diff --git a/.vscode/settings.json b/.vscode/settings.json index b7a1292efb2c6..fd28e2d7b335c 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -40,3 +40,4 @@ "-build/include_subdir", "-runtime/references" ] +} diff --git a/cgmanifests/generate_cgmanifest.py b/cgmanifests/generate_cgmanifest.py index a9eaacc6f2938..81181d3ccfb20 100644 --- a/cgmanifests/generate_cgmanifest.py +++ b/cgmanifests/generate_cgmanifest.py @@ -90,55 +90,6 @@ def add_github_dep(name, parsed_url): git_deps[dep] = name -with open( - os.path.join(REPO_DIR, "tools", "ci_build", "github", "linux", "docker", "Dockerfile.manylinux2_28_cuda11"), -) as f: - for line in f: - if not line.strip(): - package_name = None - package_filename = None - package_url = None - if package_filename is None: - m = re.match(r"RUN\s+export\s+(.+?)_ROOT=(\S+).*", line) - if m is not None: - package_name = m.group(1) - package_filename = m.group(2) - else: - m = re.match(r"RUN\s+export\s+(.+?)_VERSION=(\S+).*", line) - if m is not None: - package_name = m.group(1) - package_filename = m.group(2) - elif package_url is None: - m = re.match(r"(.+?)_DOWNLOAD_URL=(\S+)", line) - if m is not None: - package_url = m.group(2) - if package_name == "LIBXCRYPT": - package_url = m.group(2) + "/v" + package_filename + ".tar.gz" - elif package_name == "CMAKE": - package_url = m.group(2) + "/v" + package_filename + "/cmake-" + package_filename + ".tar.gz" - else: - package_url = m.group(2) + "/" + package_filename + ".tar.gz" - parsed_url = urlparse(package_url) - if parsed_url.hostname == "github.com": - add_github_dep("manylinux dependency " + package_name, parsed_url) - else: - registration = { - "Component": { - "Type": "other", - "other": { - "Name": package_name.lower(), - "Version": package_filename.split("-")[-1], - "DownloadUrl": package_url, - }, - "comments": "manylinux dependency", - } - } - registrations.append(registration) - package_name = None - package_filename = None - package_url = None - - def normalize_path_separators(path): return path.replace(os.path.sep, "/") diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index 6f1ca84e1a304..08ca90d7c3b7f 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -2,112 +2,6 @@ "$schema": "https://json.schemastore.org/component-detection-manifest.json", "Version": 1, "Registrations": [ - { - "Component": { - "Type": "other", - "other": { - "Name": "autoconf", - "Version": "2.71", - "DownloadUrl": "http://ftp.gnu.org/gnu/autoconf/autoconf-2.71.tar.gz" - }, - "comments": "manylinux dependency" - } - }, - { - "Component": { - "Type": "other", - "other": { - "Name": "automake", - "Version": "1.16.5", - "DownloadUrl": "http://ftp.gnu.org/gnu/automake/automake-1.16.5.tar.gz" - }, - "comments": "manylinux dependency" - } - }, - { - "Component": { - "Type": "other", - "other": { - "Name": "libtool", - "Version": "2.4.7", - "DownloadUrl": "http://ftp.gnu.org/gnu/libtool/libtool-2.4.7.tar.gz" - }, - "comments": "manylinux dependency" - } - }, - { - "Component": { - "Type": "other", - "other": { - "Name": "git", - "Version": "2.36.2", - "DownloadUrl": "https://www.kernel.org/pub/software/scm/git/git-2.36.2.tar.gz" - }, - "comments": "manylinux dependency" - } - }, - { - "Component": { - "Type": "other", - "other": { - "Name": "sqlite_autoconf", - "Version": "3390200", - "DownloadUrl": "https://www.sqlite.org/2022/sqlite-autoconf-3390200.tar.gz" - }, - "comments": "manylinux dependency" - } - }, - { - "Component": { - "Type": "other", - "other": { - "Name": "openssl", - "Version": "1.1.1q", - "DownloadUrl": "https://www.openssl.org/source/openssl-1.1.1q.tar.gz" - }, - "comments": "manylinux dependency" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "50cf2b6dd4fdf04309445f2eec8de7051d953abf", - "repositoryUrl": "https://github.com/besser82/libxcrypt.git" - }, - "comments": "manylinux dependency LIBXCRYPT" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "a896e3d066448b3530dbcaa48869fafefd738f57", - "repositoryUrl": "https://github.com/emscripten-core/emsdk.git" - }, - "comments": "git submodule at cmake/external/emsdk" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "7a2ed51a6b682a83e345ff49fc4cfd7ca47550db", - "repositoryUrl": "https://github.com/google/libprotobuf-mutator.git" - }, - "comments": "git submodule at cmake/external/libprotobuf-mutator" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "e2525550194ce3d8a2c4a3af451c9d9b3ae6650e", - "repositoryUrl": "https://github.com/onnx/onnx.git" - }, - "comments": "git submodule at cmake/external/onnx" - } - }, { "component": { "type": "git", @@ -268,6 +162,16 @@ "comments": "mp11" } }, + { + "component": { + "type": "git", + "git": { + "commitHash": "fdefbe85ed9c362b95b9b401cd19db068a76141f", + "repositoryUrl": "https://github.com/onnx/onnx.git" + }, + "comments": "onnx" + } + }, { "component": { "type": "git", diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 496ca72bb1b6c..b03a3019764ca 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -39,7 +39,12 @@ include(CMakeDependentOption) include(FetchContent) include(CheckFunctionExists) +# TODO: update this once all system adapt c++20 +if(CMAKE_SYSTEM_NAME STREQUAL "Darwin") +set(CMAKE_CXX_STANDARD 20) +else() set(CMAKE_CXX_STANDARD 17) +endif() set_property(GLOBAL PROPERTY USE_FOLDERS ON) # NOTE: POSITION INDEPENDENT CODE hurts performance, and it only make sense on POSIX systems @@ -68,6 +73,11 @@ option(onnxruntime_ENABLE_PYTHON "Enable python buildings" OFF) # Enable it may cause LNK1169 error option(onnxruntime_ENABLE_MEMLEAK_CHECKER "Experimental: Enable memory leak checker in Windows debug build" OFF) option(onnxruntime_USE_CUDA "Build with CUDA support" OFF) +# Enable ONNX Runtime CUDA EP's internal unit tests that directly access the EP's internal functions instead of through +# OpKernels. When the option is ON, we will have two copies of GTest library in the same process. It is not a typical +# use. If you hit any problem with that, please do not report it to GTest. Turn OFF the following build option instead. +cmake_dependent_option(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS "Build with CUDA unit tests" OFF "onnxruntime_USE_CUDA;onnxruntime_BUILD_UNIT_TESTS;LINUX" OFF) + option(onnxruntime_ENABLE_CUDA_LINE_NUMBER_INFO "When building with CUDA support, generate device code line number information." OFF) option(onnxruntime_USE_OPENVINO "Build with OpenVINO support" OFF) option(onnxruntime_USE_COREML "Build with CoreML support" OFF) @@ -141,10 +151,11 @@ option(onnxruntime_DISABLE_SPARSE_TENSORS "Disable sparse tensors data types" OF option(onnxruntime_DISABLE_OPTIONAL_TYPE "Disable optional type" OFF) option(onnxruntime_DISABLE_FLOAT8_TYPES "Disable float 8 types" OFF) option(onnxruntime_MINIMAL_BUILD "Exclude as much as possible from the build. Support ORT format models. No support for ONNX format models." OFF) -cmake_dependent_option(onnxruntime_DISABLE_RTTI "Disable RTTI" ON "NOT onnxruntime_ENABLE_PYTHON" OFF) +cmake_dependent_option(onnxruntime_DISABLE_RTTI "Disable RTTI" ON "NOT onnxruntime_ENABLE_PYTHON;NOT onnxruntime_USE_CUDA" OFF) # For now onnxruntime_DISABLE_EXCEPTIONS will only work with onnxruntime_MINIMAL_BUILD, more changes (ONNX, non-CPU EP, ...) are required to run this standalone cmake_dependent_option(onnxruntime_DISABLE_EXCEPTIONS "Disable exception handling. Requires onnxruntime_MINIMAL_BUILD currently." ON "onnxruntime_MINIMAL_BUILD;NOT onnxruntime_ENABLE_PYTHON" OFF) -option(onnxruntime_DISABLE_ABSEIL "Do not link to Abseil. Redefine Inlined containers to STD containers." OFF) +# Even when onnxruntime_DISABLE_ABSEIL is ON, ONNX Runtime still needs to link to abseil. +option(onnxruntime_DISABLE_ABSEIL "Do not use Abseil data structures in ONNX Runtime source code. Redefine Inlined containers to STD containers." OFF) option(onnxruntime_EXTENDED_MINIMAL_BUILD "onnxruntime_MINIMAL_BUILD with support for execution providers that compile kernels." OFF) option(onnxruntime_MINIMAL_BUILD_CUSTOM_OPS "Add custom operator kernels support to a minimal build." OFF) @@ -264,10 +275,6 @@ if (onnxruntime_ENABLE_TRAINING_APIS) endif() endif() -if (onnxruntime_USE_CUDA) - set(onnxruntime_DISABLE_RTTI OFF) -endif() - if (onnxruntime_USE_ROCM) if (WIN32) message(FATAL_ERROR "ROCM does not support build in Windows!") diff --git a/cmake/deps.txt b/cmake/deps.txt index 279b5ca649dba..7cf49f02333a4 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -24,7 +24,7 @@ microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf36 microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5 mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee7d34223d0567892db5179849939c8769dc41 mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063 -onnx;https://github.com/onnx/onnx/archive/e2525550194ce3d8a2c4a3af451c9d9b3ae6650e.zip;782f23d788185887f520a90535513e244218e928 +onnx;https://github.com/onnx/onnx/archive/14303de049144035dfd94ace5f7a3b44773b1aad.zip;250eab9690392b248d75b56e605fb49eca373442 #use the commit of supporting all the plugins and TRT 8.6-GA (https://github.com/onnx/onnx-tensorrt/commit/0462dc31ae78f48744b6141ae376df1f96d3f459) onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/0462dc31ae78f48744b6141ae376df1f96d3f459.zip;5ff086361956cceb81ed17453a1fd8db2aa4328d protobuf;https://github.com/protocolbuffers/protobuf/archive/refs/tags/v21.12.zip;7cf2733949036c7d52fda017badcab093fe73bfa @@ -44,4 +44,4 @@ tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2 cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.0.0.zip;0f95b3c1fc1bd1175c4a90b2c9e39074d1bccefd utf8_range;https://github.com/protocolbuffers/utf8_range/archive/72c943dea2b9240cd09efde15191e144bc7c7d38.zip;9925739c9debc0efa2adcb194d371a35b6a03156 extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d8391c9791ec71c38336436319a2d4ac7a0.zip;4365ac5140338b4cb75a39944a4be276e3829b3c -composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/d52ec01652b7d620386251db92455968d8d90bdc.zip;6b5ce8edf3625f8817086c194fbf94b664e1b0e0 \ No newline at end of file +composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/d52ec01652b7d620386251db92455968d8d90bdc.zip;6b5ce8edf3625f8817086c194fbf94b664e1b0e0 diff --git a/cmake/deps_update_and_upload.py b/cmake/deps_update_and_upload.py new file mode 100644 index 0000000000000..194d21435f0be --- /dev/null +++ b/cmake/deps_update_and_upload.py @@ -0,0 +1,56 @@ +# in case deps.txt is updated, run this file to update and upload the dependencies so that CI can use them. +# Before running the script, increase the version number found at: +# https://aiinfra.visualstudio.com/Lotus/_artifacts/feed/Lotus/UPack/onnxruntime_build_dependencies/versions +# Run without --do-upload once to verify downloading. Use --do-upload when you are ready to publish. +# python cmake/deps_update_and_upload.py --root-path C:/temp/onnxruntime_deps --version 1.0.82 --do-upload +# update version number in tools\ci_build\github\azure-pipelines\templates\download-deps.yml +import re +import subprocess +import os +import argparse +import tempfile + +parser = argparse.ArgumentParser(description="Update dependencies and publish to Azure Artifacts") +parser.add_argument( + "--root-path", type=str, default=tempfile.gettempdir(), help="Target root path for downloaded files" +) +parser.add_argument("--version", type=str, default="1.0.82", help="Package version to publish") +parser.add_argument("--do-upload", action="store_true", help="Upload the package to Azure Artifacts") +args = parser.parse_args() + +with open("cmake/deps.txt") as file: + text = file.read() + +lines = [line for line in text.split("\n") if not line.startswith("#") and ";" in line] + +root_path = args.root_path + +for line in lines: + url = re.sub("^[^;]+?;https://([^;]+?);.*", r"https://\1", line) + filename = re.sub("^[^;]+?;https://([^;]+?);.*", r"\1", line) + full_path = os.path.join(root_path, filename) + subprocess.run(["curl", "-sSL", "--create-dirs", "-o", full_path, url]) + +package_name = "onnxruntime_build_dependencies" +version = args.version + +# Check if the user is logged in to Azure +result = subprocess.run("az account show", shell=True, capture_output=True, text=True) +if "No subscriptions found" in result.stderr: + # Prompt the user to log in to Azure + print("You are not logged in to Azure. Please log in to continue.") + subprocess.run("az login", shell=True) + +# Publish the package to Azure Artifacts if --no-upload is not specified + +cmd = f'az artifacts universal publish --organization https://dev.azure.com/onnxruntime --feed onnxruntime --name {package_name} --version {version} --description "onnxruntime build time dependencies" --path {root_path}' +if args.do_upload: + subprocess.run(cmd, shell=True) +else: + print("would have run: " + cmd) + +cmd = f'az artifacts universal publish --organization https://dev.azure.com/aiinfra --feed Lotus --name {package_name} --version {version} --description "onnxruntime build time dependencies" --path {root_path}' +if args.do_upload: + subprocess.run(cmd, shell=True) +else: + print("would have run: " + cmd) diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index e1671bcf43ed9..019c6341d2e46 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -37,8 +37,12 @@ if (onnxruntime_BUILD_UNIT_TESTS) set(gtest_disable_pthreads ON) endif() set(INSTALL_GTEST OFF CACHE BOOL "" FORCE) - # Set it to ON will cause crashes in onnxruntime_test_all when onnxruntime_USE_CUDA is ON - set(GTEST_HAS_ABSL OFF CACHE BOOL "" FORCE) + if (CMAKE_SYSTEM_NAME STREQUAL "iOS") + # Needs to update onnxruntime/test/xctest/xcgtest.mm + set(GTEST_HAS_ABSL OFF CACHE BOOL "" FORCE) + else() + set(GTEST_HAS_ABSL ON CACHE BOOL "" FORCE) + endif() # gtest and gmock FetchContent_Declare( googletest diff --git a/cmake/onnxruntime_config.h.in b/cmake/onnxruntime_config.h.in index 2aef9dcf209e0..e3ea767401ddc 100644 --- a/cmake/onnxruntime_config.h.in +++ b/cmake/onnxruntime_config.h.in @@ -22,5 +22,5 @@ #cmakedefine HAS_UNUSED_BUT_SET_VARIABLE #cmakedefine HAS_UNUSED_VARIABLE #cmakedefine HAS_USELESS_CAST -#cmakedefine ORT_BUILD_INFO u8"@ORT_BUILD_INFO@" -#cmakedefine ORT_VERSION u8"@ORT_VERSION@" +#cmakedefine ORT_BUILD_INFO "@ORT_BUILD_INFO@" +#cmakedefine ORT_VERSION "@ORT_VERSION@" diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index e0ccc504d7b27..992908392c946 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -581,7 +581,7 @@ set_target_properties(onnxruntime_mlas PROPERTIES FOLDER "ONNXRuntime") if (WIN32) target_compile_options(onnxruntime_mlas PRIVATE "$<$:/wd6385>" "$<$:/wd4127>") if (onnxruntime_ENABLE_STATIC_ANALYSIS) - target_compile_options(onnxruntime_mlas PRIVATE "$<$:/analyze:stacksize" 131072>) + target_compile_options(onnxruntime_mlas PRIVATE "$<$:/analyze:stacksize 131072>") endif() endif() diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index b9e7873132089..96c05e5282bb5 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -460,13 +460,17 @@ if (onnxruntime_USE_CUDA) if (onnxruntime_REDUCED_OPS_BUILD) substitute_op_reduction_srcs(onnxruntime_providers_cuda_src) endif() - # cuda_provider_interface.cc is removed from the object target: onnxruntime_providers_cuda_obj and - # add to the lib onnxruntime_providers_cuda separatedly. - # onnxruntime_providers_cuda_ut can share all the object files with onnxruntime_providers_cuda except cuda_provider_interface.cc. - set(cuda_provider_interface_src ${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_provider_interface.cc) - list(REMOVE_ITEM onnxruntime_providers_cuda_src ${cuda_provider_interface_src}) - onnxruntime_add_object_library(onnxruntime_providers_cuda_obj ${onnxruntime_providers_cuda_src}) - onnxruntime_add_shared_library_module(onnxruntime_providers_cuda ${cuda_provider_interface_src} $) + if(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS) + # cuda_provider_interface.cc is removed from the object target: onnxruntime_providers_cuda_obj and + # add to the lib onnxruntime_providers_cuda separatedly. + # onnxruntime_providers_cuda_ut can share all the object files with onnxruntime_providers_cuda except cuda_provider_interface.cc. + set(cuda_provider_interface_src ${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_provider_interface.cc) + list(REMOVE_ITEM onnxruntime_providers_cuda_src ${cuda_provider_interface_src}) + onnxruntime_add_object_library(onnxruntime_providers_cuda_obj ${onnxruntime_providers_cuda_src}) + onnxruntime_add_shared_library_module(onnxruntime_providers_cuda ${cuda_provider_interface_src} $) + else() + onnxruntime_add_shared_library_module(onnxruntime_providers_cuda ${onnxruntime_providers_cuda_src}) + endif() # config_cuda_provider_shared_module can be used to config onnxruntime_providers_cuda_obj, onnxruntime_providers_cuda & onnxruntime_providers_cuda_ut. # This function guarantees that all 3 targets have the same configurations. function(config_cuda_provider_shared_module target) @@ -600,7 +604,9 @@ if (onnxruntime_USE_CUDA) target_compile_definitions(${target} PRIVATE ENABLE_ATEN) endif() endfunction() - config_cuda_provider_shared_module(onnxruntime_providers_cuda_obj) + if(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS) + config_cuda_provider_shared_module(onnxruntime_providers_cuda_obj) + endif() config_cuda_provider_shared_module(onnxruntime_providers_cuda) install(TARGETS onnxruntime_providers_cuda @@ -1750,7 +1756,7 @@ if (onnxruntime_USE_TVM) target_include_directories(onnxruntime_providers_tvm PRIVATE ${TVM_INCLUDES} - ${PYTHON_INLCUDE_DIRS}) + ${PYTHON_INCLUDE_DIRS}) onnxruntime_add_include_to_target(onnxruntime_providers_tvm onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface) add_dependencies(onnxruntime_providers_tvm ${onnxruntime_EXTERNAL_DEPENDENCIES}) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index ec83eb2095071..0e642c5a7e0aa 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -35,13 +35,17 @@ function(AddTest) if (MSVC AND NOT CMAKE_SIZEOF_VOID_P EQUAL 8) #TODO: fix the warnings, they are dangerous - target_compile_options(${_UT_TARGET} PRIVATE "/wd4244") + target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options /wd4244>" + "$<$>:/wd4244>") endif() if (MSVC) - target_compile_options(${_UT_TARGET} PRIVATE "/wd6330") + target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options /wd6330>" + "$<$>:/wd6330>") #Abseil has a lot of C4127/C4324 warnings. - target_compile_options(${_UT_TARGET} PRIVATE "/wd4127") - target_compile_options(${_UT_TARGET} PRIVATE "/wd4324") + target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options /wd4127>" + "$<$>:/wd4127>") + target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options /wd4324>" + "$<$>:/wd4324>") endif() set_target_properties(${_UT_TARGET} PROPERTIES FOLDER "ONNXRuntimeTest") @@ -60,6 +64,11 @@ function(AddTest) Threads::Threads) target_compile_definitions(${_UT_TARGET} PRIVATE -DUSE_ONNXRUNTIME_DLL) else() + if(onnxruntime_USE_CUDA) + #XXX: we should not need to do this. onnxruntime_test_all.exe should not have direct dependency on CUDA DLLs, + # otherwise it will impact when CUDA DLLs can be unloaded. + target_link_libraries(${_UT_TARGET} PRIVATE cudart) + endif() target_link_libraries(${_UT_TARGET} PRIVATE ${_UT_LIBS} GTest::gtest GTest::gmock ${onnxruntime_EXTERNAL_LIBRARIES}) endif() @@ -85,29 +94,22 @@ function(AddTest) # include dbghelp in case tests throw an ORT exception, as that exception includes a stacktrace, which requires dbghelp. target_link_libraries(${_UT_TARGET} PRIVATE debug dbghelp) - if (onnxruntime_USE_CUDA) - # disable a warning from the CUDA headers about unreferenced local functions - if (MSVC) - target_compile_options(${_UT_TARGET} PRIVATE "$<$:-Xcompiler /wd4505>" - "$<$>:/wd4505>") - endif() - endif() if (MSVC) # warning C6326: Potential comparison of a constant with another constant. # Lot of such things came from gtest - target_compile_options(${_UT_TARGET} PRIVATE "$<$:-Xcompiler /wd6326>" + target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options /wd6326>" "$<$>:/wd6326>") # Raw new and delete. A lot of such things came from googletest. - target_compile_options(${_UT_TARGET} PRIVATE "$<$:-Xcompiler /wd26409>" + target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options /wd26409>" "$<$>:/wd26409>") # "Global initializer calls a non-constexpr function." - target_compile_options(${_UT_TARGET} PRIVATE "$<$:-Xcompiler /wd26426>" + target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options /wd26426>" "$<$>:/wd26426>") endif() target_compile_options(${_UT_TARGET} PRIVATE ${disabled_warnings}) else() target_compile_options(${_UT_TARGET} PRIVATE ${DISABLED_WARNINGS_FOR_TVM}) - target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:-Xcompiler -Wno-error=sign-compare>" + target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options -Wno-error=sign-compare>" "$<$>:-Wno-error=sign-compare>") target_compile_options(${_UT_TARGET} PRIVATE "-Wno-error=uninitialized") endif() @@ -698,7 +700,7 @@ onnxruntime_add_static_library(onnxruntime_test_utils ${onnxruntime_test_utils_s if(MSVC) target_compile_options(onnxruntime_test_utils PRIVATE "$<$:SHELL:--compiler-options /utf-8>" "$<$>:/utf-8>") - target_compile_options(onnxruntime_test_utils PRIVATE "$<$:-Xcompiler /wd6326>" + target_compile_options(onnxruntime_test_utils PRIVATE "$<$:SHELL:--compiler-options /wd6326>" "$<$>:/wd6326>") else() target_compile_definitions(onnxruntime_test_utils PUBLIC -DNSYNC_ATOMIC_CPP11) @@ -755,13 +757,8 @@ set_target_properties(onnx_test_runner_common PROPERTIES FOLDER "ONNXRuntimeTest set(all_tests ${onnxruntime_test_common_src} ${onnxruntime_test_ir_src} ${onnxruntime_test_optimizer_src} ${onnxruntime_test_framework_src} ${onnxruntime_test_providers_src} ${onnxruntime_test_quantiztion_src}) -if(NOT TARGET onnxruntime AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - list(APPEND all_tests ${onnxruntime_shared_lib_test_SRC}) -endif() -if (onnxruntime_USE_CUDA) - onnxruntime_add_static_library(onnxruntime_test_cuda_ops_lib ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu) - list(APPEND onnxruntime_test_common_libs onnxruntime_test_cuda_ops_lib) +if (onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS) file(GLOB onnxruntime_test_providers_cuda_ut_src CONFIGURE_DEPENDS "${TEST_SRC_DIR}/providers/cuda/test_cases/*" ) @@ -822,7 +819,9 @@ if (onnxruntime_USE_TENSORRT) # made test name contain the "ep" and "model path" information, so we can easily filter the tests using cuda ep or other ep with *cpu_* or *xxx_*. list(APPEND test_all_args "--gtest_filter=-*cpu_*:*cuda_*" ) endif () - +if(NOT onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS) + list(REMOVE_ITEM all_tests ${TEST_SRC_DIR}/providers/cuda/cuda_provider_test.cc) +endif() AddTest( TARGET onnxruntime_test_all SOURCES ${all_tests} ${onnxruntime_unittest_main_src} @@ -832,11 +831,15 @@ AddTest( DEPENDS ${all_dependencies} TEST_ARGS ${test_all_args} ) + if (MSVC) # The warning means the type of two integral values around a binary operator is narrow than their result. # If we promote the two input values first, it could be more tolerant to integer overflow. # However, this is test code. We are less concerned. - target_compile_options(onnxruntime_test_all PRIVATE "/wd26451" "/wd4244") + target_compile_options(onnxruntime_test_all PRIVATE "$<$:SHELL:--compiler-options /wd26451>" + "$<$>:/wd26451>") + target_compile_options(onnxruntime_test_all PRIVATE "$<$:SHELL:--compiler-options /wd4244>" + "$<$>:/wd4244>") else() target_compile_options(onnxruntime_test_all PRIVATE "-Wno-parentheses") endif() @@ -1092,18 +1095,18 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) target_include_directories(onnxruntime_benchmark PRIVATE ${ONNXRUNTIME_ROOT} ${onnxruntime_graph_header} ${ONNXRUNTIME_ROOT}/core/mlas/inc) target_compile_definitions(onnxruntime_benchmark PRIVATE BENCHMARK_STATIC_DEFINE) if(WIN32) - target_compile_options(onnxruntime_benchmark PRIVATE "$<$:-Xcompiler /wd4141>" + target_compile_options(onnxruntime_benchmark PRIVATE "$<$:SHELL:--compiler-options /wd4141>" "$<$>:/wd4141>") # Avoid using new and delete. But this is a benchmark program, it's ok if it has a chance to leak. - target_compile_options(onnxruntime_benchmark PRIVATE "$<$:-Xcompiler /wd26409>" + target_compile_options(onnxruntime_benchmark PRIVATE "$<$:SHELL:--compiler-options /wd26409>" "$<$>:/wd26409>") - target_compile_options(onnxruntime_benchmark PRIVATE "$<$:-Xcompiler /wd26400>" + target_compile_options(onnxruntime_benchmark PRIVATE "$<$:SHELL:--compiler-options /wd26400>" "$<$>:/wd26400>") - target_compile_options(onnxruntime_benchmark PRIVATE "$<$:-Xcompiler /wd26814>" + target_compile_options(onnxruntime_benchmark PRIVATE "$<$:SHELL:--compiler-options /wd26814>" "$<$>:/wd26814>") - target_compile_options(onnxruntime_benchmark PRIVATE "$<$:-Xcompiler /wd26814>" + target_compile_options(onnxruntime_benchmark PRIVATE "$<$:SHELL:--compiler-options /wd26814>" "$<$>:/wd26497>") - target_compile_options(onnxruntime_benchmark PRIVATE "$<$:-Xcompiler /wd26426>" + target_compile_options(onnxruntime_benchmark PRIVATE "$<$:SHELL:--compiler-options /wd26426>" "$<$>:/wd26426>") target_compile_options(onnxruntime_benchmark PRIVATE "$<$:SHELL:--compiler-options /utf-8>" "$<$>:/utf-8>") @@ -1255,7 +1258,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) list(APPEND onnxruntime_shared_lib_test_LIBS cpuinfo) endif() if (onnxruntime_USE_CUDA) - list(APPEND onnxruntime_shared_lib_test_LIBS onnxruntime_test_cuda_ops_lib cudart) + list(APPEND onnxruntime_shared_lib_test_LIBS cudart) endif() if (onnxruntime_USE_TENSORRT) list(APPEND onnxruntime_shared_lib_test_LIBS ${TENSORRT_LIBRARY_INFER}) @@ -1270,7 +1273,10 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) LIBS ${onnxruntime_shared_lib_test_LIBS} DEPENDS ${all_dependencies} ) - + if (onnxruntime_USE_CUDA) + target_include_directories(onnxruntime_shared_lib_test PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + target_sources(onnxruntime_shared_lib_test PRIVATE ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu) + endif() if (CMAKE_SYSTEM_NAME STREQUAL "Android") target_sources(onnxruntime_shared_lib_test PRIVATE "${ONNXRUNTIME_ROOT}/core/platform/android/cxa_demangle.cc" @@ -1356,13 +1362,13 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) ) onnxruntime_add_executable(onnxruntime_mlas_test ${onnxruntime_mlas_test_src}) if(MSVC) - target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:-Xcompiler /wd26409>" + target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /wd26409>" "$<$>:/wd26409>") target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /utf-8>" "$<$>:/utf-8>") - target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:-Xcompiler /wd6326>" + target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /wd6326>" "$<$>:/wd6326>") - target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:-Xcompiler /wd26426>" + target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /wd26426>" "$<$>:/wd26426>") endif() if(${CMAKE_SYSTEM_NAME} STREQUAL "iOS") @@ -1476,7 +1482,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") "${TEST_SRC_DIR}/testdata/custom_op_library/cuda/cuda_ops.*") list(APPEND custom_op_lib_include ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} ${onnxruntime_CUDNN_HOME}/include) if (HAS_QSPECTRE) - list(APPEND custom_op_lib_option "$<$:SHELL:-Xcompiler /Qspectre>") + list(APPEND custom_op_lib_option "$<$:SHELL:--compiler-options /Qspectre>") endif() endif() @@ -1503,7 +1509,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") else() set(ONNXRUNTIME_CUSTOM_OP_LIB_LINK_FLAG "-DEF:${TEST_SRC_DIR}/testdata/custom_op_library/custom_op_library.def") if (NOT onnxruntime_USE_CUDA) - target_compile_options(custom_op_library PRIVATE "$<$:-Xcompiler /wd26409>" + target_compile_options(custom_op_library PRIVATE "$<$:SHELL:--compiler-options /wd26409>" "$<$>:/wd26409>") endif() endif() diff --git a/cmake/patches/onnx/onnx.patch b/cmake/patches/onnx/onnx.patch index 155d153019f85..a2d7672a3d48d 100644 --- a/cmake/patches/onnx/onnx.patch +++ b/cmake/patches/onnx/onnx.patch @@ -64,16 +64,3 @@ index 0aab3e26..0f859267 100644 +#endif + #endif // ! ONNX_ONNX_PB_H -diff --git a/onnx/checker.cc b/onnx/checker.cc -index 8fdaf037..1beb1b88 100644 ---- a/onnx/checker.cc -+++ b/onnx/checker.cc -@@ -190,7 +190,7 @@ void check_tensor(const TensorProto& tensor, const CheckerContext& ctx) { - } - std::string data_path = path_join(ctx.get_model_dir(), relative_path); - // use stat64 to check whether the file exists --#ifdef __APPLE__ -+#if defined(__APPLE__) || defined(__wasm__) - struct stat buffer; // APPLE does not have stat64 - if (stat((data_path).c_str(), &buffer) != 0) { - #else diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs index 659c6303702ac..6889112acb385 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/CheckpointState.shared.cs @@ -40,20 +40,16 @@ internal enum PropertyType : long String = 2 } - private void AddPropertyImpl(string propertyName, PropertyType propertyType, T propertyValue) + private void AddPropertyImpl(string propertyName, PropertyType propertyType, T propertyValue) where T : unmanaged { var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName); - T[] value = new T[1]; - value[0] = propertyValue; - Memory memory = value; - using (var memHandle = memory.Pin()) + T[] value = { propertyValue }; + unsafe { - IntPtr memPtr; - unsafe + fixed (T* memPtr = value) { - memPtr = (IntPtr)memHandle.Pointer; + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, propertyType, (IntPtr)memPtr)); } - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, propertyType, memPtr)); } } @@ -103,13 +99,13 @@ public static void SaveCheckpoint(CheckpointState state, string checkpointPath, } /// - /// Adds the given int property to the checkpoint state. + /// Adds or updates the given int property to/in the checkpoint state. /// - /// Runtime properties that are ints such as epoch, training step, and others can be added to the checkpoint - /// state by the user if they desire by calling this function with the appropriate property name and - /// value. The given property name must be unique to be able to successfully add the property. + /// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint + /// state by the user by calling this function with the corresponding property name and value. + /// The given property name must be unique to be able to successfully add the property. /// - /// Unique name of the property being added. + /// Name of the property being added or updated. /// Property value associated with the given name. public void AddProperty(string propertyName, long propertyValue) { @@ -117,13 +113,13 @@ public void AddProperty(string propertyName, long propertyValue) } /// - /// Adds the given float property to the checkpoint state. + /// Adds or updates the given float property to/in the checkpoint state. /// - /// Runtime properties that are floats such as loss, best score, and others can be added to the checkpoint - /// state by the user if they desire by calling this function with the appropriate property name and - /// value. The given property name must be unique to be able to successfully add the property. + /// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint + /// state by the user by calling this function with the corresponding property name and value. + /// The given property name must be unique to be able to successfully add the property. /// - /// Unique name of the property being added. + /// Name of the property being added or updated. /// Property value associated with the given name. public void AddProperty(string propertyName, float propertyValue) { @@ -131,28 +127,25 @@ public void AddProperty(string propertyName, float propertyValue) } /// - /// Adds the given string property to the checkpoint state. + /// Adds or updates the given string property to/in the checkpoint state. /// - /// Runtime properties that are strings such as parameter names, custom strings, and others can be added - /// to the checkpoint state by the user if they desire by calling this function with the appropriate property - /// name and value. The given property name must be unique to be able to successfully add the property. + /// Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint + /// state by the user by calling this function with the corresponding property name and value. + /// The given property name must be unique to be able to successfully add the property. /// - /// Unique name of the property being added. + /// Name of the property being added or updated. /// Property value associated with the given name. public void AddProperty(string propertyName, string propertyValue) { var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName); var propertyValueUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyValue); - IntPtr unmanagedPointer = Marshal.AllocHGlobal(propertyValueUtf8.Length); - try - { - Marshal.Copy(propertyValueUtf8, 0, unmanagedPointer, propertyValueUtf8.Length); - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, PropertyType.String, unmanagedPointer)); - } - finally + unsafe { - Marshal.FreeHGlobal(unmanagedPointer); + fixed (byte* p = propertyValueUtf8) + { + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtAddProperty(handle, propertyNameUtf8, PropertyType.String, (IntPtr)p)); + } } } @@ -162,34 +155,86 @@ public void AddProperty(string propertyName, string propertyValue) /// Gets the property value from an existing entry in the checkpoint state. The property must /// exist in the checkpoint state to be able to retrieve it successfully. /// - /// Unique name of the property being retrieved. + /// Name of the property being retrieved. /// Property value associated with the given property name. public object GetProperty(string propertyName) { var propertyNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(propertyName); var allocator = OrtAllocator.DefaultInstance; IntPtr propertyValue = IntPtr.Zero; + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetProperty(handle, propertyNameUtf8, allocator.Pointer, out PropertyType propertyType, out propertyValue)); - if (propertyType == PropertyType.Int) + try { - var longPropertyValue = Marshal.ReadInt64(propertyValue); - allocator.FreeMemory(propertyValue); - return longPropertyValue; + if (propertyType == PropertyType.Int) + { + Int64 value; + unsafe + { + value = *(Int64*)propertyValue; + } + return value; + } + else if (propertyType == PropertyType.Float) + { + float value; + unsafe + { + value = *(float*)propertyValue; + } + return value; + } + else if (propertyType == PropertyType.String) + { + return NativeOnnxValueHelper.StringFromNativeUtf8(propertyValue); + } + + throw new ArgumentException("Expected the property type to be one of long, float or string. Unknown type retrieved " + propertyValue.ToString()); } - else if (propertyType == PropertyType.Float) + finally { - float[] value = new float[1]; - Marshal.Copy(propertyValue, value, 0, 1); allocator.FreeMemory(propertyValue); - return value[0]; } - else if (propertyType == PropertyType.String) + } + + /// + /// Updates the data associated with the model parameter in the checkpoint state for the given parameter name. + /// + /// This function updates a model parameter in the checkpoint state with the given parameter data. + /// The training session must be already created with the checkpoint state that contains the parameter + /// being updated. The given parameter is copied over to the registered device for the training session. + /// The parameter must exist in the checkpoint state to be able to update it successfully. + /// + /// Name of the parameter being updated. + /// The parameter data that should replace the existing parameter data. + public void UpdateParameter(string parameterName, OrtValue parameter) + { + if (parameter.OnnxType != OnnxValueType.ONNX_TYPE_TENSOR) { - return NativeOnnxValueHelper.StringFromNativeUtf8(propertyValue, allocator); + throw new ArgumentException("Incorrect buffer received. Expected a tensor parameter."); } - throw new ArgumentException("Expected the property type to be one of long, float or string. Unknown type retrieved " + propertyValue.ToString()); + var parameterNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(parameterName); + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtUpdateParameter(handle, parameterNameUtf8, parameter.Handle)); + } + + /// + /// Gets the data associated with the model parameter from the checkpoint state for the given parameter name. + /// + /// This function retrieves the model parameter data from the checkpoint state for the given parameter name. + /// The parameter is copied over to the provided OrtValue. The training session must be already created + /// with the checkpoint state that contains the parameter being retrieved. + /// The parameter must exist in the checkpoint state to be able to retrieve it successfully. + /// + /// Name of the parameter being updated. + /// The parameter data that is retrieved from the checkpoint state. + public OrtValue GetParameter(string parameterName) + { + var parameterNameUtf8 = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(parameterName); + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParameter(handle, parameterNameUtf8, OrtAllocator.DefaultInstance.Pointer, out IntPtr parameterHandle)); + + return new OrtValue(parameterHandle); } #region SafeHandle diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs index 1868ff509bfc3..68a399f8b9671 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/NativeTrainingMethods.shared.cs @@ -42,6 +42,9 @@ public struct OrtTrainingApi public IntPtr AddProperty; public IntPtr GetProperty; public IntPtr LoadCheckpointFromBuffer; + public IntPtr GetParameterTypeAndShape; + public IntPtr UpdateParameter; + public IntPtr GetParameter; } internal static class NativeTrainingMethods @@ -97,6 +100,9 @@ static NativeTrainingMethods() OrtGetEvalModelInputName = (DOrtGetEvalModelInputName)Marshal.GetDelegateForFunctionPointer(trainingApi_.TrainingSessionGetEvalModelInputName, typeof(DOrtGetEvalModelInputName)); OrtAddProperty = (DOrtAddProperty)Marshal.GetDelegateForFunctionPointer(trainingApi_.AddProperty, typeof(DOrtAddProperty)); OrtGetProperty = (DOrtGetProperty)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetProperty, typeof(DOrtGetProperty)); + OrtGetParameterTypeAndShape = (DOrtGetParameterTypeAndShape)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetParameterTypeAndShape, typeof(DOrtGetParameterTypeAndShape)); + OrtUpdateParameter = (DOrtUpdateParameter)Marshal.GetDelegateForFunctionPointer(trainingApi_.UpdateParameter, typeof(DOrtUpdateParameter)); + OrtGetParameter = (DOrtGetParameter)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetParameter, typeof(DOrtGetParameter)); } } @@ -359,6 +365,34 @@ out UIntPtr inputCount public static DOrtGetProperty OrtGetProperty; + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtGetParameterTypeAndShape( + IntPtr /*(OrtCheckpointState*)*/ checkpointState, + byte[] /*(const char*)*/ parameterName, + out IntPtr /*(OrtTensorTypeAndShapeInfo**)*/ parameterTypeAndShape + ); + + public static DOrtGetParameterTypeAndShape OrtGetParameterTypeAndShape; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtUpdateParameter( + IntPtr /*(OrtCheckpointState*)*/ checkpointState, + byte[] /*(const char*)*/ parameterName, + IntPtr /*(OrtValue*)*/ parameter + ); + + public static DOrtUpdateParameter OrtUpdateParameter; + + [UnmanagedFunctionPointer(CallingConvention.Winapi)] + public delegate IntPtr /*(OrtStatus*)*/ DOrtGetParameter( + IntPtr /*(OrtCheckpointState*)*/ checkpointState, + byte[] /*(const char*)*/ parameterName, + IntPtr /*(OrtAllocator*)*/ allocator, + out IntPtr /*(OrtValue**)*/ parameter + ); + + public static DOrtGetParameter OrtGetParameter; + #endregion TrainingSession API public static bool TrainingEnabled() diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs index 33993c2be135b..877677dcad57b 100644 --- a/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs +++ b/csharp/src/Microsoft.ML.OnnxRuntime/Training/TrainingSession.shared.cs @@ -358,13 +358,14 @@ public void EvalStep( IReadOnlyCollection inputValues, IReadOnlyCollection outputValues) { - if (!_evalOutputCount.Equals(outputValues.Count)) + if (_evalOutputCount != (ulong)outputValues.Count()) { - throw new ArgumentException($"Length of {nameof(outputValues)} ({outputValues.Count}) must match that of train model ({_trainOutputCount})."); + throw new ArgumentException($"Length of {nameof(outputValues)} ({outputValues.Count}) must match that of eval model ({_evalOutputCount})."); } - IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, true); + const bool isInput = true; + IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, isInput); - IntPtr[] outputValuesArray = GetOrtValuesHandles(outputValues, false); /* pointers to Pre-allocated OrtValue instances */ + IntPtr[] outputValuesArray = GetOrtValuesHandles(outputValues, !isInput); /* pointers to Pre-allocated OrtValue instances */ NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtEvalStep(_nativeHandle, options.Handle, (UIntPtr)inputValues.Count, inputValuesArray, (UIntPtr)outputValues.Count, outputValuesArray)); } @@ -509,18 +510,17 @@ public void ExportModelForInferencing(string inferenceModelPath, IReadOnlyCollec /// Returns a contiguous buffer that holds a copy of all training state parameters /// /// Whether to only copy trainable parameters or to copy all parameters. - public FixedBufferOnnxValue ToBuffer(bool onlyTrainable) + public OrtValue ToBuffer(bool onlyTrainable) { UIntPtr bufferSize = UIntPtr.Zero; NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out bufferSize, onlyTrainable)); float[] bufferMemory = new float[bufferSize.ToUInt64()]; - var memInfo = OrtMemoryInfo.DefaultInstance; // CPU - var shape = new long[] { (long)bufferSize.ToUInt64() }; - var buffer = FixedBufferOnnxValue.CreateFromMemory(memInfo, bufferMemory, Tensors.TensorElementType.Float, shape, (long)bufferSize.ToUInt64() * sizeof(float)); + var shape = new long[] { (long)bufferSize }; + var buffer = OrtValue.CreateAllocatedTensorValue(OrtAllocator.DefaultInstance, Tensors.TensorElementType.Float, shape); - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyParametersToBuffer(_nativeHandle, buffer.Value.Handle, onlyTrainable)); + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyParametersToBuffer(_nativeHandle, buffer.Handle, onlyTrainable)); return buffer; } @@ -528,45 +528,30 @@ public FixedBufferOnnxValue ToBuffer(bool onlyTrainable) /// /// Loads the training session model parameters from a contiguous buffer /// - /// Contiguous buffer to load the parameters from. - public void FromBuffer(FixedBufferOnnxValue buffer) + /// Contiguous buffer to load the parameters from. + /// Whether to only load trainable parameters or to load all parameters. + public void FromBuffer(OrtValue ortValue, bool onlyTrainable) { - if (buffer.OnnxValueType != OnnxValueType.ONNX_TYPE_TENSOR) + if (ortValue.OnnxType != OnnxValueType.ONNX_TYPE_TENSOR) { throw new ArgumentException("Incorrect buffer received. Expected a tensor buffer."); } - IntPtr typeAndShapeInfo = IntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorTypeAndShape(buffer.Value.Handle, out typeAndShapeInfo)); - UIntPtr numDimensions = UIntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensionsCount(typeAndShapeInfo, out numDimensions)); - if (numDimensions.ToUInt64() != 1) + var tensorInfo = ortValue.GetTensorTypeAndShape(); + if (tensorInfo.ElementDataType != Tensors.TensorElementType.Float) { - string errorMessage = "Incorrect buffer shape received. Expected a contiguous tensor buffer. Expected number of dimensions: 1, Actual: " + numDimensions.ToString(); - throw new ArgumentException(errorMessage); - } - - // Here buffer size represents the number of elements in the buffer - NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorShapeElementCount(typeAndShapeInfo, out UIntPtr bufferSize)); - - // OrtGetParametersSize returns the total number of elements in the model's parameters. - UIntPtr numElementsTrainingOnly = UIntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElementsTrainingOnly, true)); - if ((ulong)bufferSize == (ulong)numElementsTrainingOnly) - { - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Value.Handle, true)); - return; + throw new ArgumentException("Incorrect buffer received. Expected a tensor buffer of type float."); } UIntPtr numElements = UIntPtr.Zero; - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElements, false)); - if ((ulong)bufferSize != (ulong)numElements) + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetParametersSize(_nativeHandle, out numElements, onlyTrainable)); + if ((ulong)tensorInfo.ElementCount != (ulong)numElements) { - string errorMessage = "Incorrect buffer size received. Expected size to be one of " + numElementsTrainingOnly.ToString() + " (training only) or " + numElements.ToString() + " (all parameters). Actual size: " + bufferSize.ToString(); + string errorMessage = "Incorrect buffer size received. Expected size to be " + numElements.ToString() + ". Actual size: " + tensorInfo.ElementCount.ToString(); throw new ArgumentException(errorMessage); } - NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, buffer.Value.Handle, false)); + NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCopyBufferToParameters(_nativeHandle, ortValue.Handle, onlyTrainable)); } /// diff --git a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs index ea2b6d7dbc118..68b1d5bcc6147 100644 --- a/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs +++ b/csharp/test/Microsoft.ML.OnnxRuntime.Tests.Common/TrainingTest.cs @@ -484,20 +484,23 @@ public void TestEvalModelOutputNames() public void TestToBuffer() { string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); - using (var cleanUp = new DisposableListTest()) + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + using (var state = CheckpointState.LoadCheckpoint(checkpointPath)) + using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath)) { - var state = CheckpointState.LoadCheckpoint(checkpointPath); - cleanUp.Add(state); Assert.NotNull(state); - string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); - string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); - string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); - var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath); - cleanUp.Add(trainingSession); - - var buffer = trainingSession.ToBuffer(true); - cleanUp.Add(buffer); + using (var buffer = trainingSession.ToBuffer(true)) + { + Assert.NotNull(buffer); + var typeShape = buffer.GetTensorTypeAndShape(); + Assert.Equal(1, typeShape.DimensionsCount); + var fetchedShape = typeShape.Shape; + Assert.Equal(397510, fetchedShape[0]); + } } } @@ -505,22 +508,25 @@ public void TestToBuffer() public void TestFromBuffer() { string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); - using (var cleanUp = new DisposableListTest()) + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + using (var state = CheckpointState.LoadCheckpoint(checkpointPath)) + using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath)) { - var state = CheckpointState.LoadCheckpoint(checkpointPath); - cleanUp.Add(state); Assert.NotNull(state); - string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); - string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); - string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); - - var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath); - cleanUp.Add(trainingSession); - var buffer = trainingSession.ToBuffer(true); - cleanUp.Add(buffer); + using (var buffer = trainingSession.ToBuffer(true)) + { + Assert.NotNull(buffer); + var typeShape = buffer.GetTensorTypeAndShape(); + Assert.Equal(1, typeShape.DimensionsCount); + var fetchedShape = typeShape.Shape; + Assert.Equal(397510, fetchedShape[0]); - trainingSession.FromBuffer(buffer); + trainingSession.FromBuffer(buffer, true); + } } } @@ -530,6 +536,82 @@ public void TestSetSeed() TrainingUtils.SetSeed(8888); } + [Fact(DisplayName = "TestGetParameter")] + public void TestGetParameter() + { + string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + using (var state = CheckpointState.LoadCheckpoint(checkpointPath)) + using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath)) + using (var parameter = state.GetParameter("fc1.weight")) + { + Assert.NotNull(state); + Assert.NotNull(parameter); + + var typeShape = parameter.GetTensorTypeAndShape(); + Assert.Equal(2, typeShape.DimensionsCount); + var fetchedShape = typeShape.Shape; + Assert.Equal(500, fetchedShape[0]); + Assert.Equal(784, fetchedShape[1]); + } + } + + [Fact(DisplayName = "TestUpdateParameter")] + public void TestUpdateParameter() + { + string checkpointPath = Path.Combine(Directory.GetCurrentDirectory(), "checkpoint.ckpt"); + string trainingPath = Path.Combine(Directory.GetCurrentDirectory(), "training_model.onnx"); + string evalPath = Path.Combine(Directory.GetCurrentDirectory(), "eval_model.onnx"); + string optimizerPath = Path.Combine(Directory.GetCurrentDirectory(), "adamw.onnx"); + + using (var state = CheckpointState.LoadCheckpoint(checkpointPath)) + using (var trainingSession = new TrainingSession(state, trainingPath, evalPath, optimizerPath)) + { + Assert.NotNull(state); + + using (var parameter = state.GetParameter("fc1.weight")) + { + Assert.NotNull(parameter); + var typeShape = parameter.GetTensorTypeAndShape(); + + Assert.Equal(2, typeShape.DimensionsCount); + var fetchedShape = typeShape.Shape; + Assert.Equal(500, fetchedShape[0]); + Assert.Equal(784, fetchedShape[1]); + + float maxVal = 20; + Random randNum = new Random(); + float[] updated_parameter_buffer = Enumerable + .Repeat(0, 500 * 784) + .Select(i => maxVal * (float)randNum.NextDouble()) + .ToArray(); + + using (var updated_parameter = OrtValue.CreateTensorValueFromMemory(updated_parameter_buffer, fetchedShape)) + { + state.UpdateParameter("fc1.weight", updated_parameter); + using (var current_parameter = state.GetParameter("fc1.weight")) + { + var current_parameter_tensor = current_parameter.GetTensorDataAsSpan().ToArray(); + Assert.Equal(updated_parameter_buffer, current_parameter_tensor); + Assert.NotEqual(parameter.GetTensorDataAsSpan().ToArray(), current_parameter_tensor); + } + + state.UpdateParameter("fc1.weight", parameter); + + using (var current_parameter = state.GetParameter("fc1.weight")) + { + var current_parameter_tensor = current_parameter.GetTensorDataAsSpan().ToArray(); + Assert.Equal(parameter.GetTensorDataAsSpan().ToArray(), current_parameter_tensor); + Assert.NotEqual(updated_parameter_buffer, current_parameter_tensor); + } + } + } + } + } + internal class FloatComparer : IEqualityComparer { private float atol = 1e-3f; diff --git a/docs/FAQ.md b/docs/FAQ.md index e039f4e4e4160..70bedbd02e944 100644 --- a/docs/FAQ.md +++ b/docs/FAQ.md @@ -2,13 +2,13 @@ Here are some commonly raised questions from users of ONNX Runtime and brought up in [Issues](https://github.com/microsoft/onnxruntime/issues). ## Do the GPU builds support quantized models? -The default CUDA build supports 3 standard quantization operators: QuantizeLinear, DequantizeLinear, and MatMulInteger. The TensorRT EP has limited support for INT8 quantized ops. In general, support of quantized models through ORT is continuing to expand on a model-driven basis. For performance improvements, quantization is not always required, and we suggest trying alternative strategies to [performance tune](./ONNX_Runtime_Perf_Tuning.md) before determining that quantization is necessary. +The default CUDA build supports 3 standard quantization operators: QuantizeLinear, DequantizeLinear, and MatMulInteger. The TensorRT EP has limited support for INT8 quantized ops. In general, support of quantized models through ORT is continuing to expand on a model-driven basis. For performance improvements, quantization is not always required, and we suggest trying alternative strategies to [performance tune](https://onnxruntime.ai/docs/performance/tune-performance/) before determining that quantization is necessary. ## How do I change the severity level of the default logger to something other than the default (WARNING)? Setting the severity level to VERBOSE is most useful when debugging errors. Refer to the API documentation: -* Python - [RunOptions.log_severity_level](https://microsoft.github.io/onnxruntime/python/api_summary.html#onnxruntime.RunOptions.log_severity_level) +* Python - [RunOptions.log_severity_level](https://onnxruntime.ai/docs/api/python/api_summary.html#onnxruntime.RunOptions.log_severity_level) ``` import onnxruntime as ort ort.set_default_logger_severity(0) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 33c187a28b62e..14b6b339c11f3 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -67,7 +67,8 @@ Do not modify directly.* |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[4, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |ConcatFromSequence|*in* input_sequence:**S**
*out* concat_result:**T**|11+|**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| -|ConstantOfShape|*in* input:**T1**
*out* output:**T2**|9+|**T1** = tensor(int64)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|ConstantOfShape|*in* input:**T1**
*out* output:**T2**|20+|**T1** = tensor(int64)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[9, 19]|**T1** = tensor(int64)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Conv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|11+|**T** = tensor(float)| |||[1, 10]|**T** = tensor(float)| |ConvInteger|*in* x:**T1**
*in* w:**T2**
*in* x_zero_point:**T1**
*in* w_zero_point:**T2**
*out* y:**T3**|10+|**T1** = tensor(uint8)
**T2** = tensor(uint8)
**T3** = tensor(int32)| @@ -78,7 +79,7 @@ Do not modify directly.* |Crop|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float)| |CumSum|*in* x:**T**
*in* axis:**T2**
*out* y:**T**|14+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int32), tensor(int64)| |||[11, 13]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int32), tensor(int64)| -|DFT|*in* input:**T1**
*in* dft_length:**T2**
*out* output:**T1**|17+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int32), tensor(int64)| +|DFT|*in* input:**T1**
*in* dft_length:**T2**
*in* axis:**tensor(int64)**
*out* output:**T1**

or

*in* input:**T1**
*in* dft_length:**T2**
*out* output:**T1**|17+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int32), tensor(int64)| |DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float)| |||[11, 12]|**T** = tensor(double), tensor(float)| |||[1, 10]|**T** = tensor(double), tensor(float)| @@ -935,7 +936,7 @@ Do not modify directly.* |Crop|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |CumSum|*in* x:**T**
*in* axis:**T2**
*out* y:**T**|14+|**T** = tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |||11+|**T** = tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| -|DFT|*in* input:**T1**
*in* dft_length:**T2**
*out* output:**T1**|17+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(int64)| +|DFT|*in* input:**T1**
*in* dft_length:**T2**
*in* axis:**tensor(int64)**
*out* output:**T1**

or

*in* input:**T1**
*in* dft_length:**T2**
*out* output:**T1**|17+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(int64)| |DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/include/onnxruntime/core/framework/tensor.h b/include/onnxruntime/core/framework/tensor.h index 7f3f26fa4aa02..a867ab6066485 100644 --- a/include/onnxruntime/core/framework/tensor.h +++ b/include/onnxruntime/core/framework/tensor.h @@ -3,8 +3,9 @@ #pragma once -#include +#include #include +#include #include #include @@ -37,7 +38,8 @@ namespace onnxruntime { class Tensor final { public: - // NB! Removing Create() methods returning unique_ptr. Still available in other EPs that are dynamically linked. + // NB! Removing Create() methods returning unique_ptr. + // Still available in other EPs that are dynamically linked. // Strive not to allocate Tensor with new/delete as it is a shallow class and using it by value is just fine. // Use InitOrtValue() methods to allocate for OrtValue. @@ -45,105 +47,104 @@ class Tensor final { /** * Create tensor with given type, shape, pre-allocated memory and allocator info. - * This function won't check if the preallocated buffer(p_data) has enough room for the shape. - * \param p_type Data type of the tensor + * This function does not check if the preallocated buffer(p_data) has enough room for the shape. + * \param elt_type Data type of the tensor elements. * \param shape Shape of the tensor * \param p_data A preallocated buffer. Can be NULL if the shape is empty. - * Tensor does not own the data and will not delete it - * \param alloc Where the buffer('p_data') was allocated from + * Tensor does not own the data and will not delete it + * \param location Memory info for location of p_data. * \param offset Offset in bytes to start of Tensor within p_data. * \param strides Strides span. Can be empty if the tensor is contiguous. */ - Tensor(MLDataType p_type, const TensorShape& shape, void* p_data, const OrtMemoryInfo& alloc, + Tensor(MLDataType elt_type, const TensorShape& shape, void* p_data, const OrtMemoryInfo& location, ptrdiff_t offset = 0, gsl::span strides = {}); - /// - /// Creates an instance of Tensor on the heap using the appropriate __ctor and - /// initializes OrtValue with it. - /// - /// - /// - /// - /// - /// - /// - static void InitOrtValue(MLDataType p_type, const TensorShape& shape, - void* p_data, const OrtMemoryInfo& location, - OrtValue& ort_value, ptrdiff_t offset = 0, - gsl::span strides = {}); - - /// - /// Creates an instance of Tensor who own the pre-allocated buffer. - /// - /// - /// - /// - /// - /// - /// - static void InitOrtValue(MLDataType p_type, const TensorShape& shape, - void* p_data, std::shared_ptr allocator, - OrtValue& ort_value, ptrdiff_t offset = 0, - gsl::span strides = {}); - - static size_t CalculateTensorStorageSize(MLDataType p_type, - const TensorShape& shape, - gsl::span strides = {}); - /** - * Deprecated. The original design is this Tensor class won't do any allocation / release. - * However, this function will allocate the buffer for the shape, and do placement new if p_type is string tensor. - */ - Tensor(MLDataType p_type, const TensorShape& shape, std::shared_ptr allocator, - gsl::span strides = {}); - - /// - /// Creates an instance of Tensor on the heap using the appropriate __ctor and - /// initializes OrtValue with it. - /// - /// - /// - /// - /// - /// - static void InitOrtValue(MLDataType elt_type, - const TensorShape& shape, - std::shared_ptr allocator, - OrtValue& ort_value, - gsl::span strides = {}); - - /// - /// Creates an instance of Tensor on the heap using the appropriate __ctor and - /// initializes OrtValue with it. - /// - /// - /// - static void InitOrtValue(Tensor&& tensor, OrtValue& ort_value); - - /** - * Create tensor with given type, shape, pre-allocated memory and allocator which will be used to free the pre-allocated memory. - * This function won't check if the preallocated buffer(p_data) has enough room for the shape. - * However, this function will de-allocate the buffer upon the tensor getting destructed. - * \param p_type Data type of the tensor + * Create tensor with given type, shape, pre-allocated memory and allocator which will be used to free the + * pre-allocated memory. The Tensor will take over ownership of p_data. + * This function does not check if the preallocated buffer(p_data) has enough room for the shape. + * \param elt_type Data type of the tensor elements. * \param shape Shape of the tensor * \param p_data A preallocated buffer. Can be NULL if the shape is empty. - * Tensor will own the memory and will delete it when the tensor instance is destructed. + * Tensor will own the memory and will delete it when the tensor instance is destructed. * \param deleter Allocator used to free the pre-allocated memory * \param offset Offset in bytes to start of Tensor within p_data. * \param strides Strides span. Can be empty if the tensor is contiguous. */ - Tensor(MLDataType p_type, const TensorShape& shape, void* p_data, std::shared_ptr deleter, + Tensor(MLDataType elt_type, const TensorShape& shape, void* p_data, std::shared_ptr deleter, ptrdiff_t offset = 0, gsl::span strides = {}); + /// + /// Create a Tensor that allocates and owns the buffer required for the specified shape. + /// + /// Data type of the tensor elements. + /// Tensor shape. + /// Allocator to use to create and free buffer. + Tensor(MLDataType elt_type, const TensorShape& shape, std::shared_ptr allocator); + ~Tensor(); // Move is allowed ORT_DISALLOW_COPY_AND_ASSIGNMENT(Tensor); Tensor(Tensor&& other) noexcept; - Tensor& operator=(Tensor&& other) noexcept; + /// + /// Creates an instance of Tensor on the heap and initializes OrtValue with it. + /// + /// Data type of the tensor elements. + /// Tensor shape. + /// Tensor data. + /// Memory info for location of p_data. + /// OrtValue to populate with Tensor. + /// Optional offset if Tensor refers to a subset of p_data. + /// Optional strides if Tensor refers to a subset of p_data. + static void InitOrtValue(MLDataType elt_type, const TensorShape& shape, void* p_data, const OrtMemoryInfo& location, + OrtValue& ort_value, + ptrdiff_t offset = 0, gsl::span strides = {}); + + /// + /// Creates an instance of Tensor on the heap which will take over ownership of the pre-allocated buffer. + /// + /// Data type of the tensor elements. + /// + /// Tensor data. + /// Allocator that was used to create p_data and will be used to free it. + /// OrtValue to populate with Tensor. + /// Optional offset if Tensor refers to a subset of p_data. + /// Optional strides if Tensor refers to a subset of p_data. + static void InitOrtValue(MLDataType elt_type, const TensorShape& shape, void* p_data, + std::shared_ptr allocator, + OrtValue& ort_value, + ptrdiff_t offset = 0, gsl::span strides = {}); + + /// + /// Creates an instance of Tensor on the heap and initializes OrtValue with it. + /// The Tensor instance will allocate and own the data required for `shape`. + /// + /// Data type of the tensor elements. + /// Tensor shape. + /// Allocator that was used to create p_data and will be used to free it. + /// OrtValue to populate with Tensor. + static void InitOrtValue(MLDataType elt_type, const TensorShape& shape, std::shared_ptr allocator, + OrtValue& ort_value); + + /// + /// Initializes OrtValue with an existing Tensor. + /// + /// Tensor. + /// OrtValue to populate with Tensor. + static void InitOrtValue(Tensor&& tensor, OrtValue& ort_value); + + /// + /// Calculate the required storage for the tensor. + /// + /// Data type of the tensor elements. + /// Tensor shape. + /// Bytes required. + static size_t CalculateTensorStorageSize(MLDataType elt_type, const TensorShape& shape); + /** Returns the data type. */ @@ -294,7 +295,7 @@ class Tensor final { // More API methods. private: - void Init(MLDataType p_type, + void Init(MLDataType elt_type, const TensorShape& shape, void* p_raw_data, AllocatorPtr deleter, diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 47356c3fe3608..45f81783421e0 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -2055,6 +2055,7 @@ struct KernelContext { void* GetGPUComputeStream() const; Logger GetLogger() const; OrtAllocator* GetAllocator(const OrtMemoryInfo& memory_info) const; + OrtKernelContext* GetOrtKernelContext() const { return ctx_; } private: OrtKernelContext* ctx_; diff --git a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h index fd42824e81a56..887339ebc50d6 100644 --- a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h +++ b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h @@ -18,22 +18,81 @@ #include "onnxruntime_cxx_api.h" #include #include +#include #include namespace Ort { namespace Custom { -class TensorBase { +class ArgBase { public: - TensorBase(OrtKernelContext* ctx) : ctx_(ctx) {} - virtual ~TensorBase() {} + ArgBase(OrtKernelContext* ctx, + size_t indice, + bool is_input) : ctx_(ctx), indice_(indice), is_input_(is_input) {} + virtual ~ArgBase(){}; + + protected: + struct KernelContext ctx_; + size_t indice_; + bool is_input_; +}; + +using ArgPtr = std::unique_ptr; +using ArgPtrs = std::vector; + +class TensorBase : public ArgBase { + public: + TensorBase(OrtKernelContext* ctx, + size_t indice, + bool is_input) : ArgBase(ctx, indice, is_input) {} + operator bool() const { return shape_.has_value(); } + const std::vector& Shape() const { + if (!shape_.has_value()) { + ORT_CXX_API_THROW("tensor shape is not yet initialized", OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + return shape_.value(); + } + + ONNXTensorElementDataType Type() const { + return type_; + } + + int64_t NumberOfElement() const { + if (shape_.has_value()) { + return std::accumulate(shape_->begin(), shape_->end(), 1LL, std::multiplies()); + } else { + return 0; + } + } + + std::string Shape2Str() const { + if (shape_.has_value()) { + std::string shape_str; + for (const auto& dim : *shape_) { + shape_str.append(std::to_string(dim)); + shape_str.append(", "); + } + return shape_str; + } else { + return "empty"; + } + } + + bool IsCpuTensor() const { + return strcmp("Cpu", mem_type_) == 0; + } + + virtual const void* DataRaw() const = 0; + virtual size_t SizeInBytes() const = 0; + protected: - struct KernelContext ctx_; std::optional> shape_; + ONNXTensorElementDataType type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + const char* mem_type_ = "Cpu"; }; template @@ -48,13 +107,14 @@ struct Span { T operator[](size_t indice) const { return data_[indice]; } + const T* data() const { return data_; } }; template class Tensor : public TensorBase { public: using TT = typename std::remove_reference::type; - Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx), indice_(indice), is_input_(is_input) { + Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) { if (is_input_) { if (indice >= ctx_.GetInputCount()) { ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT); @@ -64,19 +124,6 @@ class Tensor : public TensorBase { shape_ = type_shape_info.GetShape(); } } - const std::vector& Shape() const { - if (!shape_.has_value()) { - ORT_CXX_API_THROW("tensor shape is not yet initialized", OrtErrorCode::ORT_RUNTIME_EXCEPTION); - } - return shape_.value(); - } - int64_t NumberOfElement() const { - if (shape_.has_value()) { - return std::accumulate(shape_->begin(), shape_->end(), 1LL, std::multiplies()); - } else { - return 0; - } - } const TT* Data() const { return reinterpret_cast(const_value_.GetTensorRawData()); } @@ -104,10 +151,15 @@ class Tensor : public TensorBase { } return *Data(); } + const void* DataRaw() const override { + return reinterpret_cast(Data()); + } + + size_t SizeInBytes() const override { + return sizeof(TT) * static_cast(NumberOfElement()); + } private: - size_t indice_; - bool is_input_; ConstValue const_value_; // for input TT* data_{}; // for output Span span_; @@ -118,7 +170,7 @@ class Tensor : public TensorBase { public: using strings = std::vector; - Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx), indice_(indice), is_input_(is_input) { + Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) { if (is_input_) { if (indice >= ctx_.GetInputCount()) { ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT); @@ -147,16 +199,21 @@ class Tensor : public TensorBase { } } } - int64_t NumberOfElement() const { - if (shape_.has_value()) { - return std::accumulate(shape_->begin(), shape_->end(), 1ULL, std::multiplies()); - } else { - return 0; - } - } const strings& Data() const { return input_strings_; } + const void* DataRaw() const override { + if (input_strings_.size() != 1) { + ORT_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION); + } + return reinterpret_cast(input_strings_[0].c_str()); + } + size_t SizeInBytes() const override { + if (input_strings_.size() != 1) { + ORT_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION); + } + return input_strings_[0].size(); + } void SetStringOutput(const strings& ss, const std::vector& dims) { shape_ = dims; std::vector raw; @@ -179,8 +236,6 @@ class Tensor : public TensorBase { } private: - size_t indice_; - bool is_input_; std::vector input_strings_; // for input }; @@ -190,7 +245,7 @@ class Tensor : public TensorBase { using strings = std::vector; using string_views = std::vector; - Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx), indice_(indice), is_input_(is_input) { + Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) { if (is_input_) { if (indice >= ctx_.GetInputCount()) { ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT); @@ -211,16 +266,21 @@ class Tensor : public TensorBase { } } } - int64_t NumberOfElement() const { - if (shape_.has_value()) { - return std::accumulate(shape_->begin(), shape_->end(), 1ULL, std::multiplies()); - } else { - return 0; - } - } const string_views& Data() const { return input_string_views_; } + const void* DataRaw() const override { + if (input_string_views_.size() != 1) { + ORT_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION); + } + return reinterpret_cast(input_string_views_[0].data()); + } + size_t SizeInBytes() const override { + if (input_string_views_.size() != 1) { + ORT_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION); + } + return input_string_views_[0].size(); + } void SetStringOutput(const strings& ss, const std::vector& dims) { shape_ = dims; std::vector raw; @@ -243,15 +303,101 @@ class Tensor : public TensorBase { } private: - size_t indice_; - bool is_input_; std::vector chars_; // for input std::vector input_string_views_; // for input }; using TensorPtr = std::unique_ptr; +using TensorPtrs = std::vector; -//////////////////////////// OrtLiteCustomOp //////////////////////////////// +struct TensorArray : public ArgBase { + TensorArray(OrtKernelContext* ctx, + size_t start_indice, + bool is_input) : ArgBase(ctx, + start_indice, + is_input) { + if (is_input) { + auto input_count = ctx_.GetInputCount(); + for (size_t ith_input = start_indice; ith_input < input_count; ++ith_input) { + auto const_value = ctx_.GetInput(start_indice); + auto type_shape_info = const_value.GetTensorTypeAndShapeInfo(); + auto type = type_shape_info.GetElementType(); + TensorPtr tensor; + switch (type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: + tensor = std::make_unique>(ctx, ith_input, true); + break; + default: + ORT_CXX_API_THROW("unknow input type", ORT_RUNTIME_EXCEPTION); + break; + } + tensors_.emplace_back(tensor.release()); + } // for + } + } + template + T* AllocateOutput(size_t ith_output, const std::vector& shape) { + // ith_output is the indice of output relative to the tensor array + // indice_ + ith_output is the indice relative to context + auto tensor = std::make_unique>(ctx_.GetOrtKernelContext(), indice_ + ith_output, false); + auto raw_output = tensor.get()->Allocate(shape); + tensors_.emplace_back(tensor.release()); + return raw_output; + } + Tensor& AllocateStringTensor(size_t ith_output) { + // ith_output is the indice of output relative to the tensor array + // indice_ + ith_output is the indice relative to context + auto tensor = std::make_unique>(ctx_.GetOrtKernelContext(), indice_ + ith_output, false); + Tensor& output = *tensor; + tensors_.emplace_back(tensor.release()); + return output; + } + size_t Size() const { + return tensors_.size(); + } + const TensorPtr& operator[](size_t ith_input) const { + // ith_input is the indice of output relative to the tensor array + return tensors_.at(ith_input); + } + + private: + TensorPtrs tensors_; +}; + +using Variadic = TensorArray; struct OrtLiteCustomOp : public OrtCustomOp { using ConstOptionalFloatTensor = std::optional&>; @@ -260,34 +406,34 @@ struct OrtLiteCustomOp : public OrtCustomOp { // CreateTuple template static typename std::enable_if>::type - CreateTuple(OrtKernelContext*, std::vector&, size_t, size_t, const std::string&) { + CreateTuple(OrtKernelContext*, ArgPtrs&, size_t, size_t, const std::string&) { return std::make_tuple(); } template static typename std::enable_if::value, std::tuple>::type - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { std::tuple current = std::tuple{context}; - auto next = CreateTuple(context, tensors, num_input, num_output, ep); + auto next = CreateTuple(context, args, num_input, num_output, ep); return std::tuple_cat(current, next); } template static typename std::enable_if::value, std::tuple>::type - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { std::tuple current = std::tuple{*context}; - auto next = CreateTuple(context, tensors, num_input, num_output, ep); + auto next = CreateTuple(context, args, num_input, num_output, ep); return std::tuple_cat(current, next); } #ifdef ORT_CUDA_CTX template static typename std::enable_if::value, std::tuple>::type - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { thread_local CudaContext cuda_context; cuda_context.Init(*context); std::tuple current = std::tuple{cuda_context}; - auto next = CreateTuple(context, tensors, num_input, num_output, ep); + auto next = CreateTuple(context, args, num_input, num_output, ep); return std::tuple_cat(current, next); } #endif @@ -295,143 +441,179 @@ struct OrtLiteCustomOp : public OrtCustomOp { #ifdef ORT_ROCM_CTX template static typename std::enable_if::value, std::tuple>::type - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { thread_local RocmContext rocm_context; rocm_context.Init(*context); std::tuple current = std::tuple{rocm_context}; - auto next = CreateTuple(context, tensors, num_input, num_output, ep); + auto next = CreateTuple(context, args, num_input, num_output, ep); return std::tuple_cat(current, next); } #endif -#define CREATE_TUPLE_INPUT(data_type) \ - template \ - static typename std::enable_if*>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - tensors.push_back(std::make_unique>(context, ith_input, true)); \ - std::tuple current = std::tuple{reinterpret_cast(tensors.back().get())}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - template \ - static typename std::enable_if&>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - tensors.push_back(std::make_unique>(context, ith_input, true)); \ - std::tuple current = std::tuple{reinterpret_cast(*tensors.back().get())}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - template \ - static typename std::enable_if*>>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - if (ith_input < num_input) { \ - tensors.push_back(std::make_unique>(context, ith_input, true)); \ - std::tuple current = std::tuple{reinterpret_cast*>(tensors.back().get())}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } else { \ - std::tuple current = std::tuple{}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - } \ - template \ - static typename std::enable_if*>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - if ("CPUExecutionProvider" != ep) { \ - ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ - } \ - tensors.push_back(std::make_unique>(context, ith_input, true)); \ - std::tuple current = std::tuple{&reinterpret_cast*>(tensors.back().get())->AsSpan()}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - template \ - static typename std::enable_if&>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - if ("CPUExecutionProvider" != ep) { \ - ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ - } \ - tensors.push_back(std::make_unique>(context, ith_input, true)); \ - std::tuple current = std::tuple{reinterpret_cast*>(tensors.back().get())->AsSpan()}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - template \ - static typename std::enable_if*>>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - if (ith_input < num_input) { \ - if ("CPUExecutionProvider" != ep) { \ - ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ - } \ - tensors.push_back(std::make_unique>(context, ith_input, true)); \ - std::tuple current = std::tuple{&reinterpret_cast*>(tensors.back().get())->AsSpan()}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } else { \ - std::tuple current = std::tuple{}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - } \ - template \ - static typename std::enable_if::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - if ("CPUExecutionProvider" != ep) { \ - ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ - } \ - tensors.push_back(std::make_unique>(context, ith_input, true)); \ - std::tuple current = std::tuple{reinterpret_cast*>(tensors.back().get())->AsScalar()}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - template \ - static typename std::enable_if>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - if (ith_input < num_input) { \ - if ("CPUExecutionProvider" != ep) { \ - ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ - } \ - tensors.push_back(std::make_unique>(context, ith_input, true)); \ - std::tuple current = std::tuple{reinterpret_cast*>(tensors.back().get())->AsScalar()}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } else { \ - std::tuple current = std::tuple{}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - } -#define CREATE_TUPLE_OUTPUT(data_type) \ - template \ - static typename std::enable_if*>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - tensors.push_back(std::make_unique>(context, ith_output, false)); \ - std::tuple current = std::tuple{reinterpret_cast(tensors.back().get())}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - template \ - static typename std::enable_if&>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - tensors.push_back(std::make_unique>(context, ith_output, false)); \ - std::tuple current = std::tuple{reinterpret_cast(*tensors.back().get())}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - template \ - static typename std::enable_if*>>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - if (ith_output < num_output) { \ - tensors.push_back(std::make_unique>(context, ith_output, false)); \ - std::tuple current = std::tuple{reinterpret_cast*>(tensors.back().get())}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } else { \ - std::tuple current = std::tuple{}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ + template + static typename std::enable_if::value, std::tuple>::type + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { + args.push_back(std::make_unique(context, ith_input, true)); + std::tuple current = std::tuple{reinterpret_cast(args.back().get())}; + auto next = CreateTuple(context, args, num_input, num_output, ep); + return std::tuple_cat(current, next); + } + + template + static typename std::enable_if::value, std::tuple>::type + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { + args.push_back(std::make_unique(context, ith_input, true)); + std::tuple current = std::tuple{reinterpret_cast(*args.back().get())}; + auto next = CreateTuple(context, args, num_input, num_output, ep); + return std::tuple_cat(current, next); + } + + template + static typename std::enable_if::value, std::tuple>::type + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { + args.push_back(std::make_unique(context, ith_output, false)); + std::tuple current = std::tuple{reinterpret_cast(args.back().get())}; + auto next = CreateTuple(context, args, num_input, num_output, ep); + return std::tuple_cat(current, next); + } + + template + static typename std::enable_if::value, std::tuple>::type + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { + args.push_back(std::make_unique(context, ith_output, false)); + std::tuple current = std::tuple{reinterpret_cast(*args.back().get())}; + auto next = CreateTuple(context, args, num_input, num_output, ep); + return std::tuple_cat(current, next); + } + +#define CREATE_TUPLE_INPUT(data_type) \ + template \ + static typename std::enable_if*>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{reinterpret_cast(args.back().get())}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if&>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{reinterpret_cast(*args.back().get())}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if*>>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + if (ith_input < num_input) { \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{reinterpret_cast*>(args.back().get())}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } else { \ + std::tuple current = std::tuple{}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + } \ + template \ + static typename std::enable_if*>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + if ("CPUExecutionProvider" != ep) { \ + ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ + } \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{&reinterpret_cast*>(args.back().get())->AsSpan()}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if&>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + if ("CPUExecutionProvider" != ep) { \ + ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ + } \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{reinterpret_cast*>(args.back().get())->AsSpan()}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if*>>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + if (ith_input < num_input) { \ + if ("CPUExecutionProvider" != ep) { \ + ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ + } \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{&reinterpret_cast*>(args.back().get())->AsSpan()}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } else { \ + std::tuple current = std::tuple{}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + } \ + template \ + static typename std::enable_if::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + if ("CPUExecutionProvider" != ep) { \ + ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ + } \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{reinterpret_cast*>(args.back().get())->AsScalar()}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + if (ith_input < num_input) { \ + if ("CPUExecutionProvider" != ep) { \ + ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ + } \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{reinterpret_cast*>(args.back().get())->AsScalar()}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } else { \ + std::tuple current = std::tuple{}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + } +#define CREATE_TUPLE_OUTPUT(data_type) \ + template \ + static typename std::enable_if*>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + args.push_back(std::make_unique>(context, ith_output, false)); \ + std::tuple current = std::tuple{reinterpret_cast(args.back().get())}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if&>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + args.push_back(std::make_unique>(context, ith_output, false)); \ + std::tuple current = std::tuple{reinterpret_cast(*args.back().get())}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if*>>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + if (ith_output < num_output) { \ + args.push_back(std::make_unique>(context, ith_output, false)); \ + std::tuple current = std::tuple{reinterpret_cast*>(args.back().get())}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } else { \ + std::tuple current = std::tuple{}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ } #define CREATE_TUPLE(data_type) \ CREATE_TUPLE_INPUT(data_type) \ @@ -491,6 +673,34 @@ struct OrtLiteCustomOp : public OrtCustomOp { } #endif + template + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type + ParseArgs(std::vector& input_types, std::vector& output_types) { + input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); + ParseArgs(input_types, output_types); + } + + template + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type + ParseArgs(std::vector& input_types, std::vector& output_types) { + input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); + ParseArgs(input_types, output_types); + } + + template + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type + ParseArgs(std::vector& input_types, std::vector& output_types) { + output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); + ParseArgs(input_types, output_types); + } + + template + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type + ParseArgs(std::vector& input_types, std::vector& output_types) { + output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); + ParseArgs(input_types, output_types); + } + #define PARSE_INPUT_BASE(pack_type, onnx_type) \ template \ static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type \ @@ -499,6 +709,12 @@ struct OrtLiteCustomOp : public OrtCustomOp { ParseArgs(input_types, output_types); \ } \ template \ + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same>::value>::type \ + ParseArgs(std::vector& input_types, std::vector& output_types) { \ + input_types.push_back(onnx_type); \ + ParseArgs(input_types, output_types); \ + } \ + template \ static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same>::value>::type \ ParseArgs(std::vector& input_types, std::vector& output_types) { \ input_types.push_back(onnx_type); \ @@ -585,18 +801,40 @@ struct OrtLiteCustomOp : public OrtCustomOp { return self->output_types_[indice]; }; - OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp*, size_t) { - return INPUT_OUTPUT_OPTIONAL; + OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* op, size_t indice) { + auto self = reinterpret_cast(op); + return self->input_types_[indice] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ? INPUT_OUTPUT_VARIADIC : INPUT_OUTPUT_OPTIONAL; }; - OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp*, size_t) { - return INPUT_OUTPUT_OPTIONAL; + OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t indice) { + auto self = reinterpret_cast(op); + return self->output_types_[indice] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ? INPUT_OUTPUT_VARIADIC : INPUT_OUTPUT_OPTIONAL; + }; + + OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) { + return 1; + }; + + OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) { + return 0; + }; + + OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) { + return 1; + }; + + OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) { + return 0; }; OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) { return 0; }; OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) { return 0; }; OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) { return 0; }; OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) { return 0; }; + + OrtCustomOp::CreateKernelV2 = {}; + OrtCustomOp::KernelComputeV2 = {}; + OrtCustomOp::KernelCompute = {}; } const std::string op_name_; @@ -619,12 +857,14 @@ struct OrtLiteCustomOp : public OrtCustomOp { template struct OrtLiteCustomFunc : public OrtLiteCustomOp { using ComputeFn = void (*)(Args...); + using ComputeFnReturnStatus = Status (*)(Args...); using MyType = OrtLiteCustomFunc; struct Kernel { size_t num_input_{}; size_t num_output_{}; ComputeFn compute_fn_{}; + ComputeFnReturnStatus compute_fn_return_status_{}; std::string ep_{}; }; @@ -636,8 +876,8 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { auto kernel = reinterpret_cast(op_kernel); - std::vector tensors; - auto t = CreateTuple<0, 0, Args...>(context, tensors, kernel->num_input_, kernel->num_output_, kernel->ep_); + std::vector args; + auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_); std::apply([kernel](Args const&... t_args) { kernel->compute_fn_(t_args...); }, t); }; @@ -656,7 +896,36 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { }; } - ComputeFn compute_fn_; + OrtLiteCustomFunc(const char* op_name, + const char* execution_provider, + ComputeFnReturnStatus compute_fn_return_status) : OrtLiteCustomOp(op_name, execution_provider), + compute_fn_return_status_(compute_fn_return_status) { + ParseArgs(input_types_, output_types_); + + OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr { + auto kernel = reinterpret_cast(op_kernel); + std::vector args; + auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_); + return std::apply([kernel](Args const&... t_args) { Status status = kernel->compute_fn_return_status_(t_args...); return status.release(); }, t); + }; + + OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { + auto kernel = std::make_unique(); + kernel->compute_fn_return_status_ = static_cast(this_)->compute_fn_return_status_; + Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_)); + Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_)); + auto self = static_cast(this_); + kernel->ep_ = self->execution_provider_; + return reinterpret_cast(kernel.release()); + }; + + OrtCustomOp::KernelDestroy = [](void* op_kernel) { + delete reinterpret_cast(op_kernel); + }; + } + + ComputeFn compute_fn_ = {}; + ComputeFnReturnStatus compute_fn_return_status_ = {}; }; // struct OrtLiteCustomFunc /////////////////////////// OrtLiteCustomStruct /////////////////////////// @@ -679,6 +948,10 @@ template struct OrtLiteCustomStruct : public OrtLiteCustomOp { template using CustomComputeFn = void (CustomOp::*)(Args...); + + template + using CustomComputeFnReturnStatus = Status (CustomOp::*)(Args...); + using MyType = OrtLiteCustomStruct; struct Kernel { @@ -692,18 +965,6 @@ struct OrtLiteCustomStruct : public OrtLiteCustomOp { const char* execution_provider) : OrtLiteCustomOp(op_name, execution_provider) { init(&CustomOp::Compute); - } - - template - void init(CustomComputeFn) { - ParseArgs(input_types_, output_types_); - - OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { - auto kernel = reinterpret_cast(op_kernel); - std::vector tensors; - auto t = CreateTuple<0, 0, Args...>(context, tensors, kernel->num_input_, kernel->num_output_, kernel->ep_); - std::apply([kernel](Args const&... t_args) { kernel->custom_op_->Compute(t_args...); }, t); - }; OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { auto kernel = std::make_unique(); @@ -719,6 +980,28 @@ struct OrtLiteCustomStruct : public OrtLiteCustomOp { delete reinterpret_cast(op_kernel); }; } + + template + void init(CustomComputeFn) { + ParseArgs(input_types_, output_types_); + OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { + auto kernel = reinterpret_cast(op_kernel); + ArgPtrs args; + auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_); + std::apply([kernel](Args const&... t_args) { kernel->custom_op_->Compute(t_args...); }, t); + }; + } + + template + void init(CustomComputeFnReturnStatus) { + ParseArgs(input_types_, output_types_); + OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr { + auto kernel = reinterpret_cast(op_kernel); + ArgPtrs args; + auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_); + return std::apply([kernel](Args const&... t_args) { Status status = kernel->custom_op_->Compute(t_args...); return status.release(); }, t); + }; + } }; // struct OrtLiteCustomStruct /////////////////////////// CreateLiteCustomOp //////////////////////////// @@ -731,6 +1014,14 @@ OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name, return std::make_unique(op_name, execution_provider, custom_compute_fn).release(); } +template +OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name, + const char* execution_provider, + Status (*custom_compute_fn_v2)(Args...)) { + using LiteOp = OrtLiteCustomFunc; + return std::make_unique(op_name, execution_provider, custom_compute_fn_v2).release(); +} + template OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name, const char* execution_provider) { @@ -739,4 +1030,4 @@ OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name, } } // namespace Custom -} // namespace Ort +} // namespace Ort \ No newline at end of file diff --git a/java/src/main/java/ai/onnxruntime/OrtSession.java b/java/src/main/java/ai/onnxruntime/OrtSession.java index 435f86daa5fe2..fbea13d155507 100644 --- a/java/src/main/java/ai/onnxruntime/OrtSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtSession.java @@ -239,7 +239,7 @@ public Result run(Map inputs, RunOptions runOp */ public Result run(Map inputs, Set requestedOutputs) throws OrtException { - return run(inputs, requestedOutputs, null); + return run(inputs, requestedOutputs, Collections.emptyMap(), null); } /** @@ -259,17 +259,90 @@ public Result run( Set requestedOutputs, RunOptions runOptions) throws OrtException { + return run(inputs, requestedOutputs, Collections.emptyMap(), runOptions); + } + + /** + * Scores an input feed dict, returning the map of pinned outputs. + * + *

The outputs are sorted based on the supplied map traversal order. + * + *

Note: pinned outputs are not owned by the {@link Result} object, and are not closed + * when the result object is closed. + * + * @param inputs The inputs to score. + * @param pinnedOutputs The requested outputs which the user has allocated. + * @return The inferred outputs. + * @throws OrtException If there was an error in native code, the input or output names are + * invalid, or if there are zero or too many inputs or outputs. + */ + public Result run( + Map inputs, Map pinnedOutputs) + throws OrtException { + return run(inputs, Collections.emptySet(), pinnedOutputs, null); + } + + /** + * Scores an input feed dict, returning the map of requested and pinned outputs. + * + *

The outputs are sorted based on the supplied set traversal order with pinned outputs first, + * then requested outputs. An {@link IllegalArgumentException} is thrown if the same output name + * appears in both the requested outputs and the pinned outputs. + * + *

Note: pinned outputs are not owned by the {@link Result} object, and are not closed + * when the result object is closed. + * + * @param inputs The inputs to score. + * @param requestedOutputs The requested outputs which ORT will allocate. + * @param pinnedOutputs The requested outputs which the user has allocated. + * @return The inferred outputs. + * @throws OrtException If there was an error in native code, the input or output names are + * invalid, or if there are zero or too many inputs or outputs. + */ + public Result run( + Map inputs, + Set requestedOutputs, + Map pinnedOutputs) + throws OrtException { + return run(inputs, requestedOutputs, pinnedOutputs, null); + } + + /** + * Scores an input feed dict, returning the map of requested and pinned outputs. + * + *

The outputs are sorted based on the supplied set traversal order with pinned outputs first, + * then requested outputs. An {@link IllegalArgumentException} is thrown if the same output name + * appears in both the requested outputs and the pinned outputs. + * + *

Note: pinned outputs are not owned by the {@link Result} object, and are not closed + * when the result object is closed. + * + * @param inputs The inputs to score. + * @param requestedOutputs The requested outputs which ORT will allocate. + * @param pinnedOutputs The requested outputs which the user has allocated. + * @param runOptions The RunOptions to control this run. + * @return The inferred outputs. + * @throws OrtException If there was an error in native code, the input or output names are + * invalid, or if there are zero or too many inputs or outputs. + */ + public Result run( + Map inputs, + Set requestedOutputs, + Map pinnedOutputs, + RunOptions runOptions) + throws OrtException { if (!closed) { if ((inputs.isEmpty() && (numInputs != 0)) || (inputs.size() > numInputs)) { throw new OrtException( "Unexpected number of inputs, expected [1," + numInputs + ") found " + inputs.size()); } - if (requestedOutputs.isEmpty() || (requestedOutputs.size() > numOutputs)) { + int totalOutputs = requestedOutputs.size() + pinnedOutputs.size(); + if ((totalOutputs == 0) || (totalOutputs > numOutputs)) { throw new OrtException( - "Unexpected number of requestedOutputs, expected [1," + "Unexpected number of requestedOutputs & pinnedOutputs, expected [1," + numOutputs + ") found " - + requestedOutputs.size()); + + totalOutputs); } String[] inputNamesArray = new String[inputs.size()]; long[] inputHandles = new long[inputs.size()]; @@ -284,20 +357,41 @@ public Result run( "Unknown input name " + t.getKey() + ", expected one of " + inputNames.toString()); } } - String[] outputNamesArray = new String[requestedOutputs.size()]; + String[] outputNamesArray = new String[requestedOutputs.size() + pinnedOutputs.size()]; + OnnxValue[] outputValues = new OnnxValue[outputNamesArray.length]; + long[] outputHandles = new long[outputNamesArray.length]; i = 0; + for (Map.Entry e : pinnedOutputs.entrySet()) { + if (outputNames.contains(e.getKey())) { + outputNamesArray[i] = e.getKey(); + outputValues[i] = e.getValue(); + outputHandles[i] = getHandle(e.getValue()); + i++; + } else { + throw new OrtException( + "Unknown output name " + e.getKey() + ", expected one of " + outputNames.toString()); + } + } for (String s : requestedOutputs) { if (outputNames.contains(s)) { - outputNamesArray[i] = s; - i++; + if (!pinnedOutputs.containsKey(s)) { + outputNamesArray[i] = s; + // outputValues and outputHandles can be null/0 for these outputs as ORT will allocate + // them. + i++; + } else { + throw new OrtException( + "Output '" + + s + + "' was found in both the requested outputs and the pinned outputs"); + } } else { throw new OrtException( "Unknown output name " + s + ", expected one of " + outputNames.toString()); } } long runOptionsHandle = runOptions == null ? 0 : runOptions.getNativeHandle(); - - OnnxValue[] outputValues = + boolean[] ownedByResult = run( OnnxRuntime.ortApiHandle, nativeHandle, @@ -307,13 +401,40 @@ public Result run( inputNamesArray.length, outputNamesArray, outputNamesArray.length, + outputValues, + outputHandles, runOptionsHandle); - return new Result(outputNamesArray, outputValues); + return new Result(outputNamesArray, outputValues, ownedByResult); } else { throw new IllegalStateException("Trying to score a closed OrtSession."); } } + /** + * Pulls out the native handle by casting it to the appropriate type. + * + * @param v The OnnxValue. + * @return The native handle. + */ + static long getHandle(OnnxValue v) { + /* + * Note this method exists as interface methods are all public, but we do not want users to be + * able to access the native pointer via a public API so can't add a method to OnnxValue which + * exposes it. + */ + if (v instanceof OnnxTensorLike) { + return ((OnnxTensorLike) v).nativeHandle; + } else if (v instanceof OnnxSequence) { + return ((OnnxSequence) v).nativeHandle; + } else if (v instanceof OnnxMap) { + return ((OnnxMap) v).nativeHandle; + } else { + throw new IllegalArgumentException( + "Unexpected OnnxValue subclass, should be {OnnxTensorLike, OnnxSequence, OnnxMap}, found " + + v.getClass()); + } + } + /** * Gets the metadata for the currently loaded model. * @@ -409,8 +530,9 @@ private native NodeInfo[] getOutputInfo(long apiHandle, long nativeHandle, long throws OrtException; /** - * The native run call. runOptionsHandle can be zero (i.e. the null pointer), but all other - * handles must be valid pointers. + * The native run call. runOptionsHandle can be zero (i.e. the null pointer), outputValues can + * contain null entries, and outputHandles can contain zero values (i.e. the null pointer), but + * all other handles must be valid pointers. * * @param apiHandle The pointer to the api. * @param nativeHandle The pointer to the session. @@ -419,12 +541,14 @@ private native NodeInfo[] getOutputInfo(long apiHandle, long nativeHandle, long * @param inputs The input tensors. * @param numInputs The number of inputs. * @param outputNamesArray The requested output names. + * @param outputValues The OnnxValue output array. + * @param outputHandles The OrtValue output pointer array. * @param numOutputs The number of requested outputs. * @param runOptionsHandle The (possibly null) pointer to the run options. - * @return The OnnxValues produced by this run. + * @return A boolean array representing if the OnnxValues were allocated by this run call. * @throws OrtException If the native call failed in some way. */ - private native OnnxValue[] run( + private native boolean[] run( long apiHandle, long nativeHandle, long allocatorHandle, @@ -433,6 +557,8 @@ private native OnnxValue[] run( long numInputs, String[] outputNamesArray, long numOutputs, + OnnxValue[] outputValues, + long[] outputHandles, long runOptionsHandle) throws OrtException; @@ -1417,9 +1543,13 @@ private native void addRunConfigEntry( /** * An {@link AutoCloseable} wrapper around a {@link Map} containing {@link OnnxValue}s. * - *

When this is closed it closes all the {@link OnnxValue}s inside it. If you maintain a - * reference to a value after this object has been closed it will throw an {@link + *

When this is closed it closes all the {@link OnnxValue}s owned by the result object. If you + * maintain a reference to a value after this object has been closed it will throw an {@link * IllegalStateException} upon access. + * + *

{@link OnnxValue}s which are supplied as pinned outputs to a {@code run} call are not closed + * by the {@link Result#close()} method. Ownership of each output can be checked with {@link + * Result#isResultOwner(int)}. */ public static class Result implements AutoCloseable, Iterable> { @@ -1429,6 +1559,8 @@ public static class Result implements AutoCloseable, Iterable list; + private final boolean[] ownedByResult; + private boolean closed; /** @@ -1437,21 +1569,23 @@ public static class Result implements AutoCloseable, IterableThrows {@link IllegalStateException} if the container has been closed, and {@link + * ArrayIndexOutOfBoundsException} if the index is invalid. + * + * @param index The index to lookup. + * @return Is that value owned by this result object? + */ + public boolean isResultOwner(int index) { + if (!closed) { + return ownedByResult[index]; + } else { + throw new IllegalStateException("Result is closed"); + } + } + /** * Returns the number of outputs in this Result. * diff --git a/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java b/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java index 8c03c5b80433c..49ddf29c22335 100644 --- a/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java +++ b/java/src/main/java/ai/onnxruntime/OrtTrainingSession.java @@ -418,7 +418,7 @@ private static native void setSeed(long apiHandle, long trainingHandle, long see */ public OrtSession.Result trainStep(Map inputs) throws OrtException { - return trainStep(inputs, trainOutputNames, null); + return trainStep(inputs, trainOutputNames, Collections.emptyMap(), null); } /** @@ -432,7 +432,7 @@ public OrtSession.Result trainStep(Map inputs) public OrtSession.Result trainStep( Map inputs, OrtSession.RunOptions runOptions) throws OrtException { - return trainStep(inputs, trainOutputNames, runOptions); + return trainStep(inputs, trainOutputNames, Collections.emptyMap(), runOptions); } /** @@ -446,14 +446,41 @@ public OrtSession.Result trainStep( public OrtSession.Result trainStep( Map inputs, Set requestedOutputs) throws OrtException { - return trainStep(inputs, requestedOutputs, null); + return trainStep(inputs, requestedOutputs, Collections.emptyMap(), null); } /** * Performs a single step of training, accumulating the gradients. * + *

The outputs are sorted based on the supplied map traversal order. + * + *

Note: pinned outputs are not owned by the {@link OrtSession.Result} object, and are + * not closed when the result object is closed. + * * @param inputs The inputs (must include both the features and the target). - * @param requestedOutputs The requested outputs. + * @param pinnedOutputs The requested outputs which the user has allocated. + * @return Requested outputs produced by the training step. + * @throws OrtException If the native call failed. + */ + public OrtSession.Result trainStep( + Map inputs, Map pinnedOutputs) + throws OrtException { + return trainStep(inputs, Collections.emptySet(), pinnedOutputs, null); + } + + /** + * Performs a single step of training, accumulating the gradients. + * + *

The outputs are sorted based on the supplied set traversal order with pinned outputs first, + * then requested outputs. An {@link IllegalArgumentException} is thrown if the same output name + * appears in both the requested outputs and the pinned outputs. + * + *

Note: pinned outputs are not owned by the {@link OrtSession.Result} object, and are + * not closed when the result object is closed. + * + * @param inputs The inputs (must include both the features and the target). + * @param requestedOutputs The requested outputs which ORT will allocate. + * @param pinnedOutputs The requested outputs which the user has allocated. * @param runOptions Run options for controlling this specific call. * @return Requested outputs produced by the training step. * @throws OrtException If the native call failed. @@ -461,6 +488,7 @@ public OrtSession.Result trainStep( public OrtSession.Result trainStep( Map inputs, Set requestedOutputs, + Map pinnedOutputs, OrtSession.RunOptions runOptions) throws OrtException { checkClosed(); @@ -472,12 +500,14 @@ public OrtSession.Result trainStep( + ") found " + inputs.size()); } - if (requestedOutputs.isEmpty() || (requestedOutputs.size() > trainOutputNames.size())) { + int numTrainOutputs = trainOutputNames.size(); + int totalOutputs = requestedOutputs.size() + pinnedOutputs.size(); + if ((totalOutputs == 0) || (totalOutputs > numTrainOutputs)) { throw new OrtException( - "Unexpected number of requestedOutputs, expected [1," - + trainOutputNames.size() + "Unexpected number of requestedOutputs & pinnedOutputs, expected [1," + + numTrainOutputs + ") found " - + requestedOutputs.size()); + + totalOutputs); } String[] inputNamesArray = new String[inputs.size()]; long[] inputHandles = new long[inputs.size()]; @@ -492,12 +522,35 @@ public OrtSession.Result trainStep( "Unknown input name " + t.getKey() + ", expected one of " + trainInputNames); } } - String[] outputNamesArray = new String[requestedOutputs.size()]; + String[] outputNamesArray = new String[requestedOutputs.size() + pinnedOutputs.size()]; + OnnxValue[] outputValues = new OnnxValue[outputNamesArray.length]; + long[] outputHandles = new long[outputNamesArray.length]; i = 0; + for (Map.Entry e : pinnedOutputs.entrySet()) { + if (trainOutputNames.contains(e.getKey())) { + outputNamesArray[i] = e.getKey(); + outputValues[i] = e.getValue(); + outputHandles[i] = OrtSession.getHandle(e.getValue()); + i++; + } else { + throw new OrtException( + "Unknown output name " + + e.getKey() + + ", expected one of " + + trainOutputNames.toString()); + } + } for (String s : requestedOutputs) { if (trainOutputNames.contains(s)) { - outputNamesArray[i] = s; - i++; + if (!pinnedOutputs.containsKey(s)) { + outputNamesArray[i] = s; + // outputValues and outputHandles can be null/0 for these outputs as ORT will allocate + // them. + i++; + } else { + throw new OrtException( + "Output '" + s + "' was found in both the requested outputs and the pinned outputs"); + } } else { throw new OrtException( "Unknown output name " + s + ", expected one of " + trainOutputNames.toString()); @@ -505,7 +558,7 @@ public OrtSession.Result trainStep( } long runOptionsHandle = runOptions == null ? 0 : runOptions.getNativeHandle(); - OnnxValue[] outputValues = + boolean[] ownedByResult = trainStep( OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, @@ -516,8 +569,10 @@ public OrtSession.Result trainStep( inputNamesArray.length, outputNamesArray, outputNamesArray.length, + outputValues, + outputHandles, runOptionsHandle); - return new OrtSession.Result(outputNamesArray, outputValues); + return new OrtSession.Result(outputNamesArray, outputValues, ownedByResult); } /* @@ -540,7 +595,7 @@ public OrtSession.Result trainStep( * run_options, size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs, size_t * outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs); */ - private native OnnxValue[] trainStep( + private native boolean[] trainStep( long apiHandle, long trainingApiHandle, long nativeHandle, @@ -550,6 +605,8 @@ private native OnnxValue[] trainStep( long numInputs, String[] outputNamesArray, long numOutputs, + OnnxValue[] outputValues, + long[] outputHandles, long runOptionsHandle); /** @@ -561,7 +618,7 @@ private native OnnxValue[] trainStep( */ public OrtSession.Result evalStep(Map inputs) throws OrtException { - return evalStep(inputs, evalOutputNames, null); + return evalStep(inputs, evalOutputNames, Collections.emptyMap(), null); } /** @@ -575,7 +632,7 @@ public OrtSession.Result evalStep(Map inputs) public OrtSession.Result evalStep( Map inputs, OrtSession.RunOptions runOptions) throws OrtException { - return evalStep(inputs, evalOutputNames, runOptions); + return evalStep(inputs, evalOutputNames, Collections.emptyMap(), runOptions); } /** @@ -589,14 +646,41 @@ public OrtSession.Result evalStep( public OrtSession.Result evalStep( Map inputs, Set requestedOutputs) throws OrtException { - return evalStep(inputs, requestedOutputs, null); + return evalStep(inputs, requestedOutputs, Collections.emptyMap(), null); } /** * Performs a single evaluation step using the supplied inputs. * - * @param inputs The model inputs. - * @param requestedOutputs The requested output names. + *

The outputs are sorted based on the supplied map traversal order. + * + *

Note: pinned outputs are not owned by the {@link OrtSession.Result} object, and are + * not closed when the result object is closed. + * + * @param inputs The inputs to score. + * @param pinnedOutputs The requested outputs which the user has allocated. + * @return The requested outputs. + * @throws OrtException If the native call failed. + */ + public OrtSession.Result evalStep( + Map inputs, Map pinnedOutputs) + throws OrtException { + return evalStep(inputs, Collections.emptySet(), pinnedOutputs, null); + } + + /** + * Performs a single evaluation step using the supplied inputs. + * + *

The outputs are sorted based on the supplied set traversal order with pinned outputs first, + * then requested outputs. An {@link IllegalArgumentException} is thrown if the same output name + * appears in both the requested outputs and the pinned outputs. + * + *

Note: pinned outputs are not owned by the {@link OrtSession.Result} object, and are + * not closed when the result object is closed. + * + * @param inputs The inputs to score. + * @param requestedOutputs The requested outputs which ORT will allocate. + * @param pinnedOutputs The requested outputs which the user has allocated. * @param runOptions Run options for controlling this specific call. * @return The requested outputs. * @throws OrtException If the native call failed. @@ -604,6 +688,7 @@ public OrtSession.Result evalStep( public OrtSession.Result evalStep( Map inputs, Set requestedOutputs, + Map pinnedOutputs, OrtSession.RunOptions runOptions) throws OrtException { checkClosed(); @@ -615,12 +700,14 @@ public OrtSession.Result evalStep( + ") found " + inputs.size()); } - if (requestedOutputs.isEmpty() || (requestedOutputs.size() > evalOutputNames.size())) { + int numEvalOutputs = evalOutputNames.size(); + int totalOutputs = requestedOutputs.size() + pinnedOutputs.size(); + if ((totalOutputs == 0) || (totalOutputs > numEvalOutputs)) { throw new OrtException( - "Unexpected number of requestedOutputs, expected [1," - + evalOutputNames.size() + "Unexpected number of requestedOutputs & pinnedOutputs, expected [1," + + numEvalOutputs + ") found " - + requestedOutputs.size()); + + totalOutputs); } String[] inputNamesArray = new String[inputs.size()]; long[] inputHandles = new long[inputs.size()]; @@ -635,12 +722,35 @@ public OrtSession.Result evalStep( "Unknown input name " + t.getKey() + ", expected one of " + evalInputNames.toString()); } } - String[] outputNamesArray = new String[requestedOutputs.size()]; + String[] outputNamesArray = new String[requestedOutputs.size() + pinnedOutputs.size()]; + OnnxValue[] outputValues = new OnnxValue[outputNamesArray.length]; + long[] outputHandles = new long[outputNamesArray.length]; i = 0; + for (Map.Entry e : pinnedOutputs.entrySet()) { + if (evalOutputNames.contains(e.getKey())) { + outputNamesArray[i] = e.getKey(); + outputValues[i] = e.getValue(); + outputHandles[i] = OrtSession.getHandle(e.getValue()); + i++; + } else { + throw new OrtException( + "Unknown output name " + + e.getKey() + + ", expected one of " + + evalOutputNames.toString()); + } + } for (String s : requestedOutputs) { if (evalOutputNames.contains(s)) { - outputNamesArray[i] = s; - i++; + if (!pinnedOutputs.containsKey(s)) { + outputNamesArray[i] = s; + // outputValues and outputHandles can be null/0 for these outputs as ORT will allocate + // them. + i++; + } else { + throw new OrtException( + "Output '" + s + "' was found in both the requested outputs and the pinned outputs"); + } } else { throw new OrtException( "Unknown output name " + s + ", expected one of " + evalOutputNames.toString()); @@ -648,7 +758,7 @@ public OrtSession.Result evalStep( } long runOptionsHandle = runOptions == null ? 0 : runOptions.getNativeHandle(); - OnnxValue[] outputValues = + boolean[] ownedByResult = evalStep( OnnxRuntime.ortApiHandle, OnnxRuntime.ortTrainingApiHandle, @@ -659,8 +769,10 @@ public OrtSession.Result evalStep( inputNamesArray.length, outputNamesArray, outputNamesArray.length, + outputValues, + outputHandles, runOptionsHandle); - return new OrtSession.Result(outputNamesArray, outputValues); + return new OrtSession.Result(outputNamesArray, outputValues, ownedByResult); } /* @@ -682,7 +794,7 @@ public OrtSession.Result evalStep( * run_options, size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs, size_t * outputs_len, _Inout_updates_all_(outputs_len) OrtValue** outputs); */ - private native OnnxValue[] evalStep( + private native boolean[] evalStep( long apiHandle, long trainingApiHandle, long nativeHandle, @@ -692,6 +804,8 @@ private native OnnxValue[] evalStep( long numInputs, String[] outputNamesArray, long numOutputs, + OnnxValue[] outputValues, + long[] outputHandles, long runOptionsHandle) throws OrtException; diff --git a/java/src/main/java/ai/onnxruntime/TensorInfo.java b/java/src/main/java/ai/onnxruntime/TensorInfo.java index 34d635b5c40f8..69ccb954e8afe 100644 --- a/java/src/main/java/ai/onnxruntime/TensorInfo.java +++ b/java/src/main/java/ai/onnxruntime/TensorInfo.java @@ -58,14 +58,37 @@ public enum OnnxTensorType { */ ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16( 16), // Non-IEEE floating-point format based on IEEE754 single-precision - ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN( - 17), // Non-IEEE floating-point format based on IEEE754 single-precision - ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ( - 18), // Non-IEEE floating-point format based on IEEE754 single-precision - ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2( - 19), // Non-IEEE floating-point format based on IEEE754 single-precision - ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ( - 20); // Non-IEEE floating-point format based on IEEE754 single-precision + /** + * A non-IEEE 8-bit floating point format with 4 exponent bits and 3 mantissa bits, with NaN and + * no infinite values (FN). + * + *

See the float 8 ONNX standard for + * details. + */ + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN(17), + /** + * A non-IEEE 8-bit floating point format with 4 exponent bits and 3 mantissa bits, with NaN, no + * infinite values (FN) and no negative zero (UZ). + * + *

See the float 8 ONNX standard for + * details. + */ + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ(18), + /** + * A non-IEEE 8-bit floating point format with 5 exponent bits and 2 mantissa bits. + * + *

See the float 8 ONNX standard for + * details. + */ + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2(19), + /** + * A non-IEEE 8-bit floating point format with 5 exponent bits and 2 mantissa bits, with NaN, no + * infinite values (FN) and no negative zero (UZ). + * + *

See the float 8 ONNX standard for + * details. + */ + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ(20); /** The int id on the native side. */ public final int value; diff --git a/java/src/main/native/ai_onnxruntime_OrtSession.c b/java/src/main/native/ai_onnxruntime_OrtSession.c index 6f4e34648cf81..f4d5ab080cd31 100644 --- a/java/src/main/native/ai_onnxruntime_OrtSession.c +++ b/java/src/main/native/ai_onnxruntime_OrtSession.c @@ -316,14 +316,19 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_getOutputInfo(JNIE /* * Class: ai_onnxruntime_OrtSession * Method: run - * Signature: (JJJ[Ljava/lang/String;[JJ[Ljava/lang/String;JJ)[Lai/onnxruntime/OnnxValue; - * private native OnnxValue[] run(long apiHandle, long nativeHandle, long allocatorHandle, String[] inputNamesArray, long[] inputs, long numInputs, String[] outputNamesArray, long numOutputs) + * Signature: (JJJ[Ljava/lang/String;[JJ[Ljava/lang/String;J[Lai/onnxruntime/OnnxValue;[JJ)[Z + * private native boolean[] run(long apiHandle, long nativeHandle, long allocatorHandle, + * String[] inputNamesArray, long[] inputs, long numInputs, + * String[] outputNamesArray, long numOutputs, + * OnnxValue[] outputValues, long[] outputHandles, + * long runOptionsHandle) throws OrtException; */ -JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_run(JNIEnv* jniEnv, jobject jobj, jlong apiHandle, +JNIEXPORT jbooleanArray JNICALL Java_ai_onnxruntime_OrtSession_run(JNIEnv* jniEnv, jobject jobj, jlong apiHandle, jlong sessionHandle, jlong allocatorHandle, jobjectArray inputNamesArr, jlongArray tensorArr, jlong numInputs, jobjectArray outputNamesArr, - jlong numOutputs, jlong runOptionsHandle) { + jlong numOutputs, jobjectArray outputValuesArr, + jlongArray outputHandlesArr, jlong runOptionsHandle) { (void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*)apiHandle; @@ -331,7 +336,7 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_run(JNIEnv* jniEnv OrtSession* session = (OrtSession*)sessionHandle; OrtRunOptions* runOptions = (OrtRunOptions*)runOptionsHandle; - jobjectArray outputArray = NULL; + jbooleanArray outputArray = NULL; // Create the buffers for the Java input & output strings, and the input pointers const char** inputNames = allocarray(numInputs, sizeof(char*)); @@ -376,13 +381,19 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_run(JNIEnv* jniEnv // Release the java array copy of pointers to the tensors. (*jniEnv)->ReleaseLongArrayElements(jniEnv, tensorArr, inputValueLongs, JNI_ABORT); + // Extract a C array of longs which are pointers to the output tensors. + jlong* outputHandleLongs = (*jniEnv)->GetLongArrayElements(jniEnv, outputHandlesArr, NULL); + // Extract the names of the output values. for (int i = 0; i < numOutputs; i++) { javaOutputStrings[i] = (*jniEnv)->GetObjectArrayElement(jniEnv, outputNamesArr, i); outputNames[i] = (*jniEnv)->GetStringUTFChars(jniEnv, javaOutputStrings[i], NULL); - outputValues[i] = NULL; + outputValues[i] = (OrtValue*)outputHandleLongs[i]; } + // Release the java array copy of pointers to the outputs. + (*jniEnv)->ReleaseLongArrayElements(jniEnv, outputHandlesArr, outputHandleLongs, JNI_ABORT); + // Actually score the inputs. // ORT_API_STATUS(OrtRun, _Inout_ OrtSession* sess, _In_ OrtRunOptions* run_options, // _In_ const char* const* input_names, _In_ const OrtValue* const* input, size_t input_len, @@ -394,21 +405,26 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtSession_run(JNIEnv* jniEnv goto cleanup_output_values; } - // Construct the output array of ONNXValues - jclass onnxValueClass = (*jniEnv)->FindClass(jniEnv, ORTJNI_OnnxValueClassName); - outputArray = (*jniEnv)->NewObjectArray(jniEnv, safecast_int64_to_jsize(numOutputs), onnxValueClass, NULL); + // Create the output boolean array denoting if ORT owns the memory for each output. + // Java boolean arrays are initialized to false. + outputArray = (*jniEnv)->NewBooleanArray(jniEnv, safecast_int64_to_jsize(numOutputs)); + jboolean* boolArr = (*jniEnv)->GetBooleanArrayElements(jniEnv, outputArray, NULL); // Convert the output tensors into ONNXValues for (int i = 0; i < numOutputs; i++) { - if (outputValues[i] != NULL) { + if (outputValues[i] != NULL && (*jniEnv)->GetObjectArrayElement(jniEnv, outputValuesArr, i) == NULL) { jobject onnxValue = convertOrtValueToONNXValue(jniEnv, api, allocator, outputValues[i]); if (onnxValue == NULL) { break; // go to cleanup, exception thrown } - (*jniEnv)->SetObjectArrayElement(jniEnv, outputArray, i, onnxValue); + boolArr[i] = 1; + (*jniEnv)->SetObjectArrayElement(jniEnv, outputValuesArr, i, onnxValue); } } + // Write the output array back to Java. + (*jniEnv)->ReleaseBooleanArrayElements(jniEnv, outputArray, boolArr, 0); + // Note these gotos are in a specific order so they mirror the allocation pattern above. // They must be changed if the allocation code is rearranged. cleanup_output_values: diff --git a/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c b/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c index b3b530a8b15aa..9f7b8d3a3dcfc 100644 --- a/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c +++ b/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022 Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2022, 2023, Oracle and/or its affiliates. All rights reserved. * Licensed under the MIT License. */ #include @@ -330,12 +330,12 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtTrainingSession_lazyResetGrad /* * Class: ai_onnxruntime_OrtTrainingSession * Method: trainStep - * Signature: (JJJJ[Ljava/lang/String;[JJ[Ljava/lang/String;JJ)[Lai/onnxruntime/OnnxValue; + * Signature: (JJJJ[Ljava/lang/String;[JJ[Ljava/lang/String;J[Lai/onnxruntime/OnnxValue;[JJ)[Z */ -JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_trainStep +JNIEXPORT jbooleanArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_trainStep (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong trainApiHandle, jlong nativeHandle, jlong allocatorHandle, jobjectArray inputNamesArr, jlongArray inputHandles, jlong numInputs, - jobjectArray outputNamesArr, jlong numOutputs, jlong runOptionsHandle) { + jobjectArray outputNamesArr, jlong numOutputs, jobjectArray outputValuesArr, jlongArray outputHandlesArr, jlong runOptionsHandle) { (void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*)apiHandle; const OrtTrainingApi* trainApi = (const OrtTrainingApi*)trainApiHandle; @@ -343,31 +343,31 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_trainStep OrtTrainingSession* trainSession = (OrtTrainingSession*)nativeHandle; OrtRunOptions* runOptions = (OrtRunOptions*)runOptionsHandle; - jobjectArray outputArray = NULL; + jbooleanArray outputArray = NULL; // Create the buffers for the Java input & output strings, and the input pointers - const char** inputNames = malloc(sizeof(char*) * numInputs); + const char** inputNames = allocarray(numInputs, sizeof(char*)); if (inputNames == NULL) { // Nothing to cleanup, return and throw exception return outputArray; } - const char** outputNames = malloc(sizeof(char*) * numOutputs); + const char** outputNames = allocarray(numOutputs, sizeof(char*)); if (outputNames == NULL) { goto cleanup_input_names; } - jobject* javaInputStrings = malloc(sizeof(jobject) * numInputs); + jobject* javaInputStrings = allocarray(numInputs, sizeof(jobject)); if (javaInputStrings == NULL) { goto cleanup_output_names; } - jobject* javaOutputStrings = malloc(sizeof(jobject) * numOutputs); + jobject* javaOutputStrings = allocarray(numOutputs, sizeof(jobject)); if (javaOutputStrings == NULL) { goto cleanup_java_input_strings; } - const OrtValue** inputValuePtrs = malloc(sizeof(OrtValue*) * numInputs); + const OrtValue** inputValuePtrs = allocarray(numInputs, sizeof(OrtValue*)); if (inputValuePtrs == NULL) { goto cleanup_java_output_strings; } - OrtValue** outputValues = malloc(sizeof(OrtValue*) * numOutputs); + OrtValue** outputValues = allocarray(numOutputs, sizeof(OrtValue*)); if (outputValues == NULL) { goto cleanup_input_values; } @@ -388,13 +388,19 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_trainStep // Release the java array copy of pointers to the tensors. (*jniEnv)->ReleaseLongArrayElements(jniEnv, inputHandles, inputValueLongs, JNI_ABORT); + // Extract a C array of longs which are pointers to the output tensors. + jlong* outputHandleLongs = (*jniEnv)->GetLongArrayElements(jniEnv, outputHandlesArr, NULL); + // Extract the names of the output values. for (int i = 0; i < numOutputs; i++) { javaOutputStrings[i] = (*jniEnv)->GetObjectArrayElement(jniEnv, outputNamesArr, i); outputNames[i] = (*jniEnv)->GetStringUTFChars(jniEnv, javaOutputStrings[i], NULL); - outputValues[i] = NULL; + outputValues[i] = (OrtValue*)outputHandleLongs[i]; } + // Release the java array copy of pointers to the outputs. + (*jniEnv)->ReleaseLongArrayElements(jniEnv, outputHandlesArr, outputHandleLongs, JNI_ABORT); + // Actually score the inputs. //ORT_API2_STATUS(TrainStep, _Inout_ OrtTrainingSession* sess, _In_opt_ const OrtRunOptions* run_options, // size_t inputs_len, _In_reads_(inputs_len) const OrtValue* const* inputs, @@ -406,24 +412,29 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_trainStep goto cleanup_output_values; } - // Construct the output array of ONNXValues - jclass onnxValueClass = (*jniEnv)->FindClass(jniEnv, "ai/onnxruntime/OnnxValue"); - outputArray = (*jniEnv)->NewObjectArray(jniEnv, safecast_int64_to_jsize(numOutputs), onnxValueClass, NULL); + // Create the output boolean array denoting if ORT owns the memory for each output. + // Java boolean arrays are initialized to false. + outputArray = (*jniEnv)->NewBooleanArray(jniEnv, safecast_int64_to_jsize(numOutputs)); + jboolean* boolArr = (*jniEnv)->GetBooleanArrayElements(jniEnv, outputArray, NULL); // Convert the output tensors into ONNXValues for (int i = 0; i < numOutputs; i++) { - if (outputValues[i] != NULL) { + if (outputValues[i] != NULL && (*jniEnv)->GetObjectArrayElement(jniEnv, outputValuesArr, i) == NULL) { jobject onnxValue = convertOrtValueToONNXValue(jniEnv, api, allocator, outputValues[i]); if (onnxValue == NULL) { break; // go to cleanup, exception thrown } - (*jniEnv)->SetObjectArrayElement(jniEnv, outputArray, i, onnxValue); + boolArr[i] = 1; + (*jniEnv)->SetObjectArrayElement(jniEnv, outputValuesArr, i, onnxValue); } } + // Write the output array back to Java. + (*jniEnv)->ReleaseBooleanArrayElements(jniEnv, outputArray, boolArr, 0); + // Note these gotos are in a specific order so they mirror the allocation pattern above. // They must be changed if the allocation code is rearranged. - cleanup_output_values: +cleanup_output_values: free(outputValues); // Release the Java output strings @@ -437,15 +448,15 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_trainStep } // Release the buffers - cleanup_input_values: +cleanup_input_values: free((void*)inputValuePtrs); - cleanup_java_output_strings: +cleanup_java_output_strings: free(javaOutputStrings); - cleanup_java_input_strings: +cleanup_java_input_strings: free(javaInputStrings); - cleanup_output_names: +cleanup_output_names: free((void*)outputNames); - cleanup_input_names: +cleanup_input_names: free((void*)inputNames); return outputArray; @@ -454,12 +465,12 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_trainStep /* * Class: ai_onnxruntime_OrtTrainingSession * Method: evalStep - * Signature: (JJJJ[Ljava/lang/String;[JJ[Ljava/lang/String;JJ)[Lai/onnxruntime/OnnxValue; + * Signature: (JJJJ[Ljava/lang/String;[JJ[Ljava/lang/String;J[Lai/onnxruntime/OnnxValue;[JJ)[Z */ -JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_evalStep +JNIEXPORT jbooleanArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_evalStep (JNIEnv * jniEnv, jobject jobj, jlong apiHandle, jlong trainApiHandle, jlong nativeHandle, jlong allocatorHandle, jobjectArray inputNamesArr, jlongArray inputHandles, jlong numInputs, - jobjectArray outputNamesArr, jlong numOutputs, jlong runOptionsHandle) { + jobjectArray outputNamesArr, jlong numOutputs, jobjectArray outputValuesArr, jlongArray outputHandlesArr, jlong runOptionsHandle) { (void)jobj; // Required JNI parameter not needed by functions which don't need to access their host object. const OrtApi* api = (const OrtApi*)apiHandle; const OrtTrainingApi* trainApi = (const OrtTrainingApi*)trainApiHandle; @@ -467,31 +478,31 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_evalStep OrtTrainingSession* trainSession = (OrtTrainingSession*)nativeHandle; OrtRunOptions* runOptions = (OrtRunOptions*)runOptionsHandle; - jobjectArray outputArray = NULL; + jbooleanArray outputArray = NULL; // Create the buffers for the Java input & output strings, and the input pointers - const char** inputNames = malloc(sizeof(char*) * numInputs); + const char** inputNames = allocarray(numInputs, sizeof(char*)); if (inputNames == NULL) { // Nothing to cleanup, return and throw exception return outputArray; } - const char** outputNames = malloc(sizeof(char*) * numOutputs); + const char** outputNames = allocarray(numOutputs, sizeof(char*)); if (outputNames == NULL) { goto cleanup_input_names; } - jobject* javaInputStrings = malloc(sizeof(jobject) * numInputs); + jobject* javaInputStrings = allocarray(numInputs, sizeof(jobject)); if (javaInputStrings == NULL) { goto cleanup_output_names; } - jobject* javaOutputStrings = malloc(sizeof(jobject) * numOutputs); + jobject* javaOutputStrings = allocarray(numOutputs, sizeof(jobject)); if (javaOutputStrings == NULL) { goto cleanup_java_input_strings; } - const OrtValue** inputValuePtrs = malloc(sizeof(OrtValue*) * numInputs); + const OrtValue** inputValuePtrs = allocarray(numInputs, sizeof(OrtValue*)); if (inputValuePtrs == NULL) { goto cleanup_java_output_strings; } - OrtValue** outputValues = malloc(sizeof(OrtValue*) * numOutputs); + OrtValue** outputValues = allocarray(numOutputs, sizeof(OrtValue*)); if (outputValues == NULL) { goto cleanup_input_values; } @@ -512,11 +523,14 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_evalStep // Release the java array copy of pointers to the tensors. (*jniEnv)->ReleaseLongArrayElements(jniEnv, inputHandles, inputValueLongs, JNI_ABORT); + // Extract a C array of longs which are pointers to the output tensors. + jlong* outputHandleLongs = (*jniEnv)->GetLongArrayElements(jniEnv, outputHandlesArr, NULL); + // Extract the names of the output values. for (int i = 0; i < numOutputs; i++) { javaOutputStrings[i] = (*jniEnv)->GetObjectArrayElement(jniEnv, outputNamesArr, i); outputNames[i] = (*jniEnv)->GetStringUTFChars(jniEnv, javaOutputStrings[i], NULL); - outputValues[i] = NULL; + outputValues[i] = (OrtValue*)outputHandleLongs[i]; } // Actually score the inputs. @@ -530,24 +544,29 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_evalStep goto cleanup_output_values; } - // Construct the output array of ONNXValues - jclass onnxValueClass = (*jniEnv)->FindClass(jniEnv, "ai/onnxruntime/OnnxValue"); - outputArray = (*jniEnv)->NewObjectArray(jniEnv, safecast_int64_to_jsize(numOutputs), onnxValueClass, NULL); + // Create the output boolean array denoting if ORT owns the memory for each output. + // Java boolean arrays are initialized to false. + outputArray = (*jniEnv)->NewBooleanArray(jniEnv, safecast_int64_to_jsize(numOutputs)); + jboolean* boolArr = (*jniEnv)->GetBooleanArrayElements(jniEnv, outputArray, NULL); // Convert the output tensors into ONNXValues for (int i = 0; i < numOutputs; i++) { - if (outputValues[i] != NULL) { + if (outputValues[i] != NULL && (*jniEnv)->GetObjectArrayElement(jniEnv, outputValuesArr, i) == NULL) { jobject onnxValue = convertOrtValueToONNXValue(jniEnv, api, allocator, outputValues[i]); if (onnxValue == NULL) { break; // go to cleanup, exception thrown } - (*jniEnv)->SetObjectArrayElement(jniEnv, outputArray, i, onnxValue); + boolArr[i] = 1; + (*jniEnv)->SetObjectArrayElement(jniEnv, outputValuesArr, i, onnxValue); } } + // Write the output array back to Java. + (*jniEnv)->ReleaseBooleanArrayElements(jniEnv, outputArray, boolArr, 0); + // Note these gotos are in a specific order so they mirror the allocation pattern above. // They must be changed if the allocation code is rearranged. - cleanup_output_values: +cleanup_output_values: free(outputValues); // Release the Java output strings @@ -561,15 +580,15 @@ JNIEXPORT jobjectArray JNICALL Java_ai_onnxruntime_OrtTrainingSession_evalStep } // Release the buffers - cleanup_input_values: +cleanup_input_values: free((void*)inputValuePtrs); - cleanup_java_output_strings: +cleanup_java_output_strings: free(javaOutputStrings); - cleanup_java_input_strings: +cleanup_java_input_strings: free(javaInputStrings); - cleanup_output_names: +cleanup_output_names: free((void*)outputNames); - cleanup_input_names: +cleanup_input_names: free((void*)inputNames); return outputArray; diff --git a/java/src/test/java/ai/onnxruntime/InferenceTest.java b/java/src/test/java/ai/onnxruntime/InferenceTest.java index 08d2a5698d579..e975117fb75bd 100644 --- a/java/src/test/java/ai/onnxruntime/InferenceTest.java +++ b/java/src/test/java/ai/onnxruntime/InferenceTest.java @@ -6,11 +6,14 @@ import static org.junit.jupiter.api.Assertions.assertArrayEquals; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNotSame; import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; +import ai.onnxruntime.OrtException.OrtErrorCode; import ai.onnxruntime.OrtSession.Result; import ai.onnxruntime.OrtSession.SessionOptions; import ai.onnxruntime.OrtSession.SessionOptions.ExecutionMode; @@ -31,6 +34,8 @@ import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; import java.util.List; import java.util.Map; import java.util.Optional; @@ -71,7 +76,7 @@ public void environmentTest() { @Test public void testVersion() { String version = env.getVersion(); - Assertions.assertFalse(version.isEmpty()); + assertFalse(version.isEmpty()); } @Test @@ -749,6 +754,151 @@ public void testOverridingInitializer() throws OrtException { } } + @Test + public void testPinnedOutputs() throws OrtException { + String modelPath = TestHelpers.getResourcePath("/java-three-output-matmul.onnx").toString(); + FloatBuffer outputABuf = + ByteBuffer.allocateDirect(4 * 4).order(ByteOrder.nativeOrder()).asFloatBuffer(); + FloatBuffer outputBBuf = + ByteBuffer.allocateDirect(4 * 4).order(ByteOrder.nativeOrder()).asFloatBuffer(); + FloatBuffer outputCBuf = + ByteBuffer.allocateDirect(4 * 4).order(ByteOrder.nativeOrder()).asFloatBuffer(); + FloatBuffer tooSmallBuf = + ByteBuffer.allocateDirect(4 * 2).order(ByteOrder.nativeOrder()).asFloatBuffer(); + FloatBuffer tooBigBuf = + ByteBuffer.allocateDirect(4 * 6).order(ByteOrder.nativeOrder()).asFloatBuffer(); + FloatBuffer wrongShapeBuf = + ByteBuffer.allocateDirect(4 * 4).order(ByteOrder.nativeOrder()).asFloatBuffer(); + LongBuffer wrongTypeBuf = + ByteBuffer.allocateDirect(8 * 4).order(ByteOrder.nativeOrder()).asLongBuffer(); + + try (SessionOptions options = new SessionOptions()) { + try (OrtSession session = env.createSession(modelPath, options); + OnnxTensor t = OnnxTensor.createTensor(env, new float[][] {{1, 2, 3, 4}}); + OnnxTensor outputA = OnnxTensor.createTensor(env, outputABuf, new long[] {1, 4}); + OnnxTensor outputB = OnnxTensor.createTensor(env, outputBBuf, new long[] {1, 4}); + OnnxTensor outputC = OnnxTensor.createTensor(env, outputCBuf, new long[] {1, 4}); + OnnxTensor tooSmall = OnnxTensor.createTensor(env, tooSmallBuf, new long[] {1, 2}); + OnnxTensor tooBig = OnnxTensor.createTensor(env, tooBigBuf, new long[] {1, 6}); + OnnxTensor wrongShape = OnnxTensor.createTensor(env, wrongShapeBuf, new long[] {2, 2}); + OnnxTensor wrongType = OnnxTensor.createTensor(env, wrongTypeBuf, new long[] {1, 4})) { + Map inputMap = Collections.singletonMap("input", t); + Set requestedOutputs = new LinkedHashSet<>(); + Map pinnedOutputs = new LinkedHashMap<>(); + + // Test that all outputs can be pinned + pinnedOutputs.put("output-0", outputA); + pinnedOutputs.put("output-1", outputB); + pinnedOutputs.put("output-2", outputC); + try (OrtSession.Result r = session.run(inputMap, requestedOutputs, pinnedOutputs)) { + assertEquals(3, r.size()); + assertSame(outputA, r.get(0)); + assertSame(outputB, r.get(1)); + assertSame(outputC, r.get(2)); + assertFalse(r.isResultOwner(0)); + assertFalse(r.isResultOwner(1)); + assertFalse(r.isResultOwner(2)); + // More tests + } + TestHelpers.zeroBuffer(outputABuf); + TestHelpers.zeroBuffer(outputBBuf); + TestHelpers.zeroBuffer(outputCBuf); + requestedOutputs.clear(); + pinnedOutputs.clear(); + + // Test a single pinned output + pinnedOutputs.put("output-1", outputB); + try (OrtSession.Result r = session.run(inputMap, requestedOutputs, pinnedOutputs)) { + assertEquals(1, r.size()); + assertSame(outputB, r.get(0)); + assertSame(outputB, r.get("output-1").get()); + assertFalse(r.isResultOwner(0)); + // More tests + } + TestHelpers.zeroBuffer(outputABuf); + TestHelpers.zeroBuffer(outputBBuf); + TestHelpers.zeroBuffer(outputCBuf); + requestedOutputs.clear(); + pinnedOutputs.clear(); + + // Test a mixture of pinned and generated outputs + requestedOutputs.add("output-0"); + requestedOutputs.add("output-2"); + pinnedOutputs.put("output-1", outputB); + try (OrtSession.Result r = session.run(inputMap, requestedOutputs, pinnedOutputs)) { + assertEquals(3, r.size()); + // pinned outputs are first + assertSame(outputB, r.get(0)); + assertSame(outputB, r.get("output-1").get()); + // requested outputs are different + assertNotSame(outputA, r.get("output-0").get()); + assertNotSame(outputC, r.get("output-2").get()); + // check ownership. + assertFalse(r.isResultOwner(0)); + assertTrue(r.isResultOwner(1)); + assertTrue(r.isResultOwner(2)); + // More tests + } + TestHelpers.zeroBuffer(outputABuf); + TestHelpers.zeroBuffer(outputBBuf); + TestHelpers.zeroBuffer(outputCBuf); + requestedOutputs.clear(); + pinnedOutputs.clear(); + + // Test that overlapping names causes an error + requestedOutputs.add("output-1"); + pinnedOutputs.put("output-1", outputB); + try (OrtSession.Result r = session.run(inputMap, requestedOutputs, pinnedOutputs)) { + fail("Should have thrown OrtException"); + } catch (OrtException e) { + assertEquals(OrtErrorCode.ORT_JAVA_UNKNOWN, e.getCode()); + } + requestedOutputs.clear(); + pinnedOutputs.clear(); + + // Test that a tensor of the wrong type causes an error + pinnedOutputs.put("output-0", wrongType); + try (OrtSession.Result r = session.run(inputMap, requestedOutputs, pinnedOutputs)) { + fail("Should have thrown OrtException"); + } catch (OrtException e) { + assertEquals(OrtErrorCode.ORT_INVALID_ARGUMENT, e.getCode()); + } + requestedOutputs.clear(); + pinnedOutputs.clear(); + + // Test that a tensor of the wrong shape (but right capacity) causes an error. + pinnedOutputs.put("output-1", wrongShape); + try (OrtSession.Result r = session.run(inputMap, requestedOutputs, pinnedOutputs)) { + fail("Should have thrown OrtException"); + } catch (OrtException e) { + assertEquals(OrtErrorCode.ORT_INVALID_ARGUMENT, e.getCode()); + } + requestedOutputs.clear(); + pinnedOutputs.clear(); + + // Test that a tensor which is too small causes an error + pinnedOutputs.put("output-1", tooSmall); + try (OrtSession.Result r = session.run(inputMap, requestedOutputs, pinnedOutputs)) { + fail("Should have thrown OrtException"); + } catch (OrtException e) { + assertEquals(OrtErrorCode.ORT_INVALID_ARGUMENT, e.getCode()); + } + requestedOutputs.clear(); + pinnedOutputs.clear(); + + // Test that a tensor which is too large causes an error + pinnedOutputs.put("output-1", tooBig); + try (OrtSession.Result r = session.run(inputMap, requestedOutputs, pinnedOutputs)) { + fail("Should have thrown OrtException"); + } catch (OrtException e) { + assertEquals(OrtErrorCode.ORT_INVALID_ARGUMENT, e.getCode()); + } + requestedOutputs.clear(); + pinnedOutputs.clear(); + } + } + } + private static File getTestModelsDir() throws IOException { // get build directory, append downloaded models location String cwd = System.getProperty("user.dir"); diff --git a/java/src/test/java/ai/onnxruntime/ModelGenerators.java b/java/src/test/java/ai/onnxruntime/ModelGenerators.java index 90fda4c5cf610..7bf7cef43208a 100644 --- a/java/src/test/java/ai/onnxruntime/ModelGenerators.java +++ b/java/src/test/java/ai/onnxruntime/ModelGenerators.java @@ -182,6 +182,102 @@ public void generateMatMul() throws IOException { } } + public void generateThreeOutputMatmul() throws IOException { + OnnxMl.GraphProto.Builder graph = OnnxMl.GraphProto.newBuilder(); + graph.setName("ort-test-three-matmul"); + + // Add placeholders + OnnxMl.ValueInfoProto.Builder input = OnnxMl.ValueInfoProto.newBuilder(); + input.setName("input"); + OnnxMl.TypeProto inputType = + buildTensorTypeNode( + new long[] {-1, 4}, + new String[] {"batch_size", null}, + OnnxMl.TensorProto.DataType.FLOAT); + input.setType(inputType); + graph.addInput(input); + OnnxMl.ValueInfoProto.Builder outputA = OnnxMl.ValueInfoProto.newBuilder(); + outputA.setName("output-0"); + OnnxMl.TypeProto outputType = + buildTensorTypeNode( + new long[] {-1, 4}, + new String[] {"batch_size", null}, + OnnxMl.TensorProto.DataType.FLOAT); + outputA.setType(outputType); + graph.addOutput(outputA); + OnnxMl.ValueInfoProto.Builder outputB = OnnxMl.ValueInfoProto.newBuilder(); + outputB.setName("output-1"); + outputB.setType(outputType); + graph.addOutput(outputB); + OnnxMl.ValueInfoProto.Builder outputC = OnnxMl.ValueInfoProto.newBuilder(); + outputC.setName("output-2"); + outputC.setType(outputType); + graph.addOutput(outputC); + + // Add initializers + OnnxMl.TensorProto.Builder tensor = OnnxMl.TensorProto.newBuilder(); + tensor.addDims(4); + tensor.addDims(4); + Float[] floats = + new Float[] {1f, 2f, 3f, 4f, 5f, 6f, 7f, 8f, 9f, 10f, 11f, 12f, 13f, 14f, 15f, 16f}; + tensor.addAllFloatData(Arrays.asList(floats)); + tensor.setDataType(OnnxMl.TensorProto.DataType.FLOAT.getNumber()); + tensor.setName("tensor"); + graph.addInitializer(tensor); + OnnxMl.TensorProto.Builder addInit = OnnxMl.TensorProto.newBuilder(); + addInit.addDims(4); + Float[] addFloats = new Float[] {1f, 2f, 3f, 4f}; + addInit.addAllFloatData(Arrays.asList(addFloats)); + addInit.setDataType(OnnxMl.TensorProto.DataType.FLOAT.getNumber()); + addInit.setName("add-init"); + graph.addInitializer(addInit); + + // Add operations + OnnxMl.NodeProto.Builder matmul = OnnxMl.NodeProto.newBuilder(); + matmul.setName("matmul-0"); + matmul.setOpType("MatMul"); + matmul.addInput("input"); + matmul.addInput("tensor"); + matmul.addOutput("matmul-output"); + graph.addNode(matmul); + + OnnxMl.NodeProto.Builder id = OnnxMl.NodeProto.newBuilder(); + id.setName("id-1"); + id.setOpType("Identity"); + id.addInput("matmul-output"); + id.addOutput("output-0"); + graph.addNode(id); + + OnnxMl.NodeProto.Builder add = OnnxMl.NodeProto.newBuilder(); + add.setName("add-2"); + add.setOpType("Add"); + add.addInput("matmul-output"); + add.addInput("add-init"); + add.addOutput("output-1"); + graph.addNode(add); + + OnnxMl.NodeProto.Builder log = OnnxMl.NodeProto.newBuilder(); + log.setName("log-3"); + log.setOpType("Log"); + log.addInput("matmul-output"); + log.addOutput("output-2"); + graph.addNode(log); + + // Build model + OnnxMl.ModelProto.Builder model = OnnxMl.ModelProto.newBuilder(); + model.setGraph(graph); + model.setDocString("ORT three output matmul test"); + model.setModelVersion(0); + model.setIrVersion(8); + model.setDomain("ai.onnxruntime.test"); + model.addOpsetImport(OnnxMl.OperatorSetIdProto.newBuilder().setVersion(18).build()); + try (OutputStream os = + Files.newOutputStream( + Paths.get("src", "test", "resources", "java-three-output-matmul.onnx"))) { + model.build().writeTo(os); + } + } + private static void genCast( String name, OnnxMl.TensorProto.DataType inputDataType, diff --git a/java/src/test/java/ai/onnxruntime/TestHelpers.java b/java/src/test/java/ai/onnxruntime/TestHelpers.java index 7d41918b1c6c7..55d8169434d48 100644 --- a/java/src/test/java/ai/onnxruntime/TestHelpers.java +++ b/java/src/test/java/ai/onnxruntime/TestHelpers.java @@ -262,6 +262,12 @@ public static Path getResourcePath(String path) { return new File(TestHelpers.class.getResource(path).getFile()).toPath(); } + public static void zeroBuffer(FloatBuffer buf) { + for (int i = 0; i < buf.capacity(); i++) { + buf.put(i, 0.0f); + } + } + public static float[] loadTensorFromFile(Path filename) { return loadTensorFromFile(filename, true); } diff --git a/java/src/test/java/ai/onnxruntime/TrainingTest.java b/java/src/test/java/ai/onnxruntime/TrainingTest.java index a02f5a88b2ac5..eaa7da1fc6a16 100644 --- a/java/src/test/java/ai/onnxruntime/TrainingTest.java +++ b/java/src/test/java/ai/onnxruntime/TrainingTest.java @@ -16,7 +16,6 @@ import java.util.Map; import java.util.Set; import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfSystemProperty; @@ -69,8 +68,6 @@ public void testCreateTrainingSessionWithEval() throws OrtException { } } - // this test is not enabled as ORT Java doesn't support supplying an output buffer - @Disabled @Test public void testTrainingSessionTrainStep() throws OrtException { String checkpointPath = TestHelpers.getResourcePath("/checkpoint.ckpt").toString(); @@ -99,14 +96,11 @@ public void testTrainingSessionTrainStep() throws OrtException { ByteBuffer.allocateDirect(4 * expectedOutput.length) .order(ByteOrder.nativeOrder()) .asFloatBuffer(); - OnnxTensor outputTensor = - OnnxTensor.createTensor(env, output, new long[expectedOutput.length]); + OnnxTensor outputTensor = OnnxTensor.createTensor(env, output, new long[0]); outputMap.put("onnx::loss::21273", outputTensor); - /* Disabled as we haven't implemented this yet - try (trainingSession.trainStep(pinnedInputs, outputMap)) { - Assertions.assertArrayEquals(expectedOutput, (float[]) outputTensor.getValue(), 1e-3f); + try (OrtSession.Result r = trainingSession.trainStep(pinnedInputs, outputMap)) { + Assertions.assertEquals(expectedOutput[0], (float) outputTensor.getValue(), 1e-3f); } - */ } finally { OnnxValue.close(outputMap); OnnxValue.close(pinnedInputs); diff --git a/java/src/test/resources/java-three-output-matmul.onnx b/java/src/test/resources/java-three-output-matmul.onnx new file mode 100644 index 0000000000000..fed0bbca460cf Binary files /dev/null and b/java/src/test/resources/java-three-output-matmul.onnx differ diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index 59da1369e152e..b7b2ff4537095 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import type {Tensor} from 'onnxruntime-common'; + export declare namespace JSEP { type BackendType = unknown; type AllocFunction = (size: number) => number; @@ -9,11 +11,8 @@ export declare namespace JSEP { type DownloadFunction = (gpuDataId: number, dataOffset: number, size: number) => Promise; type CreateKernelFunction = (name: string, kernel: number, attribute: unknown) => void; type ReleaseKernelFunction = (kernel: number) => void; - type RunFunction = (kernel: number, contextDataOffset: number, sessionState: SessionState) => number; - export interface SessionState { - sessionId: number; - errors: Array>; - } + type RunFunction = + (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array>) => number; } export interface OrtWasmModule extends EmscriptenModule { @@ -40,14 +39,23 @@ export interface OrtWasmModule extends EmscriptenModule { _OrtFree(stringHandle: number): void; - _OrtCreateTensor(dataType: number, dataOffset: number, dataLength: number, dimsOffset: number, dimsLength: number): - number; + _OrtCreateTensor( + dataType: number, dataOffset: number, dataLength: number, dimsOffset: number, dimsLength: number, + dataLocation: number): number; _OrtGetTensorData(tensorHandle: number, dataType: number, dataOffset: number, dimsOffset: number, dimsLength: number): number; _OrtReleaseTensor(tensorHandle: number): void; + _OrtCreateBinding(sessionHandle: number): number; + _OrtBindInput(bindingHandle: number, nameOffset: number, tensorHandle: number): Promise; + _OrtBindOutput(bindingHandle: number, nameOffset: number, tensorHandle: number, location: number): number; + _OrtClearBoundOutputs(ioBindingHandle: number): void; + _OrtReleaseBinding(ioBindingHandle: number): void; + _OrtRunWithBinding( + sessionHandle: number, ioBindingHandle: number, outputCount: number, outputsOffset: number, + runOptionsHandle: number): Promise; _OrtRun( sessionHandle: number, inputNamesOffset: number, inputsOffset: number, inputCount: number, - outputNamesOffset: number, outputCount: number, outputsOffset: number, runOptionsHandle: number): number; + outputNamesOffset: number, outputCount: number, outputsOffset: number, runOptionsHandle: number): Promise; _OrtCreateSessionOptions( graphOptimizationLevel: number, enableCpuMemArena: boolean, enableMemPattern: boolean, executionMode: number, @@ -102,17 +110,67 @@ export interface OrtWasmModule extends EmscriptenModule { // #endregion // #region JSEP + /** + * This is the entry of JSEP initialization. This function is called once when initializing ONNX Runtime. + * This function initializes WebGPU backend and registers a few callbacks that will be called in C++ code. + */ jsepInit? (backend: JSEP.BackendType, alloc: JSEP.AllocFunction, free: JSEP.FreeFunction, upload: JSEP.UploadFunction, download: JSEP.DownloadFunction, createKernel: JSEP.CreateKernelFunction, releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction): void; + /** + * [exported from wasm] Specify a kernel's output when running OpKernel::Compute(). + * + * @param context - specify the kernel context pointer. + * @param index - specify the index of the output. + * @param data - specify the pointer to encoded data of type and dims. + */ _JsepOutput(context: number, index: number, data: number): number; + /** + * [exported from wasm] Get name of an operator node. + * + * @param kernel - specify the kernel pointer. + * @returns the pointer to a C-style UTF8 encoded string representing the node name. + */ _JsepGetNodeName(kernel: number): number; - jsepOnRunStart?(sessionId: number): void; - jsepOnRunEnd?(sessionId: number): Promise; - jsepRunPromise?: Promise; + /** + * [exported from js_internal_api.js] Register a user GPU buffer for usage of a session's input or output. + * + * @param sessionId - specify the session ID. + * @param index - specify an integer to represent which input/output it is registering for. For input, it is the + * input_index corresponding to the session's inputNames. For output, it is the inputCount + output_index + * corresponding to the session's ouputNames. + * @param buffer - specify the GPU buffer to register. + * @param size - specify the original data size in byte. + * @returns the GPU data ID for the registered GPU buffer. + */ + jsepRegisterBuffer: (sessionId: number, index: number, buffer: GPUBuffer, size: number) => number; + /** + * [exported from js_internal_api.js] Unregister all user GPU buffers for a session. + * + * @param sessionId - specify the session ID. + */ + jsepUnregisterBuffers?: (sessionId: number) => void; + /** + * [exported from js_internal_api.js] Get the GPU buffer by GPU data ID. + * + * @param dataId - specify the GPU data ID + * @returns the GPU buffer. + */ + jsepGetBuffer: (dataId: number) => GPUBuffer; + /** + * [exported from js_internal_api.js] Create a function to be used to create a GPU Tensor. + * + * @param gpuBuffer - specify the GPU buffer + * @param size - specify the original data size in byte. + * @param type - specify the tensor type. + * @returns the generated downloader function. + */ + jsepCreateDownloader: + (gpuBuffer: GPUBuffer, size: number, + type: Tensor.GpuBufferDataTypes) => () => Promise; // #endregion } diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 5e77a0343b4ee..5bec562b157ac 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -1,11 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Env} from 'onnxruntime-common'; +import {Env, Tensor} from 'onnxruntime-common'; import {configureLogger, LOG_DEBUG} from './log'; -import {TensorView} from './tensor-view'; -import {createGpuDataManager, GpuDataManager} from './webgpu/gpu-data-manager'; +import {createView, TensorView} from './tensor-view'; +import {createGpuDataManager, downloadGpuData, GpuDataManager} from './webgpu/gpu-data-manager'; import {RunFunction, WEBGPU_OP_RESOLVE_RULES} from './webgpu/op-resolve-rules'; import {ProgramManager} from './webgpu/program-manager'; import {ComputeContext, GpuData, ProgramInfo, ProgramInfoLoader} from './webgpu/types'; @@ -98,6 +98,11 @@ export class WebGpuBackend { env: Env; + /** + * a SessionID -> a Map of (InputOutputIndex -> [ID, GPUBuffer]) mapping. + */ + sessionExternalDataMapping: Map> = new Map(); + async initialize(env: Env): Promise { if (!navigator.gpu) { // WebGPU is not available. @@ -192,11 +197,13 @@ export class WebGpuBackend { } flush(): void { - this.endComputePass(); - this.device.queue.submit([this.getCommandEncoder().finish()]); - this.gpuDataManager.refreshPendingBuffers(); - this.commandEncoder = null; - this.pendingDispatchNumber = 0; + if (this.commandEncoder) { + this.endComputePass(); + this.device.queue.submit([this.getCommandEncoder().finish()]); + this.gpuDataManager.refreshPendingBuffers(); + this.commandEncoder = null; + this.pendingDispatchNumber = 0; + } } /** @@ -304,12 +311,9 @@ export class WebGpuBackend { } async download(gpuDataId: number, getTargetBuffer: () => Uint8Array): Promise { - const arrayBuffer = await this.gpuDataManager.download(gpuDataId); - // the underlying buffer may be changed after the async function is called. so we use a getter function to make sure // the buffer is up-to-date. - const data = getTargetBuffer(); - data.set(new Uint8Array(arrayBuffer, 0, data.byteLength)); + await this.gpuDataManager.download(gpuDataId, getTargetBuffer); } alloc(size: number): number { @@ -372,7 +376,7 @@ export class WebGpuBackend { kernelEntry(context, attributes[1]); return 0; // ORT_OK } catch (e) { - LOG_DEBUG('warning', `[WebGPU] Kernel "[${opType}] ${nodeName}" failed. Error: ${e}`); + errors.push(Promise.resolve(`[WebGPU] Kernel "[${opType}] ${nodeName}" failed. ${e}`)); return 1; // ORT_FAIL } finally { if (useErrorScope) { @@ -387,4 +391,40 @@ export class WebGpuBackend { this.currentKernelId = null; } } + + // #region external buffer + registerBuffer(sessionId: number, index: number, buffer: GPUBuffer, size: number): number { + let sessionInputOutputMapping = this.sessionExternalDataMapping.get(sessionId); + if (!sessionInputOutputMapping) { + sessionInputOutputMapping = new Map(); + this.sessionExternalDataMapping.set(sessionId, sessionInputOutputMapping); + } + + const previousBuffer = sessionInputOutputMapping.get(index); + const id = this.gpuDataManager.registerExternalBuffer(buffer, size, previousBuffer?.[1]); + sessionInputOutputMapping.set(index, [id, buffer]); + return id; + } + unregisterBuffers(sessionId: number): void { + const sessionInputOutputMapping = this.sessionExternalDataMapping.get(sessionId); + if (sessionInputOutputMapping) { + sessionInputOutputMapping.forEach(bufferInfo => this.gpuDataManager.unregisterExternalBuffer(bufferInfo[1])); + this.sessionExternalDataMapping.delete(sessionId); + } + } + getBuffer(gpuDataId: number): GPUBuffer { + const gpuData = this.gpuDataManager.get(gpuDataId); + if (!gpuData) { + throw new Error(`no GPU data for buffer: ${gpuDataId}`); + } + return gpuData.buffer; + } + createDownloader(gpuBuffer: GPUBuffer, size: number, type: Tensor.GpuBufferDataTypes): + () => Promise { + return async () => { + const data = await downloadGpuData(this, gpuBuffer, size); + return createView(data.buffer, type); + }; + } + // #endregion } diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index 78316cbe1c825..6ff3971d720fd 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -3,7 +3,7 @@ import {Env} from 'onnxruntime-common'; -import {JSEP, OrtWasmModule} from '../binding/ort-wasm'; +import {OrtWasmModule} from '../binding/ort-wasm'; import {DataType, getTensorElementSize} from '../wasm-common'; import {WebGpuBackend} from './backend-webgpu'; @@ -120,6 +120,11 @@ class ComputeContextImpl implements ComputeContext { this.module.HEAPU32[offset++] = dims[i]; } return this.module._JsepOutput(this.opKernelContext, index, data); + } catch (e) { + throw new Error( + `Failed to generate kernel's output[${index}] with dims [${dims}]. ` + + 'If you are running with pre-allocated output, please make sure the output type/dims are correct. ' + + `Error: ${e}`); } finally { this.module.stackRestore(stack); } @@ -138,7 +143,7 @@ export const init = async(module: OrtWasmModule, env: Env): Promise => { init( // backend - {backend}, + backend, // jsepAlloc() (size: number) => backend.alloc(size), @@ -178,13 +183,13 @@ export const init = async(module: OrtWasmModule, env: Env): Promise => { (kernel: number) => backend.releaseKernel(kernel), // jsepRun - (kernel: number, contextDataOffset: number, sessionState: JSEP.SessionState) => { + (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array>) => { LOG_DEBUG( 'verbose', - () => `[WebGPU] jsepRun: sessionId=${sessionState.sessionId}, kernel=${kernel}, contextDataOffset=${ + () => `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${ contextDataOffset}`); const context = new ComputeContextImpl(module, backend, contextDataOffset); - return backend.computeKernel(kernel, context, sessionState.errors); + return backend.computeKernel(kernel, context, errors); }); } }; diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts index 92fdd5abc3892..131f7a9bfa29b 100644 --- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts @@ -35,7 +35,7 @@ export interface GpuDataManager { /** * copy data from GPU to CPU. */ - download(id: GpuDataId): Promise; + download(id: GpuDataId, getTargetBuffer: () => Uint8Array): Promise; /** * refresh the buffers that marked for release. @@ -46,6 +46,19 @@ export interface GpuDataManager { */ refreshPendingBuffers(): void; + /** + * register an external buffer for IO Binding. If the buffer is already registered, return the existing GPU data ID. + * + * GPU data manager only manages a mapping between the buffer and the GPU data ID. It will not manage the lifecycle of + * the external buffer. + */ + registerExternalBuffer(buffer: GPUBuffer, originalSize: number, previousBuffer?: GPUBuffer): number; + + /** + * unregister an external buffer for IO Binding. + */ + unregisterExternalBuffer(buffer: GPUBuffer): void; + /** * destroy all gpu buffers. Call this when the session.release is called. */ @@ -62,12 +75,56 @@ interface StorageCacheValue { */ const calcNormalizedBufferSize = (size: number) => Math.ceil(size / 16) * 16; -let guid = 0; +let guid = 1; const createNewGpuDataId = () => guid++; +/** + * exported standard download function. This function is used by the session to download the data from GPU, and also by + * factory to create GPU tensors with the capacity of downloading data from GPU. + * + * @param backend - the WebGPU backend + * @param gpuBuffer - the GPU buffer to download + * @param originalSize - the original size of the data + * @param getTargetBuffer - optional. If provided, the data will be copied to the target buffer. Otherwise, a new buffer + * will be created and returned. + */ +export const downloadGpuData = + async(backend: WebGpuBackend, gpuBuffer: GPUBuffer, originalSize: number, getTargetBuffer?: () => Uint8Array): + Promise => { + const bufferSize = calcNormalizedBufferSize(originalSize); + const gpuReadBuffer = backend.device.createBuffer( + // eslint-disable-next-line no-bitwise + {size: bufferSize, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ}); + try { + const commandEncoder = backend.getCommandEncoder(); + backend.endComputePass(); + commandEncoder.copyBufferToBuffer( + gpuBuffer /* source buffer */, 0 /* source offset */, gpuReadBuffer /* destination buffer */, + 0 /* destination offset */, bufferSize /* size */ + ); + backend.flush(); + + await gpuReadBuffer.mapAsync(GPUMapMode.READ); + + const arrayBuffer = gpuReadBuffer.getMappedRange(); + if (getTargetBuffer) { + // if we already have a CPU buffer to accept the data, no need to clone the ArrayBuffer. + const targetBuffer = getTargetBuffer(); + targetBuffer.set(new Uint8Array(arrayBuffer, 0, originalSize)); + return targetBuffer; + } else { + // the mapped ArrayBuffer will be released when the GPU buffer is destroyed. Need to clone the + // ArrayBuffer. + return new Uint8Array(arrayBuffer.slice(0, originalSize)); + } + } finally { + gpuReadBuffer.destroy(); + } + }; + class GpuDataManagerImpl implements GpuDataManager { // GPU Data ID => GPU Data ( storage buffer ) - storageCache: Map; + private storageCache: Map; // pending buffers for uploading ( data is unmapped ) private buffersForUploadingPending: GPUBuffer[]; @@ -77,11 +134,15 @@ class GpuDataManagerImpl implements GpuDataManager { // The reusable storage buffers for computing. private freeBuffers: Map; + // The external buffers registered users for IO Binding. + private externalBuffers: Map; + constructor(private backend: WebGpuBackend) { this.storageCache = new Map(); this.freeBuffers = new Map(); this.buffersForUploadingPending = []; this.buffersPending = []; + this.externalBuffers = new Map(); } upload(id: GpuDataId, data: Uint8Array): void { @@ -143,6 +204,42 @@ class GpuDataManagerImpl implements GpuDataManager { sourceGpuDataCache.gpuData.buffer, 0, destinationGpuDataCache.gpuData.buffer, 0, size); } + registerExternalBuffer(buffer: GPUBuffer, originalSize: number, previousBuffer?: GPUBuffer): number { + let id: number|undefined; + if (previousBuffer) { + id = this.externalBuffers.get(previousBuffer); + if (id === undefined) { + throw new Error('previous buffer is not registered'); + } + if (buffer === previousBuffer) { + LOG_DEBUG( + 'verbose', + () => `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${ + id}, buffer is the same, skip.`); + return id; + } + this.externalBuffers.delete(previousBuffer); + } else { + id = createNewGpuDataId(); + } + + this.storageCache.set(id, {gpuData: {id, type: GpuDataType.default, buffer}, originalSize}); + this.externalBuffers.set(buffer, id); + LOG_DEBUG( + 'verbose', + () => `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${id}, registered.`); + return id; + } + + unregisterExternalBuffer(buffer: GPUBuffer): void { + const id = this.externalBuffers.get(buffer); + if (id !== undefined) { + this.storageCache.delete(id); + this.externalBuffers.delete(buffer); + LOG_DEBUG('verbose', () => `[WebGPU] GpuDataManager.unregisterExternalBuffer() => id=${id}`); + } + } + // eslint-disable-next-line no-bitwise create(size: number, usage = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST): GpuData { const bufferSize = calcNormalizedBufferSize(size); @@ -193,31 +290,13 @@ class GpuDataManagerImpl implements GpuDataManager { return cachedData.originalSize; } - async download(id: GpuDataId): Promise { + async download(id: GpuDataId, getTargetBuffer: () => Uint8Array): Promise { const cachedData = this.storageCache.get(id); if (!cachedData) { throw new Error('data does not exist'); } - const commandEncoder = this.backend.getCommandEncoder(); - this.backend.endComputePass(); - const bufferSize = calcNormalizedBufferSize(cachedData.originalSize); - const gpuReadBuffer = this.backend.device.createBuffer( - // eslint-disable-next-line no-bitwise - {size: bufferSize, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ}); - commandEncoder.copyBufferToBuffer( - cachedData.gpuData.buffer /* source buffer */, 0 /* source offset */, gpuReadBuffer /* destination buffer */, - 0 /* destination offset */, bufferSize /* size */ - ); - this.backend.flush(); - - return new Promise((resolve) => { - gpuReadBuffer.mapAsync(GPUMapMode.READ).then(() => { - const data = gpuReadBuffer.getMappedRange().slice(0); - gpuReadBuffer.destroy(); - resolve(data); - }); - }); + await downloadGpuData(this.backend, cachedData.gpuData.buffer, cachedData.originalSize, getTargetBuffer); } refreshPendingBuffers(): void { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts new file mode 100644 index 0000000000000..3925e1cb4f564 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -0,0 +1,243 @@ +/** + * @license + * Copyright 2021 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +// sampled from [@tensorflow/tfjs] tfjs-backend-webgpu/src/conv_backprop_mm_webgpu.ts +// +// modified to fit the needs of the project + +import {LOG_DEBUG} from '../../../log'; +import {TensorView} from '../../../tensor-view'; +import {ShapeUtil} from '../../../util'; +import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types'; +import {ConvTransposeAttributes} from '../conv-transpose'; + +import {Activation, activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util'; +import {utilFunctions} from './conv_util'; +import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu'; + +const conv2dTransposeCommonSnippet = + (isChannelsLast: boolean, addBias = false, activation?: Activation, hasPreluActivationWeights = false, + innerElementSize = 4): string => { + const getWSnippet = (innerElementSize: number) => { + switch (innerElementSize) { + case 1: + return 'return W[getIndexFromCoords4D(coord, wShape)];'; + case 4: + return ` + let coord1 = vec4(coordX, coordY, col + 1, rowInner); + let coord2 = vec4(coordX, coordY, col + 2, rowInner); + let coord3 = vec4(coordX, coordY, col + 3, rowInner); + let v0 = W[getIndexFromCoords4D(coord, wShape)]; + let v1 = W[getIndexFromCoords4D(coord1, wShape)]; + let v2 = W[getIndexFromCoords4D(coord2, wShape)]; + let v3 = W[getIndexFromCoords4D(coord3, wShape)]; + return vec4(v0, v1, v2, v3); + `; + default: + throw new Error(`innerElementSize ${innerElementSize} is not supported.`); + } + }; + const coordASnippet = isChannelsLast ? ` + let coord = vec4(batch, iXR, iXC, xCh); + ` : + ` + let coord = vec4(batch, xCh, iXR, iXC); + `; + + const coordResSnippet = isChannelsLast ? ` + let coords = vec4( + batch, + row / outWidth, + row % outWidth, + col); + ` : + ` + let coords = vec4( + batch, + row, + col / outWidth, + col % outWidth); + `; + + const xHeight = isChannelsLast ? 'outBackprop[1]' : 'outBackprop[2]'; + const xWidth = isChannelsLast ? 'outBackprop[2]' : 'outBackprop[3]'; + const row = isChannelsLast ? 'row' : 'col'; + const col = isChannelsLast ? 'col' : 'row'; + + const readASnippet = ` + let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'}; + let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'}; + let outRow = ${row} / outWidth; + let outCol = ${row} % outWidth; + + let WRow = ${col} / (filterDims[1] * inChannels); + let WCol = ${col} / inChannels % filterDims[1]; + let xR = f32(outRow - pads[0] + dilation[0] * WRow) / f32(strides[0]); + let xC = f32(outCol - pads[1] + dilation[1] * WCol) / f32(strides[1]); + if (xR < 0.0 || xR >= f32(${xHeight}) || fract(xR) > 0.0) { + return ${typeSnippet(innerElementSize)}(0.0); + } + if (xC < 0.0 || xC >= f32(${xWidth}) || fract(xC) > 0.0) { + return ${typeSnippet(innerElementSize)}(0.0); + } + let iXR = i32(xR); + let iXC = i32(xC); + let xCh = ${col} % inChannels; + ${coordASnippet} + return x[getIndexFromCoords4D(coord, xShape)/${innerElementSize}];`; + + const sampleA = isChannelsLast ? ` + let col = colIn * ${innerElementSize}; + if (row < dimAOuter && col < dimInner) { + ${readASnippet} + } + return ${typeSnippet(innerElementSize)}(0.0);` : + ` + let col = colIn * ${innerElementSize}; + if (row < dimInner && col < dimBOuter) { + ${readASnippet} + } + return ${typeSnippet(innerElementSize)}(0.0);`; + + const sampleW = ` + let col = colIn * ${innerElementSize}; + let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'}; + let coordX = filterDims.x - 1 - row / (filterDims[1] * inChannels); + let coordY = filterDims.y - 1 - (row / inChannels) % filterDims[1]; + if (${ + isChannelsLast ? 'row < dimInner && col < dimBOuter' : + 'row < dimInner && col < dimAOuter'} && coordX >= 0 && coordY >= 0) { + let rowInner = row % inChannels; + let coord = vec4(coordX, coordY, col, rowInner); + ${getWSnippet(innerElementSize)} + } + return ${typeSnippet(innerElementSize)}(0.0); + `; + + + const userCode = ` + ${activationFnSnippet(activation, hasPreluActivationWeights, innerElementSize === 4, 4)} + fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${typeSnippet(innerElementSize)} { + ${isChannelsLast ? sampleA : sampleW} + } + + fn mm_readB(batch: i32, row : i32, colIn : i32) -> ${typeSnippet(innerElementSize)} { + ${isChannelsLast ? sampleW : sampleA} + } + + fn mm_write(batch: i32, row : i32, colIn : i32, valueInput : ${typeSnippet(innerElementSize)}) { + let col = colIn * ${innerElementSize}; + if (row < dimAOuter && col < dimBOuter) { + var value = valueInput; + let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'}; + ${coordResSnippet} + ${biasActivationSnippet(addBias, activation)} + result[getIndexFromCoords4D(coords, outShape)/${innerElementSize}] = value; + } + }`; + return userCode; + }; + +export const createConv2DTransposeMatMulProgramInfo = + (inputs: readonly TensorView[], metadata: ProgramMetadata, attributes: ConvTransposeAttributes, + outputShape: readonly number[], dimAOuter: number, dimBOuter: number, dimInner: number, hasBias: boolean, + sequentialAccessByThreads: boolean): ProgramInfo => { + const isChannelsLast = attributes.format === 'NHWC'; + const inChannels = isChannelsLast ? inputs[0].dims[3] : inputs[0].dims[1]; + const batchSize = outputShape[0]; + const outWidth = isChannelsLast ? outputShape[2] : outputShape[3]; + const outHeight = isChannelsLast ? outputShape[1] : outputShape[2]; + const outChannels = isChannelsLast ? outputShape[3] : outputShape[1]; + const isVec4 = + isChannelsLast ? inChannels % 4 === 0 && outChannels % 4 === 0 : outWidth % 4 === 0 && outChannels % 4 === 0; + + // TODO: fine tune size + const dispatchX = isChannelsLast ? outChannels : outWidth * outHeight; + const dispatchY = isChannelsLast ? outWidth * outHeight : outChannels; + const workGroupSize: [number, number, number] = isVec4 ? + [8, 8, 1] : + [(dispatchX <= 4 || dispatchY <= 4) ? 4 : 16, dispatchX > 4 && dispatchY <= 4 ? 4 : 16, 1]; + const elementsPerThread = + isVec4 ? [4, 4, 1] : [dispatchX <= 4 ? 1 : 4, dispatchX > 4 && dispatchY <= 4 ? 1 : 4, 1]; + const dispatch = [ + Math.ceil(dispatchX / workGroupSize[0] / elementsPerThread[0]), + Math.ceil(dispatchY / workGroupSize[1] / elementsPerThread[1]), + Math.ceil(batchSize / workGroupSize[2] / elementsPerThread[2]) + ]; + + LOG_DEBUG('verbose', () => `[conv_backprop_mm_webgpu] dispatch = ${dispatch}`); + + const innerElementSize = isVec4 ? 4 : 1; + const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]); + + + const declareInputs = [ + `@group(0) @binding(0) var x: array<${isVec4 ? 'vec4' : 'f32'}>;`, + '@group(0) @binding(1) var W: array;' + ]; + let declareFunctions = ''; + if (hasBias) { + declareInputs.push(`@group(0) @binding(2) var bias: array<${isVec4 ? 'vec4' : 'f32'}>;`); + declareFunctions += ` + fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { + return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; + }`; + } + return { + ...metadata, + outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], + dispatchGroup: () => ({x: dispatch[0], y: dispatch[1], z: dispatch[2]}), + getShaderSource: () => ` + ${utilFunctions} + ${declareInputs.join('\n')} + @group(0) @binding(${declareInputs.length}) var result: array<${ + isVec4 ? 'vec4' : 'f32'}>; + const outBackprop : vec4 = vec4(${inputs[0].dims.join(',')}); + const xShape : vec4 = vec4(${inputs[0].dims.join(',')}); + const wShape : vec4 = vec4(${inputs[1].dims.join(',')}); + const outShape : vec4 = vec4(${outputShape.join(',')}); + const outShapeStrides : vec3 = vec3(${ShapeUtil.computeStrides(outputShape).slice(0, 3).join(',')}); + const filterDims : vec2 = vec2(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${ + attributes.kernelShape[isChannelsLast ? 2 : 3]}); + const effectiveFilterDims : vec2 = filterDims + vec2( + ${ + attributes.dilations[0] <= 1 ? + 0 : + (attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)}, + ${ + attributes.dilations[1] <= 1 ? + 0 : + (attributes.kernelShape[isChannelsLast ? 2 : 3] - 1) * (attributes.dilations[1] - 1)}); + const pads : vec2 = vec2(i32(effectiveFilterDims[0]) - 1 - (${ + attributes.pads[0] + attributes.pads[2]})/2, + i32(effectiveFilterDims[1]) - 1 - (${ + attributes.pads[1] + attributes.pads[3]})/2); + const strides : vec2 = vec2(${attributes.strides[0]}, ${attributes.strides[1]}); + const dilation : vec2 = vec2(${attributes.dilations[0]}, ${attributes.dilations[1]}); + const dimAOuter : i32 = ${dimAOuter}; + const dimBOuter : i32 = ${dimBOuter}; + const dimInner : i32 = ${dimInner}; + ${declareFunctions} + ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, undefined, false, innerElementSize)} + ${ + isVec4 ? + makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, undefined, !isChannelsLast, tileInner) : + makeMatMulPackedSource( + elementsPerThread, workGroupSize, undefined, !isChannelsLast, tileInner, false, undefined, + sequentialAccessByThreads)}` + }; + }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts index aa912ff4ad48f..c60b94056a360 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts @@ -198,14 +198,14 @@ const createConvTranspose2DOpProgramShaderSource = continue; } let idyC: u32 = u32(dyC); - + var inputChannel = groupId * ${inputChannelsPerGroup}; for (var d2: u32 = 0; d2 < ${inputChannelsPerGroup}; d2 = d2 + 1) { - let inputChannel = groupId * ${inputChannelsPerGroup} + d2; let xValue = ${ isChannelsLast ? dy.get('batch', 'idyR', 'idyC', 'inputChannel') : dy.get('batch', 'inputChannel', 'idyR', 'idyC')}; let wValue = ${w.get('inputChannel', 'wOutChannel', 'u32(wRPerm)', 'u32(wCPerm)')}; dotProd = dotProd + xValue * wValue; + inputChannel = inputChannel + 1; } } } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts index 13d3a91bb339e..9c05080f7e118 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts @@ -62,14 +62,24 @@ const createBinaryOpProgramShader = let assignment: string; if (vectorize) { if (doBroadcast) { - assignment = ` + const isAOneElement = ShapeUtil.size(dimsA) === 1; + const isBOneElement = ShapeUtil.size(dimsB) === 1; + if (isAOneElement || isBOneElement) { + assignment = output.setByOffset( + 'global_idx', + expressionVector( + isAOneElement ? `${a.type.value}(${a.getByOffset('0')}.x)` : a.getByOffset('global_idx'), + isBOneElement ? `${b.type.value}(${b.getByOffset('0')}.x)` : b.getByOffset('global_idx'))); + } else { + assignment = ` let outputIndices = ${output.offsetToIndices('global_idx * 4u')}; let offsetA = calcOffsetA(outputIndices); let offsetB = calcOffsetB(outputIndices); ${ - output.setByOffset( - 'global_idx', expressionVector(a.getByOffset('offsetA / 4u'), b.getByOffset('offsetB / 4u')))} + output.setByOffset( + 'global_idx', expressionVector(a.getByOffset('offsetA / 4u'), b.getByOffset('offsetB / 4u')))} `; + } } else { assignment = output.setByOffset( 'global_idx', expressionVector(a.getByOffset('global_idx'), b.getByOffset('global_idx'))); @@ -141,6 +151,8 @@ const createBinaryOpProgramInfo = } outputShape = calculatedShape; outputSize = ShapeUtil.size(outputShape); + const isAOneElement = ShapeUtil.size(a.dims) === 1; + const isBOneElement = ShapeUtil.size(b.dims) === 1; // check whether vectorize can be enabled let sharedDimension = 1; @@ -153,7 +165,7 @@ const createBinaryOpProgramInfo = break; } } - if (sharedDimension % 4 === 0) { + if (sharedDimension % 4 === 0 || isAOneElement || isBOneElement) { vectorize = true; } } else { @@ -167,8 +179,7 @@ const createBinaryOpProgramInfo = shaderHelper, a.dims, b.dims, outputShape, vectorize, isBroadcast, funcCall, a.dataType, b.dataType, outputDataType, additionalImplementation), outputs: [{dims: outputShape, dataType: outputDataType, gpuDataType: GpuDataType.default}], - dispatchGroup: () => - ({x: Math.ceil(outputSize / 64 /* workgroup size */ / (vectorize ? 4 : 1) /* vec size */)}) + dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)}) }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index c054da51a3098..0ab777bfbdee9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -366,7 +366,7 @@ const createIndicesHelper = const getByIndicesImplementation = rank < 2 ? '' : ` fn get_${name}ByIndices(indices: ${type.indices}) -> ${valueType} { - return ${name}[i2o_${name}(indices)]; + return ${getByOffset(`i2o_${name}(indices)`)}; }`; const getImplementation = rank < 2 ? '' : (() => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts index 2b418e65e7c3b..59f11d6d9abba 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts @@ -7,7 +7,9 @@ import {ComputeContext, GpuDataType, ProgramInfoLoader, ProgramMetadata} from '. import {createConvTranspose2DProgramInfo} from './3rd-party/conv_backprop_webgpu'; import {ConvAttributes} from './conv'; +import {createConv2DTransposeMatMulProgramInfoLoader} from './conv2dtranspose-mm'; import {parseInternalActivationAttributes} from './fuse-utils'; +import {createTransposeProgramInfo, TransposeAttributes, transposeProgramMetadata} from './transpose'; const computeTotalPad = (inDim: number, stride: number, adj: number, kernel: number, dilation: number, outSize: number) => @@ -62,7 +64,7 @@ const getAdjustedConvTransposeAttributes = (attributes: T, inputs: readonly TensorView[]): T => { const kernelShape = attributes.kernelShape.slice(); // if kernelShape is not specified in the attributes of this op, infer it from the weight tensor dims - if (attributes.kernelShape.length === 0 || attributes.kernelShape.reduce((a, b) => a * b, 0) === 0) { + if (attributes.kernelShape.length === 0 || attributes.kernelShape.reduce((a, b) => a * b, 1) === 0) { kernelShape.length = 0; for (let i = 2; i < inputs[1].dims.length; ++i) { kernelShape.push(inputs[1].dims[i]); @@ -94,9 +96,11 @@ const getAdjustedConvTransposeAttributes = // always return a new object so does not modify the original attributes const newAttributes: T = Object.assign({}, attributes); - Object.assign( - newAttributes, - {kernelShape, pads, outputPadding, outputShape, dilations, strides, cacheKey: attributes.cacheKey}); + const cacheKey = attributes.cacheKey + [ + kernelShape.join('n,'), pads.join(','), strides.join(','), outputPadding.join(','), outputShape.join(','), + dilations.join(',') + ].join('_'); + Object.assign(newAttributes, {kernelShape, pads, outputPadding, outputShape, dilations, strides, cacheKey}); return newAttributes; }; @@ -216,12 +220,64 @@ const createConvTranspose2DProgramInfoLoader = }; }; +// for transposing weight tensor from [C, M/group, KH, KW] to [KH, KW, M/group, C] +const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({perm: [2, 3, 1, 0]}); + const convTranspose2d = (context: ComputeContext, inputs: readonly TensorView[], attributes: ConvTransposeAttributes): void => { const adjustedAttributes = getAdjustedConvTransposeAttributes(attributes, inputs); + const isChannelsLast = attributes.format === 'NHWC'; + const hasBias = inputs.length === 3; + if (adjustedAttributes.group !== 1) { + context.compute(createConvTranspose2DProgramInfoLoader(inputs, adjustedAttributes)); + return; + } + const outputShape = adjustedAttributes.outputShape; + const outHeight = outputShape[isChannelsLast ? 1 : 2]; + const outWidth = outputShape[isChannelsLast ? 2 : 3]; + const outChannels = outputShape[isChannelsLast ? 3 : 1]; + const weightHeight = inputs[1].dims[2]; + const weightWidth = inputs[1].dims[3]; + const inputChannels = inputs[0].dims[isChannelsLast ? 3 : 1]; + + const dimAOuter = isChannelsLast ? outHeight * outWidth : outChannels; + const dimBOuter = isChannelsLast ? outChannels : outHeight * outWidth; + const dimInner = weightHeight * weightWidth * inputChannels; + + const sequentialAccessByThreads = /* backend.adapterInfo.isIntel() */ true; + - context.compute(createConvTranspose2DProgramInfoLoader(inputs, adjustedAttributes)); + // STEP.1: transpose weight + const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ?? + context.compute( + { + ...transposeProgramMetadata, + cacheHint: weightTransposeAttribute.cacheKey, + get: () => createTransposeProgramInfo(inputs[1], weightTransposeAttribute.perm) + }, + {inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0]; + if (attributes.wIsConst && !context.kernelCustomData.wT) { + context.kernelCustomData.wT = transposedWeight; + } + + // STEP.2: prepare reshaped inputs + const convTransposeInputs = [inputs[0], transposedWeight]; + if (hasBias) { + if (!isChannelsLast && inputs[2].dims.length === 1) { + convTransposeInputs.push(inputs[2].reshape([inputs[2].dims[0], 1, 1])); + } else { + convTransposeInputs.push(inputs[2]); + } + } + + // STEP.3: compute matmul + context.compute( + createConv2DTransposeMatMulProgramInfoLoader( + convTransposeInputs, adjustedAttributes, outputShape, dimAOuter, dimBOuter, dimInner, hasBias, + sequentialAccessByThreads), + {inputs: convTransposeInputs}); }; + const convTranspose1d = (context: ComputeContext, attributes: ConvTransposeAttributes): void => { // extend the input to 2D by adding H dimension const isChannelLast = attributes.format === 'NHWC'; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv2dtranspose-mm.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv2dtranspose-mm.ts new file mode 100644 index 0000000000000..da04b5063a9f0 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv2dtranspose-mm.ts @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {TensorView} from '../../tensor-view'; +import {GpuDataType, ProgramInfoLoader, ProgramMetadata} from '../types'; + +import {createConv2DTransposeMatMulProgramInfo} from './3rd-party/conv_backprop_mm_webgpu'; +import {ConvTransposeAttributes} from './conv-transpose'; + + +const createConv2DTransposeMatMulProgramMetadata = (hasBias: boolean, cacheHint: string): ProgramMetadata => ({ + name: 'Conv2DTransposeMatMul', + inputTypes: hasBias ? [GpuDataType.default, GpuDataType.default, GpuDataType.default] : + [GpuDataType.default, GpuDataType.default], + cacheHint +}); + +export const createConv2DTransposeMatMulProgramInfoLoader = + (inputs: readonly TensorView[], attributes: ConvTransposeAttributes, outputShape: readonly number[], + dimAOuter: number, dimBOuter: number, dimInner: number, hasBias: boolean, + sequentialAccessByThreads: boolean): ProgramInfoLoader => { + const metadata = createConv2DTransposeMatMulProgramMetadata(hasBias, attributes.cacheKey); + return { + ...metadata, + get: () => createConv2DTransposeMatMulProgramInfo( + inputs, metadata, attributes, outputShape, dimAOuter, dimBOuter, dimInner, hasBias, + sequentialAccessByThreads) + }; + }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts index 0db060dbec54a..47aae13d6799d 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts @@ -1,13 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; -import {ShaderHelper} from './common'; +import {inputVariable, outputVariable, ShaderHelper} from './common'; export interface GatherAttributes extends AttributeWithCacheKey { axis: number; @@ -30,63 +29,55 @@ const createGatherProgramInfo = const outputShape = inputShape.slice(0); outputShape.splice(axis, 1, ...indicesShape); - const inputDataType = inputs[0].dataType; - const block = ShapeUtil.sizeFromDimension(inputShape, axis + 1); - const elementSize = [DataType.int64, DataType.uint64, DataType.double].includes(inputDataType) ? 2 : 1; - const indicesElementSize = inputs[1].dataType === DataType.int64 ? 2 : 1; - const blockSize = elementSize * block; - const M = ShapeUtil.sizeToDimension(inputShape, axis); - const N = ShapeUtil.size(indicesShape); - const dataBatchElements = ShapeUtil.sizeFromDimension(inputShape, axis) * elementSize; - const gatheredBatchElements = N * block * elementSize; const axisDimLimit = inputShape[axis]; + const outputSize = ShapeUtil.size(outputShape); + + const data = inputVariable('data', inputs[0].dataType, inputs[0].dims); + const indices = inputVariable('inputIndices', inputs[1].dataType, inputs[1].dims); + const output = outputVariable('output', inputs[0].dataType, outputShape); + const calcDataIndices = (): string => { + const indicesRank = indicesShape.length; + let calcStr = `var indicesIndices = ${indices.type.indices}(0);`; + for (let i = 0; i < indicesRank; i++) { + calcStr += `${indicesRank > 1 ? `indicesIndices[${i}]` : 'indicesIndices'} = ${ + outputShape.length > 1 ? `outputIndices[${axis + i}]` : 'outputIndices'};`; + } + calcStr += ` + var idx = ${indices.getByIndices('indicesIndices')}; + if (idx < 0) { + idx = idx + ${axisDimLimit}; + } + var dataIndices = ${data.type.indices}(0); + `; + for (let i = 0, j = 0; i < inputRank; i++) { + if (i === axis) { + calcStr += `${inputRank > 1 ? `dataIndices[${i}]` : 'dataIndices'} = u32(idx);`; + j += indicesRank; + } else { + calcStr += `${inputRank > 1 ? `dataIndices[${i}]` : 'dataIndices'} = ${ + outputShape.length > 1 ? `outputIndices[${j}]` : 'outputIndices'};`; + j++; + } + } + return calcStr; + }; - const inputSize = ShapeUtil.size(inputShape) * elementSize; - const outputSize = ShapeUtil.size(outputShape) * elementSize; - - const totalGathers = M * N; - // int64 indices would be treated as little endian i32 with assumption they fall in i32 limits - // That assumption is safe as it's not possible to allocate >2gb buffer for input tensor - // Input data will be treated as u32 or two u32 for 8-byte tensors const getShaderSource = (shaderHelper: ShaderHelper) => ` - const N: u32 = ${N}; - const elementSize: u32 = ${elementSize}; - const indicesElementSize: u32 = ${indicesElementSize}; - - @group(0) @binding(0) var input : array; - @group(0) @binding(1) var inputIndices : array; - @group(0) @binding(2) var output: array; - - ${shaderHelper.mainStart()} - let batch: u32 = global_idx / N; - let i: u32 = global_idx % N; - - let srcOffsetBatch: u32 = batch * ${dataBatchElements}; - let dstOffsetBatch: u32 = batch * ${gatheredBatchElements}; - var idx = inputIndices[i * indicesElementSize]; - if (idx < 0) { - idx = idx + ${axisDimLimit}; - } - - let srcOffset = srcOffsetBatch + u32(idx) * ${blockSize}; - let dstOffset = dstOffsetBatch + i * ${blockSize}; - if (srcOffset >= ${inputSize}) { - return; - } - if (dstOffset >= ${outputSize}) { - return; - } - for (var j: u32 = 0; j < ${blockSize}; j++) { - output[dstOffset + j] = input[srcOffset + j]; - } - }`; + ${shaderHelper.declareVariables(data, indices, output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + let outputIndices = ${output.offsetToIndices('global_idx')}; + ${calcDataIndices()}; + let value = ${data.getByIndices('dataIndices')}; + ${output.setByOffset('global_idx', 'value')}; + }`; return { ...metadata, outputs: [ {dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}, ], getShaderSource, - dispatchGroup: () => ({x: Math.ceil(totalGathers / 64 /* workgroup size */)}) + dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)}) }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts index 8c8c12fc54ddb..120a0e9de5490 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {PoolConvUtil, ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -22,9 +21,6 @@ const validateInputs = (inputs: readonly TensorView[]): void => { if (inputs[0].dims.length !== 4) { throw new Error('Pool ops supports 2-D inputs only for now.'); } - if (inputs[0].dataType !== DataType.float) { - throw new Error('Invalid input type.'); - } }; const getAdjustedPoolAttributesAndOutputShape = ( @@ -248,7 +244,7 @@ const createAveragePoolProgramInfo = const kernelSize = ShapeUtil.size(adjustedAttributes.kernelShape); const x = inputVariable('x', input.dataType, input.dims); - const dataType = 'f32'; + const dataType = x.type.value; const op1 = 'value += x_val;'; let op2 = ''; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts index 0b8d03ea73b6b..598b1db033c61 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts @@ -17,10 +17,6 @@ const validateInputs = (inputs: readonly TensorView[]): void => { if (inputs.length === 2 && inputs[1].dims.length !== 1) { throw new Error('Invalid axes input dims.'); } - - if (inputs[0].dataType !== DataType.float) { - throw new Error('Invalid input type.'); - } }; export interface ReduceAttributes extends AttributeWithCacheKey { @@ -161,7 +157,7 @@ export const reduceL1 = (context: ComputeContext, attributes: ReduceAttributes): export const reduceL2 = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, output) => - [`var t = f32(0); var value = ${output.type.storage}(0);`, + [`var t = ${output.type.value}(0); var value = ${output.type.value}(0);`, '', `t = ${input.getByOffset('inputOffset')}; value += (t * t);`, 'value = sqrt(value);', @@ -212,10 +208,10 @@ export const reduceMean = (context: ComputeContext, attributes: ReduceAttributes } return [ - `var value = ${output.type.storage}(0);`, + 'var sum = f32(0);', '', - `value += ${input.getByOffset('inputOffset')};`, - `value = value / ${size}.;`, + `sum += f32(${input.getByOffset('inputOffset')});`, + `let value = ${output.type.value}(sum / ${size});`, ]; }; context.compute(createReduceProgramInfoLoader(context.inputs, 'ReduceMean', attributes, reduceOp), {inputs: [0]}); @@ -266,7 +262,7 @@ export const reduceSum = (context: ComputeContext, attributes: ReduceAttributes) export const reduceSumSquare = (context: ComputeContext, attributes: ReduceAttributes): void => { validateInputs(context.inputs); const reduceOp: ReduceOp = (input, output) => - [`var t = f32(0); var value = ${output.type.storage}(0);`, + [`var t = ${output.type.value}(0); var value = ${output.type.value}(0);`, '', `t = ${input.getByOffset('inputOffset')}; value += t * t;`, '', diff --git a/js/web/lib/wasm/proxy-messages.ts b/js/web/lib/wasm/proxy-messages.ts index e5a2d8c2351b8..43f70c23f7193 100644 --- a/js/web/lib/wasm/proxy-messages.ts +++ b/js/web/lib/wasm/proxy-messages.ts @@ -3,20 +3,24 @@ import {Env, InferenceSession, Tensor} from 'onnxruntime-common'; -/** - * tuple elements are: ORT element type; dims; tensor data - */ -export type SerializableTensor = [Tensor.Type, readonly number[], Tensor.DataType]; +export type SerializableTensorMetadata = + [dataType: Tensor.Type, dims: readonly number[], data: Tensor.DataType, location: 'cpu']; -/** - * tuple elements are: InferenceSession handle; input names; output names - */ -export type SerializableSessionMetadata = [number, string[], string[]]; +export type GpuBufferMetadata = { + gpuBuffer: Tensor.GpuBufferType; + download?: () => Promise; + dispose?: () => void; +}; -/** - * tuple elements are: modeldata.offset, modeldata.length - */ -export type SerializableModeldata = [number, number]; +export type UnserializableTensorMetadata = + [dataType: Tensor.Type, dims: readonly number[], data: GpuBufferMetadata, location: 'gpu-buffer']| + [dataType: Tensor.Type, dims: readonly number[], data: Tensor.DataType, location: 'cpu-pinned']; + +export type TensorMetadata = SerializableTensorMetadata|UnserializableTensorMetadata; + +export type SerializableSessionMetadata = [sessionHandle: number, inputNames: string[], outputNames: string[]]; + +export type SerializableModeldata = [modelDataOffset: number, modelDataLength: number]; interface MessageError { err?: string; @@ -58,10 +62,10 @@ interface MessageReleaseSession extends MessageError { interface MessageRun extends MessageError { type: 'run'; in ?: { - sessionId: number; inputIndices: number[]; inputs: SerializableTensor[]; outputIndices: number[]; + sessionId: number; inputIndices: number[]; inputs: SerializableTensorMetadata[]; outputIndices: number[]; options: InferenceSession.RunOptions; }; - out?: SerializableTensor[]; + out?: SerializableTensorMetadata[]; } interface MesssageEndProfiling extends MessageError { diff --git a/js/web/lib/wasm/proxy-wrapper.ts b/js/web/lib/wasm/proxy-wrapper.ts index 815b223e40379..202209ed3bfed 100644 --- a/js/web/lib/wasm/proxy-wrapper.ts +++ b/js/web/lib/wasm/proxy-wrapper.ts @@ -3,7 +3,7 @@ import {Env, env, InferenceSession} from 'onnxruntime-common'; -import {OrtWasmMessage, SerializableModeldata, SerializableSessionMetadata, SerializableTensor} from './proxy-messages'; +import {OrtWasmMessage, SerializableModeldata, SerializableSessionMetadata, SerializableTensorMetadata, TensorMetadata} from './proxy-messages'; import * as core from './wasm-core-impl'; import {initializeWebAssembly} from './wasm-factory'; @@ -22,7 +22,7 @@ const createSessionAllocateCallbacks: Array> = []; const createSessionCallbacks: Array> = []; const releaseSessionCallbacks: Array> = []; -const runCallbacks: Array> = []; +const runCallbacks: Array> = []; const endProfilingCallbacks: Array> = []; const ensureWorker = (): void => { @@ -177,6 +177,10 @@ export const createSessionFinalize = async(modeldata: SerializableModeldata, opt export const createSession = async(model: Uint8Array, options?: InferenceSession.SessionOptions): Promise => { if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { + // check unsupported options + if (options?.preferredOutputLocation) { + throw new Error('session option "preferredOutputLocation" is not supported for proxy.'); + } ensureWorker(); return new Promise((resolve, reject) => { createSessionCallbacks.push([resolve, reject]); @@ -202,17 +206,27 @@ export const releaseSession = async(sessionId: number): Promise => { }; export const run = async( - sessionId: number, inputIndices: number[], inputs: SerializableTensor[], outputIndices: number[], - options: InferenceSession.RunOptions): Promise => { + sessionId: number, inputIndices: number[], inputs: TensorMetadata[], outputIndices: number[], + outputs: Array, options: InferenceSession.RunOptions): Promise => { if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { + // check inputs location + if (inputs.some(t => t[3] !== 'cpu')) { + throw new Error('input tensor on GPU is not supported for proxy.'); + } + // check outputs location + if (outputs.some(t => t)) { + throw new Error('pre-allocated output tensor is not supported for proxy.'); + } ensureWorker(); - return new Promise((resolve, reject) => { + return new Promise((resolve, reject) => { runCallbacks.push([resolve, reject]); - const message: OrtWasmMessage = {type: 'run', in : {sessionId, inputIndices, inputs, outputIndices, options}}; - proxyWorker!.postMessage(message, core.extractTransferableBuffers(inputs)); + const serializableInputs = inputs as SerializableTensorMetadata[]; // every input is on CPU. + const message: OrtWasmMessage = + {type: 'run', in : {sessionId, inputIndices, inputs: serializableInputs, outputIndices, options}}; + proxyWorker!.postMessage(message, core.extractTransferableBuffers(serializableInputs)); }); } else { - return core.run(sessionId, inputIndices, inputs, outputIndices, options); + return core.run(sessionId, inputIndices, inputs, outputIndices, outputs, options); } }; diff --git a/js/web/lib/wasm/session-handler.ts b/js/web/lib/wasm/session-handler.ts index d8c5ae7886fe4..4e00878d0063b 100644 --- a/js/web/lib/wasm/session-handler.ts +++ b/js/web/lib/wasm/session-handler.ts @@ -5,12 +5,41 @@ import {readFile} from 'fs'; import {env, InferenceSession, SessionHandler, Tensor} from 'onnxruntime-common'; import {promisify} from 'util'; -import {SerializableModeldata} from './proxy-messages'; +import {SerializableModeldata, TensorMetadata} from './proxy-messages'; import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, releaseSession, run} from './proxy-wrapper'; +import {isGpuBufferSupportedType} from './wasm-common'; let runtimeInitialized: boolean; let runtimeInitializationPromise: Promise|undefined; +const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => { + switch (tensor.location) { + case 'cpu': + return [tensor.type, tensor.dims, tensor.data, 'cpu']; + case 'gpu-buffer': + return [tensor.type, tensor.dims, {gpuBuffer: tensor.gpuBuffer}, 'gpu-buffer']; + default: + throw new Error(`invalid data location: ${tensor.location} for ${getName()}`); + } +}; + +const decodeTensorMetadata = (tensor: TensorMetadata): Tensor => { + switch (tensor[3]) { + case 'cpu': + return new Tensor(tensor[0], tensor[2], tensor[1]); + case 'gpu-buffer': { + const dataType = tensor[0]; + if (!isGpuBufferSupportedType(dataType)) { + throw new Error(`not supported data type: ${dataType} for deserializing GPU tensor`); + } + const {gpuBuffer, download, dispose} = tensor[2]; + return Tensor.fromGpuBuffer(gpuBuffer, {dataType, dims: tensor[1], download, dispose}); + } + default: + throw new Error(`invalid data location: ${tensor[3]}`); + } +}; + export class OnnxruntimeWebAssemblySessionHandler implements SessionHandler { private sessionId: number; @@ -74,25 +103,31 @@ export class OnnxruntimeWebAssemblySessionHandler implements SessionHandler { inputIndices.push(index); }); + const outputArray: Array = []; const outputIndices: number[] = []; Object.entries(fetches).forEach(kvp => { const name = kvp[0]; - // TODO: support pre-allocated output + const tensor = kvp[1]; const index = this.outputNames.indexOf(name); if (index === -1) { throw new Error(`invalid output '${name}'`); } + outputArray.push(tensor); outputIndices.push(index); }); - const outputs = - await run(this.sessionId, inputIndices, inputArray.map(t => [t.type, t.dims, t.data]), outputIndices, options); + const inputs = + inputArray.map((t, i) => encodeTensorMetadata(t, () => `input "${this.inputNames[inputIndices[i]]}"`)); + const outputs = outputArray.map( + (t, i) => t ? encodeTensorMetadata(t, () => `output "${this.outputNames[outputIndices[i]]}"`) : null); + + const results = await run(this.sessionId, inputIndices, inputs, outputIndices, outputs, options); - const result: SessionHandler.ReturnType = {}; - for (let i = 0; i < outputs.length; i++) { - result[this.outputNames[outputIndices[i]]] = new Tensor(outputs[i][0], outputs[i][2], outputs[i][1]); + const resultMap: SessionHandler.ReturnType = {}; + for (let i = 0; i < results.length; i++) { + resultMap[this.outputNames[outputIndices[i]]] = outputArray[i] ?? decodeTensorMetadata(results[i]); } - return result; + return resultMap; } startProfiling(): void { diff --git a/js/web/lib/wasm/wasm-common.ts b/js/web/lib/wasm/wasm-common.ts index 389773f3e8884..b9eff45e890c4 100644 --- a/js/web/lib/wasm/wasm-common.ts +++ b/js/web/lib/wasm/wasm-common.ts @@ -164,3 +164,35 @@ export const logLevelStringToEnum = (logLevel?: 'verbose'|'info'|'warning'|'erro throw new Error(`unsupported logging level: ${logLevel}`); } }; + +/** + * Check whether the given tensor type is supported by GPU buffer + */ +export const isGpuBufferSupportedType = (type: Tensor.Type): type is Tensor.GpuBufferDataTypes => type === 'float32' || + type === 'int32' || type === 'int64' || type === 'bool' || type === 'float16' || type === 'uint32'; + +/** + * Map string data location to integer value + */ +export const dataLocationStringToEnum = (location: Tensor.DataLocation): number => { + switch (location) { + case 'none': + return 0; + case 'cpu': + return 1; + case 'cpu-pinned': + return 2; + case 'texture': + return 3; + case 'gpu-buffer': + return 4; + default: + throw new Error(`unsupported data location: ${location}`); + } +}; + +/** + * Map integer data location to string value + */ +export const dataLocationEnumToString = (location: number): Tensor.DataLocation|undefined => + (['none', 'cpu', 'cpu-pinned', 'texture', 'gpu-buffer'] as const)[location]; diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index fcca82ab2aa54..5b49a1d4202e3 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -3,10 +3,10 @@ import {Env, InferenceSession, Tensor} from 'onnxruntime-common'; -import {SerializableModeldata, SerializableSessionMetadata, SerializableTensor} from './proxy-messages'; +import {SerializableModeldata, SerializableSessionMetadata, SerializableTensorMetadata, TensorMetadata} from './proxy-messages'; import {setRunOptions} from './run-options'; import {setSessionOptions} from './session-options'; -import {logLevelStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common'; +import {dataLocationStringToEnum, getTensorElementSize, isGpuBufferSupportedType, logLevelStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common'; import {getInstance} from './wasm-factory'; import {allocWasmString, checkLastError} from './wasm-utils'; @@ -60,9 +60,36 @@ export const initRuntime = async(env: Env): Promise => { }; /** - * tuple elements are: InferenceSession ID; inputNamesUTF8Encoded; outputNamesUTF8Encoded + * valid data locations for input/output tensors. */ -type SessionMetadata = [number, number[], number[]]; +type SupportedTensorDataLocationForInputOutput = 'cpu'|'cpu-pinned'|'gpu-buffer'; + +type IOBindingState = { + /** + * the handle of IO binding. + */ + readonly handle: number; + + /** + * the preferred location for each output tensor. + * + * value is one of 'cpu', 'cpu-pinned', 'gpu-buffer'. + */ + readonly outputPreferredLocations: readonly SupportedTensorDataLocationForInputOutput[]; + + /** + * enum value of the preferred location for each output tensor. + */ + readonly outputPreferredLocationsEncoded: readonly number[]; +}; + +/** + * tuple elements are: InferenceSession ID; inputNamesUTF8Encoded; outputNamesUTF8Encoded; bindingState + */ +type SessionMetadata = [ + inferenceSessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[], + bindingState: IOBindingState|null +]; const activeSessions = new Map(); @@ -92,6 +119,7 @@ export const createSessionFinalize = let sessionHandle = 0; let sessionOptionsHandle = 0; + let ioBindingHandle = 0; let allocs: number[] = []; const inputNamesUTF8Encoded = []; const outputNamesUTF8Encoded = []; @@ -108,6 +136,7 @@ export const createSessionFinalize = const inputNames = []; const outputNames = []; + const outputPreferredLocations: SupportedTensorDataLocationForInputOutput[] = []; for (let i = 0; i < inputCount; i++) { const name = wasm._OrtGetInputName(sessionHandle, i); if (name === 0) { @@ -122,15 +151,45 @@ export const createSessionFinalize = checkLastError('Can\'t get an output name.'); } outputNamesUTF8Encoded.push(name); - outputNames.push(wasm.UTF8ToString(name)); + const nameString = wasm.UTF8ToString(name); + outputNames.push(nameString); + + if (!BUILD_DEFS.DISABLE_WEBGPU) { + const location = typeof options?.preferredOutputLocation === 'string' ? + options.preferredOutputLocation : + options?.preferredOutputLocation?.[nameString] ?? 'cpu'; + if (location !== 'cpu' && location !== 'cpu-pinned' && location !== 'gpu-buffer') { + throw new Error(`Not supported preferred output location: ${location}.`); + } + outputPreferredLocations.push(location); + } + } + + // use IO binding only when at least one output is preffered to be on GPU. + let bindingState: IOBindingState|null = null; + if (!BUILD_DEFS.DISABLE_WEBGPU && outputPreferredLocations.some(l => l === 'gpu-buffer')) { + ioBindingHandle = wasm._OrtCreateBinding(sessionHandle); + if (ioBindingHandle === 0) { + checkLastError('Can\'t create IO binding.'); + } + + bindingState = { + handle: ioBindingHandle, + outputPreferredLocations, + outputPreferredLocationsEncoded: outputPreferredLocations.map(l => dataLocationStringToEnum(l)), + }; } - activeSessions.set(sessionHandle, [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded]); + activeSessions.set(sessionHandle, [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, bindingState]); return [sessionHandle, inputNames, outputNames]; } catch (e) { inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + if (ioBindingHandle !== 0) { + wasm._OrtReleaseBinding(ioBindingHandle); + } + if (sessionHandle !== 0) { wasm._OrtReleaseSession(sessionHandle); } @@ -161,7 +220,13 @@ export const releaseSession = (sessionId: number): void => { if (!session) { throw new Error(`cannot release session. invalid session id: ${sessionId}`); } - const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded] = session; + const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState] = session; + + if (ioBindingState) { + wasm._OrtReleaseBinding(ioBindingState.handle); + } + + wasm.jsepUnregisterBuffers?.(sessionId); inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); @@ -169,18 +234,84 @@ export const releaseSession = (sessionId: number): void => { activeSessions.delete(sessionId); }; +const prepareInputOutputTensor = + (tensor: TensorMetadata|null, tensorHandles: number[], allocs: number[], sessionId: number, index: number): + void => { + if (!tensor) { + tensorHandles.push(0); + return; + } + + const wasm = getInstance(); + + const dataType = tensor[0]; + const dims = tensor[1]; + const location = tensor[3]; + + let rawData: number; + let dataByteLength: number; + + if (dataType === 'string' && location === 'gpu-buffer') { + throw new Error('String tensor is not supported on GPU.'); + } + + if (location === 'gpu-buffer') { + const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer; + const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!; + dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes; + rawData = wasm.jsepRegisterBuffer(sessionId, index, gpuBuffer, dataByteLength); + } else { + const data = tensor[2]; + + if (Array.isArray(data)) { + // string tensor + dataByteLength = 4 * data.length; + rawData = wasm._malloc(dataByteLength); + allocs.push(rawData); + let dataIndex = rawData / 4; + for (let i = 0; i < data.length; i++) { + if (typeof data[i] !== 'string') { + throw new TypeError(`tensor data at index ${i} is not a string`); + } + wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], allocs); + } + } else { + dataByteLength = data.byteLength; + rawData = wasm._malloc(dataByteLength); + allocs.push(rawData); + wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData); + } + } + + const stack = wasm.stackSave(); + const dimsOffset = wasm.stackAlloc(4 * dims.length); + try { + let dimIndex = dimsOffset / 4; + dims.forEach(d => wasm.HEAP32[dimIndex++] = d); + const tensor = wasm._OrtCreateTensor( + tensorDataTypeStringToEnum(dataType), rawData, dataByteLength, dimsOffset, dims.length, + dataLocationStringToEnum(location)); + if (tensor === 0) { + checkLastError(`Can't create tensor for input/output. session=${sessionId}, index=${index}.`); + } + tensorHandles.push(tensor); + } finally { + wasm.stackRestore(stack); + } + }; + /** * perform inference run */ export const run = async( - sessionId: number, inputIndices: number[], inputs: SerializableTensor[], outputIndices: number[], - options: InferenceSession.RunOptions): Promise => { + sessionId: number, inputIndices: number[], inputTensors: TensorMetadata[], outputIndices: number[], + outputTensors: Array, options: InferenceSession.RunOptions): Promise => { const wasm = getInstance(); const session = activeSessions.get(sessionId); if (!session) { throw new Error(`cannot run inference. invalid session id: ${sessionId}`); } - const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded] = session; + const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState] = session; const inputCount = inputIndices.length; const outputCount = outputIndices.length; @@ -188,171 +319,200 @@ export const run = async( let runOptionsHandle = 0; let runOptionsAllocs: number[] = []; - const inputValues: number[] = []; - const inputAllocs: number[] = []; + const inputTensorHandles: number[] = []; + const outputTensorHandles: number[] = []; + const inputOutputAllocs: number[] = []; + + const beforeRunStack = wasm.stackSave(); + const inputValuesOffset = wasm.stackAlloc(inputCount * 4); + const inputNamesOffset = wasm.stackAlloc(inputCount * 4); + const outputValuesOffset = wasm.stackAlloc(outputCount * 4); + const outputNamesOffset = wasm.stackAlloc(outputCount * 4); try { [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); // create input tensors for (let i = 0; i < inputCount; i++) { - const dataType = inputs[i][0]; - const dims = inputs[i][1]; - const data = inputs[i][2]; - - let dataOffset: number; - let dataByteLength: number; - - if (Array.isArray(data)) { - // string tensor - dataByteLength = 4 * data.length; - dataOffset = wasm._malloc(dataByteLength); - inputAllocs.push(dataOffset); - let dataIndex = dataOffset / 4; - for (let i = 0; i < data.length; i++) { - if (typeof data[i] !== 'string') { - throw new TypeError(`tensor data at index ${i} is not a string`); - } - wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], inputAllocs); - } - } else { - dataByteLength = data.byteLength; - dataOffset = wasm._malloc(dataByteLength); - inputAllocs.push(dataOffset); - wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), dataOffset); - } + prepareInputOutputTensor(inputTensors[i], inputTensorHandles, inputOutputAllocs, sessionId, inputIndices[i]); + } - const stack = wasm.stackSave(); - const dimsOffset = wasm.stackAlloc(4 * dims.length); - try { - let dimIndex = dimsOffset / 4; - dims.forEach(d => wasm.HEAP32[dimIndex++] = d); - const tensor = wasm._OrtCreateTensor( - tensorDataTypeStringToEnum(dataType), dataOffset, dataByteLength, dimsOffset, dims.length); - if (tensor === 0) { - checkLastError(`Can't create tensor for input[${i}].`); - } - inputValues.push(tensor); - } finally { - wasm.stackRestore(stack); - } + // create output tensors + for (let i = 0; i < outputCount; i++) { + prepareInputOutputTensor( + outputTensors[i], outputTensorHandles, inputOutputAllocs, sessionId, inputCount + outputIndices[i]); + } + + let inputValuesIndex = inputValuesOffset / 4; + let inputNamesIndex = inputNamesOffset / 4; + let outputValuesIndex = outputValuesOffset / 4; + let outputNamesIndex = outputNamesOffset / 4; + for (let i = 0; i < inputCount; i++) { + wasm.HEAPU32[inputValuesIndex++] = inputTensorHandles[i]; + wasm.HEAPU32[inputNamesIndex++] = inputNamesUTF8Encoded[inputIndices[i]]; + } + for (let i = 0; i < outputCount; i++) { + wasm.HEAPU32[outputValuesIndex++] = outputTensorHandles[i]; + wasm.HEAPU32[outputNamesIndex++] = outputNamesUTF8Encoded[outputIndices[i]]; } - const beforeRunStack = wasm.stackSave(); - const inputValuesOffset = wasm.stackAlloc(inputCount * 4); - const inputNamesOffset = wasm.stackAlloc(inputCount * 4); - const outputValuesOffset = wasm.stackAlloc(outputCount * 4); - const outputNamesOffset = wasm.stackAlloc(outputCount * 4); - - try { - let inputValuesIndex = inputValuesOffset / 4; - let inputNamesIndex = inputNamesOffset / 4; - let outputValuesIndex = outputValuesOffset / 4; - let outputNamesIndex = outputNamesOffset / 4; + if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) { + const {handle, outputPreferredLocations, outputPreferredLocationsEncoded} = ioBindingState; + + if (inputNamesUTF8Encoded.length !== inputCount) { + throw new Error(`input count from feeds (${ + inputCount}) is expected to be always equal to model's input count (${inputNamesUTF8Encoded.length}).`); + } + + // process inputs for (let i = 0; i < inputCount; i++) { - wasm.HEAPU32[inputValuesIndex++] = inputValues[i]; - wasm.HEAPU32[inputNamesIndex++] = inputNamesUTF8Encoded[inputIndices[i]]; + const index = inputIndices[i]; + const errorCode = await wasm._OrtBindInput(handle, inputNamesUTF8Encoded[index], inputTensorHandles[i]); + if (errorCode !== 0) { + checkLastError(`Can't bind input[${i}] for session=${sessionId}.`); + } } + + // process pre-allocated outputs for (let i = 0; i < outputCount; i++) { - wasm.HEAPU32[outputValuesIndex++] = 0; - wasm.HEAPU32[outputNamesIndex++] = outputNamesUTF8Encoded[outputIndices[i]]; + const index = outputIndices[i]; + const location = outputTensors[i]?.[3]; // undefined means output is not pre-allocated. + + if (location) { + // output is pre-allocated. bind the tensor. + const errorCode = wasm._OrtBindOutput(handle, outputNamesUTF8Encoded[index], outputTensorHandles[i], 0); + if (errorCode !== 0) { + checkLastError(`Can't bind pre-allocated output[${i}] for session=${sessionId}.`); + } + } else { + // output is not pre-allocated. reset preferred location. + const errorCode = + wasm._OrtBindOutput(handle, outputNamesUTF8Encoded[index], 0, outputPreferredLocationsEncoded[index]); + if (errorCode !== 0) { + checkLastError(`Can't bind output[${i}] to ${outputPreferredLocations[i]} for session=${sessionId}.`); + } + } } + } - // jsepOnRunStart is only available when JSEP is enabled. - wasm.jsepOnRunStart?.(sessionId); + let errorCode: number; - // support RunOptions - let errorCode = wasm._OrtRun( + if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) { + errorCode = await wasm._OrtRunWithBinding( + sessionHandle, ioBindingState.handle, outputCount, outputValuesOffset, runOptionsHandle); + } else { + errorCode = await wasm._OrtRun( sessionHandle, inputNamesOffset, inputValuesOffset, inputCount, outputNamesOffset, outputCount, outputValuesOffset, runOptionsHandle); + } - const runPromise = wasm.jsepRunPromise; - if (runPromise) { - // jsepRunPromise is a Promise object. It is only available when JSEP is enabled. - // - // OrtRun() is a synchrnous call, but it internally calls async functions. Emscripten's ASYNCIFY allows it to - // work in this way. However, OrtRun() does not return a promise, so when code reaches here, it is earlier than - // the async functions are finished. - // - // To make it work, we created a Promise and resolve the promise when the C++ code actually reaches the end of - // OrtRun(). If the promise exists, we need to await for the promise to be resolved. - errorCode = await runPromise; - } + if (errorCode !== 0) { + checkLastError('failed to call OrtRun().'); + } - const jsepOnRunEnd = wasm.jsepOnRunEnd; - if (jsepOnRunEnd) { - // jsepOnRunEnd is only available when JSEP is enabled. - // - // it returns a promise, which is resolved or rejected when the following async functions are finished: - // - collecting GPU validation errors. - await jsepOnRunEnd(sessionId); + const output: TensorMetadata[] = []; + + for (let i = 0; i < outputCount; i++) { + const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i]; + if (tensor === outputTensorHandles[i]) { + // output tensor is pre-allocated. no need to copy data. + output.push(outputTensors[i]!); + continue; } - const output: SerializableTensor[] = []; + const beforeGetTensorDataStack = wasm.stackSave(); + // stack allocate 4 pointer value + const tensorDataOffset = wasm.stackAlloc(4 * 4); - if (errorCode !== 0) { - checkLastError('failed to call OrtRun().'); - } + let keepOutputTensor = false; + let type: Tensor.Type|undefined, dataOffset = 0; + try { + const errorCode = wasm._OrtGetTensorData( + tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12); + if (errorCode !== 0) { + checkLastError(`Can't access output tensor data on index ${i}.`); + } + let tensorDataIndex = tensorDataOffset / 4; + const dataType = wasm.HEAPU32[tensorDataIndex++]; + dataOffset = wasm.HEAPU32[tensorDataIndex++]; + const dimsOffset = wasm.HEAPU32[tensorDataIndex++]; + const dimsLength = wasm.HEAPU32[tensorDataIndex++]; + const dims = []; + for (let i = 0; i < dimsLength; i++) { + dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); + } + wasm._OrtFree(dimsOffset); - for (let i = 0; i < outputCount; i++) { - const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i]; + const size = dims.reduce((a, b) => a * b, 1); + type = tensorDataTypeEnumToString(dataType); - const beforeGetTensorDataStack = wasm.stackSave(); - // stack allocate 4 pointer value - const tensorDataOffset = wasm.stackAlloc(4 * 4); + const preferredLocation = ioBindingState?.outputPreferredLocations[outputIndices[i]]; - let type: Tensor.Type|undefined, dataOffset = 0; - try { - errorCode = wasm._OrtGetTensorData( - tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12); - if (errorCode !== 0) { - checkLastError(`Can't access output tensor data on index ${i}.`); + if (type === 'string') { + if (preferredLocation === 'gpu-buffer') { + throw new Error('String tensor is not supported on GPU.'); } - let tensorDataIndex = tensorDataOffset / 4; - const dataType = wasm.HEAPU32[tensorDataIndex++]; - dataOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsLength = wasm.HEAPU32[tensorDataIndex++]; - const dims = []; - for (let i = 0; i < dimsLength; i++) { - dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); + const stringData: string[] = []; + let dataIndex = dataOffset / 4; + for (let i = 0; i < size; i++) { + const offset = wasm.HEAPU32[dataIndex++]; + const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset; + stringData.push(wasm.UTF8ToString(offset, maxBytesToRead)); } - wasm._OrtFree(dimsOffset); - - const size = dims.length === 0 ? 1 : dims.reduce((a, b) => a * b); - type = tensorDataTypeEnumToString(dataType); - if (type === 'string') { - const stringData: string[] = []; - let dataIndex = dataOffset / 4; - for (let i = 0; i < size; i++) { - const offset = wasm.HEAPU32[dataIndex++]; - const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset; - stringData.push(wasm.UTF8ToString(offset, maxBytesToRead)); + output.push([type, dims, stringData, 'cpu']); + } else { + // If a certain output's preferred location is GPU but the tensor is empty, we still need to create a CPU + // tensor for it. There is no mapping GPU buffer for an empty tensor. + if (preferredLocation === 'gpu-buffer' && size > 0) { + const gpuBuffer = wasm.jsepGetBuffer(dataOffset); + const elementSize = getTensorElementSize(dataType); + if (elementSize === undefined || !isGpuBufferSupportedType(type)) { + throw new Error(`Unsupported data type: ${type}`); } - output.push([type, dims, stringData]); + + // do not release the tensor right now. it will be released when user calls tensor.dispose(). + keepOutputTensor = true; + + output.push([ + type, dims, { + gpuBuffer, + download: wasm.jsepCreateDownloader(gpuBuffer, size * elementSize, type), + dispose: () => { + wasm._OrtReleaseTensor(tensor); + } + }, + 'gpu-buffer' + ]); } else { const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type); const data = new typedArrayConstructor(size); new Uint8Array(data.buffer, data.byteOffset, data.byteLength) .set(wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength)); - output.push([type, dims, data]); - } - } finally { - wasm.stackRestore(beforeGetTensorDataStack); - if (type === 'string' && dataOffset) { - wasm._free(dataOffset); + output.push([type, dims, data, 'cpu']); } + } + } finally { + wasm.stackRestore(beforeGetTensorDataStack); + if (type === 'string' && dataOffset) { + wasm._free(dataOffset); + } + if (!keepOutputTensor) { wasm._OrtReleaseTensor(tensor); } } + } - return output; - } finally { - wasm.stackRestore(beforeRunStack); + if (ioBindingState) { + wasm._OrtClearBoundOutputs(ioBindingState.handle); } + + return output; } finally { - inputValues.forEach(v => wasm._OrtReleaseTensor(v)); - inputAllocs.forEach(p => wasm._free(p)); + wasm.stackRestore(beforeRunStack); + + inputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v)); + outputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v)); + inputOutputAllocs.forEach(p => wasm._free(p)); if (runOptionsHandle !== 0) { wasm._OrtReleaseRunOptions(runOptionsHandle); @@ -380,11 +540,11 @@ export const endProfiling = (sessionId: number): void => { wasm._OrtFree(profileFileName); }; -export const extractTransferableBuffers = (tensors: readonly SerializableTensor[]): ArrayBufferLike[] => { +export const extractTransferableBuffers = (tensors: readonly SerializableTensorMetadata[]): ArrayBufferLike[] => { const buffers: ArrayBufferLike[] = []; for (const tensor of tensors) { const data = tensor[2]; - if (!Array.isArray(data) && data.buffer) { + if (!Array.isArray(data) && 'buffer' in data) { buffers.push(data.buffer); } } diff --git a/js/web/script/build.ts b/js/web/script/build.ts index d3a5be429bfa1..03510ae86b85f 100644 --- a/js/web/script/build.ts +++ b/js/web/script/build.ts @@ -176,7 +176,12 @@ npmlog.info('Build', 'Building bundle...'); webpackArgs.push('--env', `-f=${FILTER}`); } npmlog.info('Build.Bundle', `CMD: npx ${webpackArgs.join(' ')}`); - const webpack = spawnSync('npx', webpackArgs, {shell: true, stdio: 'inherit', cwd: ROOT_FOLDER}); + const webpack = spawnSync('npx', webpackArgs, { + shell: true, + stdio: 'inherit', + cwd: ROOT_FOLDER, + env: {...process.env, NODE_OPTIONS: (process.env.NODE_OPTIONS ?? '') + ' --max-old-space-size=5120'} + }); if (webpack.status !== 0) { console.error(webpack.error); process.exit(webpack.status === null ? undefined : webpack.status); diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts index f90f568879146..3f903515694db 100644 --- a/js/web/script/test-runner-cli-args.ts +++ b/js/web/script/test-runner-cli-args.ts @@ -51,6 +51,10 @@ Options: -P[=<...>], --perf[=<...>] Generate performance number. Cannot be used with flag --debug. This flag can be used with a number as value, specifying the total count of test cases to run. The test cases may be used multiple times. Default value is 10. -c, --file-cache Enable file cache. + -i=<...>, --io-binding=<...> Specify the IO binding testing type. Should be one of the following: + none (default) + gpu-tensor use pre-allocated GPU tensors for inputs and outputs + gpu-location use pre-allocated GPU tensors for inputs and set preferredOutputLocation to 'gpu-buffer' *** Session Options *** -u=<...>, --optimized-model-file-path=<...> Specify whether to dump the optimized model. @@ -109,6 +113,7 @@ export declare namespace TestRunnerCliArgs { type Backend = 'cpu'|'webgl'|'webgpu'|'wasm'|'onnxruntime'|'xnnpack'|'webnn'; type Environment = 'chrome'|'edge'|'firefox'|'electron'|'safari'|'node'|'bs'; type BundleMode = 'prod'|'dev'|'perf'; + type IOBindingMode = 'none'|'gpu-tensor'|'gpu-location'; } export interface TestRunnerCliArgs { @@ -140,6 +145,8 @@ export interface TestRunnerCliArgs { */ bundleMode: TestRunnerCliArgs.BundleMode; + ioBindingMode: TestRunnerCliArgs.IOBindingMode; + logConfig: Test.Config['log']; /** @@ -416,6 +423,13 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs logConfig.push({category: 'TestRunner.Perf', config: {minimalSeverity: 'verbose'}}); } + // Option: -i=<...>, --io-binding=<...> + const ioBindingArg = args['io-binding'] || args.i; + const ioBindingMode = (typeof ioBindingArg !== 'string') ? 'none' : ioBindingArg; + if (['none', 'gpu-tensor', 'gpu-location'].indexOf(ioBindingMode) === -1) { + throw new Error(`not supported io binding mode ${ioBindingMode}`); + } + // Option: -u, --optimized-model-file-path const optimizedModelFilePath = args['optimized-model-file-path'] || args.u || undefined; if (typeof optimizedModelFilePath !== 'undefined' && typeof optimizedModelFilePath !== 'string') { @@ -455,6 +469,7 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs npmlog.verbose('TestRunnerCli.Init', ` Env: ${env}`); npmlog.verbose('TestRunnerCli.Init', ` Debug: ${debug}`); npmlog.verbose('TestRunnerCli.Init', ` Backend: ${backend}`); + npmlog.verbose('TestRunnerCli.Init', ` IO Binding Mode: ${ioBindingMode}`); npmlog.verbose('TestRunnerCli.Init', 'Parsing commandline arguments... DONE'); return { @@ -467,6 +482,7 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs logConfig, profile, times: perf ? times : undefined, + ioBindingMode: ioBindingMode as TestRunnerCliArgs['ioBindingMode'], optimizedModelFilePath, graphOptimizationLevel: graphOptimizationLevel as TestRunnerCliArgs['graphOptimizationLevel'], fileCache, diff --git a/js/web/script/test-runner-cli.ts b/js/web/script/test-runner-cli.ts index a75321d45f1ef..d8fecec1b8084 100644 --- a/js/web/script/test-runner-cli.ts +++ b/js/web/script/test-runner-cli.ts @@ -257,7 +257,7 @@ async function main() { times?: number): Test.ModelTest { if (times === 0) { npmlog.verbose('TestRunnerCli.Init.Model', `Skip test data from folder: ${testDataRootFolder}`); - return {name: path.basename(testDataRootFolder), backend, modelUrl: '', cases: []}; + return {name: path.basename(testDataRootFolder), backend, modelUrl: '', cases: [], ioBinding: args.ioBindingMode}; } let modelUrl: string|null = null; @@ -323,6 +323,16 @@ async function main() { } } + let ioBinding: Test.IOBindingMode; + if (backend !== 'webgpu' && args.ioBindingMode !== 'none') { + npmlog.warn( + 'TestRunnerCli.Init.Model', `Ignoring IO Binding Mode "${args.ioBindingMode}" for backend "${backend}".`); + ioBinding = 'none'; + } else { + ioBinding = args.ioBindingMode; + } + + npmlog.verbose('TestRunnerCli.Init.Model', 'Finished preparing test data.'); npmlog.verbose('TestRunnerCli.Init.Model', '==============================================================='); npmlog.verbose('TestRunnerCli.Init.Model', ` Model file: ${modelUrl}`); @@ -330,7 +340,7 @@ async function main() { npmlog.verbose('TestRunnerCli.Init.Model', ` Test set(s): ${cases.length} (${caseCount})`); npmlog.verbose('TestRunnerCli.Init.Model', '==============================================================='); - return {name: path.basename(testDataRootFolder), platformCondition, modelUrl, backend, cases}; + return {name: path.basename(testDataRootFolder), platformCondition, modelUrl, backend, cases, ioBinding}; } function tryLocateModelTestFolder(searchPattern: string): string { @@ -390,6 +400,13 @@ async function main() { for (const test of tests) { test.backend = backend; test.opset = test.opset || {domain: '', version: MAX_OPSET_VERSION}; + if (backend !== 'webgpu' && args.ioBindingMode !== 'none') { + npmlog.warn( + 'TestRunnerCli.Init.Op', `Ignoring IO Binding Mode "${args.ioBindingMode}" for backend "${backend}".`); + test.ioBinding = 'none'; + } else { + test.ioBinding = args.ioBindingMode; + } } npmlog.verbose('TestRunnerCli.Init.Op', 'Finished preparing test data.'); npmlog.verbose('TestRunnerCli.Init.Op', '==============================================================='); @@ -493,19 +510,13 @@ async function main() { karmaArgs.push('--force-localhost'); } if (webgpu) { - if (browser.includes('Canary')) { - chromiumFlags.push('--enable-dawn-features=allow_unsafe_apis,use_dxc'); - } else { - chromiumFlags.push('--enable-dawn-features=use_dxc'); - chromiumFlags.push('--disable-dawn-features=disallow_unsafe_apis'); - } + // flag 'allow_unsafe_apis' is required to enable experimental features like fp16 and profiling inside pass. + // flag 'use_dxc' is required to enable DXC compiler. + chromiumFlags.push('--enable-dawn-features=allow_unsafe_apis,use_dxc'); } if (webnn) { chromiumFlags.push('--enable-experimental-web-platform-features'); } - if (config.options.globalEnvFlags?.webgpu?.profilingMode === 'default') { - chromiumFlags.push('--disable-dawn-features=disallow_unsafe_apis'); - } karmaArgs.push(`--bundle-mode=${args.bundleMode}`); karmaArgs.push(...chromiumFlags.map(flag => `--chromium-flags=${flag}`)); if (browser.startsWith('Edge')) { diff --git a/js/web/test/data/ops/conv-transpose.jsonc b/js/web/test/data/ops/conv-transpose.jsonc index a249dc807fa0b..7038e2a4f8766 100644 --- a/js/web/test/data/ops/conv-transpose.jsonc +++ b/js/web/test/data/ops/conv-transpose.jsonc @@ -28,6 +28,37 @@ } ] }, + { + "name": "ConvTranspose without bias addition A - NHWC", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "operator": "ConvTranspose", + "attributes": [{ "name": "kernel_shape", "data": [2, 2], "type": "ints" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, 30, 40], + "dims": [1, 1, 2, 2], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [10, 40, 40, 60, 200, 160, 90, 240, 160], + "dims": [1, 1, 3, 3], + "type": "float32" + } + ] + } + ] + }, { "name": "ConvTranspose without bias addition B", "operator": "ConvTranspose", @@ -74,26 +105,22 @@ }, { "data": [ - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64 ], "dims": [4, 4, 2, 2], "type": "float32" }, { - "data": [0.1, 0.2, 0.3, 0.4], + "data": [65, 66, 67, 68], "dims": [4], "type": "float32" } ], "outputs": [ { - "data": [ - 100.0999984741211, 100.0999984741211, 100.0999984741211, 100.0999984741211, 100.19999694824219, - 100.19999694824219, 100.19999694824219, 100.19999694824219, 100.30000305175781, 100.30000305175781, - 100.30000305175781, 100.30000305175781, 100.4000015258789, 100.4000015258789, 100.4000015258789, - 100.4000015258789 - ], + "data": [3365, 3465, 3565, 3665, 3766, 3866, 3966, 4066, 4167, 4267, 4367, 4467, 4568, 4668, 4768, 4868], "dims": [1, 4, 2, 2], "type": "float32" } @@ -115,7 +142,43 @@ "type": "float32" }, { - "data": [1, 1, 1, 1], + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + }, + { + "data": [5], + "dims": [1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [11, 25, 28, 19, 32, 86, 99, 55, 40, 114, 131, 67, 29, 73, 80, 41], + "dims": [1, 1, 4, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "ConvTranspose with bias addition B - NHWC", + "operator": "ConvTranspose", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "attributes": [{ "name": "kernel_shape", "data": [2, 2], "type": "ints" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [6, 8, 7, 9, 15, 11, 8, 12, 9], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], "dims": [1, 1, 2, 2], "type": "float32" }, @@ -127,7 +190,7 @@ ], "outputs": [ { - "data": [11, 19, 20, 12, 20, 43, 46, 23, 22, 49, 52, 25, 13, 25, 26, 14], + "data": [11, 25, 28, 19, 32, 86, 99, 55, 40, 114, 131, 67, 29, 73, 80, 41], "dims": [1, 1, 4, 4], "type": "float32" } @@ -251,7 +314,6 @@ } ] }, - { "name": "ConvTranspose- pointwise", "operator": "ConvTranspose", @@ -285,5 +347,50 @@ ] } ] + }, + { + "name": "ConvTranspose with bias addition C", + "operator": "ConvTranspose", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "attributes": [{ "name": "kernel_shape", "data": [1, 1], "type": "ints" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64 + ], + "dims": [1, 4, 4, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "dims": [4, 4, 1, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 1021, 1049, 1077, 1105, 1133, 1161, 1189, 1217, 1245, 1273, 1301, 1329, 1357, 1385, 1413, 1441, 1122, + 1154, 1186, 1218, 1250, 1282, 1314, 1346, 1378, 1410, 1442, 1474, 1506, 1538, 1570, 1602, 1223, 1259, + 1295, 1331, 1367, 1403, 1439, 1475, 1511, 1547, 1583, 1619, 1655, 1691, 1727, 1763, 1324, 1364, 1404, + 1444, 1484, 1524, 1564, 1604, 1644, 1684, 1724, 1764, 1804, 1844, 1884, 1924 + ], + "dims": [1, 4, 4, 4], + "type": "float32" + } + ] + } + ] } ] diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index 46d80a9f56f35..628e5408150f8 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -14,7 +14,8 @@ import {Operator} from '../lib/onnxjs/operators'; import {onnx} from '../lib/onnxjs/ort-schema/protobuf/onnx'; import {Tensor} from '../lib/onnxjs/tensor'; import {ProtoUtil} from '../lib/onnxjs/util'; -import {tensorDataTypeStringToEnum} from '../lib/wasm/wasm-common'; +import {createView} from '../lib/wasm/jsep/tensor-view'; +import {getTensorElementSize, isGpuBufferSupportedType, tensorDataTypeStringToEnum} from '../lib/wasm/wasm-common'; import {base64toBuffer, createMockGraph, readFile} from './test-shared'; import {Test} from './test-types'; @@ -136,8 +137,8 @@ async function loadTensors( } async function initializeSession( - modelFilePath: string, backendHint: string, profile: boolean, sessionOptions: ort.InferenceSession.SessionOptions, - fileCache?: FileCacheBuffer): Promise { + modelFilePath: string, backendHint: string, ioBindingMode: Test.IOBindingMode, profile: boolean, + sessionOptions: ort.InferenceSession.SessionOptions, fileCache?: FileCacheBuffer): Promise { const preloadModelData: Uint8Array|undefined = fileCache && fileCache[modelFilePath] ? fileCache[modelFilePath] : undefined; Logger.verbose( @@ -146,8 +147,14 @@ async function initializeSession( preloadModelData ? ` [preloaded(${preloadModelData.byteLength})]` : ''}`); const profilerConfig = profile ? {maxNumberEvents: 65536} : undefined; - const sessionConfig = - {...sessionOptions, executionProviders: [backendHint], profiler: profilerConfig, enableProfiling: profile}; + const sessionConfig = { + ...sessionOptions, + executionProviders: [backendHint], + profiler: profilerConfig, + enableProfiling: profile, + preferredOutputLocation: ioBindingMode === 'gpu-location' ? ('gpu-buffer' as const) : undefined + }; + let session: ort.InferenceSession; try { @@ -181,6 +188,7 @@ export class ModelTestContext { readonly session: ort.InferenceSession, readonly backend: string, readonly perfData: ModelTestContext.ModelTestPerfData, + readonly ioBinding: Test.IOBindingMode, private readonly profile: boolean, ) {} @@ -232,8 +240,8 @@ export class ModelTestContext { this.initializing = true; const initStart = now(); - const session = - await initializeSession(modelTest.modelUrl, modelTest.backend!, profile, sessionOptions || {}, this.cache); + const session = await initializeSession( + modelTest.modelUrl, modelTest.backend!, modelTest.ioBinding, profile, sessionOptions || {}, this.cache); const initEnd = now(); for (const testCase of modelTest.cases) { @@ -244,6 +252,7 @@ export class ModelTestContext { session, modelTest.backend!, {init: initEnd - initStart, firstRun: -1, runs: [], count: 0}, + modelTest.ioBinding, profile, ); } finally { @@ -481,6 +490,130 @@ export class TensorResultValidator { } } +function createGpuTensorForInput(cpuTensor: ort.Tensor): ort.Tensor { + if (!isGpuBufferSupportedType(cpuTensor.type) || Array.isArray(cpuTensor.data)) { + throw new Error(`createGpuTensorForInput can not work with ${cpuTensor.type} tensor`); + } + const device = ort.env.webgpu.device as GPUDevice; + const gpuBuffer = device.createBuffer({ + // eslint-disable-next-line no-bitwise + usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE, + size: Math.ceil(cpuTensor.data.byteLength / 16) * 16, + mappedAtCreation: true + }); + const arrayBuffer = gpuBuffer.getMappedRange(); + new Uint8Array(arrayBuffer) + .set(new Uint8Array(cpuTensor.data.buffer, cpuTensor.data.byteOffset, cpuTensor.data.byteLength)); + gpuBuffer.unmap(); + + // TODO: how to "await" for the copy to finish, so that we can get more accurate performance data? + + return ort.Tensor.fromGpuBuffer( + gpuBuffer, {dataType: cpuTensor.type, dims: cpuTensor.dims, dispose: () => gpuBuffer.destroy()}); +} + +function createGpuTensorForOutput(type: ort.Tensor.Type, dims: readonly number[]) { + if (!isGpuBufferSupportedType(type)) { + throw new Error(`createGpuTensorForOutput can not work with ${type} tensor`); + } + + const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(type))!; + const size = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes; + + const device = ort.env.webgpu.device as GPUDevice; + const gpuBuffer = device.createBuffer({ + // eslint-disable-next-line no-bitwise + usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE, + size: Math.ceil(size / 16) * 16 + }); + + return ort.Tensor.fromGpuBuffer(gpuBuffer, { + dataType: type, + dims, + dispose: () => gpuBuffer.destroy(), + download: async () => { + const stagingBuffer = device.createBuffer({ + // eslint-disable-next-line no-bitwise + usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST, + size: gpuBuffer.size + }); + const encoder = device.createCommandEncoder(); + encoder.copyBufferToBuffer(gpuBuffer, 0, stagingBuffer, 0, gpuBuffer.size); + device.queue.submit([encoder.finish()]); + + await stagingBuffer.mapAsync(GPUMapMode.READ); + const arrayBuffer = stagingBuffer.getMappedRange().slice(0, size); + stagingBuffer.destroy(); + + return createView(arrayBuffer, type) as ort.Tensor.DataTypeMap[ort.Tensor.GpuBufferDataTypes]; + } + }); +} + +export async function sessionRun(options: { + session: ort.InferenceSession; feeds: Record; + outputsMetaInfo: Record>; + ioBinding: Test.IOBindingMode; +}): Promise<[number, number, ort.InferenceSession.OnnxValueMapType]> { + const session = options.session; + const feeds = options.feeds; + const fetches: Record = {}; + + // currently we only support IO Binding for WebGPU + // + // For inputs, we create GPU tensors on both 'gpu-tensor' and 'gpu-location' binding testing mode. + // For outputs, we create GPU tensors on 'gpu-tensor' binding testing mode only. + // in 'gpu-device' binding mode, outputs are not pre-allocated. + const shouldUploadInput = options.ioBinding === 'gpu-tensor' || options.ioBinding === 'gpu-location'; + const shouldUploadOutput = options.ioBinding === 'gpu-tensor'; + try { + if (shouldUploadInput) { + // replace the CPU tensors in feeds into GPU tensors + for (const name in feeds) { + if (Object.hasOwnProperty.call(feeds, name)) { + feeds[name] = createGpuTensorForInput(feeds[name]); + } + } + } + + if (shouldUploadOutput) { + for (const name in options.outputsMetaInfo) { + if (Object.hasOwnProperty.call(options.outputsMetaInfo, name)) { + const {type, dims} = options.outputsMetaInfo[name]; + fetches[name] = createGpuTensorForOutput(type, dims); + } + } + } + + const start = now(); + Logger.verbose('TestRunner', `Timestamp before session run: ${start}`); + const outputs = await ( + shouldUploadOutput ? session.run(feeds, fetches) : + session.run(feeds, Object.getOwnPropertyNames(options.outputsMetaInfo))); + const end = now(); + Logger.verbose('TestRunner', `Timestamp after session run: ${end}`); + + // download each output tensor if needed + for (const name in outputs) { + if (Object.hasOwnProperty.call(outputs, name)) { + const tensor = outputs[name]; + // Tensor.getData(true) release the underlying resource + await tensor.getData(true); + } + } + + return [start, end, outputs]; + } finally { + // dispose the GPU tensors in feeds + for (const name in feeds) { + if (Object.hasOwnProperty.call(feeds, name)) { + const tensor = feeds[name]; + tensor.dispose(); + } + } + } +} + /** * run a single model test case. the inputs/outputs tensors should already been prepared. */ @@ -491,12 +624,11 @@ export async function runModelTestSet( const validator = new TensorResultValidator(context.backend); try { const feeds: Record = {}; + const outputsMetaInfo: Record = {}; testCase.inputs!.forEach((tensor, i) => feeds[context.session.inputNames[i]] = tensor); - const start = now(); - Logger.verbose('TestRunner', `Timestamp before session run: ${start}`); - const outputs = await context.session.run(feeds); - const end = now(); - Logger.verbose('TestRunner', `Timestamp after session run: ${end}`); + testCase.outputs!.forEach((tensor, i) => outputsMetaInfo[context.session.outputNames[i]] = tensor); + const [start, end, outputs] = + await sessionRun({session: context.session, feeds, outputsMetaInfo, ioBinding: context.ioBinding}); if (context.perfData.count === 0) { context.perfData.firstRun = end - start; } else { @@ -575,6 +707,7 @@ export class ProtoOpTestContext { private readonly loadedData: Uint8Array; // model data, inputs, outputs session: ort.InferenceSession; readonly backendHint: string; + readonly ioBindingMode: Test.IOBindingMode; constructor(test: Test.OperatorTest, private readonly sessionOptions: ort.InferenceSession.SessionOptions = {}) { const opsetImport = onnx.OperatorSetIdProto.create(test.opset); const operator = test.operator; @@ -713,6 +846,7 @@ export class ProtoOpTestContext { model.graph.name = test.name; this.backendHint = test.backend!; + this.ioBindingMode = test.ioBinding; this.loadedData = onnx.ModelProto.encode(model).finish(); // in debug mode, open a new tab in browser for the generated onnx model. @@ -729,8 +863,11 @@ export class ProtoOpTestContext { } } async init(): Promise { - this.session = await ort.InferenceSession.create( - this.loadedData, {executionProviders: [this.backendHint], ...this.sessionOptions}); + this.session = await ort.InferenceSession.create(this.loadedData, { + executionProviders: [this.backendHint], + preferredOutputLocation: this.ioBindingMode === 'gpu-location' ? ('gpu-buffer' as const) : undefined, + ...this.sessionOptions + }); } async dispose(): Promise { @@ -739,10 +876,11 @@ export class ProtoOpTestContext { } async function runProtoOpTestcase( - session: ort.InferenceSession, testCase: Test.OperatorTestCase, validator: TensorResultValidator): Promise { + session: ort.InferenceSession, testCase: Test.OperatorTestCase, ioBindingMode: Test.IOBindingMode, + validator: TensorResultValidator): Promise { const feeds: Record = {}; - const fetches: string[] = []; - testCase.inputs!.forEach((input, i) => { + const fetches: Record> = {}; + testCase.inputs.forEach((input, i) => { if (input.data) { let data: number[]|BigUint64Array|BigInt64Array = input.data; if (input.type === 'uint64') { @@ -756,7 +894,7 @@ async function runProtoOpTestcase( const outputs: ort.Tensor[] = []; const expectedOutputNames: string[] = []; - testCase.outputs!.forEach((output, i) => { + testCase.outputs.forEach((output, i) => { if (output.data) { let data: number[]|BigUint64Array|BigInt64Array = output.data; if (output.type === 'uint64') { @@ -766,11 +904,11 @@ async function runProtoOpTestcase( } outputs.push(new ort.Tensor(output.type, data, output.dims)); expectedOutputNames.push(`output_${i}`); - fetches.push(`output_${i}`); + fetches[`output_${i}`] = {dims: output.dims, type: output.type}; } }); - const results = await session.run(feeds, fetches); + const [, , results] = await sessionRun({session, feeds, outputsMetaInfo: fetches, ioBinding: ioBindingMode}); const actualOutputNames = Object.getOwnPropertyNames(results); expect(actualOutputNames.length).to.equal(expectedOutputNames.length); @@ -821,7 +959,8 @@ async function runOpTestcase( export async function runOpTest( testcase: Test.OperatorTestCase, context: ProtoOpTestContext|OpTestContext): Promise { if (context instanceof ProtoOpTestContext) { - await runProtoOpTestcase(context.session, testcase, new TensorResultValidator(context.backendHint)); + await runProtoOpTestcase( + context.session, testcase, context.ioBindingMode, new TensorResultValidator(context.backendHint)); } else { await runOpTestcase( context.inferenceHandler, context.createOperator(), testcase, new TensorResultValidator(context.backendHint)); diff --git a/js/web/test/test-types.ts b/js/web/test/test-types.ts index 1f95d1cd8e682..88915e7972383 100644 --- a/js/web/test/test-types.ts +++ b/js/web/test/test-types.ts @@ -43,6 +43,18 @@ export declare namespace Test { */ export type PlatformCondition = string; + /** + * The IOBindingMode represents how to test a model with GPU data. + * + * - none: inputs will be pre-allocated as CPU tensors; no output will be pre-allocated; `preferredOutputLocation` + * will not be set. + * - gpu-location: inputs will be pre-allocated as GPU tensors; no output will be pre-allocated; + * `preferredOutputLocation` will be set to `gpu-buffer`. + * - gpu-tensor: inputs and outputs will all be pre-allocated as GPU tensors. `preferredOutputLocation` + * will not be set. + */ + export type IOBindingMode = 'none'|'gpu-tensor'|'gpu-location'; + export interface ModelTestCase { name: string; dataFiles: readonly string[]; @@ -54,6 +66,7 @@ export declare namespace Test { name: string; modelUrl: string; backend?: string; // value should be populated at build time + ioBinding: IOBindingMode; platformCondition?: PlatformCondition; cases: readonly ModelTestCase[]; } @@ -82,6 +95,7 @@ export declare namespace Test { inputShapeDefinitions?: 'none'|'rankOnly'|'static'|ReadonlyArray; opset?: OperatorTestOpsetImport; backend?: string; // value should be populated at build time + ioBinding: IOBindingMode; platformCondition?: PlatformCondition; attributes?: readonly AttributeValue[]; cases: readonly OperatorTestCase[]; diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index fd147eaa11f3f..0ed7d887fc5e5 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -42,6 +42,7 @@ from onnxruntime.capi._pybind_state import get_build_info # noqa: F401 from onnxruntime.capi._pybind_state import get_device # noqa: F401 from onnxruntime.capi._pybind_state import get_version_string # noqa: F401 + from onnxruntime.capi._pybind_state import has_collective_ops # noqa: F401 from onnxruntime.capi._pybind_state import set_default_logger_severity # noqa: F401 from onnxruntime.capi._pybind_state import set_default_logger_verbosity # noqa: F401 from onnxruntime.capi._pybind_state import set_seed # noqa: F401 diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index f0385ea5abdfb..48af0a9a6adec 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -140,27 +140,29 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { #endif if (!use_flash_attention) { - if (is_unidirectional_ && enable_fused_causal_attention_) { // GPT - // GPT fused kernels requires left side padding. mask can be: - // none (no padding), 1D sequence lengths or 2d mask. - // Fused kernels don't support different sequence lengths of q and kv, so only apply to the first token - // where past state is empty. - bool is_mask_2d_key_padding = parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING; - bool use_causal_fused_runner = (nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) && - nullptr == relative_position_bias && - parameters.past_sequence_length == 0 && - parameters.hidden_size == parameters.v_hidden_size && - FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length, - enable_trt_flash_attention_, true); - if (use_causal_fused_runner) { - // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node. - if (nullptr == fused_fp16_runner_.get()) { - fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_, - enable_trt_flash_attention_, parameters.scale); + if (is_unidirectional_) { // GPT + if (enable_fused_causal_attention_) { + // GPT fused kernels requires left side padding. mask can be: + // none (no padding), 1D sequence lengths or 2d mask. + // Fused kernels don't support different sequence lengths of q and kv, so only apply to the first token + // where past state is empty. + bool is_mask_2d_key_padding = parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING; + bool use_causal_fused_runner = (nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) && + nullptr == relative_position_bias && + parameters.past_sequence_length == 0 && + parameters.hidden_size == parameters.v_hidden_size && + FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length, + enable_trt_flash_attention_, true); + if (use_causal_fused_runner) { + // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node. + if (nullptr == fused_fp16_runner_.get()) { + fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_, + enable_trt_flash_attention_, parameters.scale); + } + + // Here we assume all causal kernels can be loaded into shared memory. TODO: add a function to check. + fused_runner = fused_fp16_runner_.get(); } - - // Here we assume all causal kernels can be loaded into shared memory. TODO: add a function to check. - fused_runner = fused_fp16_runner_.get(); } } else { // BERT bool use_fused_runner = !disable_fused_self_attention_ && diff --git a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu index 5827bdfee1ab5..c8877a5e3f872 100644 --- a/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu @@ -174,7 +174,6 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio q = add_vec(q, q_bias); } - T* params_k_cache = reinterpret_cast(params.k_cache); const float inv_sqrt_dh = params.scale; @@ -350,24 +349,22 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio // The keys loaded from the key cache. K_vec_k k_vec[K_VECS_PER_THREAD]; + if (ti < tlength) { + if (has_beams) { + const int beam_offset = beam_indices[ti] * params.num_heads * params.max_sequence_length * head_size; - if (has_beams) { #pragma unroll - for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - int jj = ii * params.max_sequence_length + ti; + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + int jj = ii * params.max_sequence_length + ti; - if (ti < tlength) { - const int beam_offset = beam_indices[ti] * params.num_heads * params.max_sequence_length * head_size; k_vec[ii] = vec_conversion( (*reinterpret_cast(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]))); } - } - } else { + } else { #pragma unroll - for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - int jj = ii * params.max_sequence_length + ti; + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + int jj = ii * params.max_sequence_length + ti; - if (ti < tlength) { k_vec[ii] = vec_conversion( (*reinterpret_cast(&k_cache_batch[jj * QK_ELTS_IN_16B]))); } diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index a23409292bb74..6a82b3fcc734d 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -135,38 +135,34 @@ void CPUIDInfo::ArmLinuxInit() { LOGS_DEFAULT(WARNING) << "Failed to init pytorch cpuinfo library, may cause CPU EP performance degradation due to undetected CPU features."; return; } + is_hybrid_ = cpuinfo_get_uarchs_count() > 1; + has_arm_neon_dot_ = cpuinfo_has_arm_neon_dot(); + has_fp16_ = cpuinfo_has_arm_neon_fp16_arith(); + const uint32_t core_cnt = cpuinfo_get_cores_count(); + core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown); + is_armv8_narrow_ld_.resize(core_cnt, false); + for (uint32_t c = 0; c < core_cnt; c++) { + const struct cpuinfo_processor* proc = cpuinfo_get_processor(c); + if (proc == nullptr) { + continue; + } + const struct cpuinfo_core* corep = proc->core; + if (corep == nullptr) { + continue; + } + auto coreid = proc->linux_id; + auto uarch = corep->uarch; + core_uarchs_[coreid] = uarch; + if (uarch == cpuinfo_uarch_cortex_a53 || uarch == cpuinfo_uarch_cortex_a55r0 || + uarch == cpuinfo_uarch_cortex_a55) { + is_armv8_narrow_ld_[coreid] = true; + } + } #else pytorch_cpuinfo_init_ = false; + has_arm_neon_dot_ = ((getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0); + has_fp16_ |= has_arm_neon_dot_; #endif - - if (pytorch_cpuinfo_init_) { - is_hybrid_ = cpuinfo_get_uarchs_count() > 1; - has_arm_neon_dot_ = cpuinfo_has_arm_neon_dot(); - has_fp16_ = cpuinfo_has_arm_neon_fp16_arith(); - const uint32_t core_cnt = cpuinfo_get_cores_count(); - core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown); - is_armv8_narrow_ld_.resize(core_cnt, false); - for (uint32_t c = 0; c < core_cnt; c++) { - const struct cpuinfo_processor* proc = cpuinfo_get_processor(c); - if (proc == nullptr) { - continue; - } - const struct cpuinfo_core* corep = proc->core; - if (corep == nullptr) { - continue; - } - auto coreid = proc->linux_id; - auto uarch = corep->uarch; - core_uarchs_[coreid] = uarch; - if (uarch == cpuinfo_uarch_cortex_a53 || uarch == cpuinfo_uarch_cortex_a55r0 || - uarch == cpuinfo_uarch_cortex_a55) { - is_armv8_narrow_ld_[coreid] = true; - } - } - } else { - has_arm_neon_dot_ = ((getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0); - has_fp16_ |= has_arm_neon_dot_; - } } #elif defined(_WIN32) diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index dede1ecc95885..1b492a3561396 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -177,9 +177,9 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) { } #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - // Run layout transformer only for EPs other than CPU EP and provided the preferred layout is NHWC + // Run layout transformer for EPs with preferred layout of NHWC // CPU EP layout transformation happens later when level 3 transformers are run. - if (params.mode != GraphPartitioner::Mode::kAssignOnly && + if (params.mode != GraphPartitioner::Mode::kAssignOnly && params.transform_layout.get() && current_ep.GetPreferredLayout() == DataLayout::NHWC) { for (auto& capability : capabilities) { TryAssignNodes(graph, *capability->sub_graph, ep_type); diff --git a/onnxruntime/core/framework/kernel_registry_manager.cc b/onnxruntime/core/framework/kernel_registry_manager.cc index 38c8a4a4e3d5e..c4eef5b27c1bb 100644 --- a/onnxruntime/core/framework/kernel_registry_manager.cc +++ b/onnxruntime/core/framework/kernel_registry_manager.cc @@ -63,8 +63,9 @@ Status KernelRegistryManager::SearchKernelRegistry(const Node& node, auto create_error_message = [&node, &status](const std::string& prefix) { std::ostringstream errormsg; errormsg << prefix << node.OpType() << "(" << node.SinceVersion() << ")"; - if (!node.Name().empty()) errormsg << " (node " << node.Name() << "). "; - if (!status.IsOK()) errormsg << status.ErrorMessage(); + errormsg << " (node:'" << node.Name() << "' ep:'" << node.GetExecutionProviderType() << "'). "; + if (!status.IsOK()) + errormsg << status.ErrorMessage(); return errormsg.str(); }; diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index f0e5fbbd38721..6244d426450a2 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -1046,7 +1046,8 @@ Status SessionState::CreateSubgraphSessionState() { auto subgraph_session_state = std::make_unique(*subgraph, execution_providers_, thread_pool_, inter_op_thread_pool_, data_transfer_mgr_, - logger_, profiler_, sess_options_, nullptr, allocators_); + logger_, profiler_, sess_options_, + prepacked_weights_container_, allocators_); // Pass fused function manager to subgraph subgraph_session_state->fused_funcs_mgr_.SetFusedFuncs(fused_funcs_mgr_); diff --git a/onnxruntime/core/framework/tensor.cc b/onnxruntime/core/framework/tensor.cc index a9e5038e19e02..36f03a9b1046a 100644 --- a/onnxruntime/core/framework/tensor.cc +++ b/onnxruntime/core/framework/tensor.cc @@ -27,80 +27,68 @@ int64_t GetSizeFromStrides(const TensorShape& shape, gsl::span st } // namespace #endif -size_t Tensor::CalculateTensorStorageSize(MLDataType p_type, - const TensorShape& shape, - gsl::span strides) { -#ifdef ENABLE_STRIDED_TENSORS - int64_t shape_size = 1; - if (shape.NumDimensions() > 0 && !strides.empty()) { - ORT_ENFORCE(shape.NumDimensions() == strides.size(), "Length of strides doesn't match with tensor dimension size."); - shape_size = GetSizeFromStrides(shape, strides); - } else { - shape_size = shape.Size(); - } -#else - ORT_ENFORCE(strides.empty(), "Strided tensor is supported for training only for now."); +size_t Tensor::CalculateTensorStorageSize(MLDataType elt_type, const TensorShape& shape) { int64_t shape_size = shape.Size(); -#endif - if (shape_size < 0) ORT_THROW("shape.Size() must >=0"); + if (shape_size < 0) + ORT_THROW("shape.Size() must >=0"); if (shape_size > 0) { SafeInt len = 0; - if (!IAllocator::CalcMemSizeForArray(SafeInt(shape_size), p_type->Size(), &len)) + if (!IAllocator::CalcMemSizeForArray(SafeInt(shape_size), elt_type->Size(), &len)) ORT_THROW("tensor failed memory size calculation"); return len; } + return 0; } -Tensor::Tensor(MLDataType p_type, const TensorShape& shape, void* p_data, const OrtMemoryInfo& alloc, +Tensor::Tensor(MLDataType elt_type, const TensorShape& shape, void* p_data, const OrtMemoryInfo& location, ptrdiff_t offset, gsl::span strides) - : alloc_info_(alloc) { - ORT_ENFORCE(p_type != nullptr); - Init(p_type, shape, p_data, nullptr, offset, strides); + : alloc_info_(location) { + ORT_ENFORCE(elt_type != nullptr); + Init(elt_type, shape, p_data, nullptr, offset, strides); } -Tensor::Tensor(MLDataType p_type, const TensorShape& shape, std::shared_ptr allocator, - gsl::span strides) +Tensor::Tensor(MLDataType elt_type, const TensorShape& shape, std::shared_ptr allocator) : alloc_info_(allocator->Info()) { - ORT_ENFORCE(p_type != nullptr); - size_t len = Tensor::CalculateTensorStorageSize(p_type, shape, strides); + ORT_ENFORCE(elt_type != nullptr); + size_t len = Tensor::CalculateTensorStorageSize(elt_type, shape); void* p_data = nullptr; if (len > 0) { p_data = allocator->Alloc(len); } - Init(p_type, shape, p_data, allocator, 0L, strides); + Init(elt_type, shape, p_data, allocator, 0L); } -Tensor::Tensor(MLDataType p_type, const TensorShape& shape, void* p_data, std::shared_ptr deleter, +Tensor::Tensor(MLDataType elt_type, const TensorShape& shape, void* p_data, std::shared_ptr deleter, ptrdiff_t offset, gsl::span strides) : alloc_info_(deleter->Info()) { - ORT_ENFORCE(p_type != nullptr); - Init(p_type, shape, p_data, deleter, offset, strides); + ORT_ENFORCE(elt_type != nullptr); + Init(elt_type, shape, p_data, deleter, offset, strides); } void Tensor::InitOrtValue(MLDataType elt_type, const TensorShape& shape, std::shared_ptr allocator, - OrtValue& ort_value, gsl::span strides) { - auto p_tensor = std::make_unique(elt_type, shape, std::move(allocator), strides); + OrtValue& ort_value) { + auto p_tensor = std::make_unique(elt_type, shape, std::move(allocator)); auto ml_tensor = DataTypeImpl::GetType(); ort_value.Init(p_tensor.release(), ml_tensor, ml_tensor->GetDeleteFunc()); } -void Tensor::InitOrtValue(MLDataType p_type, const TensorShape& shape, void* p_data, const OrtMemoryInfo& location, +void Tensor::InitOrtValue(MLDataType elt_type, const TensorShape& shape, void* p_data, const OrtMemoryInfo& location, OrtValue& ort_value, ptrdiff_t offset, gsl::span strides) { auto ml_tensor = DataTypeImpl::GetType(); - auto p_tensor = std::make_unique(p_type, shape, p_data, location, offset, strides); + auto p_tensor = std::make_unique(elt_type, shape, p_data, location, offset, strides); ort_value.Init(p_tensor.release(), ml_tensor, ml_tensor->GetDeleteFunc()); } -void Tensor::InitOrtValue(MLDataType p_type, const TensorShape& shape, +void Tensor::InitOrtValue(MLDataType elt_type, const TensorShape& shape, void* p_data, std::shared_ptr allocator, OrtValue& ort_value, ptrdiff_t offset, gsl::span strides) { auto ml_tensor = DataTypeImpl::GetType(); - auto p_tensor = std::make_unique(p_type, shape, p_data, std::move(allocator), offset, strides); + auto p_tensor = std::make_unique(elt_type, shape, p_data, std::move(allocator), offset, strides); ort_value.Init(p_tensor.release(), ml_tensor, ml_tensor->GetDeleteFunc()); } @@ -123,27 +111,31 @@ size_t Tensor::SizeInBytes() const { return ret; } -void Tensor::Init(MLDataType p_type, const TensorShape& shape, void* p_raw_data, AllocatorPtr deleter, ptrdiff_t offset, - gsl::span strides) { +void Tensor::Init(MLDataType elt_type, const TensorShape& shape, void* p_raw_data, AllocatorPtr deleter, + ptrdiff_t offset, gsl::span strides) { int64_t shape_size = shape.Size(); - if (shape_size < 0) ORT_THROW("shape.Size() must >=0"); - dtype_ = p_type->AsPrimitiveDataType(); + if (shape_size < 0) + ORT_THROW("shape.Size() must >=0"); + + dtype_ = elt_type->AsPrimitiveDataType(); ORT_ENFORCE(dtype_ != nullptr, - "Tensor is expected to contain one of the primitive data types. Got: ", DataTypeImpl::ToString(p_type)); + "Tensor is expected to contain one of the primitive data types. Got: ", + DataTypeImpl::ToString(elt_type)); shape_ = shape; p_data_ = p_raw_data; - // if caller passed in a deleter, that means this tensor own this buffer - // we will release the buffer when this tensor is deconstructed. + // if caller passed in a deleter we now own p_data_ and must free it in the dtor buffer_deleter_ = std::move(deleter); // for string tensors, if this tensor own the buffer (caller passed in the deleter) // do the placement new for strings on pre-allocated buffer. if (buffer_deleter_ && IsDataTypeString()) { utils::ConstructStrings(p_data_, shape_size); } + byte_offset_ = offset; + #ifdef ENABLE_STRIDED_TENSORS if (shape.NumDimensions() > 0 && !strides.empty()) { - ORT_ENFORCE(shape.NumDimensions() == strides.size(), "Length of strides doesn't match with tensor dimension size."); + ORT_ENFORCE(shape.NumDimensions() == strides.size(), "Length of strides doesn't match tensor dimension size."); strides_.assign(strides.begin(), strides.end()); is_contiguous_ = CheckIsContiguous(); } @@ -156,13 +148,21 @@ Tensor::Tensor(Tensor&& other) noexcept : p_data_(other.p_data_), buffer_deleter_(other.buffer_deleter_), shape_(other.shape_), +#ifdef ENABLE_STRIDED_TENSORS + strides_(other.strides_), + is_contiguous_(other.is_contiguous_), +#endif dtype_(other.dtype_), alloc_info_(other.alloc_info_), byte_offset_(other.byte_offset_) { - other.dtype_ = DataTypeImpl::GetType()->AsPrimitiveDataType(); - other.shape_ = TensorShape(std::vector(1, 0)); other.p_data_ = nullptr; other.buffer_deleter_ = nullptr; + other.dtype_ = DataTypeImpl::GetType()->AsPrimitiveDataType(); + other.shape_ = TensorShape(std::vector(1, 0)); +#ifdef ENABLE_STRIDED_TENSORS + other.strides_ = {}; + other.is_contiguous_ = true; +#endif other.byte_offset_ = 0; } @@ -170,19 +170,28 @@ Tensor& Tensor::operator=(Tensor&& other) noexcept { if (this != &other) { ReleaseBuffer(); - dtype_ = other.dtype_; + p_data_ = other.p_data_; + buffer_deleter_ = other.buffer_deleter_; shape_ = other.shape_; +#ifdef ENABLE_STRIDED_TENSORS + strides_ = other.strides_; + is_contiguous_ = other.is_contiguous_; +#endif + dtype_ = other.dtype_; alloc_info_ = other.alloc_info_; byte_offset_ = other.byte_offset_; - p_data_ = other.p_data_; - buffer_deleter_ = other.buffer_deleter_; - other.dtype_ = DataTypeImpl::GetType()->AsPrimitiveDataType(); - other.shape_ = TensorShape(std::vector(1, 0)); other.p_data_ = nullptr; - other.byte_offset_ = 0; other.buffer_deleter_ = nullptr; + other.shape_ = TensorShape(std::vector(1, 0)); +#ifdef ENABLE_STRIDED_TENSORS + other.strides_ = {}; + other.is_contiguous_ = true; +#endif + other.dtype_ = DataTypeImpl::GetType()->AsPrimitiveDataType(); + other.byte_offset_ = 0; } + return *this; } diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 08ed811d9ac38..fd32aaedcc2ee 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -87,6 +87,7 @@ bool operator!=(const ONNX_NAMESPACE::TensorShapeProto_Dimension& l, } // namespace ONNX_NAMESPACE +namespace onnxruntime { namespace { // This function doesn't support string tensors @@ -162,9 +163,9 @@ static Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_prot // Uses the tensor_proto_dir to construct the full path for external data. If tensor_proto_dir == nullptr // then uses the current directory instead. // This function does not unpack string_data of an initializer tensor -static Status ReadExternalDataForTensor(const ONNX_NAMESPACE::TensorProto& tensor_proto, - const ORTCHAR_T* tensor_proto_dir, - std::vector& unpacked_tensor) { +Status ReadExternalDataForTensor(const ONNX_NAMESPACE::TensorProto& tensor_proto, + const ORTCHAR_T* tensor_proto_dir, + std::vector& unpacked_tensor) { std::basic_string external_file_path; onnxruntime::FileOffsetType file_offset; SafeInt tensor_byte_size; @@ -184,10 +185,52 @@ static Status ReadExternalDataForTensor(const ONNX_NAMESPACE::TensorProto& tenso return Status::OK(); } + +// TODO(unknown): Change the current interface to take Path object for model path +// so that validating and manipulating path for reading external data becomes easy +Status TensorProtoToOrtValueImpl(const Env& env, const ORTCHAR_T* model_path, + const ONNX_NAMESPACE::TensorProto& tensor_proto, + const MemBuffer* m, AllocatorPtr alloc, + OrtValue& value) { + if (m && m->GetBuffer() == nullptr) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "MemBuffer has not been allocated."); + } + + // to construct a Tensor with std::string we need to pass an allocator to the Tensor ctor + // as the contents of each string needs to be allocated and freed separately. + ONNXTensorElementDataType ele_type = utils::GetTensorElementType(tensor_proto); + if (ele_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING && (m || !alloc)) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "string tensor requires allocator to be provided."); + } + + // Note: We permit an empty tensor_shape_vec, and treat it as a scalar (a tensor of size 1). + TensorShape tensor_shape = GetTensorShapeFromTensorProto(tensor_proto); + const DataTypeImpl* const type = DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType(); + + std::unique_ptr tensor; + + if (m) { + tensor = std::make_unique(type, tensor_shape, m->GetBuffer(), m->GetAllocInfo()); + + if (tensor->SizeInBytes() > m->GetLen()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "The preallocated buffer is too small. Requires ", + tensor->SizeInBytes(), ", Got ", m->GetLen()); + } + } else { + tensor = std::make_unique(type, tensor_shape, alloc); + } + + ORT_RETURN_IF_ERROR(TensorProtoToTensor(env, model_path, tensor_proto, *tensor)); + + auto ml_tensor = DataTypeImpl::GetType(); + value.Init(tensor.release(), ml_tensor, ml_tensor->GetDeleteFunc()); + return Status::OK(); +} + } // namespace -namespace onnxruntime { namespace utils { + #if !defined(ORT_MINIMAL_BUILD) static Status UnpackTensorWithExternalDataImpl(const ONNX_NAMESPACE::TensorProto& tensor, const ORTCHAR_T* tensor_proto_dir, @@ -810,7 +853,7 @@ Status GetExtDataFromTensorProto(const Env& env, const ORTCHAR_T* model_path, * @param env * @param model_path * @param tensor_proto tensor data in protobuf format - * @param tensorp pre-allocated tensor object, where we store the data + * @param tensor pre-allocated tensor object, where we store the data * @return */ Status TensorProtoToTensor(const Env& env, const ORTCHAR_T* model_path, @@ -824,7 +867,7 @@ Status TensorProtoToTensor(const Env& env, const ORTCHAR_T* model_path, const DataTypeImpl* const source_type = DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType(); if (source_type->Size() > tensor.DataType()->Size()) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "TensorProto type ", DataTypeImpl::ToString(source_type), - " can not be writen into Tensor type ", DataTypeImpl::ToString(tensor.DataType())); + " can not be written into Tensor type ", DataTypeImpl::ToString(tensor.DataType())); } // find raw data in proto buf @@ -900,44 +943,18 @@ Status TensorProtoToTensor(const Env& env, const ORTCHAR_T* model_path, return Status::OK(); } -#ifdef _MSC_VER -#pragma warning(push) -#pragma warning(disable : 6239) -#endif -// TODO: Change the current interface to take Path object for model path -// so that validating and manipulating path for reading external data becomes easy -Status TensorProtoToMLValue(const Env& env, const ORTCHAR_T* model_path, - const ONNX_NAMESPACE::TensorProto& tensor_proto, - const MemBuffer& m, OrtValue& value) { - if (m.GetBuffer() == nullptr) { - return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, - "TensorProtoToMLValue() must take a pre-allocated MemBuffer!"); - } - - ONNXTensorElementDataType ele_type = utils::GetTensorElementType(tensor_proto); - if (ele_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) { - return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "string tensor can not use pre-allocated buffer"); - } - - // Note: We permit an empty tensor_shape_vec, and treat it as a scalar (a tensor of size 1). - TensorShape tensor_shape = GetTensorShapeFromTensorProto(tensor_proto); - const DataTypeImpl* const type = DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType(); - std::unique_ptr tensorp = std::make_unique(type, tensor_shape, m.GetBuffer(), m.GetAllocInfo()); - if (tensorp->SizeInBytes() > m.GetLen()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "The preallocated buffer is too small. Requires ", - tensorp->SizeInBytes(), ", Got ", m.GetLen()); - } - - ORT_RETURN_IF_ERROR(TensorProtoToTensor(env, model_path, tensor_proto, *tensorp)); +Status TensorProtoToOrtValue(const Env& env, const ORTCHAR_T* model_path, + const ONNX_NAMESPACE::TensorProto& tensor_proto, + const MemBuffer& m, OrtValue& value) { + return TensorProtoToOrtValueImpl(env, model_path, tensor_proto, &m, nullptr, value); +} - auto ml_tensor = DataTypeImpl::GetType(); - value.Init(tensorp.release(), ml_tensor, ml_tensor->GetDeleteFunc()); - return Status::OK(); +Status TensorProtoToOrtValue(const Env& env, const ORTCHAR_T* model_path, + const ONNX_NAMESPACE::TensorProto& tensor_proto, + AllocatorPtr alloc, OrtValue& value) { + return TensorProtoToOrtValueImpl(env, model_path, tensor_proto, nullptr, alloc, value); } -#ifdef _MSC_VER -#pragma warning(pop) -#pragma warning(disable : 6239) -#endif + #define CASE_TYPE(X) \ case ONNX_NAMESPACE::TensorProto_DataType_##X: \ return ONNX_TENSOR_ELEMENT_DATA_TYPE_##X; diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h index 6f1b60dd2982e..000502ba47594 100644 --- a/onnxruntime/core/framework/tensorprotoutils.h +++ b/onnxruntime/core/framework/tensorprotoutils.h @@ -39,13 +39,28 @@ TensorShape GetTensorShapeFromTensorShapeProto(const ONNX_NAMESPACE::TensorShape TensorShape GetTensorShapeFromTensorProto(const ONNX_NAMESPACE::TensorProto& tensor_proto); /** - * deserialize a TensorProto into a preallocated memory buffer. - * \param tensor_proto_path A local file path of where the 'input' was loaded from. Can be NULL if the tensor proto doesn't - * have any external data or it was loaded from current working dir. This path could be either a - * relative path or an absolute path. + * deserialize a TensorProto into a preallocated memory buffer on CPU. + * \param tensor_proto_path A local file path of where the 'input' was loaded from. + * Can be NULL if the tensor proto doesn't have external data or it was loaded from + * the current working dir. This path could be either a relative path or an absolute path. + * \return Status::OK on success with 'value' containing the Tensor in CPU based memory. */ -common::Status TensorProtoToMLValue(const Env& env, const ORTCHAR_T* tensor_proto_path, - const ONNX_NAMESPACE::TensorProto& input, const MemBuffer& m, OrtValue& value); +common::Status TensorProtoToOrtValue(const Env& env, const ORTCHAR_T* tensor_proto_path, + const ONNX_NAMESPACE::TensorProto& input, + const MemBuffer& m, OrtValue& value); + +/** + * deserialize a TensorProto into a buffer on CPU allocated using 'alloc'. + * \param tensor_proto_path A local file path of where the 'input' was loaded from. + * Can be NULL if the tensor proto doesn't have external data or it was loaded from + * the current working dir. This path could be either a relative path or an absolute path. + * \param alloc Allocator to use for allocating the buffer. Must allocate CPU based memory. + * \return Status::OK on success with 'value' containing the Tensor in CPU based memory. + */ +common::Status TensorProtoToOrtValue(const Env& env, const ORTCHAR_T* tensor_proto_path, + const ONNX_NAMESPACE::TensorProto& input, + AllocatorPtr alloc, OrtValue& value); + /** * @brief Deserialize a TensorProto into a preallocated empty Tensor * @param env diff --git a/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc b/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc index 3ce7c40e754dc..d3fc5873cb274 100644 --- a/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc +++ b/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc @@ -94,7 +94,7 @@ void OpSet_Internal_NHWC_ONNX::ForEachSchema(const std::functionGemmU8U8Dispatch = &MlasGemmU8S8DispatchAmx; this->GemmU8S8Dispatch = &MlasGemmU8S8DispatchAmx; diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index 54511aa02a57c..c4416068e2457 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -189,6 +189,8 @@ InlinedVector> GenerateTransformers( const InlinedHashSet cpu_ep = {onnxruntime::kCpuExecutionProvider}; #endif const InlinedHashSet dml_ep = {onnxruntime::kDmlExecutionProvider}; + AllocatorPtr cpu_allocator = std::make_shared(); + switch (level) { case TransformerLevel::Level1: { // RewriteRule optimizations are the simplest (they generally remove unnecessary nodes and are cheap to run) @@ -240,13 +242,14 @@ InlinedVector> GenerateTransformers( // run TransposeOptimizer last as it works in a slightly different way by moving Transpose nodes around. // shouldn't affect the end result - just easier to debug any issue if it's last. - // local CPU allocator is enough as this allocator is finally passed to a local tensor. - // We will also benefit by using a local allocator as we don't need to pass allocator as parameter for EP API refactor - AllocatorPtr cpu_allocator = std::make_shared(); transformers.emplace_back(std::make_unique(std::move(cpu_allocator))); } break; case TransformerLevel::Level2: { + // we run TransposeOptimizer again in Level2 for some CPU EP specific optimizations that can only be + // applied once nodes are assigned to the CPU EP (which happens between level 1 and level 2). + transformers.emplace_back(std::make_unique(std::move(cpu_allocator), kCpuExecutionProvider)); + const bool enable_quant_qdq_cleanup = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsEnableQuantQDQCleanup, "0") == "1"; #if !defined(DISABLE_CONTRIB_OPS) @@ -366,16 +369,16 @@ InlinedVector> GenerateTransformers( if (MlasNchwcGetBlockSize() > 1) { transformers.emplace_back(std::make_unique()); } - AllocatorPtr cpu_allocator = std::make_shared(); + auto cpu_registry = cpu_execution_provider.GetKernelRegistry(); auto nhwc_transformer = std::make_unique(std::move(cpu_allocator), std::move(cpu_registry)); if (nhwc_transformer->IsActive()) { transformers.emplace_back(std::move(nhwc_transformer)); } - // NCHWCtransformer should have a higher priority versus this. Because NCHWCtransformer also do the similar things - // of fusion patterns and target on CPU. However, NCHWCtransformer will reorder the layout to nchwc which is only available for - // x86-64 cpu, not edge cpu like arm. But This transformer could be used by opencl-ep/cpu-ep. So - // we will prefer NhwcTransformer once ort runs on x86-64 CPU, otherwise ConvAddActivationFusion is enabled. + + // NchwcTransformer must have a higher priority than ConvAddActivationFusion. NchwcTransformer does similar + // fusions targeting CPU but also reorders the layout to NCHWc which is expected to be more efficient but is + // only available on x86-64. // PR #6351 implemented similar fusion-pattern for CUDA only, and can only fuse conv-add-relu, // while we can fuse more activation. transformers.emplace_back(std::make_unique(cpu_ep)); diff --git a/onnxruntime/core/optimizer/initializer.cc b/onnxruntime/core/optimizer/initializer.cc index 9cdc0d9ef0473..c8da15f65a6d7 100644 --- a/onnxruntime/core/optimizer/initializer.cc +++ b/onnxruntime/core/optimizer/initializer.cc @@ -3,22 +3,23 @@ #include "core/optimizer/initializer.h" -#include "core/common/gsl.h" +#include +#include +#include "core/common/gsl.h" #include "core/common/path.h" #include "core/framework/tensorprotoutils.h" #include "core/framework/tensor_external_data_info.h" #include "core/platform/env.h" -#include - namespace onnxruntime { Initializer::Initializer(ONNX_NAMESPACE::TensorProto_DataType data_type, std::string_view name, gsl::span dims) : name_(name), - data_(DataTypeImpl::TensorTypeFromONNXEnum(data_type)->GetElementType(), dims, std::make_shared()) { + data_(DataTypeImpl::TensorTypeFromONNXEnum(data_type)->GetElementType(), dims, + std::make_shared()) { if (!data_.IsDataTypeString()) { memset(data_.MutableDataRaw(), 0, data_.SizeInBytes()); } @@ -39,7 +40,8 @@ Initializer::Initializer(const ONNX_NAMESPACE::TensorProto& tensor_proto, const auto proto_shape = utils::GetTensorShapeFromTensorProto(tensor_proto); // This must be pre-allocated - Tensor w(DataTypeImpl::TensorTypeFromONNXEnum(proto_data_type)->GetElementType(), proto_shape, std::make_shared()); + Tensor w(DataTypeImpl::TensorTypeFromONNXEnum(proto_data_type)->GetElementType(), proto_shape, + std::make_shared()); ORT_THROW_IF_ERROR(utils::TensorProtoToTensor(Env::Default(), model_path.ToPathString().c_str(), tensor_proto, w)); data_ = std::move(w); } diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc index bf36f11521be2..159e3b23d1ab0 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc @@ -414,20 +414,20 @@ Status LayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int graph_level, NodeArg* scale = nullptr; NodeArg* bias = nullptr; for (size_t i = 0; i < mul_node.MutableInputDefs().size(); i++) { - if (graph_utils::NodeArgIsConstant(graph, *(mul_node.MutableInputDefs()[i])) || - graph_utils::IsGraphInput(graph, mul_node.MutableInputDefs()[i])) { - if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast(axes_values.size())) { - scale = mul_node.MutableInputDefs()[i]; - } + if (mul_node.MutableInputDefs()[i]->Shape() == nullptr) { + continue; + } + if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast(axes_values.size())) { + scale = mul_node.MutableInputDefs()[i]; } } for (size_t i = 0; i < last_add_node.MutableInputDefs().size(); i++) { - if (graph_utils::NodeArgIsConstant(graph, *(last_add_node.MutableInputDefs()[i])) || - graph_utils::IsGraphInput(graph, last_add_node.MutableInputDefs()[i])) { - if (last_add_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast(axes_values.size())) { - bias = last_add_node.MutableInputDefs()[i]; - } + if (last_add_node.MutableInputDefs()[i]->Shape() == nullptr) { + continue; + } + if (last_add_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast(axes_values.size())) { + bias = last_add_node.MutableInputDefs()[i]; } } if (scale == nullptr || bias == nullptr) { @@ -667,20 +667,20 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr // because SkipLayerNorm kernel, for example, has dependency on single dim size NodeArg* scale = nullptr; for (size_t i = 0; i < mul_node.MutableInputDefs().size(); i++) { - if (graph_utils::NodeArgIsConstant(graph, *(mul_node.MutableInputDefs()[i])) || - graph_utils::IsGraphInput(graph, mul_node.MutableInputDefs()[i])) { + if (mul_node.MutableInputDefs()[i]->Shape() == nullptr) { + continue; + } #ifdef ENABLE_TRAINING_CORE - if (axes_values.empty() || - mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast(axes_values.size())) { - scale = mul_node.MutableInputDefs()[i]; - } + if (axes_values.empty() || + mul_node.MutableInputDefs()[i]->Shape()->dim_size() == static_cast(axes_values.size())) { + scale = mul_node.MutableInputDefs()[i]; + } #else - // Scale must be 1d. - if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == 1) { - scale = mul_node.MutableInputDefs()[i]; - } -#endif + // Scale must be 1d. + if (mul_node.MutableInputDefs()[i]->Shape()->dim_size() == 1) { + scale = mul_node.MutableInputDefs()[i]; } +#endif } if (scale == nullptr) { diff --git a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc index 2d12c407e6e31..6c91949e467ae 100644 --- a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc +++ b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc @@ -13,27 +13,91 @@ using namespace onnx_transpose_optimization; namespace onnxruntime { namespace layout_transformation { +namespace { +// Cost check for aggressively pushing the Transpose nodes involved in the layout transformation further out. +CostCheckResult PostLayoutTransformCostCheck(const api::GraphRef& graph, const api::NodeRef& node, + const std::vector& perm, + const std::unordered_set& outputs_leading_to_transpose) { + // we aggressively push the layout transpose nodes. + // Exception: pushing through a Concat can result in Transpose nodes being added to multiple other inputs which + // can potentially be worse for performance. Use the cost check in that case. + if (node.OpType() != "Concat" && + (perm == ChannelFirstToLastPerm(perm.size()) || perm == ChannelLastToFirstPerm(perm.size()))) { + return CostCheckResult::kPushTranspose; + } + + // for other nodes use the default ORT cost check + return OrtEPCostCheck(graph, node, perm, outputs_leading_to_transpose); +} + +///

+/// Default function for checking if a node should have its layout changed. Allows EP specific adjustments to the +/// default set of layout sensitive operators if required. +/// +/// Longer term, if required, the EP API could allow the EP to provide a delegate to plugin EP specific logic so we +/// don't hardcode it here. +/// +/// Node to check +/// true if the node should have its layout converted to NHWC. +bool ConvertNodeLayout(const api::NodeRef& node) { + // skip if op is not an ONNX or contrib op + auto domain = node.Domain(); + if (domain != kOnnxDomain && domain != kMSDomain) { + return false; + } + + const auto& layout_sensitive_ops = GetORTLayoutSensitiveOps(); + + // handle special cases +#if defined(USE_XNNPACK) + if (node.GetExecutionProviderType() == kXnnpackExecutionProvider) { + if (node.OpType() == "Resize") { + // XNNPACK supports NCHW and NHWC for Resize so we don't need to use the internal NHWC domain and wrap the Resize + // with Transpose nodes. EPAwareHandleResize will allow an NCHW <-> NHWC Transpose to be pushed through + // the Resize during transpose optimization. + return false; + } + } +#endif + +#if defined(USE_JSEP) + // TODO(fs-eire): Remove special case handing of JSEP once NHWC Resize implementation is fixed + if (node.GetExecutionProviderType() == kJsExecutionProvider) { + if (node.OpType() == "Resize") { + // leave Resize as-is pending bugfix for NHWC implementation. this means the node will remain in the ONNX domain + // with the original input layout. + return false; + } + } +#endif + + // #if defined(USE_CUDA) + // if (node.GetExecutionProviderType() == kCudaExecutionProvider) { + // Update as per https://github.com/microsoft/onnxruntime/pull/17200 with CUDA ops that support NHWC + // } + // #endif + + return layout_sensitive_ops.count(node.OpType()) != 0; +} +} // namespace // Layout sensitive NCHW ops. TransformLayoutForEP will wrap these with Transpose nodes to convert the input // data to NHWC and output data back to NCHW, and move the op to the internal NHWC domain (kMSInternalNHWCDomain). -// The EP requesting these ops MUST be able to handle the node with the operator in the kMSInternalNHWCDomain. +// The EP requesting these ops MUST be able to handle the node with the operator in the kMSInternalNHWCDomain domain. // Once all the layout sensitive ops requested by the EP are wrapped the transpose optimizer will attempt to remove // as many of the layout transposes as possible. const std::unordered_set& GetORTLayoutSensitiveOps() { static std::unordered_set ort_layout_sensitive_ops = []() { const auto& layout_sensitive_ops = onnx_transpose_optimization::GetLayoutSensitiveOps(); std::unordered_set ort_specific_ops = - { "FusedConv", - "QLinearAveragePool", - "QLinearGlobalAveragePool" -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_QNN) || defined(USE_WEBNN) - // The CUDA/ROCM Resize kernel is layout sensitive as it only handles NCHW input. - // The CPU kernel and ONNX spec are not limited to handling NCHW input so are not layout sensitive, and - // onnx_layout_transformation::HandleResize is used. - , - "Resize" -#endif - }; + { + "FusedConv", + "QLinearAveragePool", + "QLinearGlobalAveragePool", + // Whilst the ONNX spec doesn't specify a layout for Resize, we treat it as layout sensitive by default + // as EPs tend to only support one layout. + "Resize", + }; ort_specific_ops.insert(layout_sensitive_ops.cbegin(), layout_sensitive_ops.cend()); return ort_specific_ops; @@ -42,45 +106,21 @@ const std::unordered_set& GetORTLayoutSensitiveOps() { return ort_layout_sensitive_ops; } -// Cost check for aggressively pushing the Transpose nodes involved in the layout transformation further out. -static CostCheckResult -PostLayoutTransformCostCheck(const api::GraphRef& graph, const api::NodeRef& node, - const std::vector& perm, - const std::unordered_set& outputs_leading_to_transpose) { - // we aggressively push the layout transpose nodes. - // Exception: pushing through a Concat can result in Transpose nodes being added to multiple other inputs which - // can potentially be worse for performance. Use the cost check in that case. - if (node.OpType() != "Concat" && - (perm == ChannelFirstToLastPerm(perm.size()) || perm == ChannelLastToFirstPerm(perm.size()))) { - return CostCheckResult::kPushTranspose; - } - - // for other nodes use the default ORT cost check - return OrtEPCostCheck(graph, node, perm, outputs_leading_to_transpose); -} - Status TransformLayoutForEP(Graph& graph, bool& modified, const IExecutionProvider& execution_provider, AllocatorPtr cpu_allocator, const DebugGraphFn& debug_graph_fn) { // We pass in nullptr for the new_node_ep param as new nodes will be assigned by the graph partitioner after // TransformLayoutForEP returns. - // sub graph recurse will be added later. + // sub graph recurse will be added later auto api_graph = MakeApiGraph(graph, cpu_allocator, /*new_node_ep*/ nullptr); - const auto& layout_sensitive_ops = GetORTLayoutSensitiveOps(); // to convert to NHWC we need to wrap layout sensitive nodes to Transpose from NCHW to NHWC and back. for (auto& node : api_graph->Nodes()) { - if (layout_sensitive_ops.count(node->OpType())) { - if (node->GetExecutionProviderType() != execution_provider.Type()) { - continue; - } - - auto domain = node->Domain(); - // Skip if domain is incorrect - if (domain != kOnnxDomain && domain != kMSDomain) { - continue; - } + if (node->GetExecutionProviderType() != execution_provider.Type()) { + continue; + } + if (ConvertNodeLayout(*node)) { // if already transformed then change the domain to kMSInternalNHWCDomain this way the EP // knows this op is in the expected format. if (node->GetAttributeIntDefault("channels_last", 0) == 1) { @@ -137,7 +177,6 @@ Status TransformLayoutForEP(Graph& graph, bool& modified, const IExecutionProvid WrapTransposesAroundNode(*api_graph, *node, {&input_perm}, {&output_perm}); } - // TODO: Technically Resize doesn't need to change domain as the ONNX Resize spec is not layout sensitive. SwapNodeOpTypeAndDomain(*api_graph, *node, node->OpType(), kMSInternalNHWCDomain); modified = true; } diff --git a/onnxruntime/core/optimizer/optimizer_execution_frame.cc b/onnxruntime/core/optimizer/optimizer_execution_frame.cc index fc7e694b6a69b..46041bca9dcc1 100644 --- a/onnxruntime/core/optimizer/optimizer_execution_frame.cc +++ b/onnxruntime/core/optimizer/optimizer_execution_frame.cc @@ -49,19 +49,13 @@ OptimizerExecutionFrame::Info::Info(const std::vector& nodes, InitializedTensorSet::const_iterator it = initialized_tensor_set.find(arg.Name()); if (it != initialized_tensor_set.cend()) { const auto& tensor_proto = *(it->second); - size_t cpu_tensor_length; - ORT_RETURN_IF_ERROR(utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &cpu_tensor_length)); OrtValue ort_value; - std::unique_ptr data = std::make_unique(cpu_tensor_length); - std::unique_ptr p_tensor; - ORT_RETURN_IF_ERROR(utils::TensorProtoToMLValue(Env::Default(), - model_path.IsEmpty() ? nullptr : model_path.ToPathString().c_str(), - tensor_proto, - MemBuffer(data.get(), cpu_tensor_length, allocator_ptr_->Info()), - ort_value)); - - initializers_[idx] = ort_value; - buffer_for_initialized_tensors_[idx] = std::move(data); + ORT_RETURN_IF_ERROR( + utils::TensorProtoToOrtValue(Env::Default(), + model_path.IsEmpty() ? nullptr : model_path.ToPathString().c_str(), + tensor_proto, allocator_ptr_, ort_value)); + + initializers_[idx] = std::move(ort_value); } return Status::OK(); @@ -72,7 +66,6 @@ OptimizerExecutionFrame::Info::Info(const std::vector& nodes, ort_value_name_idx_map_.Reserve(num_inputs_outputs); ort_value_idx_nodearg_map_.reserve(num_inputs_outputs); initializers_.reserve(initialized_tensor_set.size()); - buffer_for_initialized_tensors_.reserve(initialized_tensor_set.size()); for (auto* node : nodes) { ORT_THROW_IF_ERROR(onnxruntime::Node::ForEachWithIndex(node->InputDefs(), initialize_maps)); diff --git a/onnxruntime/core/optimizer/optimizer_execution_frame.h b/onnxruntime/core/optimizer/optimizer_execution_frame.h index e1b8a91545f64..13cf9e652c404 100644 --- a/onnxruntime/core/optimizer/optimizer_execution_frame.h +++ b/onnxruntime/core/optimizer/optimizer_execution_frame.h @@ -70,7 +70,6 @@ class OptimizerExecutionFrame final : public IExecutionFrame { OrtValueNameIdxMap ort_value_name_idx_map_; std::unordered_map ort_value_idx_nodearg_map_; std::unordered_map initializers_; - InlinedHashMap> buffer_for_initialized_tensors_; std::unique_ptr node_index_info_; const IExecutionProvider& execution_provider_; const std::function& is_sparse_initializer_func_; diff --git a/onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.cc b/onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.cc index a0942c31b0161..50653b368857d 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/clip_quantizelinear.cc @@ -1,8 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "core/optimizer/initializer.h" #include "core/optimizer/qdq_transformer/clip_quantizelinear.h" + +#include + +#include "core/optimizer/initializer.h" #include "core/optimizer/qdq_transformer/qdq_util.h" #include "core/optimizer/utils.h" #include "core/graph/graph_utils.h" @@ -50,14 +53,26 @@ static bool GetQConstantLowerUpper(const Graph& graph, const Node& node, float& switch (zp_initializer.data_type()) { case ONNX_NAMESPACE::TensorProto_DataType_INT8: { const int8_t zero_point = zp_initializer.data()[0]; - lower = scale * (-128 - zero_point); - upper = scale * (127 - zero_point); + lower = scale * (std::numeric_limits::lowest() - zero_point); + upper = scale * (std::numeric_limits::max() - zero_point); break; } case ONNX_NAMESPACE::TensorProto_DataType_UINT8: { const uint8_t zero_point = zp_initializer.data()[0]; - lower = scale * (0 - zero_point); - upper = scale * (255 - zero_point); + lower = scale * (std::numeric_limits::lowest() - zero_point); + upper = scale * (std::numeric_limits::max() - zero_point); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_INT16: { + const int16_t zero_point = zp_initializer.data()[0]; + lower = scale * (std::numeric_limits::lowest() - zero_point); + upper = scale * (std::numeric_limits::max() - zero_point); + break; + } + case ONNX_NAMESPACE::TensorProto_DataType_UINT16: { + const uint16_t zero_point = zp_initializer.data()[0]; + lower = scale * (std::numeric_limits::lowest() - zero_point); + upper = scale * (std::numeric_limits::max() - zero_point); break; } default: diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc index 16c7bd5fce960..5015e48fdb7b8 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.cc @@ -496,6 +496,42 @@ bool LogicalComparisonNodeGroupSelector::Check(const GraphViewer& graph_viewer, return dt_input_1 == dt_input_2; } +bool TopKNodeGroupSelector::Check(const GraphViewer& graph_viewer, + const Node& node, + const std::vector& dq_nodes, + const std::vector& q_nodes) const { + constexpr int num_dq_inputs = 1; + constexpr int num_q_outputs = 1; + if (num_dq_inputs != gsl::narrow_cast(dq_nodes.size())) { + return false; + } + + if (const auto dq_validation_status = QDQ::ValidateNodeGroupDQNodes(graph_viewer, node, dq_nodes); + !dq_validation_status.IsOK()) { + return false; + } + + if (num_q_outputs != gsl::narrow_cast(q_nodes.size())) { + return false; + } + + const Node& dq_node = *dq_nodes.front(); + const Node& q_node = *q_nodes.front(); + + int32_t dt_input = dq_node.InputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + int32_t dt_output = q_node.OutputDefs()[0]->TypeAsProto()->tensor_type().elem_type(); + + if (dt_input != dt_output) { + return false; + } + + auto get_const_initializer = [&graph_viewer](const std::string& initializer_name) { + return graph_viewer.GetConstantInitializer(initializer_name, true); + }; + + return IsQDQPairSupported(q_node, dq_node, get_const_initializer, graph_viewer.ModelPath()); +} + } // namespace QDQ } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h index d8fefdd8dc3d9..be7f7e0288eda 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h @@ -220,6 +220,14 @@ class LogicalComparisonNodeGroupSelector : public NodeGroupSelector { const std::vector& q_nodes) const override; }; +// TopK has 1 DQ input node and 1 Q output node. +// Zero point and scale are constant scalars and must match +class TopKNodeGroupSelector : public NodeGroupSelector { + bool Check(const GraphViewer& graph_viewer, const Node& node, + const std::vector& dq_nodes, + const std::vector& q_nodes) const override; +}; + /* * NodeSelector instances for use in the QDQ::SelectorActionTransformer. */ diff --git a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc index f1bdd7a99c329..3f1b2f0458bc0 100644 --- a/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc +++ b/onnxruntime/core/optimizer/qdq_transformer/selectors_actions/shared/utils.cc @@ -36,7 +36,8 @@ static const OpVersionsAndSelector::OpVersionsMap GetMiscOpVersionsMap() { {"Resize", {}}, {"Split", {}}, {"Squeeze", {}}, - {"Unsqueeze", {}}}; + {"Unsqueeze", {}}, + {"Tile", {}}}; } static const OpVersionsAndSelector::OpVersionsMap GetDropDQOpVersionsMap() { @@ -78,7 +79,8 @@ static const OpVersionsAndSelector::OpVersionsMap GetUnaryOpVersionsMap() { {"Abs", {}}, {"Neg", {}}, {"DepthToSpace", {}}, - {"SpaceToDepth", {}}}; + {"SpaceToDepth", {}}, + {"Clip", {}}}; } static const OpVersionsAndSelector::OpVersionsMap GetBinaryOpVersionsMap() { return {{"Add", {}}, @@ -127,6 +129,10 @@ static const OpVersionsAndSelector::OpVersionsMap GetPadOpVersionsMap() { return {{"Pad", {}}}; } +static const OpVersionsAndSelector::OpVersionsMap GetTopKOpVersionsMap() { + return {{"TopK", {}}}; +} + /* Selector rules registration related */ void RegisterMiscSelectors(Selectors& qdq_selectors) { /* register selectors for miscellaneous ops */ @@ -227,6 +233,13 @@ void RegisterPadSelectors(Selectors& qdq_selectors) { std::move(selector)); } +void RegisterTopKSelector(Selectors& qdq_selectors) { + /* register selector for TopK op */ + std::unique_ptr selector = std::make_unique(); + qdq_selectors.RegisterSelector(GetTopKOpVersionsMap(), + std::move(selector)); +} + void SelectorManager::CreateSelectors() { RegisterMiscSelectors(qdq_selectors_); RegisterDropDQSelectors(qdq_selectors_); @@ -242,6 +255,7 @@ void SelectorManager::CreateSelectors() { RegisterLogicalComparisonSelectors(qdq_selectors_); RegisterWhereSelectors(qdq_selectors_); RegisterPadSelectors(qdq_selectors_); + RegisterTopKSelector(qdq_selectors_); } void SelectorManager::InitializeSelectorsMap() { diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc index 2c11bf144999e..81b415c2e40ae 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc @@ -4,9 +4,11 @@ #include "onnx_transpose_optimization.h" #include +#include #include #include #include +#include #include #include "core/common/gsl.h" @@ -93,6 +95,55 @@ static std::unique_ptr MakeSqueezeOrUnsqueeze(int64_t opset, api:: return graph.AddNode(op_type, inputs, /*num_outputs*/ 1); } +/// +/// Return a DequantizeLinear node if it's input is a constant initializer with known consumers. +/// In this case the initializer can be updated in-place by UnsqueezeInput or TransposeInput. +/// +/// Current graph +/// Value to check if produced by a DQ node who's input is a constant initializer +/// NodeRef for DQ node if it meets the requirements. +static std::unique_ptr GetDQWithConstInitializerInput(const api::GraphRef& graph, + std::string_view dq_output_name) { + std::unique_ptr dq_node; + auto maybe_dq_node = graph.GetNodeProducingOutput(dq_output_name); + + if (maybe_dq_node && maybe_dq_node->OpType() == "DequantizeLinear") { + do { + auto dq_input = maybe_dq_node->Inputs()[0]; + auto dq_constant = graph.GetConstant(dq_input); + + // input to DQ must be a constant initializer + if (!dq_constant) { + break; + } + + // For now keep it simple and don't support per-axis quantization as that would require updating the + // scale and zero point values in the DQ node to re-order if transposing, or reshape if unsqueezing. + // the rank of the `scale` and `zero point` inputs must match so we only need to check `scale`. + auto dq_scale = graph.GetConstant(maybe_dq_node->Inputs()[1]); + if (!dq_scale || dq_scale->NumElements() != 1) { + break; + } + + // need to know all the initializer consumers as we're potentially going to modify it directly + auto initializer_consumers = graph.GetValueConsumers(dq_input); + if (!initializer_consumers->comprehensive) { + break; + } + + // DQ output is only used by the node we're modifying. + auto dq_consumers = graph.GetValueConsumers(dq_output_name); + if (!dq_consumers->comprehensive || dq_consumers->nodes.size() != 1) { + break; + } + + dq_node = std::move(maybe_dq_node); + } while (false); + } + + return dq_node; +} + // Returns whether perm is a valid permutation (contains each value from 0 to perm.size() - 1 exactly once) static bool IsValidPerm(const std::vector& perm) { size_t rank = perm.size(); @@ -345,6 +396,12 @@ static std::vector SortedAxesForTransposedInput(const std::vectorShape(); + graph.GetValueInfo(dq.Outputs()[0])->SetShape(&new_shape); +} + /////// /////// /////// /////// @@ -357,51 +414,126 @@ static std::string_view HelpHandleUnsqueeze(HandlerArgs& args, const std::vector // broadcasting. static void UnsqueezeInput(OptimizerCtx& ctx, api::NodeRef& node, size_t i, const std::vector& axes) { std::string_view input = node.Inputs()[i]; - // Remove this node as a consumer - node.SetInput(i, ""); std::unique_ptr constant = ctx.graph.GetLocalConstant(input); - auto consumers = ctx.graph.GetValueConsumers(input); + + // allow a constant initializer coming via a DQ node with a single consumer + std::unique_ptr dq_node; + std::string_view constant_dq_input; + + if (!constant) { + // look past a DQ node for a constant initializer. essentially we pretend the DQ node doesn't exist + // to enable directly making changes to the initializer. any nodes added for other consumers of the initializer + // in 'Case 1' are prior to the DQ so we don't break up any QDQ node units. + dq_node = GetDQWithConstInitializerInput(ctx.graph, input); + if (dq_node) { + // underlying string for the input name is in the Node so it's safe to store in string_view constant_dq_input + constant_dq_input = dq_node->Inputs()[0]; + constant = ctx.graph.GetLocalConstant(constant_dq_input); + // remove the DQ node as a consumer of the initializer while we modify things + dq_node->SetInput(0, ""); + } + } + + // Clear the input, which also removes this node's input as a consumer of the value. + // NOTE: the node may have multiple inputs consuming the value. + node.SetInput(i, ""); + auto value_to_modify = dq_node ? constant_dq_input : input; + auto consumers = ctx.graph.GetValueConsumers(value_to_modify); // Case 1: input is a constant with a known list of consumer nodes if (constant != nullptr && consumers->comprehensive) { - // We will reshape the initializer. If there are existing consumers, still reshape it but add Squeeze nodes + // We will reshape the initializer. If there are existing consumers, reshape it and add Squeeze nodes // to counteract its effect. If they later Unsqueeze the same input, the Squeeze nodes will simply be deleted // (see Case 2). if (consumers->nodes.size() > 0) { - auto squeeze_ptr = MakeSqueezeOrUnsqueeze(ctx.opset, ctx.graph, "Squeeze", input, axes); + // record the consumer node input as being special cased for use in Case 2 if a DQ node, and IsConstant + for (auto& consumer : consumers->nodes) { + auto& consumer_node_inputs = ctx.nodes_using_updated_shared_initializer[consumer->Id()]; + + // find input id/s for consumer + auto consumer_inputs = consumer->Inputs(); + for (size_t input_idx = 0; input_idx < consumer_inputs.size(); ++input_idx) { + if (consumer_inputs[input_idx] == value_to_modify) { + consumer_node_inputs.push_back(input_idx); + } + } + } + + auto squeeze_ptr = MakeSqueezeOrUnsqueeze(ctx.opset, ctx.graph, "Squeeze", value_to_modify, axes); api::NodeRef& squeeze = *squeeze_ptr; std::string_view sq_out = squeeze.Outputs()[0]; - ctx.graph.CopyValueInfo(input, sq_out); - ReplaceValueReferences(consumers->nodes, input, sq_out); + ctx.graph.CopyValueInfo(value_to_modify, sq_out); + ReplaceValueReferences(consumers->nodes, value_to_modify, sq_out); } + auto new_shape = UnsqueezeShape(constant->Shape(), axes); - ctx.graph.ReshapeInitializer(input, new_shape); - node.SetInput(i, input); + ctx.graph.ReshapeInitializer(value_to_modify, new_shape); + + if (dq_node) { + UpdateDQNodeInputAndShape(ctx.graph, *dq_node, constant_dq_input); + } + + node.SetInput(i, input); // restore the original connection return; } // Case 2: input is a Squeeze node with matching axes std::unique_ptr inp_node = ctx.graph.GetNodeProducingOutput(input); + + // check if this is a special-cased DQ node where we put the Squeeze on input 0 of the DQ in 'Case 1' above + if (inp_node && inp_node->OpType() == "DequantizeLinear" && + std::find_if(ctx.nodes_using_updated_shared_initializer.begin(), + ctx.nodes_using_updated_shared_initializer.end(), + [&inp_node](const auto& entry) { + const auto id = entry.first; + const auto& input_idxs = entry.second; + // check Id matches and the entry was for input 0 of the DQ node + return id == inp_node->Id() && + std::find(input_idxs.begin(), input_idxs.end(), size_t(0)) != input_idxs.end(); + }) != ctx.nodes_using_updated_shared_initializer.end()) { + // set things up so we can look past the DQ node to the Squeeze that was inserted in front of the reshaped + // constant initializer that was shared with this node. + dq_node = std::move(inp_node); + auto dq_input = dq_node->Inputs()[0]; + inp_node = ctx.graph.GetNodeProducingOutput(dq_input); + consumers = ctx.graph.GetValueConsumers(dq_input); + } + if (inp_node != nullptr && inp_node->IsOp("Squeeze")) { const std::vector& inp_node_inputs = inp_node->Inputs(); std::optional> squeeze_axes = std::nullopt; squeeze_axes = ReadFromAttrOrInput(ctx, *inp_node, "axes", /*inp_index*/ 1, /*opset*/ 13); if (squeeze_axes != std::nullopt && *squeeze_axes == axes) { + if (dq_node) { + UpdateDQNodeInputAndShape(ctx.graph, *dq_node, inp_node_inputs[0]); + node.SetInput(i, dq_node->Outputs()[0]); + } else { + node.SetInput(i, inp_node_inputs[0]); + } + // Remove the Squeeze node if possible - if (consumers->comprehensive && consumers->nodes.size() == 0) { + // if there's a DQ node the `consumers` list still includes it so allow for that. + // in that case UpdateDQNodeInputAndShape already updated the input of the DQ node so it's safe to remove it. + if (consumers->comprehensive && consumers->nodes.size() == size_t(dq_node ? 1 : 0)) { ctx.graph.RemoveNode(*inp_node); + if (ctx.opset >= 13 && !ctx.graph.HasValueConsumers(inp_node_inputs[1])) { ctx.graph.RemoveInitializer(inp_node_inputs[1]); } } - node.SetInput(i, inp_node_inputs[0]); + return; } // Axes don't match. Fall through to Case 3. } + // any DQ node special casing doesn't apply anymore, so go back to the original inp_node + if (dq_node) { + inp_node = std::move(dq_node); + } + // Case 3: Add an Unsqueeze node. auto unsqueeze_ptr = MakeSqueezeOrUnsqueeze(ctx.opset, ctx.graph, "Unsqueeze", input, axes); api::NodeRef& unsqueeze = *unsqueeze_ptr; @@ -453,83 +585,197 @@ static void Permute1DConstant(api::GraphRef& graph, api::NodeRef& node, api::Ten // Replaces ith input to node with transposed value. Might create a new Transpose node, find an existing one, // or transpose an initializer. -void TransposeInput(api::GraphRef& graph, api::NodeRef& node, size_t i, - const std::vector& perm, const std::vector& perm_inv) { +static void TransposeInputImpl(api::GraphRef& graph, + NodeIdToInputIdxsMap* nodes_using_updated_shared_initializer, + api::NodeRef& node, size_t i, const std::vector& perm, + const std::vector& perm_inv) { std::string_view input = node.Inputs()[i]; - // Remove this node as a consumer - node.SetInput(i, ""); + // Only local constants are editable std::unique_ptr constant = graph.GetLocalConstant(input); - auto consumers = graph.GetValueConsumers(input); + + // allow a constant initializer coming via a DQ node with a single consumer + std::unique_ptr dq_node; + std::string_view constant_dq_input; + + if (!constant) { + // look past a DQ node for a constant initializer. essentially we pretend the DQ node doesn't exist + // to enable directly making changes to the initializer. any nodes added for other consumers of the initializer + // in 'Case 1' are prior to the DQ so we don't break up any QDQ node units. + dq_node = GetDQWithConstInitializerInput(graph, input); + if (dq_node) { + // underlying string for the input name is in the Node so it's safe to store in string_view constant_dq_input + constant_dq_input = dq_node->Inputs()[0]; + constant = graph.GetLocalConstant(constant_dq_input); + // remove the DQ node as a consumer of the initializer while we modify things + dq_node->SetInput(0, ""); + } + } + + // Clear the input, which also removes this node's input as a consumer of the value. + // NOTE: the node may have multiple inputs consuming the value. + node.SetInput(i, ""); + + auto constant_to_modify = dq_node ? constant_dq_input : input; + auto consumers = graph.GetValueConsumers(constant_to_modify); // Case 1: input is a constant with a known list of consumer nodes if (constant != nullptr && consumers->comprehensive) { - // Input is scalar, return early. - if (constant->Shape().size() == 1 && constant->Shape()[0] == 0) { + // we modify the initializer in-place and need to reconnect things up when we're done. this helper will + // do that when it goes out of scope. if we have manually reconnected, input or constant_dq_input is + // set to an empty string. + auto reconnect_nodes = gsl::finally([i, &node, &dq_node, &input, &constant_dq_input] { + if (!input.empty()) { + node.SetInput(i, input); + } + + if (!constant_dq_input.empty()) { + dq_node->SetInput(0, constant_dq_input); + } + }); + + // If there is only one element return early as the transpose won't change the data + if (constant->NumElements() == 1) { return; } + // This is a special case where the constant is 1D with length == perm. - // TODO: TransposeInitializer should be updated to handle this case. + // e.g. it provides a set of values that are relative to the input axes like the `sizes` input for Resize // Permute1DConstant permutes the constant and adds a new initializer. The old initializer is removed only if // there are no other consumers. if (constant->Shape().size() == 1 && constant->Shape()[0] == gsl::narrow_cast(perm.size())) { - Permute1DConstant(graph, node, *constant, i, input, perm); + auto& node_to_update = dq_node ? *dq_node : node; + Permute1DConstant(graph, node_to_update, *constant, i, constant_to_modify, perm); + + // unset updated input so reconnect_nodes doesn't change it back + if (dq_node) { + constant_dq_input = ""; + } else { + input = ""; + } + return; } + if (consumers->nodes.size() > 0) { // Transpose the initializer. If there are existing consumers, add Transpose nodes to them using perm_inv // to counteract the effect. These Transposes will hopefully be optimized out later. - auto transpose_inv_ptr = MakeTranspose(graph, input, perm_inv); + + // record the consumer node's input as being special cased for use in Case 2 if a DQ node, and IsConstant + if (nodes_using_updated_shared_initializer) { + for (auto& consumer : consumers->nodes) { + auto& consumer_node_inputs = (*nodes_using_updated_shared_initializer)[consumer->Id()]; + + // find input id/s for consumer + auto consumer_inputs = consumer->Inputs(); + for (size_t input_idx = 0; input_idx < consumer_inputs.size(); ++input_idx) { + if (consumer_inputs[input_idx] == constant_to_modify) { + consumer_node_inputs.push_back(input_idx); + } + } + } + } + + auto transpose_inv_ptr = MakeTranspose(graph, constant_to_modify, perm_inv); api::NodeRef& transpose_inv = *transpose_inv_ptr; std::string_view transpose_out = transpose_inv.Outputs()[0]; - graph.CopyValueInfo(input, transpose_out); - ReplaceValueReferences(consumers->nodes, input, transpose_out); + graph.CopyValueInfo(constant_to_modify, transpose_out); + ReplaceValueReferences(consumers->nodes, constant_to_modify, transpose_out); } - graph.TransposeInitializer(input, perm); - node.SetInput(i, input); + + graph.TransposeInitializer(constant_to_modify, perm); + + if (dq_node) { + UpdateDQNodeInputAndShape(graph, *dq_node, constant_to_modify); + constant_dq_input = ""; // DQ input was already updated so we don't need reconnect_nodes to handle it + } + return; } // Case 2: input is a Transpose node std::unique_ptr inp_node = graph.GetNodeProducingOutput(input); + + // check if this is a special-cased DQ node where we put the Transpose on input 0 of the DQ in 'Case 1' above + if (inp_node && inp_node->OpType() == "DequantizeLinear" && + nodes_using_updated_shared_initializer && + std::find_if(nodes_using_updated_shared_initializer->begin(), nodes_using_updated_shared_initializer->end(), + [&inp_node](const auto entry) { + const auto id = entry.first; + const auto& input_idxs = entry.second; + // id matches and the entry is for input 0 of the DQ node + return id == inp_node->Id() && + std::find(input_idxs.begin(), input_idxs.end(), size_t(0)) != input_idxs.end(); + }) != nodes_using_updated_shared_initializer->end()) { + // set things up so we can look past the DQ node to the Transpose that was inserted in front of the reshaped + // constant initializer that was shared with this node. + dq_node = std::move(inp_node); + auto dq_input = dq_node->Inputs()[0]; + inp_node = graph.GetNodeProducingOutput(dq_input); + consumers = graph.GetValueConsumers(dq_input); + } + if (inp_node != nullptr && inp_node->IsOp("Transpose")) { std::optional> perm2 = GetPermAttrIfValid(*inp_node); if (perm2 != std::nullopt && perm2->size() == perm.size()) { // If they cancel, use pre_transpose_value and remove Transpose if possible. if (*perm2 == perm_inv) { std::string_view pre_transpose_value = inp_node->Inputs()[0]; - if (consumers->comprehensive && consumers->nodes.size() == 0) { + + if (dq_node) { + UpdateDQNodeInputAndShape(graph, *dq_node, pre_transpose_value); + node.SetInput(i, dq_node->Outputs()[0]); + } else { + node.SetInput(i, pre_transpose_value); + } + + // Remove the Transpose node if possible + // if there's a DQ node the `consumers` list still includes it so allow for that. + // in that case UpdateDQNodeInputAndShape already updated the input of the DQ node so it's safe to remove it. + if (consumers->comprehensive && consumers->nodes.size() == size_t(dq_node ? 1 : 0)) { graph.RemoveNode(*inp_node); } - node.SetInput(i, pre_transpose_value); - return; - } else if (*perm2 == perm) { - // we are trying to add a duplicate transpose. - // do nothing and return + return; } - // Otherwise, compose the perm and Transpose pre_transpose_value. Cost is the same and we may be able to remove - // the other Transpose. - const std::vector& perm_combined = ComposePerm(*perm2, perm); - auto transpose_ptr = MakeTranspose(graph, inp_node->Inputs()[0], perm_combined); - api::NodeRef& transpose = *transpose_ptr; - std::string_view transpose_out = transpose.Outputs()[0]; - graph.CopyValueInfo(input, transpose_out); - graph.GetValueInfo(transpose_out)->PermuteDims(perm); - if (consumers->comprehensive && consumers->nodes.size() == 0) { - graph.RemoveNode(*inp_node); + // NOTE: We expect the Transpose to cancel out when handling a special-cased DQ node that was originally + // connected to a shared constant initializer, so we don't expect to get here if dq_node is not nullptr. + // If there was a dq_node where the Transpose didn't cancel out we fall through to the next case + // so we retain the potential to cancel out for any other usages of the shared initializer. + assert(!dq_node); // assert in debug build to investigate. fall through to next case in release build to be safe. + + if (!dq_node) { + // Otherwise, compose the perm and Transpose pre_transpose_value. Cost is the same and we may be able to remove + // the other Transpose. + const std::vector& perm_combined = ComposePerm(*perm2, perm); + auto transpose_ptr = MakeTranspose(graph, inp_node->Inputs()[0], perm_combined); + api::NodeRef& transpose = *transpose_ptr; + std::string_view transpose_out = transpose.Outputs()[0]; + graph.CopyValueInfo(input, transpose_out); + graph.GetValueInfo(transpose_out)->PermuteDims(perm); + + if (consumers->comprehensive && consumers->nodes.size() == 0) { + graph.RemoveNode(*inp_node); + } + + node.SetInput(i, transpose_out); + + return; } - node.SetInput(i, transpose_out); - return; } } - // Case 3: A Transpose op might already exist - for (size_t j = 0; j < consumers->nodes.size(); ++j) { - api::NodeRef& consumer = *consumers->nodes[j]; - if (consumer.IsOp("Transpose") && GetPermAttrIfValid(consumer) == perm) { - node.SetInput(i, consumer.Outputs()[0]); + // any DQ node special casing doesn't apply anymore, so go back to the original inp_node + if (dq_node) { + inp_node = std::move(dq_node); + consumers = graph.GetValueConsumers(input); + } + + // Case 3: A Transpose op with the same perms might already exist + for (auto& consumer : consumers->nodes) { + if (consumer->IsOp("Transpose") && GetPermAttrIfValid(*consumer) == perm) { + node.SetInput(i, consumer->Outputs()[0]); return; } } @@ -540,11 +786,25 @@ void TransposeInput(api::GraphRef& graph, api::NodeRef& node, size_t i, std::string_view transpose_out = transpose.Outputs()[0]; graph.CopyValueInfo(input, transpose_out); graph.GetValueInfo(transpose_out)->PermuteDims(perm); + node.SetInput(i, transpose_out); } +void TransposeInput(api::GraphRef& graph, api::NodeRef& node, size_t i, + const std::vector& perm, + const std::vector& perm_inv) { + // this TransposeInput is used by the layout transformer to wrap a node in Transpose ops. there's no OptimizerCtx + // in that scenario and we're not tracking special-cased DQ nodes as we only do that when pushing Transpose nodes. + TransposeInputImpl(graph, /* nodes_using_updated_shared_initializer */ nullptr, node, i, perm, perm_inv); +} + +static void TransposeInput(OptimizerCtx& ctx, api::NodeRef& node, size_t i, const std::vector& perm, + const std::vector& perm_inv) { + TransposeInputImpl(ctx.graph, &ctx.nodes_using_updated_shared_initializer, node, i, perm, perm_inv); +} + // Unsqueezes inputs of node to have uniform rank. Returns false if input ranks are unknown or exceed the target rank. -static bool NormalizeInputRanks(OptimizerCtx ctx, api::NodeRef& node, size_t target_rank, +static bool NormalizeInputRanks(OptimizerCtx& ctx, api::NodeRef& node, size_t target_rank, const std::vector& input_indices) { auto inputs = node.Inputs(); @@ -579,7 +839,7 @@ void TransposeInputs(OptimizerCtx& ctx, api::NodeRef& node, const std::vector& input_indices) { auto perm_inv = InvertPerm(perm); for (size_t j : input_indices) { - TransposeInput(ctx.graph, node, j, perm, perm_inv); + TransposeInput(ctx, node, j, perm, perm_inv); } } @@ -670,23 +930,74 @@ static bool CanLikelyRemoveTranspose(const api::GraphRef& graph, api::NodeRef& t return true; } +// return true if +// - the value is a constant initializer +// - the value is the output of a DQ node who's input is a constant initializer +// - UnsqueezeInput/TranposeInput can look past the DQ to update the constant initializer directly +// - DQ node is currently ignored if it uses per-channel quantization +// - supporting per-channel quantization requires modifying the scales and zero point data, which can be done +// if/when there's a use-case to justify the development cost. +// - the input was originally connected to a shared constant initializer that was updated in place by UnsqueezeInput +// or TransposeInput, and usage by this node had Squeeze/Transpose nodes inserted to counteract the effect of the +// in-place update. if we push the same transpose through this node it should cancel out that Squeeze/Transpose +// +// in all these cases we expect pushing the transpose through to not require a runtime Transpose node +static bool IsConstant(const api::GraphRef& graph, const api::NodeRef& node, + size_t input_id, + std::string_view value_name, + const NodeIdToInputIdxsMap& nodes_using_updated_shared_initializer) { + std::unique_ptr producer_node = graph.GetNodeProducingOutput(value_name); + + if (!producer_node) { + // initializer. may or may not be constant depending on whether it has a matching graph input + std::unique_ptr constant = graph.GetConstant(value_name); + return constant != nullptr; + } + + auto node_id_to_check = node.Id(); + + // handle potentially looking past a DQ node + if (producer_node->OpType() == "DequantizeLinear") { + std::unique_ptr dq_node = GetDQWithConstInitializerInput(graph, value_name); + if (dq_node != nullptr) { + // DQ node pointing to an initializer that has not been updated in-place yet + return true; + } + + // could also be a DQ that was connected to a shared initializer that was updated in-place. + // update the info on the node/input index to check and fall through + node_id_to_check = producer_node->Id(); + input_id = 0; // can only be input 0 of a DQ node + } + + auto entry = nodes_using_updated_shared_initializer.find(node_id_to_check); + if (entry != nodes_using_updated_shared_initializer.end()) { + if (std::find(entry->second.begin(), entry->second.end(), input_id) != entry->second.end()) { + return true; + } + } + + return false; +} + // Estimates the cost of transposing an input. Currently uses rank heuristic. Negative if transpose is removed. // Feel free to improve as needed. -static int EstimateTransposeValueCost(const api::GraphRef& graph, std::string_view input, +static int EstimateTransposeValueCost(const api::GraphRef& graph, const api::NodeRef& node, + size_t input_id, std::string_view input, const std::vector& perm_inv, - const HandlerMap& extended_handlers) { + const HandlerMap& extended_handlers, + const NodeIdToInputIdxsMap& nodes_using_updated_shared_initializer) { // Case 1: Transposing constants probably costs nothing. - std::unique_ptr constant = graph.GetConstant(input); - if (constant != nullptr) { + if (IsConstant(graph, node, input_id, input, nodes_using_updated_shared_initializer)) { return 0; } // Case 2: Transposing a transpose either cancels it or composes the permutations. - std::unique_ptr node = graph.GetNodeProducingOutput(input); - if (node != nullptr && node->IsOp("Transpose")) { - std::optional> perm2 = GetPermAttrIfValid(*node); + std::unique_ptr producer_node = graph.GetNodeProducingOutput(input); + if (producer_node != nullptr && producer_node->IsOp("Transpose")) { + std::optional> perm2 = GetPermAttrIfValid(*producer_node); if (perm2 != std::nullopt) { - if (*perm2 == perm_inv && CanLikelyRemoveTranspose(graph, *node, extended_handlers)) { + if (*perm2 == perm_inv && CanLikelyRemoveTranspose(graph, *producer_node, extended_handlers)) { return -EstimateValueRank(graph, input); } else { return 0; @@ -702,11 +1013,13 @@ static int EstimateTransposeValueCost(const api::GraphRef& graph, std::string_vi static int EstimateTransposeInputsCost(const api::GraphRef& graph, const api::NodeRef& node, const std::vector& perm_inv, const std::vector& input_indices, - const HandlerMap& extended_handlers) { + const HandlerMap& extended_handlers, + const NodeIdToInputIdxsMap& nodes_using_updated_shared_initializer) { auto inputs = node.Inputs(); int cost = 0; for (size_t j : input_indices) { - cost += EstimateTransposeValueCost(graph, inputs[j], perm_inv, extended_handlers); + cost += EstimateTransposeValueCost(graph, node, j, inputs[j], perm_inv, extended_handlers, + nodes_using_updated_shared_initializer); } return cost; } @@ -734,8 +1047,10 @@ static bool HandleSimpleNodeBase(HandlerArgs& args, bool broadcast_inputs) { if (broadcast_inputs && !NormalizeInputRanks(args.ctx, args.node, rank, args.transposible_inputs)) { return false; } + TransposeInputs(args.ctx, args.node, args.perm_inv, args.transposible_inputs); TransposeOutputs(args.ctx, args.node, args.perm); + return true; } @@ -927,18 +1242,7 @@ static void PermuteInput(api::GraphRef& graph, api::NodeRef& node, size_t i, con node.SetInput(i, gather_output); } -static bool HandleResize([[maybe_unused]] HandlerArgs& args) { -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_QNN) || defined(USE_WEBNN) - // The CUDA Resize kernel requires that the input is NCHW, so we can't push a Transpose through a Resize - // in ORT builds with CUDA enabled. - // The ROCm EP is generated from the CUDA EP kernel so the same applies to builds with ROCm enabled. - // The QNN EP requires the input to be NHWC, so the Resize handler is also not enabled for QNN builds. - // - // TODO: Remove this special case once the CUDA Resize kernel is implemented "generically" (i.e.) aligning with the - // generic nature of the ONNX spec. - // See https://github.com/microsoft/onnxruntime/pull/10824 for a similar fix applied to the CPU Resize kernel. - return false; -#else +bool HandleResize([[maybe_unused]] HandlerArgs& args) { auto inputs = args.node.Inputs(); int64_t rank_int = gsl::narrow_cast(args.perm.size()); @@ -964,10 +1268,10 @@ static bool HandleResize([[maybe_unused]] HandlerArgs& args) { TransposeOutputs(args.ctx, args.node, args.perm); return true; -#endif } -constexpr HandlerInfo resize_handler = {&FirstInput, &HandleResize}; +// Not currently registered by default. +// constexpr HandlerInfo resize_handler = {&FirstInput, &HandleResize}; static bool HandlePad(HandlerArgs& args) { size_t rank = args.perm.size(); @@ -1719,8 +2023,11 @@ static const std::unordered_map handler_ma {"Split", split_handler}, {"Shape", shape_handler}, {"Pad", pad_handler}, - {"Resize", resize_handler}, - {"ReduceSum", reduce_op_handler}, + + // Execution providers tend to only implement Resize for specific layouts. Due to that, it's safer to not + // push a Transpose through a Resize unless the EP specifically checks that it can handle the change via an + // extended handler. + // {"Resize", resize_handler}, {"ReduceLogSum", reduce_op_handler}, {"ReduceLogSumExp", reduce_op_handler}, @@ -1728,6 +2035,7 @@ static const std::unordered_map handler_ma {"ReduceMean", reduce_op_handler}, {"ReduceMin", reduce_op_handler}, {"ReduceProd", reduce_op_handler}, + {"ReduceSum", reduce_op_handler}, {"ReduceSumSquare", reduce_op_handler}, {"ReduceL1", reduce_op_handler}, {"ReduceL2", reduce_op_handler}, @@ -1787,12 +2095,14 @@ static int CalculateCost(const api::GraphRef& graph, const api::NodeRef& node, const std::unordered_set& outputs_leading_to_transpose, const HandlerInfo& info, const std::vector& input_indices, - const HandlerMap& extended_handlers) { + const HandlerMap& extended_handlers, + const NodeIdToInputIdxsMap& nodes_using_updated_shared_initializer) { // We require the input cost (number of transposes before the op) and the total cost to strictly decrease. // Strict decrease of the input cost ensures the optimization is stable, since the total cost decrease is just an // estimate (the transpose after the op may or may not cancel with a subsequent transpose). We don't want // repeated runs of the optimizer to have a transpose toggle between two inputs of a binary op. - int cost = EstimateTransposeInputsCost(graph, node, perm, input_indices, extended_handlers); + int cost = EstimateTransposeInputsCost(graph, node, perm, input_indices, extended_handlers, + nodes_using_updated_shared_initializer); if (cost < 0 && info.transposes_outputs) { // If the output will be transposed and won't ultimately cancel, factor in that cost. @@ -1822,13 +2132,14 @@ static bool ShouldPushTranspose(const api::GraphRef& graph, const api::NodeRef& const std::unordered_set& outputs_leading_to_transpose, const HandlerInfo& info, const std::vector transposable_input_indices, - const HandlerMap& extended_handlers) { + const HandlerMap& extended_handlers, + const NodeIdToInputIdxsMap& nodes_using_updated_shared_initializer) { if (node.IsOp("Transpose")) { return true; } int cost = CalculateCost(graph, node, perm, outputs_leading_to_transpose, info, transposable_input_indices, - extended_handlers); + extended_handlers, nodes_using_updated_shared_initializer); return cost < 0; } @@ -1855,7 +2166,7 @@ bool ProcessTranspose(OptimizerCtx& ctx, api::NodeRef& transpose, api::NodeRef& if (cost == CostCheckResult::kFallThrough) { cost = ShouldPushTranspose(ctx.graph, node, perm, outputs_leading_to_transpose, *info, input_indices, - ctx.extended_handlers) + ctx.extended_handlers, ctx.nodes_using_updated_shared_initializer) ? CostCheckResult::kPushTranspose : CostCheckResult::kStop; } @@ -1889,7 +2200,7 @@ std::optional MakeOptimizerContext(api::GraphRef& graph, return std::nullopt; } - OptimizerCtx ctx{*opset, graph, provider_type, cost_check_fn, extended_handlers}; + OptimizerCtx ctx{*opset, graph, provider_type, cost_check_fn, extended_handlers, {}}; return ctx; } @@ -2067,6 +2378,8 @@ OptimizeResult OptimizeImpl(OptimizerCtx& ctx) { continue; } + // NOTE: this bleeds ORT specific logic into the base optimizer, however we justify that for now because we expect + // the types that the ORT DQ provides to be added to the ONNX spec, at which point this special case can go away. if (IsMSDomain(dq_domain) && !TransposeQuantizeDequantizeAxis(ctx.graph, perm_inv, *dq_node)) { continue; } diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h index 1a54e7834a4ae..cc1552704c187 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h @@ -3,6 +3,7 @@ #pragma once +#include #include // implementation details of the transpose optimizer API defined in optimizer_api.h. @@ -38,6 +39,8 @@ struct HandlerInfo { bool transposes_outputs = true; }; +using NodeIdToInputIdxsMap = std::unordered_map>; + struct OptimizerCtx { int64_t opset; api::GraphRef& graph; @@ -48,6 +51,32 @@ struct OptimizerCtx { // Handlers for ops that are not in the ONNX opset, or for ONNX ops where special handling is required. // If a handler is not found in this map, the default handlers will be used. const HandlerMap& extended_handlers; + + // When we update a shared constant initializer as part of pushing a transpose through a node we update the + // initializer in-place and insert Squeeze (in UnsqueezeInput if the initializer is broadcast) or + // Transpose (in TransposeInput) nodes between the updated initializer and the other usages. + // This map contains the set of nodes that had a Squeeze or Transpose added between them and the initializer. + // The entry contains the node id (key) and original input index/es (value) that were connected to the initializer + // prior to the insertion of the Squeeze/Transpose. + // + // Assuming we also transpose the other usages of the initializer in the same way (which would be expected) the + // Squeeze and Transpose nodes would be cancelled out, and the other usages will end up using the original + // initializer that was updated in-place. + // + // We use this information in two ways. + // + // 1. In the IsConstant calculation that determines the cost of pushing a transpose through a node. + // - as we expect the transpose to be making the same modification to all shared usages of the initializer we + // expect the Squeeze/Transpose nodes to be cancelled out, resulting in no runtime cost to push the transpose + // through that input. + // + // 2. To enable and track a special case in a QDQ format model where there is the added complexity of a DQ node + // between the initializer and each usage. + // - we look past a DQ node in UnsqueezeInput and TransposeInput to determine if there is a constant initializer + // that can be updated in-place as the DQ node is not sensitive to any rank or layout changes + // - NOTE we currently ignore DQ nodes with per-channel quantization as they are sensitive to changes + // - we also look past DQ nodes when processing the other usages in order to cancel out the Squeeze/Transpose + NodeIdToInputIdxsMap nodes_using_updated_shared_initializer; }; /// @@ -69,6 +98,7 @@ bool HandleSimpleNodeWithAxis(HandlerArgs& args, std::optional default_ // base handlers that are used by extended handlers. add from transpose_optimizer.cc as needed. bool HandleReduceOps(HandlerArgs& args); +bool HandleResize([[maybe_unused]] HandlerArgs& args); void TransposeInput(api::GraphRef& graph, api::NodeRef& node, size_t i, const std::vector& perm, diff --git a/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h b/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h index 40a03f24f7648..fb338be1c7f5a 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h +++ b/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h @@ -242,6 +242,12 @@ class NodeRef { /// since version or default value -1 virtual int SinceVersion() const = 0; + /// + /// Get the unique id of the node. + /// + /// Id + virtual int64_t Id() const = 0; + virtual ~NodeRef(){}; }; @@ -442,7 +448,7 @@ class GraphRef { } // namespace api constexpr int64_t kMinSupportedOpset = 7; -constexpr int64_t kMaxSupportedOpset = 19; +constexpr int64_t kMaxSupportedOpset = 20; // enum of results that a CostCheckFn can return. enum class CostCheckResult { diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc index b30c94d7b3e40..2fcb88cb0b9ba 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc @@ -95,7 +95,8 @@ class ApiNode final : public api::NodeRef { void ClearAttribute(std::string_view name) override; void SetInput(size_t i, std::string_view name) override; std::string_view GetExecutionProviderType() const override; - virtual int SinceVersion() const override; + int SinceVersion() const override; + int64_t Id() const override; private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ApiNode); @@ -417,6 +418,10 @@ int ApiNode::SinceVersion() const { return node_.SinceVersion(); } +int64_t ApiNode::Id() const { + return node_.Index(); +} + // std::optional ApiGraph::Opset(std::string_view domain) const { diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc index ead82a6b56741..f4f3505128737 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc @@ -5,12 +5,35 @@ #include #include "core/graph/constants.h" +#include "core/framework/utils.h" #include "core/optimizer/transpose_optimization/ort_optimizer_utils.h" using namespace onnx_transpose_optimization; namespace onnxruntime { +static bool EPAwareHandleResize(HandlerArgs& args) { + // Whilst Resize is not technically layout sensitive, execution providers typically implement handling for only one + // layout. Due to that, only push a Transpose through a Resize once it is assigned and we know it's being handled + // by an EP that supports multiple layouts. Currently that's the CPU and XNNPACK EPs. + const auto ep_type = args.node.GetExecutionProviderType(); + if (ep_type == kCpuExecutionProvider || ep_type == kXnnpackExecutionProvider) { + // allow NCHW <-> NHWC for now. not clear any other sort of transpose has a valid usage in a real model + int64_t rank_int = gsl::narrow_cast(args.perm.size()); + if (rank_int == 4) { + static const std::vector nchw_to_nhwc_perm{0, 2, 3, 1}; + static const std::vector nhwc_to_nchw_perm{0, 3, 1, 2}; + if (args.perm == nchw_to_nhwc_perm || args.perm == nhwc_to_nchw_perm) { + return HandleResize(args); + } + } + } + + return false; +} + +constexpr HandlerInfo ep_aware_resize_handler = {&FirstInput, &EPAwareHandleResize}; + static bool HandleQLinearConcat(HandlerArgs& args) { return HandleSimpleNodeWithAxis(args); } @@ -62,7 +85,7 @@ static bool HandleMaxPool(HandlerArgs& args) { ORT_UNUSED_PARAMETER(args); return false; #else - if (args.node.GetExecutionProviderType() != "CPUExecutionProvider") { + if (args.node.GetExecutionProviderType() != kCpuExecutionProvider) { return false; } @@ -103,6 +126,7 @@ static bool HandleContribQuantizeDequantizeLinear(HandlerArgs& args) { } constexpr HandlerInfo max_pool_op_handler = {&FirstInput, &HandleMaxPool}; + constexpr HandlerInfo node_1_inp_handler = {&FirstInput, &HandleSimpleNode}; constexpr HandlerInfo reduce_op_handler = {&FirstInput, &HandleReduceOps}; constexpr HandlerInfo contrib_quantize_dequantize_linear_handler = {&FirstInput, @@ -113,6 +137,7 @@ const HandlerMap& OrtExtendedHandlers() { static const HandlerMap extended_handler_map = []() { HandlerMap map = { {"MaxPool", max_pool_op_handler}, + {"Resize", ep_aware_resize_handler}, {"com.microsoft.QuantizeLinear", contrib_quantize_dequantize_linear_handler}, {"com.microsoft.DequantizeLinear", contrib_quantize_dequantize_linear_handler}, {"com.microsoft.QLinearAdd", q_linear_binary_op_handler}, diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.h b/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.h index 0a5dbd6d13d06..8245d8c3b4eae 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.h +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.h @@ -10,7 +10,6 @@ namespace onnxruntime { /// /// Get the extended handlers for ORT specific transpose optimization. /// These include handlers for contrib ops, and where we have an NHWC version of a layout sensitive op. -/// Extends the handlers returned by OrtHandlers. /// /// HandlerMap const onnx_transpose_optimization::HandlerMap& OrtExtendedHandlers(); diff --git a/onnxruntime/core/optimizer/transpose_optimizer.cc b/onnxruntime/core/optimizer/transpose_optimizer.cc index 33e3f5eeaf0fa..092df9cc7dcfb 100644 --- a/onnxruntime/core/optimizer/transpose_optimizer.cc +++ b/onnxruntime/core/optimizer/transpose_optimizer.cc @@ -18,10 +18,18 @@ namespace onnxruntime { Status TransposeOptimizer::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { - auto api_graph = MakeApiGraph(graph, cpu_allocator_, /*new_node_ep*/ nullptr); - - OptimizeResult result = onnx_transpose_optimization::Optimize(*api_graph, "", /* default cost check*/ nullptr, - OrtExtendedHandlers()); + OptimizeResult result; + + if (ep_.empty()) { + // basic usage - no EP specific optimizations + auto api_graph = MakeApiGraph(graph, cpu_allocator_, /*new_node_ep*/ nullptr); + result = onnx_transpose_optimization::Optimize(*api_graph, "", /* default cost check*/ nullptr, + OrtExtendedHandlers()); + } else { + // EP specific optimizations enabled. Currently only used for CPU EP. + auto api_graph = MakeApiGraph(graph, cpu_allocator_, /*new_node_ep*/ ep_.c_str()); + result = onnx_transpose_optimization::Optimize(*api_graph, ep_, OrtEPCostCheck, OrtExtendedHandlers()); + } if (result.error_msg) { // currently onnx_layout_transformation::Optimize only fails if we hit an unsupported opset. diff --git a/onnxruntime/core/optimizer/transpose_optimizer.h b/onnxruntime/core/optimizer/transpose_optimizer.h index 1ae6d611d2f0e..97d7ab4d0e220 100644 --- a/onnxruntime/core/optimizer/transpose_optimizer.h +++ b/onnxruntime/core/optimizer/transpose_optimizer.h @@ -15,10 +15,14 @@ Push transposes through ops and eliminate them. class TransposeOptimizer : public GraphTransformer { private: AllocatorPtr cpu_allocator_; + const std::string ep_; public: - explicit TransposeOptimizer(AllocatorPtr cpu_allocator) noexcept - : GraphTransformer("TransposeOptimizer"), cpu_allocator_(std::move(cpu_allocator)) {} + explicit TransposeOptimizer(AllocatorPtr cpu_allocator, + const std::string& ep = {}) noexcept + : GraphTransformer(ep.empty() ? "TransposeOptimizer" : "TransposeOptimizer_" + ep), + cpu_allocator_(std::move(cpu_allocator)), + ep_{ep} {} Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; diff --git a/onnxruntime/core/platform/windows/debug_alloc.cc b/onnxruntime/core/platform/windows/debug_alloc.cc index b08d189f79866..2b612a4303442 100644 --- a/onnxruntime/core/platform/windows/debug_alloc.cc +++ b/onnxruntime/core/platform/windows/debug_alloc.cc @@ -12,7 +12,7 @@ // #ifndef NDEBUG #ifdef ONNXRUNTIME_ENABLE_MEMLEAK_CHECK -constexpr int c_callstack_limit = 16; // Maximum depth of callstack in leak trace +constexpr int c_callstack_limit = 32; // Maximum depth of callstack in leak trace #define VALIDATE_HEAP_EVERY_ALLOC 0 // Call HeapValidate on every new/delete #pragma warning(disable : 4073) // initializers put in library initialization area (this is intentional) @@ -223,6 +223,10 @@ Memory_LeakCheck::~Memory_LeakCheck() { // empty_group_names = new std::map; }); if (string.find("RtlRunOnceExecuteOnce") == std::string::npos && string.find("re2::RE2::Init") == std::string::npos && + string.find("dynamic initializer for 'FLAGS_") == std::string::npos && + string.find("AbslFlagDefaultGenForgtest_") == std::string::npos && + string.find("::SetProgramUsageMessage") == std::string::npos && + string.find("testing::internal::ParseGoogleTestFlagsOnly") == std::string::npos && string.find("testing::internal::Mutex::ThreadSafeLazyInit") == std::string::npos && string.find("testing::internal::ThreadLocalRegistryImpl::GetThreadLocalsMapLocked") == std::string::npos && string.find("testing::internal::ThreadLocalRegistryImpl::GetValueOnCurrentThread") == std::string::npos && diff --git a/onnxruntime/core/platform/windows/stacktrace.cc b/onnxruntime/core/platform/windows/stacktrace.cc index cac6f4f29043b..d7d423e4a483e 100644 --- a/onnxruntime/core/platform/windows/stacktrace.cc +++ b/onnxruntime/core/platform/windows/stacktrace.cc @@ -10,7 +10,6 @@ #include #endif #endif -#include #include "core/common/logging/logging.h" #include "core/common/gsl.h" diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h index b142db86a7902..b4132d3b770ec 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h @@ -51,7 +51,7 @@ class BaseOpBuilder : public IOpBuilder { virtual bool HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const; virtual int GetMinSupportedOpSet(const Node& /* node */) const { return 1; } - virtual int GetMaxSupportedOpSet(const Node& /* node */) const { return 19; } + virtual int GetMaxSupportedOpSet(const Node& /* node */) const { return 20; } private: bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const; diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 18010960e11c8..3d03abf5b7ebc 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -273,7 +273,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDoma // Opset 9 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Compress); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, ConstantOfShape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 19, ConstantOfShape); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, MeanVarianceNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, float, Greater); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, double, Greater); @@ -958,6 +958,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, Scan); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, Shape); +// Opset 20 +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, ConstantOfShape); + // !!PLEASE READ BELOW!! Following that, add new entries above this comment /* *** IMPORTANT! *** @@ -1332,7 +1335,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { // Opset 9 BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + // Opset 20 + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/cpu/generator/constant_of_shape.cc b/onnxruntime/core/providers/cpu/generator/constant_of_shape.cc index 920db5ed34dd1..a93da12ccf595 100644 --- a/onnxruntime/core/providers/cpu/generator/constant_of_shape.cc +++ b/onnxruntime/core/providers/cpu/generator/constant_of_shape.cc @@ -11,11 +11,16 @@ ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS( kCpuExecutionProvider, kOnnxDomain, ConstantOfShape, Output, 0, ConstantOfShapeDefaultOutputTypes); +ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST( + kCpuExecutionProvider, kOnnxDomain, ConstantOfShape, 20, Output, 0, + ConstantOfShapeDefaultOutputTypesOpset20); + // pytorch converter uses ConstantOfShape with int64 to create Pad input // https://github.com/pytorch/pytorch/blob/044b519a80459f6787f6723c1c091a18b153d184/torch/onnx/symbolic_opset11.py#L449 ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES_ALL_OPSETS( kCpuExecutionProvider, kOnnxDomain, ConstantOfShape, Output, 0, int64_t); + } // namespace op_kernel_type_control namespace { @@ -24,6 +29,10 @@ using EnabledOutputTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS( kCpuExecutionProvider, kOnnxDomain, ConstantOfShape, Output, 0); +using EnabledOutputTypesOpset20 = + ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST( + kCpuExecutionProvider, kOnnxDomain, ConstantOfShape, 20, Output, 0); + class ConstantOfShape final : public ConstantOfShapeBase, public OpKernel { public: explicit ConstantOfShape(const OpKernelInfo& info) : ConstantOfShapeBase(info), OpKernel(info) {} @@ -66,13 +75,22 @@ Status ConstantOfShape::Compute(OpKernelContext* ctx) const { } // namespace -ONNX_CPU_OPERATOR_KERNEL( +ONNX_CPU_OPERATOR_VERSIONED_KERNEL( ConstantOfShape, 9, + 19, KernelDefBuilder() .TypeConstraint("T1", DataTypeImpl::GetTensorType()) .TypeConstraint("T2", BuildKernelDefConstraintsFromTypeList()), ConstantOfShape); +ONNX_CPU_OPERATOR_KERNEL( + ConstantOfShape, + 20, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", + BuildKernelDefConstraintsFromTypeList()), + ConstantOfShape); } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/generator/constant_of_shape_base.h b/onnxruntime/core/providers/cpu/generator/constant_of_shape_base.h index d96ff06e3d6d8..9aa73c714daea 100644 --- a/onnxruntime/core/providers/cpu/generator/constant_of_shape_base.h +++ b/onnxruntime/core/providers/cpu/generator/constant_of_shape_base.h @@ -23,6 +23,18 @@ using ConstantOfShapeDefaultOutputTypes = uint8_t, uint16_t, uint32_t, uint64_t, bool>; +using ConstantOfShapeDefaultOutputTypesOpset20 = + TypeList< + BFloat16, + MLFloat16, + float, double, +#if !defined(DISABLE_FLOAT8_TYPES) + Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ, +#endif + int8_t, int16_t, int32_t, int64_t, + uint8_t, uint16_t, uint32_t, uint64_t, + bool>; + template class ConstantOfShapeBase { protected: diff --git a/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc b/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc index e9fc8d857b831..21a256eee6f14 100644 --- a/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc +++ b/onnxruntime/core/providers/cpu/quantization/qlinearconv.cc @@ -77,7 +77,8 @@ class QLinearConv : public OpKernel { W_zero_point_value = W_zero_point_data[0]; for (int64_t i = 1; i < W_zero_point_size; i++) { ORT_ENFORCE(W_zero_point_data[i] == W_zero_point_value, - "QLinearConv : zero point of per-channel filter must be same"); + "QLinearConv : zero point of per-channel filter must be same. " + "This happens by design if the quantization is symmetric."); } } diff --git a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h index c13c9d42dd392..99522cdf0759a 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h +++ b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h @@ -239,7 +239,7 @@ class UpsampleBase { if (coordinate_transform_mode_name == "half_pixel_symmetric") { return HALF_PIXEL_SYMMETRIC; } - ORT_THROW("coordinate_transform_mode:[" + coordinate_transform_mode_name + "] is not supportted!"); + ORT_THROW("coordinate_transform_mode:[" + coordinate_transform_mode_name + "] is not supported!"); } GetOriginalCoordinateFunc GetOriginalCoordinateFromResizedCoordinate( diff --git a/onnxruntime/core/providers/js/allocator.cc b/onnxruntime/core/providers/js/allocator.cc index c1d0aa9abbf6b..574c507222a5c 100644 --- a/onnxruntime/core/providers/js/allocator.cc +++ b/onnxruntime/core/providers/js/allocator.cc @@ -10,6 +10,10 @@ namespace onnxruntime { namespace js { void* JsCustomAllocator::Alloc(size_t size) { + if (size == 0) { + return nullptr; + } + void* p = EM_ASM_PTR({ return Module.jsepAlloc($0); }, size); stats_.num_allocs++; stats_.bytes_in_use += size; @@ -17,8 +21,10 @@ void* JsCustomAllocator::Alloc(size_t size) { } void JsCustomAllocator::Free(void* p) { - size_t size = (size_t)(void*)EM_ASM_PTR({ return Module.jsepFree($0); }, p); - stats_.bytes_in_use -= size; + if (p != nullptr) { + size_t size = (size_t)(void*)EM_ASM_PTR({ return Module.jsepFree($0); }, p); + stats_.bytes_in_use -= size; + } } void JsCustomAllocator::GetStats(AllocatorStats* stats) { diff --git a/onnxruntime/core/providers/js/data_transfer.cc b/onnxruntime/core/providers/js/data_transfer.cc index c62362d90867f..ebea041b80128 100644 --- a/onnxruntime/core/providers/js/data_transfer.cc +++ b/onnxruntime/core/providers/js/data_transfer.cc @@ -20,23 +20,25 @@ bool DataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_dev common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { size_t bytes = src.SizeInBytes(); - const void* src_data = src.DataRaw(); - void* dst_data = dst.MutableDataRaw(); - - auto& src_device = src.Location().device; - auto& dst_device = dst.Location().device; - - if (dst_device.Type() == OrtDevice::GPU) { - if (src_device.Type() == OrtDevice::GPU) { - // copy from GPU to GPU - EM_ASM({ Module.jsepCopy($0, $1, $2, true); }, src_data, dst_data, bytes); - } else { - // copy from CPU to GPU - EM_ASM({ Module.jsepCopy($0, $1, $2); }, src_data, dst_data, bytes); + if (bytes > 0) { + const void* src_data = src.DataRaw(); + void* dst_data = dst.MutableDataRaw(); + + auto& src_device = src.Location().device; + auto& dst_device = dst.Location().device; + + if (dst_device.Type() == OrtDevice::GPU) { + if (src_device.Type() == OrtDevice::GPU) { + // copy from GPU to GPU + EM_ASM({ Module.jsepCopy($0, $1, $2, true); }, src_data, dst_data, bytes); + } else { + // copy from CPU to GPU + EM_ASM({ Module.jsepCopy($0, $1, $2); }, src_data, dst_data, bytes); + } + } else /* if (src_device.Type() == OrtDevice::GPU) */ { + // copy from GPU to CPU + jsepDownload(src_data, dst_data, bytes); } - } else /* if (src_device.Type() == OrtDevice::GPU) */ { - // copy from GPU to CPU - jsepDownload(src_data, dst_data, bytes); } return Status::OK(); diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index a1e6cc7d90acc..6ced8d4d4a4ad 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -129,56 +129,56 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 14, Rel class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 15, LeakyRelu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 16, LeakyRelu); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, float, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, 12, float, ReduceMax); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, float, ReduceMax); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, float, ReduceMax); - -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ReduceMean); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, float, ReduceMean); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, float, ReduceMean); - -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, float, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, 12, float, ReduceMin); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, float, ReduceMin); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, float, ReduceMin); - -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ReduceProd); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, float, ReduceProd); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, float, ReduceProd); - -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ReduceSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ReduceSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, ReduceSum); - -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ReduceL1); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, float, ReduceL1); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, float, ReduceL1); - -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ReduceL2); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, float, ReduceL2); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, float, ReduceL2); - -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ReduceLogSum); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, float, ReduceLogSum); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, float, ReduceLogSum); - -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ReduceSumSquare); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, float, ReduceSumSquare); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, float, ReduceSumSquare); - -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ReduceLogSumExp); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, float, ReduceLogSumExp); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, float, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceMax); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, ReduceMax); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, 12, ReduceMax); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, ReduceMax); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, ReduceMax); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceMean); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, ReduceMean); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, ReduceMean); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, ReduceMean); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceMin); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, ReduceMin); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, 12, ReduceMin); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, ReduceMin); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, ReduceMin); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceProd); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, ReduceProd); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, ReduceProd); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, ReduceProd); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceSum); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, ReduceSum); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, ReduceSum); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceL1); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, ReduceL1); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, ReduceL1); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, ReduceL1); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceL2); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, ReduceL2); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, ReduceL2); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, ReduceL2); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, ReduceLogSum); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, ReduceLogSum); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, ReduceLogSum); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, ReduceSumSquare); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, ReduceSumSquare); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, ReduceSumSquare); + +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, ReduceLogSumExp); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, 17, ReduceLogSumExp); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 18, ReduceLogSumExp); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, ThresholdedRelu); @@ -234,11 +234,11 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Tra class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, Conv); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, ConvTranspose); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 11, float, MaxPool); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 12, float, MaxPool); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, float, AveragePool); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, float, GlobalAveragePool); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, float, GlobalMaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 11, MaxPool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 12, MaxPool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, AveragePool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, GlobalAveragePool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, GlobalMaxPool); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Conv); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, Conv); @@ -251,16 +251,16 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Gem class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, MatMul); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, MatMul); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 9, float, AveragePool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, float, AveragePool); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, AveragePool); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, float, GlobalAveragePool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 7, float, MaxPool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 8, 9, float, MaxPool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, float, MaxPool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, float, MaxPool); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, float, MaxPool); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, float, GlobalMaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 9, AveragePool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, AveragePool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, AveragePool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, GlobalAveragePool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 7, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 8, 9, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 10, 10, MaxPool); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 11, MaxPool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, MaxPool); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, GlobalMaxPool); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ArgMax); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, ArgMax); @@ -438,71 +438,71 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -515,16 +515,16 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/js/js_kernel.h b/onnxruntime/core/providers/js/js_kernel.h index 177c0a9e691ed..fdd5c7dee5bfc 100644 --- a/onnxruntime/core/providers/js/js_kernel.h +++ b/onnxruntime/core/providers/js/js_kernel.h @@ -196,7 +196,7 @@ class JsKernel : public OpKernel { } int status_code = EM_ASM_INT( - { return Module.jsepRunKernel($0, $1, Module.jsepSessionState); }, + { return Module.jsepRunKernel($0, $1, Module.jsepSessionState.sessionHandle, Module.jsepSessionState.errors); }, this, reinterpret_cast(p_serialized_kernel_context)); LOGS_DEFAULT(VERBOSE) << "outputs = " << context->OutputCount() << ". Y.data=" diff --git a/onnxruntime/core/providers/js/operators/conv_transpose.h b/onnxruntime/core/providers/js/operators/conv_transpose.h index dbe4c188380eb..18ef73268005d 100644 --- a/onnxruntime/core/providers/js/operators/conv_transpose.h +++ b/onnxruntime/core/providers/js/operators/conv_transpose.h @@ -108,6 +108,23 @@ class ConvTranspose : public JsKernel { } } + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* /* prepacked_weights */) override { + is_packed = false; + + if (input_idx == 1) { + // Only handle the common case of conv2D + if (tensor.Shape().NumDimensions() != 4 || tensor.SizeInBytes() == 0) { + return Status::OK(); + } + + w_is_const_ = true; + } + + return Status::OK(); + } + protected: ConvTransposeAttributes conv_transpose_attrs_; bool w_is_const_; diff --git a/onnxruntime/core/providers/js/operators/pool.cc b/onnxruntime/core/providers/js/operators/pool.cc index 03e6caef7e5b8..7fdb4e5d114ea 100644 --- a/onnxruntime/core/providers/js/operators/pool.cc +++ b/onnxruntime/core/providers/js/operators/pool.cc @@ -8,69 +8,65 @@ namespace onnxruntime { namespace js { -#define POOLING_KERNEL(op_name, domain, is_channels_last, data_type, pool_type, since_version) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - op_name, \ - domain, \ - since_version, \ - data_type, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Pool); +#define POOLING_KERNEL(op_name, domain, is_channels_last, pool_type, since_version) \ + ONNX_OPERATOR_KERNEL_EX( \ + op_name, \ + domain, \ + since_version, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), \ + Pool); -#define POOLING_KERNEL_VERSIONED(op_name, domain, is_channels_last, data_type, pool_type, since_version, end_version) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - op_name, \ - domain, \ - since_version, \ - end_version, \ - data_type, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Pool); +#define POOLING_KERNEL_VERSIONED(op_name, domain, is_channels_last, pool_type, since_version, end_version) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + op_name, \ + domain, \ + since_version, \ + end_version, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", JsepSupportedFloatTypes()), \ + Pool); -#define POOLING_KERNEL_WITH_INDICES(op_name, domain, is_channels_last, data_type, pool_type, since_version) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - op_name, \ - domain, \ - since_version, \ - data_type, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("I", DataTypeImpl::GetTensorType()), \ - Pool); +#define POOLING_KERNEL_WITH_INDICES(op_name, domain, is_channels_last, pool_type, since_version) \ + ONNX_OPERATOR_KERNEL_EX( \ + op_name, \ + domain, \ + since_version, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", JsepSupportedFloatTypes()) \ + .TypeConstraint("I", DataTypeImpl::GetTensorType()), \ + Pool); -#define POOLING_KERNEL_VERSIONED_WITH_INDICES(op_name, domain, is_channels_last, data_type, pool_type, since_version, end_version) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - op_name, \ - domain, \ - since_version, \ - end_version, \ - data_type, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("I", DataTypeImpl::GetTensorType()), \ - Pool); +#define POOLING_KERNEL_VERSIONED_WITH_INDICES(op_name, domain, is_channels_last, pool_type, since_version, end_version) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + op_name, \ + domain, \ + since_version, \ + end_version, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", JsepSupportedFloatTypes()) \ + .TypeConstraint("I", DataTypeImpl::GetTensorType()), \ + Pool); -POOLING_KERNEL_VERSIONED(AveragePool, kOnnxDomain, false, float, AveragePool, 7, 9) -POOLING_KERNEL_VERSIONED(AveragePool, kOnnxDomain, false, float, AveragePool, 10, 10) -POOLING_KERNEL(AveragePool, kOnnxDomain, false, float, AveragePool, 11) -POOLING_KERNEL(AveragePool, kMSInternalNHWCDomain, true, float, AveragePool, 11) -POOLING_KERNEL(GlobalAveragePool, kOnnxDomain, false, float, AveragePool, 1) -POOLING_KERNEL(GlobalAveragePool, kMSInternalNHWCDomain, true, float, AveragePool, 1) +POOLING_KERNEL_VERSIONED(AveragePool, kOnnxDomain, false, AveragePool, 7, 9) +POOLING_KERNEL_VERSIONED(AveragePool, kOnnxDomain, false, AveragePool, 10, 10) +POOLING_KERNEL(AveragePool, kOnnxDomain, false, AveragePool, 11) +POOLING_KERNEL(AveragePool, kMSInternalNHWCDomain, true, AveragePool, 11) +POOLING_KERNEL(GlobalAveragePool, kOnnxDomain, false, AveragePool, 1) +POOLING_KERNEL(GlobalAveragePool, kMSInternalNHWCDomain, true, AveragePool, 1) -POOLING_KERNEL_VERSIONED(MaxPool, kOnnxDomain, false, float, MaxPool<1>, 1, 7) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, float, MaxPool<8>, 8, 9) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, float, MaxPool<8>, 10, 10) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, float, MaxPool<8>, 11, 11) -POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kMSInternalNHWCDomain, true, float, MaxPool<8>, 11, 11) -POOLING_KERNEL_WITH_INDICES(MaxPool, kOnnxDomain, false, float, MaxPool<8>, 12) -POOLING_KERNEL_WITH_INDICES(MaxPool, kMSInternalNHWCDomain, true, float, MaxPool<8>, 12) -POOLING_KERNEL(GlobalMaxPool, kOnnxDomain, false, float, MaxPool<1>, 1) -POOLING_KERNEL(GlobalMaxPool, kMSInternalNHWCDomain, true, float, MaxPool<1>, 1) +POOLING_KERNEL_VERSIONED(MaxPool, kOnnxDomain, false, MaxPool<1>, 1, 7) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, MaxPool<8>, 8, 9) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, MaxPool<8>, 10, 10) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kOnnxDomain, false, MaxPool<8>, 11, 11) +POOLING_KERNEL_VERSIONED_WITH_INDICES(MaxPool, kMSInternalNHWCDomain, true, MaxPool<8>, 11, 11) +POOLING_KERNEL_WITH_INDICES(MaxPool, kOnnxDomain, false, MaxPool<8>, 12) +POOLING_KERNEL_WITH_INDICES(MaxPool, kMSInternalNHWCDomain, true, MaxPool<8>, 12) +POOLING_KERNEL(GlobalMaxPool, kOnnxDomain, false, MaxPool<1>, 1) +POOLING_KERNEL(GlobalMaxPool, kMSInternalNHWCDomain, true, MaxPool<1>, 1) } // namespace js } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/pool.h b/onnxruntime/core/providers/js/operators/pool.h index 5dbe5d0b8881d..5723123c0c3b8 100644 --- a/onnxruntime/core/providers/js/operators/pool.h +++ b/onnxruntime/core/providers/js/operators/pool.h @@ -41,7 +41,7 @@ namespace js { #define GLOBAL_POOL_ATTRIBUTES_JS_OBJ_MAPPING ({"format" : $1 ? "NHWC" : "NCHW"}) #define GLOBAL_POOL_ATTRIBUTES_PARAM_LIST static_cast(is_channels_last) -template +template class Pool : public JsKernel, public PoolBase { public: Pool(const OpKernelInfo& info) : JsKernel(info), PoolBase(info) { @@ -65,10 +65,10 @@ class Pool : public JsKernel, public PoolBase { } }; -template -class Pool, is_channels_last> final : public Pool, is_channels_last> { +template +class Pool, is_channels_last> final : public Pool, is_channels_last> { public: - Pool(const OpKernelInfo& info) : Pool, is_channels_last>(info) {} + Pool(const OpKernelInfo& info) : Pool, is_channels_last>(info) {} }; } // namespace js diff --git a/onnxruntime/core/providers/js/operators/reduce.cc b/onnxruntime/core/providers/js/operators/reduce.cc index 21854fccc37ca..2679cfed86124 100644 --- a/onnxruntime/core/providers/js/operators/reduce.cc +++ b/onnxruntime/core/providers/js/operators/reduce.cc @@ -7,32 +7,30 @@ namespace onnxruntime { namespace js { #define REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceOp, sinceVersion, endVersion) \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ ReduceOp, \ kOnnxDomain, \ sinceVersion, endVersion, \ - float, \ kJsExecutionProvider, \ (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - ReduceOp); + .TypeConstraint("T", JsepSupportedFloatTypes()), \ + ReduceOp); // macro REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL does not set .InputMemoryType(OrtMemTypeCPU, 1), so in future if // a new opset version update applies to Reduce* operators, we may need to add another macro like // REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL_WITH_AXIS_IN_INPUT to set input memory type. // i.e. we cannot use REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL to version 18 when the opset version is increased. -#define REGISTER_REDUCE_ELEMENTWISE_KERNEL(ReduceOp, sinceVersion) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - ReduceOp, \ - kOnnxDomain, \ - sinceVersion, \ - float, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .InputMemoryType(OrtMemTypeCPU, 1), \ - ReduceOp); +#define REGISTER_REDUCE_ELEMENTWISE_KERNEL(ReduceOp, sinceVersion) \ + ONNX_OPERATOR_KERNEL_EX( \ + ReduceOp, \ + kOnnxDomain, \ + sinceVersion, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", JsepSupportedFloatTypes()) \ + .InputMemoryType(OrtMemTypeCPU, 1), \ + ReduceOp); REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceMean, 1, 10); REGISTER_REDUCE_ELEMENTWISE_VERSIONED_KERNEL(ReduceMean, 11, 12); diff --git a/onnxruntime/core/providers/js/operators/reduce.h b/onnxruntime/core/providers/js/operators/reduce.h index 19a6d298c7696..a5a4aa834c2ca 100644 --- a/onnxruntime/core/providers/js/operators/reduce.h +++ b/onnxruntime/core/providers/js/operators/reduce.h @@ -9,7 +9,7 @@ namespace onnxruntime { namespace js { #define JSEP_DEFINE_REDUCE_KERNEL(ReduceKernel) \ - template \ + template \ class ReduceKernel : public JsKernel, public ReduceKernelBase { \ public: \ using ReduceKernelBase::axes_; \ diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc index 92a7feea7fc54..df4c718949269 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/clip_op_builder.cc @@ -1,6 +1,9 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include +#include + #include "core/providers/common.h" #include "core/providers/shared/utils/utils.h" #include "core/providers/qnn/builder/qnn_model_wrapper.h" @@ -9,8 +12,6 @@ #include "base_op_builder.h" -#include - namespace onnxruntime { namespace qnn { class ClipOpBuilder : public BaseOpBuilder { @@ -33,8 +34,6 @@ class ClipOpBuilder : public BaseOpBuilder { private: Status ExplictOpCheck(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const; - mutable float min_value_ = std::numeric_limits::lowest(); - mutable float max_value_ = std::numeric_limits::max(); }; Status ClipOpBuilder::ExplictOpCheck(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const { @@ -61,61 +60,8 @@ Status ClipOpBuilder::ProcessInputs(QnnModelWrapper& qnn_model_wrapper, if (do_op_validation) { ORT_RETURN_IF_ERROR(ExplictOpCheck(qnn_model_wrapper, node_unit)); } - Qnn_DataType_t qnn_data_type = QNN_DATATYPE_FLOAT_32; - - auto inputs = node_unit.Inputs(); - for (size_t input_i = 0; input_i < inputs.size(); ++input_i) { - Qnn_QuantizeParams_t quantize_param = QNN_QUANTIZE_PARAMS_INIT; - bool is_quantized_tensor = inputs[input_i].quant_param.has_value(); - utils::InitializeQuantizeParam(quantize_param, is_quantized_tensor); - - auto& input_name = inputs[input_i].node_arg.Name(); - if (input_name.empty()) { - // Ignore unspecified/unused optional input - continue; - } - if (qnn_model_wrapper.IsQnnTensorWrapperExist(input_name)) { - LOGS(logger, VERBOSE) << "Tensor already added or the input is not named, skip it: " << input_name; - input_names.push_back(input_name); - continue; - } - - const auto* type_proto = inputs[input_i].node_arg.TypeAsProto(); - ORT_RETURN_IF_ERROR(utils::GetQnnDataType(is_quantized_tensor, type_proto, qnn_data_type)); - - std::vector input_shape; - ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(inputs[input_i].node_arg, input_shape), "Cannot get shape"); - - ORT_RETURN_IF_NOT(qnn_model_wrapper.ProcessQuantizationParameter(inputs[input_i].quant_param, - quantize_param.scaleOffsetEncoding.scale, - quantize_param.scaleOffsetEncoding.offset), - "Cannot get quantization parameter"); - - float* ini_data = nullptr; - std::vector unpacked_tensor; - bool is_initializer_input = qnn_model_wrapper.IsInitializerInput(input_name); - if (is_initializer_input) { - const auto& input_tensor = qnn_model_wrapper.GetInitializerTensors().at(input_name); - ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*input_tensor, unpacked_tensor)); - ini_data = reinterpret_cast(unpacked_tensor.data()); - if (input_i == 1) { - min_value_ = *ini_data; - continue; - } else if (input_i == 2) { - max_value_ = *ini_data; - continue; - } - } - ORT_ENFORCE(input_i == 0, "QNN ReluMinMax operator expects only one input. Min and max are expected to be parameters, ie. initializer inputs in ONNX model"); - - Qnn_TensorType_t tensor_type = GetInputTensorType(qnn_model_wrapper, input_name); - QnnTensorWrapper input_tensorwrapper(input_name, tensor_type, qnn_data_type, quantize_param, - std::move(input_shape), std::move(unpacked_tensor)); - ORT_RETURN_IF_NOT(qnn_model_wrapper.AddTensorWrapper(std::move(input_tensorwrapper)), "Failed to add tensor."); - input_names.push_back(input_name); - } - return Status::OK(); + return ProcessInput(qnn_model_wrapper, node_unit.Inputs()[0], logger, input_names); } Status ClipOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper, @@ -123,20 +69,59 @@ Status ClipOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wra std::vector&& input_names, const logging::Logger& logger, bool do_op_validation) const { + const auto& inputs = node_unit.Inputs(); + const size_t num_inputs = inputs.size(); + + const Qnn_DataType_t qnn_data_type = QNN_DATATYPE_FLOAT_32; std::vector param_tensor_names; - Qnn_Scalar_t min_qnn_scalar = QNN_SCALAR_INIT; - min_qnn_scalar.dataType = QNN_DATATYPE_FLOAT_32; - min_qnn_scalar.floatValue = min_value_; - QnnParamWrapper min_value_param(node_unit.Index(), node_unit.Name(), QNN_OP_RELU_MIN_MAX_PARAM_MIN_VALUE, min_qnn_scalar); - param_tensor_names.push_back(min_value_param.GetParamTensorName()); - qnn_model_wrapper.AddParamWrapper(std::move(min_value_param)); - - Qnn_Scalar_t max_qnn_scalar = QNN_SCALAR_INIT; - max_qnn_scalar.dataType = QNN_DATATYPE_FLOAT_32; - max_qnn_scalar.floatValue = max_value_; - QnnParamWrapper max_value_param(node_unit.Index(), node_unit.Name(), QNN_OP_RELU_MIN_MAX_PARAM_MAX_VALUE, max_qnn_scalar); - param_tensor_names.push_back(max_value_param.GetParamTensorName()); - qnn_model_wrapper.AddParamWrapper(std::move(max_value_param)); + + auto get_f32_from_bytes = [](const std::vector& bytes, float default_val) -> float { + return bytes.empty() ? default_val : *reinterpret_cast(bytes.data()); + }; + + // Set the 'min' parameter. + { + std::vector min_val_bytes; + + if (num_inputs > 1 && !inputs[1].node_arg.Name().empty()) { + OnnxInputInfo min_input_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(inputs[1], min_input_info)); + ORT_RETURN_IF_NOT(min_input_info.qnn_data_type == qnn_data_type, + "QNN EP: The 'min' input of the Clip operator must be of type float32."); + assert(min_input_info.is_initializer); // Checked by ExplicitOpCheck(). + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*min_input_info.initializer_tensor, min_val_bytes)); + } + + Qnn_Scalar_t min_qnn_scalar = QNN_SCALAR_INIT; + min_qnn_scalar.dataType = qnn_data_type; + min_qnn_scalar.floatValue = get_f32_from_bytes(min_val_bytes, std::numeric_limits::lowest()); + QnnParamWrapper min_value_param(node_unit.Index(), node_unit.Name(), QNN_OP_RELU_MIN_MAX_PARAM_MIN_VALUE, + min_qnn_scalar); + param_tensor_names.push_back(min_value_param.GetParamTensorName()); + qnn_model_wrapper.AddParamWrapper(std::move(min_value_param)); + } + + // Set the 'max' parameter. + { + std::vector max_val_bytes; + + if (num_inputs > 2 && !inputs[2].node_arg.Name().empty()) { + OnnxInputInfo max_input_info = {}; + ORT_RETURN_IF_ERROR(qnn_model_wrapper.GetOnnxInputInfo(inputs[2], max_input_info)); + ORT_RETURN_IF_NOT(max_input_info.qnn_data_type == qnn_data_type, + "QNN EP: The 'max' input of the Clip operator must of type float32."); + assert(max_input_info.is_initializer); // Checked by ExplicitOpCheck(). + ORT_RETURN_IF_ERROR(qnn_model_wrapper.UnpackInitializerData(*max_input_info.initializer_tensor, max_val_bytes)); + } + + Qnn_Scalar_t max_qnn_scalar = QNN_SCALAR_INIT; + max_qnn_scalar.dataType = qnn_data_type; + max_qnn_scalar.floatValue = get_f32_from_bytes(max_val_bytes, std::numeric_limits::max()); + QnnParamWrapper max_value_param(node_unit.Index(), node_unit.Name(), QNN_OP_RELU_MIN_MAX_PARAM_MAX_VALUE, + max_qnn_scalar); + param_tensor_names.push_back(max_value_param.GetParamTensorName()); + qnn_model_wrapper.AddParamWrapper(std::move(max_value_param)); + } ORT_RETURN_IF_ERROR(ProcessOutputs(qnn_model_wrapper, node_unit, std::move(input_names), diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc index 511f2a5149f2e..4039c4fbf8d70 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/resize_op_builder.cc @@ -2,7 +2,8 @@ // Licensed under the MIT License. #include -#include +#include +#include #include "core/providers/common.h" #include "core/providers/shared/utils/utils.h" @@ -42,76 +43,6 @@ class ResizeOpBuilder : public BaseOpBuilder { bool do_op_validation) const override ORT_MUST_USE_RESULT; private: - /** - * Returns the QNN integer value that corresponds to the given ONNX mode (string). - * - * /param onnx_modes Array of ONNX modes supported by QNN. The index of each mode corresponds to the QNN value. - * /param onnx_mode The ONNX mode for which to get the corresponding QNN value. - * /param onnx_model_label Mode label to print out in case of error (e.g., "nearest_mode"). - * /param qnn_mode Output parameter that is set to the appropriate QNN value from the given ONNX mode. - * - * /returns A status indicating failure or success. - */ - template - Status GetQnnModeFromString(const std::array& onnx_modes, std::string_view onnx_mode, - const char* onnx_mode_label, QnnValType& qnn_mode) const ORT_MUST_USE_RESULT; - - /** - * Called by IsOpSupported to validate the op for non-quantized models. - * - * /param qnn_model_wrapper The QNN model wrapper instance. - * /param node_unit The node unit containing metadata for the ONNX Resize operator. - * - * /returns A status indicating failure or success. - */ - Status ValidateOp(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const ORT_MUST_USE_RESULT; - - /** - * Called by IsOpSupported to validate the op for quantized models. - * - * /param qnn_model_wrapper The QNN model wrapper instance. - * /param node_unit The node unit containing metadata for the ONNX Resize operator and its Q/DQ nodes. - * - * /returns A status indicating failure or success. - */ - Status ValidateQDQOp(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const ORT_MUST_USE_RESULT; - - /** - * Called by ProcessAttributesAndOutputs to process the op's attributes and outputs - * for non-quantized models. - * - * /param qnn_model_wrapper The QNN model wrapper instance. - * /param node_unit The node unit containing metadata for the ONNX Resize operator. - * /param input_names The operator's input names. - * /param logger A logger. - * /param do_op_validation Set to true if the op should be validated using QNN's validation API. - * - * /returns A status indicating failure or success. - */ - Status ProcessOpAttrsAndOutputs(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& node_unit, - std::vector&& input_names, - const logging::Logger& logger, - bool do_op_validation) const ORT_MUST_USE_RESULT; - - /** - * Called by ProcessAttributesAndOutputs to process the op's attributes and outputs - * for quantized models. - * - * /param qnn_model_wrapper The QNN model wrapper instance. - * /param node_unit The node unit containing metadata for the ONNX Resize operator and its Q/DQ nodes. - * /param input_names The operator's input names. - * /param logger A logger. - * /param do_op_validation Set to true if the op should be validated using QNN's validation API. - * - * /returns A status indicating failure or success. - */ - Status ProcessQDQOpAttrsAndOutputs(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& node_unit, - std::vector&& input_names, - const logging::Logger& logger, - bool do_op_validation) const ORT_MUST_USE_RESULT; - // Info for each ONNX attribute of interest (attribute name + default value) static const OnnxAttrInfo onnx_mode_attr; static const OnnxAttrInfo onnx_coord_transf_mode_attr; @@ -119,21 +50,29 @@ class ResizeOpBuilder : public BaseOpBuilder { static const OnnxAttrInfo onnx_antialias_attr; static const OnnxAttrInfo onnx_exclude_outside_attr; - // Arrays of supported QNN modes for QNN's Resize op. The index of each mode is used as the corresponding - // QNN parameter value. Ex: The "nearest" mode is represented as the value 0 in QNN. Note, that - // not all modes are supported by every QNN backend. + // Tables that map an ONNX attribute value (string) to the corresponding integer (enum) QNN parameter value. + // Ex: The "half_pixel" coordinate_transformation_mode is represented as the value 0 in QNN. + // Only the modes supported by QNN Resize are mapped by these tables. + static const std::unordered_map supported_modes; + static const std::unordered_map supported_coord_transf_modes; + static const std::unordered_map supported_nearest_modes; +}; - // QNN values: NEAREST = 0, LINEAR = 1 - static constexpr std::array supported_modes = {"nearest", "linear"}; +const std::unordered_map ResizeOpBuilder::supported_modes = { + {"nearest", QNN_OP_RESIZE_INTERPOLATION_MODE_NEAREST}, + {"linear", QNN_OP_RESIZE_INTERPOLATION_MODE_LINEAR}}; - // QNN values: HALF_PIXEL = 0, PYTORCH_HALF_PIXEL = 1, ALIGN_CORNERS = 2, ASYMMETRIC = 3 - static constexpr std::array supported_coord_transf_modes = {"half_pixel", "pytorch_half_pixel", - "align_corners", "asymmetric"}; +const std::unordered_map ResizeOpBuilder::supported_coord_transf_modes = { + {"half_pixel", QNN_OP_RESIZE_TRANSFORMATION_MODE_HALF_PIXEL}, + {"pytorch_half_pixel", QNN_OP_RESIZE_TRANSFORMATION_MODE_PYTORCH_HALF_PIXEL}, + {"align_corners", QNN_OP_RESIZE_TRANSFORMATION_MODE_ALIGN_CORNERS}, + {"asymmetric", QNN_OP_RESIZE_TRANSFORMATION_MODE_ASYMMETRIC}}; - // QNN values: ROUND_PREFER_FLOOR = 0, ROUND_PREFER_CEIL = 1, FLOOR = 2, CEIL = 3 - static constexpr std::array supported_nearest_modes = {"round_prefer_floor", "round_prefer_ceil", - "floor", "ceil"}; -}; +const std::unordered_map ResizeOpBuilder::supported_nearest_modes = { + {"round_prefer_floor", QNN_OP_RESIZE_NEAREST_MODE_ROUND_PREFER_FLOOR}, + {"round_prefer_ceil", QNN_OP_RESIZE_NEAREST_MODE_ROUND_PREFER_CEIL}, + {"floor", QNN_OP_RESIZE_NEAREST_MODE_FLOOR}, + {"ceil", QNN_OP_RESIZE_NEAREST_MODE_CEIL}}; const OnnxAttrInfo ResizeOpBuilder::onnx_mode_attr = {"mode", "nearest"}; const OnnxAttrInfo ResizeOpBuilder::onnx_coord_transf_mode_attr = {"coordinate_transformation_mode", @@ -143,19 +82,26 @@ const OnnxAttrInfo ResizeOpBuilder::onnx_nearest_mode_attr = {"near const OnnxAttrInfo ResizeOpBuilder::onnx_antialias_attr = {"antialias", 0}; const OnnxAttrInfo ResizeOpBuilder::onnx_exclude_outside_attr = {"exclude_outside", 0}; -template -Status ResizeOpBuilder::GetQnnModeFromString(const std::array& onnx_modes, - std::string_view onnx_mode, const char* onnx_mode_label, - QnnValType& qnn_mode) const { - for (size_t i = 0; i < onnx_modes.size(); ++i) { - if (onnx_modes[i] == onnx_mode) { - qnn_mode = SafeInt(i); - return Status::OK(); - } +// Returns the QNN parameter integer value that corresponds to the given ONNX attribute mode string value. +static Status GetQnnModeValFromOnnxString(const std::unordered_map& supported_qnn_modes, + const std::string& onnx_attr_value, + const char* onnx_attr_name, + uint32_t& qnn_mode_value) { + auto it = supported_qnn_modes.find(onnx_attr_value); + if (it != supported_qnn_modes.end()) { + qnn_mode_value = it->second; + return Status::OK(); } - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN EP: Resize operator does not support ", onnx_mode_label, - " ", std::string(onnx_mode)); + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN EP: Resize operator does not support ", onnx_attr_name, + " ", std::string(onnx_attr_value)); +} + +// Returns true if the given ONNX attribute mode value is generally supported on QNN. Note that +// different QNN backends may support a smaller subset of modes. +static bool IsOnnxAttrModeSupported(const std::unordered_map& supported_qnn_modes, + const std::string& onnx_attr_value) { + return supported_qnn_modes.find(onnx_attr_value) != supported_qnn_modes.end(); } // Resize ops are sensitive with data layout, no special validation so far @@ -169,118 +115,95 @@ Status ResizeOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper, return AddToModelBuilder(qnn_model_wrapper, node_unit, logger, true); } + const bool is_npu_backend = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()); + NodeAttrHelper node_helper(node_unit); + // QNN doesn't support anti-aliasing (added in opset 18) if (node_unit.SinceVersion() >= 18) { - NodeAttrHelper node_helper(node_unit); const bool antialias = GetOnnxAttr(node_helper, onnx_antialias_attr) != 0; ORT_RETURN_IF(antialias, "QNN EP: Resize doesn't support anti-aliasing."); } - // The QNN Resize op does not currently work with the QNN cpu backend, but works with the HTP backend. Therefore, we - // currently use QNN's Resize op for quantized models and either ResizeBilinear or ResizeNearestNeighbor for - // non-quantized models. This requires separate validation for quantized models. - // TODO: Use only Resize once QNN's Resize op works in the QNN cpu backend. - bool is_npu_backend = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()); - return is_npu_backend ? ValidateQDQOp(qnn_model_wrapper, node_unit) : ValidateOp(qnn_model_wrapper, node_unit); -} - -Status ResizeOpBuilder::ValidateOp(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const { - NodeAttrHelper node_helper(node_unit); - const std::string resize_mode = GetOnnxAttr(node_helper, onnx_mode_attr); - ORT_RETURN_IF((resize_mode != "nearest") && (resize_mode != "linear"), - "QNN EP: Resize doesn't support mode '", resize_mode.c_str(), "'.", - "Only 'nearest' and 'linear' are supported."); - - const std::string coordinate_mode = GetOnnxAttr(node_helper, onnx_coord_transf_mode_attr); - ORT_RETURN_IF((coordinate_mode != "half_pixel") && (coordinate_mode != "align_corners"), - "QNN EP: coordinate transformation mode '", coordinate_mode.c_str(), "' not supported for Resize op.", - "Only 'align_corners' and 'half_pixel' are supported."); - - // Check for a valid "nearest_mode" if the mode is "nearest". - if (resize_mode == "nearest") { - // NOTE: QNN's ResizeNearestNeighbor operator does not have a way to specify rounding (i.e., "nearest_mode"). - // The output of the QNN ResizeNearestNeighbor operator is not always equivalent to ONNX's Resize - // operator with any single specific "nearest_mode". - // - // For some input/output shapes, QNN's ResizeNearestNeighbor is equivalent to ONNX's Resize with "round_prefer_floor". - // For other shapes, QNN's ResizeNearestNeighbor is equivalent to ONNX Resize with "round_prefer_ceil". - // - // From unit tests, I've found a relationship between input/output shapes and the equivalent ONNX "nearest_mode". - // If the new and old spatial dimensions are evenly divisible, the "nearest_mode" is "round_prefer_floor". - // Otherwise, the "nearest_mode" is "round_prefer_ceil". - // - // This relationship is probably incomplete/wrong. - // - // TODO: Ask Qualcomm what the correct "nearest_mode" should be, - // OR use QNN's own Resize operator once it works on QnnCpu. - const std::string& nearest_mode = GetOnnxAttr(node_helper, onnx_nearest_mode_attr); - ORT_RETURN_IF_NOT("floor" == nearest_mode, "QNN Resize only supports nearest_mode: floor!"); // This is wrong! - } - - auto& input_0 = node_unit.Inputs()[0]; - std::vector input_shape; - ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(input_0.node_arg, input_shape), - "QNN EP: Cannot get input shape for Resize op"); - - const auto& output_0 = node_unit.Outputs()[0]; - std::vector output_shape; - ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(output_0.node_arg, output_shape), - "QNN EP: Cannot get output shape for Resize op"); - - ORT_RETURN_IF(input_shape.size() != 4 || output_shape.size() != 4, "QNN Resize only supports 4D!"); - - ONNX_NAMESPACE::DataType input_data_type = input_0.node_arg.Type(); - ORT_RETURN_IF(input_data_type != ONNX_NAMESPACE::Utils::DataTypeUtils::ToType("float"), - "QNN EP: Data type ", input_data_type->c_str(), - " is not supported for Resize operator in CPU backend."); - - return Status::OK(); -} - -Status ResizeOpBuilder::ValidateQDQOp(QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const { - NodeAttrHelper node_helper(node_unit); - - using namespace onnxruntime::qnn::utils; // Check mode const std::string interp_mode = GetOnnxAttr(node_helper, onnx_mode_attr); - ORT_RETURN_IF_NOT(ArrayHasString(supported_modes, interp_mode), "QNN EP: Resize does not support mode ", + ORT_RETURN_IF_NOT(IsOnnxAttrModeSupported(supported_modes, interp_mode), "QNN EP: Resize does not support mode ", interp_mode.c_str()); // Check coordinate transformation mode const std::string transformation_mode = GetOnnxAttr(node_helper, onnx_coord_transf_mode_attr); - ORT_RETURN_IF_NOT(ArrayHasString(supported_coord_transf_modes, transformation_mode), + ORT_RETURN_IF_NOT(IsOnnxAttrModeSupported(supported_coord_transf_modes, transformation_mode), "QNN EP: Resize does not support coordinate_transformation_mode ", transformation_mode.c_str()); - // Check nearest mode + const auto& input_0 = node_unit.Inputs()[0]; + std::vector input_shape; + ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(input_0.node_arg, input_shape), + "QNN EP: Cannot get shape for Resize input"); + const size_t input_rank = input_shape.size(); + + // Validate Resize w/ "nearest" mode. + // Translation matrix of ONNX Resize w/ "nearest" mode on HTP backend. + // Table entries correspond to the QNN operator used for the given configuration + // (Resize = QNN Resize op, RNN = QNN ResizeNearestNeighbor op, X = Unsupported). + // + // nearest_mode: + // coordinate_transformation_mode: | round_prefer_floor round_prefer_ceil floor ceil + // ----------------------------------------------------------------------------------------- + // half_pixel | Resize X RNN X + // pytorch_half_pixel | Resize X X X + // align_corners | Resize X RNN X + // asymmetric | Resize X RNN X + if (interp_mode == "nearest") { const std::string nearest_mode = GetOnnxAttr(node_helper, onnx_nearest_mode_attr); - ORT_RETURN_IF_NOT(ArrayHasString(supported_nearest_modes, nearest_mode), + ORT_RETURN_IF_NOT(IsOnnxAttrModeSupported(supported_nearest_modes, nearest_mode), "QNN EP: Resize does not support nearest_mode ", nearest_mode.c_str()); - // TODO: Support 'asymmetric' transformation mode with nearest_mode != 'floor'. - // - // QNN's ONNX converter tool translates 'nearest' + 'asymmetric' (regardless of rounding mode) - // to QNN's ResizeNearestNeighbor with {align_corners: 0, half_pixel: 0}. - // This is only accurate if the rounding mode is "floor". Need to investigate how to handle - // other rounding modes with Qualcomm. Ideally, we would use QNN's Resize operator, but it doesn't support - // the "asymmetric" coordinate transformation mode on HTP. - ORT_RETURN_IF(transformation_mode == "asymmetric" && nearest_mode != "floor", - "QNN EP: Resize with coordinate_transformation_mode 'asymmetric' and nearest_mode '", nearest_mode, - "' is not currently supported on the HTP backend."); + if (is_npu_backend) { + // QNN only supports the following nearest_mode values on HTP: + // - "round_prefer_floor" via QNN's Resize operator + // - "floor" via QNN's ResizeNearestNeighbor operator + // + // QNN validation does not throw an error if unsupported nearest_mode values are used, so we have to + // catch them here. Otherwise, accuracy is significantly degraded. + ORT_RETURN_IF_NOT(nearest_mode == "round_prefer_floor" || nearest_mode == "floor", + "QNN EP: Resize on the NPU does not support nearest_mode ", nearest_mode.c_str()); + + const bool use_resize_nn_op = nearest_mode == "floor"; + + // If HTP uses ResizeNearestNeighbor ("floor"), then the "pytorch_half_pixel" coordinate_transformation_mode + // is not supported. + ORT_RETURN_IF(use_resize_nn_op && transformation_mode == "pytorch_half_pixel", + "QNN EP: Resize on the NPU does not support the combination of nearest_mode == 'floor' ", + " and coordinate_transformation_mode == 'pytorch_half_pixel'."); + + // QNN's ResizeNearestNeighbor requires rank 4 inputs. + ORT_RETURN_IF(use_resize_nn_op && input_rank != 4, + "QNN EP: Resize on the NPU with nearest_mode == 'floor' requires an input with rank 4."); + } } - // Check that input shape has at least a rank of 3. - const auto& input_0 = node_unit.Inputs()[0]; - std::vector input_shape; - ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(input_0.node_arg, input_shape), - "QNN EP: Cannot get shape for Resize input"); - ORT_RETURN_IF(input_shape.size() < 3, "QNN EP: Resize input must have a rank >= 3."); + // Check that the input shape has at least a rank of 3 (and a max of 5 on HTP). + ORT_RETURN_IF(input_rank < 3 || (is_npu_backend && input_rank > 5), + "QNN EP: Resize input must have a rank >= 3. The maximum rank is 5 on the NPU."); const auto& output_0 = node_unit.Outputs()[0]; std::vector output_shape; ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(output_0.node_arg, output_shape), "QNN EP: Cannot get shape for Resize output"); - ORT_RETURN_IF(output_shape.size() < 3, "QNN EP: Resize output must have a rank >= 3."); + + // Check that only the spatial dimensions (width, height) are resized. The batch_size (N) and channels (C) should + // be untouched. This code runs before layout transformation, so we know that the current layout is "channel first" + // (e.g., N, C, S1, S2, ..., SN), and that the minimum rank is 3. + assert(node_unit.Domain() != kMSInternalNHWCDomain); + ORT_RETURN_IF_NOT(input_shape[0] == output_shape[0] && input_shape[1] == output_shape[1], + "QNN EP: Resize may only change the spatial dimensions."); + + if (!is_npu_backend) { + ONNX_NAMESPACE::DataType input_data_type = input_0.node_arg.Type(); + ORT_RETURN_IF(input_data_type != ONNX_NAMESPACE::Utils::DataTypeUtils::ToType("float"), + "QNN EP: Data type ", input_data_type->c_str(), + " is not supported for Resize operator in CPU backend."); + } return Status::OK(); } @@ -305,92 +228,34 @@ Status ResizeOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w std::vector&& input_names, const logging::Logger& logger, bool do_op_validation) const { - // The QNN Resize op does not currently work with the QNN cpu backend, but works with the HTP backend. Therefore, we - // currently use QNN's Resize op for quantized models and either ResizeBilinear or ResizeNearestNeighbor for - // non-quantized models. This requires separate handling for quantized models. - // TODO: Use only Resize once QNN's Resize op works in the QNN cpu backend. - bool is_quantized_node = NodeUnit::Type::QDQGroup == node_unit.UnitType(); - return is_quantized_node ? ProcessQDQOpAttrsAndOutputs(qnn_model_wrapper, node_unit, std::move(input_names), logger, do_op_validation) : ProcessOpAttrsAndOutputs(qnn_model_wrapper, node_unit, std::move(input_names), logger, do_op_validation); -} - -Status ResizeOpBuilder::ProcessOpAttrsAndOutputs(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& node_unit, - std::vector&& input_names, - const logging::Logger& logger, - bool do_op_validation) const { - ORT_UNUSED_PARAMETER(logger); - NodeAttrHelper node_helper(node_unit); - const std::string resize_mode = GetOnnxAttr(node_helper, onnx_mode_attr); - std::string qnn_node_type = "ResizeNearestNeighbor"; - if ("linear" == resize_mode) { - qnn_node_type = "ResizeBilinear"; - } - - const std::string coordinate_mode = GetOnnxAttr(node_helper, onnx_coord_transf_mode_attr); - - Qnn_Scalar_t qnn_align_corners = QNN_SCALAR_INIT; - qnn_align_corners.dataType = QNN_DATATYPE_BOOL_8; - qnn_align_corners.bool8Value = static_cast(0); - - Qnn_Scalar_t qnn_half_pixel = QNN_SCALAR_INIT; - qnn_half_pixel.dataType = QNN_DATATYPE_BOOL_8; - qnn_half_pixel.bool8Value = static_cast(0); - - if ("align_corners" == coordinate_mode) { - qnn_align_corners.bool8Value = static_cast(1); - } else if ("half_pixel" == coordinate_mode) { - qnn_half_pixel.bool8Value = static_cast(1); - } - QnnParamWrapper qnn_align_corners_param(node_unit.Index(), node_unit.Name(), - QNN_OP_RESIZE_BILINEAR_PARAM_ALIGN_CORNERS, qnn_align_corners); - QnnParamWrapper qnn_half_pixel_param(node_unit.Index(), node_unit.Name(), - QNN_OP_RESIZE_BILINEAR_PARAM_HALF_PIXEL_CENTERS, qnn_half_pixel); - - std::vector param_tensor_names; - param_tensor_names.push_back(qnn_align_corners_param.GetParamTensorName()); - qnn_model_wrapper.AddParamWrapper(std::move(qnn_align_corners_param)); - param_tensor_names.push_back(qnn_half_pixel_param.GetParamTensorName()); - qnn_model_wrapper.AddParamWrapper(std::move(qnn_half_pixel_param)); - - return ProcessOutputs(qnn_model_wrapper, node_unit, std::move(input_names), std::move(param_tensor_names), - logger, do_op_validation, qnn_node_type); -} - -Status ResizeOpBuilder::ProcessQDQOpAttrsAndOutputs(QnnModelWrapper& qnn_model_wrapper, - const NodeUnit& node_unit, - std::vector&& input_names, - const logging::Logger& logger, - bool do_op_validation) const { std::vector param_tensor_names; NodeAttrHelper node_helper(node_unit); const std::string interp_mode = GetOnnxAttr(node_helper, onnx_mode_attr); const std::string transformation_mode = GetOnnxAttr(node_helper, onnx_coord_transf_mode_attr); + const std::string nearest_mode = GetOnnxAttr(node_helper, onnx_nearest_mode_attr); + const bool is_npu_backend = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()); std::string qnn_op_type = "Resize"; - // Handle Resize with {mode: "nearest", coordinate_transformation_mode: "asymmetric"} uniquely. - // QNN's ONNX converter tool translates this configuration (regardless of rounding mode) - // to QNN's ResizeNearestNeighbor with {align_corners: 0, half_pixel: 0}. - // - // NOTE: This is only accurate if the rounding mode is "floor". Need to investigate how to handle - // other rounding modes with Qualcomm. Ideally, we would use QNN's Resize operator, but it doesn't support - // the "asymmetric" coordinate transformation mode on HTP. - if (interp_mode == "nearest" && transformation_mode == "asymmetric") { + // Translate Resize with {mode: "nearest", nearest_mode: "floor", coordinate_transformation_mode: XXX} to + // QNN's ResizeNearestNeighbor operator on the HTP backend. This combination of parameters is not supported on HTP + // via QNN's Resize operator. Note that QNN's ResizeNearestNeighbor operator always uses "floor" rounding. + if (is_npu_backend && interp_mode == "nearest" && nearest_mode == "floor") { qnn_op_type = "ResizeNearestNeighbor"; - // Set parameter 'align_corners' to 0 + // Parameter 'align_corners' Qnn_Scalar_t qnn_align_corners = QNN_SCALAR_INIT; qnn_align_corners.dataType = QNN_DATATYPE_BOOL_8; - qnn_align_corners.bool8Value = static_cast(0); + qnn_align_corners.bool8Value = static_cast(transformation_mode == "align_corners"); QnnParamWrapper qnn_align_corners_param(node_unit.Index(), node_unit.Name(), QNN_OP_RESIZE_BILINEAR_PARAM_ALIGN_CORNERS, qnn_align_corners); param_tensor_names.push_back(qnn_align_corners_param.GetParamTensorName()); qnn_model_wrapper.AddParamWrapper(std::move(qnn_align_corners_param)); - // Set parameter 'half_pixel_centers' to 0 + // Parameter 'half_pixel_centers' Qnn_Scalar_t qnn_half_pixel = QNN_SCALAR_INIT; qnn_half_pixel.dataType = QNN_DATATYPE_BOOL_8; - qnn_half_pixel.bool8Value = static_cast(0); + qnn_half_pixel.bool8Value = static_cast(transformation_mode == "half_pixel"); QnnParamWrapper qnn_half_pixel_param(node_unit.Index(), node_unit.Name(), QNN_OP_RESIZE_BILINEAR_PARAM_HALF_PIXEL_CENTERS, qnn_half_pixel); param_tensor_names.push_back(qnn_half_pixel_param.GetParamTensorName()); @@ -399,11 +264,12 @@ Status ResizeOpBuilder::ProcessQDQOpAttrsAndOutputs(QnnModelWrapper& qnn_model_w // Parameter 'transformation_mode' Qnn_Scalar_t qnn_transformation_mode = QNN_SCALAR_INIT; qnn_transformation_mode.dataType = QNN_DATATYPE_UINT_32; - ORT_RETURN_IF_ERROR(GetQnnModeFromString(supported_coord_transf_modes, transformation_mode, - "coordinate_transformation_mode", qnn_transformation_mode.uint32Value)); + ORT_RETURN_IF_ERROR(GetQnnModeValFromOnnxString(supported_coord_transf_modes, transformation_mode, + "coordinate_transformation_mode", + qnn_transformation_mode.uint32Value)); - QnnParamWrapper qnn_transformation_mode_param(node_unit.Index(), node_unit.Name(), QNN_OP_RESIZE_PARAM_TRANSFORMATION_MODE, - qnn_transformation_mode); + QnnParamWrapper qnn_transformation_mode_param(node_unit.Index(), node_unit.Name(), + QNN_OP_RESIZE_PARAM_TRANSFORMATION_MODE, qnn_transformation_mode); param_tensor_names.push_back(qnn_transformation_mode_param.GetParamTensorName()); qnn_model_wrapper.AddParamWrapper(std::move(qnn_transformation_mode_param)); @@ -420,7 +286,7 @@ Status ResizeOpBuilder::ProcessQDQOpAttrsAndOutputs(QnnModelWrapper& qnn_model_w // Parameter 'interpolation_mode' Qnn_Scalar_t qnn_interp_mode = QNN_SCALAR_INIT; qnn_interp_mode.dataType = QNN_DATATYPE_UINT_32; - ORT_RETURN_IF_ERROR(GetQnnModeFromString(supported_modes, interp_mode, "mode", qnn_interp_mode.uint32Value)); + ORT_RETURN_IF_ERROR(GetQnnModeValFromOnnxString(supported_modes, interp_mode, "mode", qnn_interp_mode.uint32Value)); QnnParamWrapper qnn_interp_mode_param(node_unit.Index(), node_unit.Name(), QNN_OP_RESIZE_PARAM_INTERPOLATION_MODE, qnn_interp_mode); @@ -429,11 +295,10 @@ Status ResizeOpBuilder::ProcessQDQOpAttrsAndOutputs(QnnModelWrapper& qnn_model_w // Parameter 'nearest_mode'. Processed only when 'interpolation_mode' is NEAREST(0). if (qnn_interp_mode.uint32Value == 0) { - const std::string nearest_mode = GetOnnxAttr(node_helper, onnx_nearest_mode_attr); Qnn_Scalar_t qnn_nearest_mode = QNN_SCALAR_INIT; qnn_nearest_mode.dataType = QNN_DATATYPE_UINT_32; - ORT_RETURN_IF_ERROR(GetQnnModeFromString(supported_nearest_modes, nearest_mode, "nearest_mode", - qnn_nearest_mode.uint32Value)); + ORT_RETURN_IF_ERROR(GetQnnModeValFromOnnxString(supported_nearest_modes, nearest_mode, "nearest_mode", + qnn_nearest_mode.uint32Value)); QnnParamWrapper qnn_nearest_mode_param(node_unit.Index(), node_unit.Name(), QNN_OP_RESIZE_PARAM_NEAREST_MODE, qnn_nearest_mode); diff --git a/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc b/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc index 6ca36736f2f7f..047972294f78c 100644 --- a/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc +++ b/onnxruntime/core/providers/qnn/builder/opbuilder/topk.cc @@ -63,9 +63,20 @@ Status TopKOpBuilder::ExplictOpCheck(QnnModelWrapper& qnn_model_wrapper, const N auto rank = input_shape.size(); auto axis = node_helper.Get("axis", -1); - if (-1 == axis && axis != static_cast(rank - 1)) { - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "QNN TopK axis is always the last dimension"); + ORT_RETURN_IF_NOT(axis == -1 || axis == static_cast(rank - 1), + "QNN TopK's axis is always the last dimension"); + + // ONNX TopK outputs int64 indices, but the equivalent QNN op outputs uint32 indices. + // The QNN HTP backend does not generally support the int64 type, but QNN EP can just use the uint32 type + // for TopK ops within the graph. However, if the TopK op **generates** a graph output, + // then we cannot support it on the HTP backend. + bool is_npu_backend = IsNpuBackend(qnn_model_wrapper.GetQnnBackendType()); + if (is_npu_backend) { + const std::string& output_name = node_unit.Outputs()[0].node_arg.Name(); + ORT_RETURN_IF(qnn_model_wrapper.IsGraphOutput(output_name), + "QNN EP does not support TopK ops that generate a graph output."); } + return Status::OK(); } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 96893f63b4540..55204abc80187 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1143,46 +1143,35 @@ bool TensorrtExecutionProvider::IsGraphCaptureEnabled() const { return cuda_graph_enable_; } -bool TensorrtExecutionProvider::IsGraphCaptured() const { - return GetPerThreadContext().IsGraphCaptured(); -} - -Status TensorrtExecutionProvider::ReplayGraph() { - return GetPerThreadContext().ReplayGraph(); -} - -void TensorrtExecutionProvider::PerThreadContext::SetGraphStream(cudaStream_t stream) { - cuda_graph_.SetStream(stream); -} - -bool TensorrtExecutionProvider::PerThreadContext::IsGraphCaptureAllowed() const { +bool TensorrtExecutionProvider::IsGraphCaptureAllowed() const { return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_; } -void TensorrtExecutionProvider::PerThreadContext::CaptureBegin() { +void TensorrtExecutionProvider::CaptureBegin() { cuda_graph_.Reset(); cuda_graph_.CaptureBegin(); } -void TensorrtExecutionProvider::PerThreadContext::CaptureEnd() { +void TensorrtExecutionProvider::CaptureEnd() { cuda_graph_.CaptureEnd(); is_graph_captured_ = true; } -bool TensorrtExecutionProvider::PerThreadContext::IsGraphCaptured() const { +bool TensorrtExecutionProvider::IsGraphCaptured() const { return is_graph_captured_; } -Status TensorrtExecutionProvider::PerThreadContext::ReplayGraph() { +Status TensorrtExecutionProvider::ReplayGraph() { ORT_ENFORCE(IsGraphCaptured()); // Please note that CUDAGraph::Replay() is not thread safe. - // The cuda graph object is maintained by a per thread basis, + // ORT TRT calls ReplayGraph() in compute_func() where synchromization is enforced due to lock_guard(), // therefore calling CUDAGraph::Replay() here is guaranteed to be thread safe. return cuda_graph_.Replay(); } -void TensorrtExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGraphCapture() { - // The cuda graph object is maintained by a per thread basis, +void TensorrtExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() { + // Please note that this function is not thread safe. + // ORT TRT calls this function in compute_func() where synchronization is enforced due to lock_guard(), // therefore following increment is guaranteed to be thread safe. ++regular_run_count_before_graph_capture_; } @@ -1213,18 +1202,6 @@ Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream) { if (sync_stream && external_stream_) { CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_)); } - - // The reason of !IsGraphCaptureEnabled(): - // If cuda graph is enabled, the per thread context will not be released - // because the per thread cuda graph needs to be maintained and replayed for - // the next run. - // The reason of PerThreadContextCache()->find(this) != PerThreadContextCache()->end(): - // In extreme cases (e.g., 1-op graph and that op fallbacks to CPU), - // PerThreadContext won't be created and there is nothing to release. - if (!IsGraphCaptureEnabled() && - PerThreadContextCache()->find(this) != PerThreadContextCache()->end()) { - ReleasePerThreadContext(); - } return Status::OK(); } @@ -2384,6 +2361,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorallocate_func, context->release_func, context->allocator_handle, context->node_name, - &parsers_[context->node_name], &engines_[context->node_name], &builders_[context->node_name], + &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], &builders_[context->node_name], &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name], input_shape_ranges_[context->node_name], &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_, dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_, @@ -2445,6 +2415,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorinput_shape_ranges; auto trt_builder = trt_state->builder->get(); auto trt_engine = trt_state->engine->get(); + auto trt_context = trt_state->context->get(); auto trt_profiles = trt_state->profiles; auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr; int num_inputs = static_cast(input_indexes.size()); @@ -2502,7 +2473,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorengine->reset(); *(trt_state->engine) = std::unique_ptr( trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr)); - if (*(trt_state->engine) == nullptr) { + if (!(*(trt_state->engine))) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); } LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path; @@ -2527,7 +2498,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorengine->reset(); *(trt_state->engine) = std::unique_ptr(trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size, nullptr)); - if (*(trt_state->engine) == nullptr) { + if (!(*(trt_state->engine))) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path); } @@ -2556,10 +2527,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorcontext->reset(); trt_state->engine->reset(); auto trt_config = std::unique_ptr(trt_builder->createBuilderConfig()); trt_config->setMaxWorkspaceSize(*(trt_state->max_workspace_size_ptr)); @@ -2660,7 +2628,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectortrt_node_name_with_precision << " took: " << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() << "ms" << std::endl; } } - if (*(trt_state->engine) == nullptr) { + if (!(*(trt_state->engine))) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine."); } trt_engine = trt_state->engine->get(); @@ -2706,32 +2674,20 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector new_context; + if (context_update) { if (trt_state->context_memory_sharing_enable) { - new_context.reset(trt_state->engine->get()->createExecutionContextWithoutDeviceMemory()); + *(trt_state->context) = std::unique_ptr( + trt_state->engine->get()->createExecutionContextWithoutDeviceMemory()); } else { - new_context.reset(trt_state->engine->get()->createExecutionContext()); + *(trt_state->context) = std::unique_ptr( + trt_state->engine->get()->createExecutionContext()); } - auto context_status = GetPerThreadContext().UpdateTensorRTContext(fused_node_name, std::move(new_context)); - if (!context_status) { + if (!(*(trt_state->context))) { return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context."); } - GetPerThreadContext().UpdateProfileShapes(fused_node_name, shape_ranges); + trt_context = trt_state->context->get(); } - // Get the reference to the IExecutionContext object that is maintained on a per thread basis. - nvinfer1::IExecutionContext& trt_context = GetPerThreadContext().GetTensorRTContext(fused_node_name); - // Get input and output binding names int total_bindings = trt_engine->getNbBindings(); std::vector buffers(total_bindings); @@ -2767,12 +2723,12 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorisShapeBinding(binding_index)) { - trt_context.setInputShapeBinding(binding_index, &tensor_shape_values[input_name][0]); + trt_context->setInputShapeBinding(binding_index, &tensor_shape_values[input_name][0]); } else { for (int j = 0, end = nb_dims; j < end; ++j) { dimensions.d[j] = static_cast(tensor_shapes[j]); } - const bool status = trt_context.setBindingDimensions(binding_index, dimensions); + const bool status = trt_context->setBindingDimensions(binding_index, dimensions); if (!status) { ORT_THROW_IF_ERROR(ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP cannot set the dynamic dimensions of a binding")); @@ -2911,7 +2867,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorsecond; } - nvinfer1::Dims dimensions = trt_context.getBindingDimensions(static_cast(binding_index)); + nvinfer1::Dims dimensions = trt_context->getBindingDimensions(static_cast(binding_index)); int nb_dims = dimensions.nbDims; std::vector output_shapes(nb_dims); for (int j = 0, end = nb_dims; j < end; ++j) { @@ -3045,20 +3001,20 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector *max_context_mem_size_ptr) { *max_context_mem_size_ptr = mem_size; } - trt_context.setDeviceMemory(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr).get()); + trt_context->setDeviceMemory(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr).get()); } // Start CUDA graph capture. // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream. - if (cuda_graph_enable_ && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) { + if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured()) { LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model"; - GetPerThreadContext().SetGraphStream(stream); - GetPerThreadContext().CaptureBegin(); + cuda_graph_.SetStream(stream); + CaptureBegin(); } // Run TRT inference - if (!trt_context.enqueueV2(&buffers[0], stream, nullptr)) { + if (!trt_context->enqueueV2(&buffers[0], stream, nullptr)) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed."); } @@ -3089,14 +3045,14 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector* parser = nullptr; std::unique_ptr* engine = nullptr; + std::unique_ptr* context = nullptr; std::unique_ptr* builder = nullptr; std::unique_ptr* network = nullptr; std::vector> input_info; @@ -246,6 +247,7 @@ class TensorrtExecutionProvider : public IExecutionProvider { // For those non thread safe operations, TRT EP uses (1) lock_guard or (2) PerThreadContext to make sure synchronization. std::unordered_map> parsers_; std::unordered_map> engines_; + std::unordered_map> contexts_; std::unordered_map> builders_; std::unordered_map> networks_; std::unordered_map>> input_info_; @@ -256,6 +258,21 @@ class TensorrtExecutionProvider : public IExecutionProvider { std::unordered_map input_shape_ranges_; // The profile shape ranges that the engine is built with std::unordered_map> profiles_; + // for external stream, we need to create its cudnn/cublass handle before cuda EP enable cuda graph capture + cudnnHandle_t external_cudnn_handle_ = nullptr; + cublasHandle_t external_cublas_handle_ = nullptr; + + CUDAGraph cuda_graph_; + bool is_graph_captured_ = false; + int regular_run_count_before_graph_capture_ = 0; + // There is chance (currently only happens in CUDA EP) that the second regular run allocates GPU memory for causes like: + // (1) memory pattern is enabled. (2) arena allocation for stream. + // Since no GPU memory allocation is allowed during graph capturing, we need at least two regular runs + // to allocate enough memory in Arena before graph capturing. + const int min_num_runs_before_cuda_graph_capture_ = 1; // required min regular runs before graph capture for the necessary memory allocations. + + // [Note] We don't use PerThreadContext for now since it has issue with multithreading + // // TRT or CUDA objects that must be maintained on a per thread basis will be put under this PerThreadContext data structure. // For example, TensorRT execution context and CUDA graph are the ones to be put here. class PerThreadContext final { diff --git a/onnxruntime/core/providers/vitisai/README.md b/onnxruntime/core/providers/vitisai/README.md index 15e0c804489c5..6ddb58b8d96ae 100644 --- a/onnxruntime/core/providers/vitisai/README.md +++ b/onnxruntime/core/providers/vitisai/README.md @@ -1,4 +1,4 @@ -VitsAI Execution Prividers +VitisAI Execution Provider ============================ diff --git a/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc b/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc index 544e18350635d..ee8dfc6d03d12 100644 --- a/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc +++ b/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc @@ -34,9 +34,12 @@ static void xir_shape_infer(ONNX_NAMESPACE::InferenceContext& ctx) { updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::INT64); } else if (data_type->s() == "int1") { updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::BOOL); + } else if (data_type->s() == "bfloat16") { + updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::BFLOAT16); + } else if (data_type->s() == "float16") { + updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::FLOAT16); } else { - std::cerr << "not supported data_type " << data_type->s(); - abort(); + vai_assert(false, ", not supported data_type: " + data_type->s()); } if (shape != nullptr) { for (auto i = 0; i < shape->ints_size(); ++i) { diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index 31453e005272e..774df067fe347 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -53,9 +53,12 @@ bool IsInputSupported(const NodeArg& input, const std::string& parent_name, cons } for (const auto& dim : shape_proto->dim()) { - // For now we workaround dynamic shape support by assuming 1. + // WebNN doesn't support dynamic shape - use sessionOptions.freeDimensionOverrides to fix the shape. if (!dim.has_dim_value()) { - LOGS(logger, VERBOSE) << "Dynamic shape is not supported for now, assume to be 1, for input:" << input_name; + LOGS(logger, VERBOSE) << "Dynamic shape is not supported, " + << "use sessionOptions.FreeDimensionOverrides to set a fixed shape for input: " + << input_name; + return false; } } diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h index 301927d9c658f..01e4a3c60281f 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h @@ -46,7 +46,7 @@ class BaseOpBuilder : public IOpBuilder { // We still set the mininal supported opset to 1 as we couldn't // get the model opset version at this stage. virtual int GetMinSupportedOpSet(const Node& /* node */) const { return 1; } - virtual int GetMaxSupportedOpSet(const Node& /* node */) const { return 19; } + virtual int GetMaxSupportedOpSet(const Node& /* node */) const { return 20; } private: bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const; diff --git a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc index e7e3cee21c956..b207b804416aa 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc @@ -37,13 +37,18 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto input_size = input_shape.size(); // WebNN Softmax only support 2d input shape, reshape input to 2d. if (input_size != 2) { - int32_t new_shape_0 = SafeInt(input_shape.data()[0]); - for (size_t i = 1; i < input_size - 1; i++) { - new_shape_0 *= input_shape.data()[i]; - } - emscripten::val new_shape = emscripten::val::array(); - new_shape.call("push", new_shape_0); - new_shape.call("push", static_cast(input_shape.back())); + NodeAttrHelper helper(node); + int32_t axis = helper.Get("axis", 1); + if (node.SinceVersion() >= 13) + // Opset 13 has default value -1. + axis = helper.Get("axis", -1); + // Coerce the input into a 2-dimensional tensor with dimensions [a_0 * ... * a_{k-1}, a_k * ... * a_{n-1}]. + axis = static_cast(HandleNegativeAxis(axis, input_size)); + int32_t first_dim = static_cast(std::reduce(input_shape.begin(), input_shape.begin() + axis, + 1, std::multiplies())); + int32_t second_dim = static_cast(std::reduce(input_shape.begin() + axis, input_shape.end(), + 1, std::multiplies())); + emscripten::val new_shape = emscripten::val::array(std::vector{first_dim, second_dim}); input = model_builder.GetBuilder().call("reshape", input, new_shape); } output = model_builder.GetBuilder().call("softmax", input); @@ -76,9 +81,10 @@ bool SoftmaxOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali return false; } NodeAttrHelper helper(node); - const int32_t axis = helper.Get("axis", 1); - // WebNN softmax only support input axis 1 - if (axis != 1 && axis != -1) { + const int64_t axis = helper.Get("axis", 1); + // WebNN softmax only support reshape for the last axis or version before 13. + // TODO: support opset 13 by composing into: Exp(input) / ReduceSum(Exp(input), axis=axis, keepdims=1). + if (axis != -1 && axis != input_shape.size() - 1 && node.SinceVersion() >= 13) { LOGS(logger, VERBOSE) << "SoftMax only support axis 1 or -1, input axis: " << axis; return false; } diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index 14ca4f1a1e674..2eae8cebbbd66 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -218,12 +218,9 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i } else { dims.reserve(shape.size()); for (const auto& dim : shape) { - if (!dim.has_dim_value()) { - // FIXME: support dyanmic shape. - dims.push_back(1); - } else { - dims.push_back(SafeInt(dim.dim_value())); - } + // dim_param free dimensions should have already been excluded by IsInputSupported(). + assert(dim.has_dim_value()); + dims.push_back(SafeInt(dim.dim_value())); } } } diff --git a/onnxruntime/core/providers/xnnpack/nn/resize.cc b/onnxruntime/core/providers/xnnpack/nn/resize.cc index 672b2597279db..76c6b6acbfe32 100644 --- a/onnxruntime/core/providers/xnnpack/nn/resize.cc +++ b/onnxruntime/core/providers/xnnpack/nn/resize.cc @@ -331,7 +331,13 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(Resize, kOnnxDomain, 13, 17, kXnnpackExecution DataTypeImpl::GetTensorType()}), Resize); -ONNX_OPERATOR_KERNEL_EX(Resize, kOnnxDomain, 18, kXnnpackExecutionProvider, +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Resize, kOnnxDomain, 18, 18, kXnnpackExecutionProvider, + KernelDefBuilder().TypeConstraint("T1", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + Resize); + +ONNX_OPERATOR_KERNEL_EX(Resize, kOnnxDomain, 19, kXnnpackExecutionProvider, KernelDefBuilder().TypeConstraint("T1", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc index ba577ac38d48c..494c718cde081 100644 --- a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc +++ b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc @@ -46,7 +46,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWC class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 10, 10, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 11, 12, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 13, 17, Resize); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 18, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 18, 18, Resize); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 19, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, 11, MaxPool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 12, MaxPool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, AveragePool); @@ -84,7 +85,9 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo< ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 1, 12, Softmax)>, BuildKernelCreateInfo< - ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 18, Resize)>, + ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 19, Resize)>, + BuildKernelCreateInfo< + ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 18, 18, Resize)>, BuildKernelCreateInfo< ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 13, 17, Resize)>, BuildKernelCreateInfo< diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 14c284d7bbdec..d8256646f9707 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -27,6 +27,10 @@ static constexpr uint32_t min_ort_version_with_optional_io_support = 8; static constexpr uint32_t min_ort_version_with_variadic_io_support = 14; #endif +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) +static constexpr uint32_t min_ort_version_with_compute_v2_support = 17; +#endif + #if !defined(DISABLE_FLOAT8_TYPES) #define SUPPORTED_TENSOR_TYPES DataTypeImpl::AllTensorTypesIRv9() #else @@ -424,7 +428,8 @@ struct CustomOpKernel : OpKernel { ORT_THROW("Unsupported version '" + std::to_string(op_.version) + "' in custom op '" + op.GetName(&op)); } - if (op_.version > 15 && op_.KernelCompute == 0) { + if (op_.version >= min_ort_version_with_compute_v2_support && + op_.CreateKernelV2) { op_kernel_ = nullptr; Ort::ThrowOnError( op_.CreateKernelV2( @@ -443,13 +448,14 @@ struct CustomOpKernel : OpKernel { } Status Compute(OpKernelContext* ctx) const override { - if (op_.version > 15 && op_.KernelCompute == 0) { + if (op_.version >= min_ort_version_with_compute_v2_support && + op_.KernelComputeV2) { auto status_ptr = op_.KernelComputeV2(op_kernel_, reinterpret_cast(ctx)); return ToStatus(status_ptr); + } else { + op_.KernelCompute(op_kernel_, reinterpret_cast(ctx)); + return Status::OK(); } - - op_.KernelCompute(op_kernel_, reinterpret_cast(ctx)); - return Status::OK(); } private: diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 5a2a6efb6df4b..21c8fbe0cd2c9 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1005,18 +1005,22 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool layout_transformation::TransformLayoutForEP(graph_to_transform, modified, execution_provider, std::move(cpu_allocator), debug_graph_fn)); - if (modified) { - ORT_RETURN_IF_ERROR_SESSIONID_( - graph_transformer_mgr_.ApplyTransformers(graph_to_transform, TransformerLevel::Level1, *session_logger_)); - - // debug the graph after the L1 transformers have run against any layout transformation changes. - // this is prior to GraphPartitioner::GetCapabilityForEP calling IExecutionProvider::GetCapability the second - // time to validate the EP that requested the layout transformation can take all nodes using the new layout. - // if that fails, this allows debugging the graph used in that GetCapability call. - if (debug_graph_fn) { - debug_graph_fn(graph_to_transform); - } - } + // Previously we ran the L1 transformers to handle constant folding of any initializers that were transposed in + // a QDQ format model. The transpose optimizer can now look past DQ nodes to directly update initializers which + // takes care of most models without needing this. + // + // if (modified) { + // ORT_RETURN_IF_ERROR_SESSIONID_( + // graph_transformer_mgr_.ApplyTransformers(graph_to_transform, TransformerLevel::Level1, *session_logger_)); + // + // debug the graph after the L1 transformers have run against any layout transformation changes. + // this is prior to GraphPartitioner::GetCapabilityForEP calling IExecutionProvider::GetCapability the second + // time to validate the EP that requested the layout transformation can take all nodes using the new layout. + // if that fails, this allows debugging the graph used in that GetCapability call. + // if (debug_graph_fn) { + // debug_graph_fn(graph_to_transform); + //} + //} return Status::OK(); }; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 4c0adcdd374aa..60b6296f7f539 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2711,9 +2711,8 @@ static constexpr OrtApi ort_api_1_to_17 = { &OrtApis::GetTensorRTProviderOptionsByName, &OrtApis::UpdateCUDAProviderOptionsWithValue, &OrtApis::GetCUDAProviderOptionsByName, - // End of Version 16 - DO NOT MODIFY ABOVE (see above text for more information) - &OrtApis::KernelContext_GetResource, + // End of Version 16 - DO NOT MODIFY ABOVE (see above text for more information) }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. @@ -2742,7 +2741,7 @@ static_assert(offsetof(OrtApi, ReleaseKernelInfo) / sizeof(void*) == 218, "Size static_assert(offsetof(OrtApi, ReleaseCANNProviderOptions) / sizeof(void*) == 224, "Size of version 13 API cannot change"); static_assert(offsetof(OrtApi, GetSessionConfigEntry) / sizeof(void*) == 238, "Size of version 14 API cannot change"); static_assert(offsetof(OrtApi, GetBuildInfoString) / sizeof(void*) == 254, "Size of version 15 API cannot change"); -static_assert(offsetof(OrtApi, GetCUDAProviderOptionsByName) / sizeof(void*) == 264, "Size of version 16 API cannot change"); +static_assert(offsetof(OrtApi, KernelContext_GetResource) / sizeof(void*) == 265, "Size of version 16 API cannot change"); // So that nobody forgets to finish an API version, this check will serve as a reminder: static_assert(std::string_view(ORT_VERSION) == "1.17.0", diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index 4124822adef1f..bcc6f15129231 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -438,7 +438,7 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi # Tensorrt can fall back to CUDA if it's explicitly assigned. All others fall back to CPU. if "TensorrtExecutionProvider" in available_providers: - if any( + if providers and any( provider == "CUDAExecutionProvider" or (isinstance(provider, tuple) and provider[0] == "CUDAExecutionProvider") for provider in providers @@ -448,7 +448,7 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi self._fallback_providers = ["CPUExecutionProvider"] # MIGraphX can fall back to ROCM if it's explicitly assigned. All others fall back to CPU. elif "MIGraphXExecutionProvider" in available_providers: - if any( + if providers and any( provider == "ROCMExecutionProvider" or (isinstance(provider, tuple) and provider[0] == "ROCMExecutionProvider") for provider in providers @@ -463,14 +463,6 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi providers, provider_options = check_and_normalize_provider_args( providers, provider_options, available_providers ) - if not providers and len(available_providers) > 1: - self.disable_fallback() - raise ValueError( - f"This ORT build has {available_providers} enabled. " - "Since ORT 1.9, you are required to explicitly set " - "the providers parameter when instantiating InferenceSession. For example, " - f"onnxruntime.InferenceSession(..., providers={available_providers}, ...)" - ) session_options = self._sess_options if self._sess_options else C.get_default_session_options() if self._model_path: diff --git a/onnxruntime/python/onnxruntime_pybind_module.cc b/onnxruntime/python/onnxruntime_pybind_module.cc index f320707697c9e..6824a5d0bf98f 100644 --- a/onnxruntime/python/onnxruntime_pybind_module.cc +++ b/onnxruntime/python/onnxruntime_pybind_module.cc @@ -10,6 +10,12 @@ namespace onnxruntime { namespace python { namespace py = pybind11; +#if defined(USE_MPI) && defined(ORT_USE_NCCL) +static constexpr bool HAS_COLLECTIVE_OPS = true; +#else +static constexpr bool HAS_COLLECTIVE_OPS = false; +#endif + void CreateInferencePybindStateModule(py::module& m); PYBIND11_MODULE(onnxruntime_pybind11_state, m) { @@ -23,6 +29,7 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) { m.def("get_version_string", []() -> std::string { return ORT_VERSION; }); m.def("get_build_info", []() -> std::string { return ORT_BUILD_INFO; }); + m.def("has_collective_ops", []() -> bool { return HAS_COLLECTIVE_OPS; }); } } // namespace python } // namespace onnxruntime diff --git a/onnxruntime/python/tools/quantization/calibrate.py b/onnxruntime/python/tools/quantization/calibrate.py index bdf00f21100bf..26e74a6dfbac9 100644 --- a/onnxruntime/python/tools/quantization/calibrate.py +++ b/onnxruntime/python/tools/quantization/calibrate.py @@ -22,7 +22,7 @@ class TensorData: - _allowed = frozenset(["avg", "std", "lowest", "highest", "hist", "hist_edges"]) + _allowed = frozenset(["avg", "std", "lowest", "highest", "hist", "hist_edges", "bins"]) def __init__(self, **kwargs): for k, v in kwargs.items(): @@ -55,7 +55,7 @@ def __init__(self, calibration_method, data: Dict[str, Union[TensorData, Tuple]] self.data[k] = TensorData(lowest=v[0], highest=v[1]) continue if len(v) == 4: - self.data[k] = TensorData(lowest=v[0], highest=v[1], histogram=v[2], bins=v[3]) + self.data[k] = TensorData(lowest=v[0], highest=v[1], hist=v[2], bins=v[3]) continue raise TypeError(f"Unexpected tuple for {k:r}, it has {len(v)} elements: {v}.") if not isinstance(v, TensorData): diff --git a/onnxruntime/python/tools/quantization/operators/conv.py b/onnxruntime/python/tools/quantization/operators/conv.py index d23459b478e6a..23f9eaf4b0e0b 100644 --- a/onnxruntime/python/tools/quantization/operators/conv.py +++ b/onnxruntime/python/tools/quantization/operators/conv.py @@ -157,7 +157,7 @@ def quantize(self): nodes, ) = self.quantizer.quantize_activation(node, [0]) quant_weight_tuple = self.quantizer.quantize_weight_per_channel( - node.input[1], onnx_proto.TensorProto.INT8, 0 + node.input[1], onnx_proto.TensorProto.INT8, 0 # self.quantizer.weight_qType? ) quantized_input_names.append(quant_weight_tuple[0]) zero_point_names.append(quant_weight_tuple[1]) diff --git a/onnxruntime/python/tools/quantization/operators/lstm.py b/onnxruntime/python/tools/quantization/operators/lstm.py index 7e91f9b76ca36..90a52cb528b32 100644 --- a/onnxruntime/python/tools/quantization/operators/lstm.py +++ b/onnxruntime/python/tools/quantization/operators/lstm.py @@ -47,10 +47,10 @@ def quantize(self): R.dims[0] = R_num_dir * R_4_hidden_size quant_input_weight_tuple = self.quantizer.quantize_weight_per_channel( - node.input[1], onnx_proto.TensorProto.INT8, 0 + node.input[1], onnx_proto.TensorProto.INT8, 0 # self.quantizer.weight_qType? ) quant_recurrent_weight_tuple = self.quantizer.quantize_weight_per_channel( - node.input[2], onnx_proto.TensorProto.INT8, 0 + node.input[2], onnx_proto.TensorProto.INT8, 0 # self.quantizer.weight_qType? ) W_quant_weight = model.get_initializer(quant_input_weight_tuple[0]) # noqa: N806 diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py index e595b580b20df..5c97dd20cf507 100644 --- a/onnxruntime/python/tools/quantization/qdq_quantizer.py +++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py @@ -283,7 +283,13 @@ def _add_qdq_pair_for_initializer(self, weight_proto, tensor_type, axis=None): raise ValueError("Per-Channel support with QDQ format requires onnx opset version 13 or above.") q_weight_name, zp_name, scale_name = self.quantize_weight_per_channel( weight_name, - self.weight_qType if tensor_type is QDQQuantTensorType.WEIGHT else self.activation_qType, + # Quantization type is forced to be TensorProto.INT8. + # when the expected value would be (see below) + # self.weight_qType if tensor_type is QDQQuantTensorType.WEIGHT else self.activation_qType. + # QLinearConv expects to have a unique value for all channels. + # This code does not enforce that but it is necessarily the case when the + # quantization is symmetric (as for INT8). + onnx_proto.TensorProto.INT8, axis, keep_float_weight=self.add_qdq_pair_to_weight, ) diff --git a/onnxruntime/python/tools/quantization/quant_utils.py b/onnxruntime/python/tools/quantization/quant_utils.py index 74e54c3f1fa37..739e399042bf3 100644 --- a/onnxruntime/python/tools/quantization/quant_utils.py +++ b/onnxruntime/python/tools/quantization/quant_utils.py @@ -505,7 +505,7 @@ def apply_plot(hist, hist_edges): plt.show() -def write_calibration_table(calibration_cache): +def write_calibration_table(calibration_cache, dir="."): """ Helper function to write calibration table to files. """ @@ -519,7 +519,7 @@ def write_calibration_table(calibration_cache): logging.info(f"calibration cache: {calibration_cache}") - with open("calibration.json", "w") as file: + with open(os.path.join(dir, "calibration.json"), "w") as file: file.write(json.dumps(calibration_cache)) # use `json.loads` to do the reverse # Serialize data using FlatBuffers @@ -551,7 +551,7 @@ def write_calibration_table(calibration_cache): builder.Finish(cal_table) buf = builder.Output() - with open("calibration.flatbuffers", "wb") as file: + with open(os.path.join(dir, "calibration.flatbuffers"), "wb") as file: file.write(buf) # Deserialize data (for validation) @@ -564,7 +564,7 @@ def write_calibration_table(calibration_cache): logging.info(key_value.Value()) # write plain text - with open("calibration.cache", "w") as file: + with open(os.path.join(dir, "calibration.cache"), "w") as file: for key in sorted(calibration_cache.keys()): value = calibration_cache[key] s = key + " " + str(max(abs(value[0]), abs(value[1]))) diff --git a/onnxruntime/python/tools/quantization/shape_inference.py b/onnxruntime/python/tools/quantization/shape_inference.py index eff3dc0bcdc35..b7d4726610387 100644 --- a/onnxruntime/python/tools/quantization/shape_inference.py +++ b/onnxruntime/python/tools/quantization/shape_inference.py @@ -99,7 +99,10 @@ def quant_pre_process( sess_option = onnxruntime.SessionOptions() sess_option.optimized_model_filepath = opt_model_path sess_option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC - _ = onnxruntime.InferenceSession(input_model_path, sess_option, providers=["CPUExecutionProvider"]) + sess = onnxruntime.InferenceSession(input_model_path, sess_option, providers=["CPUExecutionProvider"]) + # Close the session to avoid the cleanup error on Windows for temp folders + # https://github.com/microsoft/onnxruntime/issues/17627 + del sess except Exception: logger.error( "ONNX Runtime Model Optimization Failed! Consider rerun with option `--skip_optimization'." diff --git a/onnxruntime/python/tools/tensorrt/perf/benchmark.py b/onnxruntime/python/tools/tensorrt/perf/benchmark.py index d440cafb23236..acd836c915910 100644 --- a/onnxruntime/python/tools/tensorrt/perf/benchmark.py +++ b/onnxruntime/python/tools/tensorrt/perf/benchmark.py @@ -2176,7 +2176,7 @@ def parse_arguments(): required=False, default=True, action="store_true", - help="Inlcude Float16 into benchmarking.", + help="Include Float16 into benchmarking.", ) parser.add_argument("--trtexec", required=False, default=None, help="trtexec executable path.") diff --git a/onnxruntime/python/tools/transformers/convert_to_packing_mode.py b/onnxruntime/python/tools/transformers/convert_to_packing_mode.py index 0b8dbdcdd9638..4da97f0de7bed 100644 --- a/onnxruntime/python/tools/transformers/convert_to_packing_mode.py +++ b/onnxruntime/python/tools/transformers/convert_to_packing_mode.py @@ -67,7 +67,7 @@ def _try_getting_last_layernorm(self) -> Union[NodeProto, None]: last_layernorm_node = node return last_layernorm_node - def _are_attentions_supportted(self) -> bool: + def _are_attentions_supported(self) -> bool: raise NotImplementedError() def _insert_removepadding_node(self, inputs: List[str], outputs: List[str]) -> None: @@ -105,7 +105,7 @@ def _get_input_to_remove_padding(self, first_attention_node) -> Union[str, None] def convert(self, use_symbolic_shape_infer: bool = True) -> None: logger.debug("start converting to packing model...") - if not self._are_attentions_supportted(): + if not self._are_attentions_supported(): return attention_mask = self._try_getting_attention_mask() @@ -164,7 +164,7 @@ class PackingAttention(PackingAttentionBase): def __init__(self, model: OnnxModel): super().__init__(model, Operators.ATTENTION) - def _are_attentions_supportted(self) -> bool: + def _are_attentions_supported(self) -> bool: for node in self.attention_nodes: if OnnxModel.get_node_attribute(node, "past_present_share_buffer") is not None: return False @@ -237,7 +237,7 @@ def _check_empty_output(self, node, index: int, name: str): return False return True - def _are_attentions_supportted(self) -> bool: + def _are_attentions_supported(self) -> bool: for node in self.attention_nodes: for attr in node.attribute: if attr.name not in ["num_heads", "mask_filter_value", "scale"]: diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index 40f2aee875382..a7460157ba409 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -78,7 +78,15 @@ def process_mask(self, input: str) -> str: # ReduceSum-13: axes is moved from attribute to input axes_name = "ort_const_1_reduce_sum_axes" if self.model.get_initializer(axes_name) is None: - self.add_initializer(name=axes_name, data_type=TensorProto.INT64, dims=[1], vals=[1], raw=False) + self.model.add_initializer( + helper.make_tensor( + name=axes_name, + data_type=TensorProto.INT64, + dims=[1], + vals=[1], + raw=False, + ) + ) mask_index_node = helper.make_node( "ReduceSum", inputs=[input_name, axes_name], diff --git a/onnxruntime/test/contrib_ops/qordered_attention_test.cc b/onnxruntime/test/contrib_ops/qordered_attention_test.cc index 24e4bff528285..1dd0162ad722f 100644 --- a/onnxruntime/test/contrib_ops/qordered_attention_test.cc +++ b/onnxruntime/test/contrib_ops/qordered_attention_test.cc @@ -240,12 +240,6 @@ static std::vector transpose(const T& src, size_t h, size_t w) { } TEST(QOrderedTest, Attention_WithData_ROW_ORDER) { - int cuda_runtime_version = 0; - // Need 11.4 or higher cuda runtime - if ((cudaRuntimeGetVersion(&cuda_runtime_version) != cudaSuccess) || (cuda_runtime_version < 11040)) { - return; - } - // Needs Turing architecture if (NeedSkipIfCudaArchLowerThan(750) || NeedSkipIfCudaArchGreaterEqualThan(800)) { return; diff --git a/onnxruntime/test/contrib_ops/qordered_longformer_attention_op_test.cc b/onnxruntime/test/contrib_ops/qordered_longformer_attention_op_test.cc index 55209d9422fdd..06fe42ca989d7 100644 --- a/onnxruntime/test/contrib_ops/qordered_longformer_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/qordered_longformer_attention_op_test.cc @@ -34,12 +34,6 @@ static void run_qordered_longformer_attention_op_test( const int64_t head_size, const int64_t window, int64_t input_hidden_size = 0) { - int cuda_runtime_version = 0; - // Need 11.4 or higher cuda runtime - if ((cudaRuntimeGetVersion(&cuda_runtime_version) != cudaSuccess) || (cuda_runtime_version < 11040)) { - return; - } - // Needs Turing architecture if (NeedSkipIfCudaArchLowerThan(750) || NeedSkipIfCudaArchGreaterEqualThan(800)) { return; diff --git a/onnxruntime/test/contrib_ops/qordered_matmul_op_test.cc b/onnxruntime/test/contrib_ops/qordered_matmul_op_test.cc index e5b3d59ef86e3..e3905db6355d9 100644 --- a/onnxruntime/test/contrib_ops/qordered_matmul_op_test.cc +++ b/onnxruntime/test/contrib_ops/qordered_matmul_op_test.cc @@ -21,12 +21,6 @@ static void RunQOrdered_MatMul_Test( OrderCublasLt weight_order, float scale_A, float scale_B, float scale_C, float scale_Y, bool add_bias = false, bool broadcast_c_batch = false, bool per_channel = false) { - int cuda_runtime_version = 0; - // Need 11.4 or higher cuda runtime - if ((cudaRuntimeGetVersion(&cuda_runtime_version) != cudaSuccess) || (cuda_runtime_version < 11040)) { - return; - } - // Needs Turing architecture if (NeedSkipIfCudaArchLowerThan(750) || NeedSkipIfCudaArchGreaterEqualThan(800)) { return; diff --git a/onnxruntime/test/contrib_ops/qordered_qdq_op_test.cc b/onnxruntime/test/contrib_ops/qordered_qdq_op_test.cc index 15e97751acf2d..0f3f702695b80 100644 --- a/onnxruntime/test/contrib_ops/qordered_qdq_op_test.cc +++ b/onnxruntime/test/contrib_ops/qordered_qdq_op_test.cc @@ -73,12 +73,6 @@ static void RunQOrdered_Quantize_Test( std::vector const& shape, OrderCublasLt order_q, float scale) { - int cuda_runtime_version = 0; - // Need 11.4 or higher cuda runtime - if ((cudaRuntimeGetVersion(&cuda_runtime_version) != cudaSuccess) || (cuda_runtime_version < 11040)) { - return; - } - auto qvec = QuantizeTransform(shape, scale, fvec, order_q); std::vector> execution_providers; @@ -153,12 +147,6 @@ static void RunQOrdered_Dequantize_Test( std::vector const& shape, OrderCublasLt order_q, float scale) { - int cuda_runtime_version = 0; - // Need 11.4 or higher cuda runtime - if ((cudaRuntimeGetVersion(&cuda_runtime_version) != cudaSuccess) || (cuda_runtime_version < 11040)) { - return; - } - auto fvec = DequantizeTransform(shape, scale, qvec, order_q); std::vector> execution_providers; diff --git a/onnxruntime/test/contrib_ops/tokenizer_test.cc b/onnxruntime/test/contrib_ops/tokenizer_test.cc index 4daf9dac886d0..b8fb964b86c13 100644 --- a/onnxruntime/test/contrib_ops/tokenizer_test.cc +++ b/onnxruntime/test/contrib_ops/tokenizer_test.cc @@ -10,7 +10,7 @@ namespace test { namespace tokenizer_test { const std::string start_mark{0x2}; const std::string end_mark{0x3}; -const std::string padval(u8"0xdeadbeaf"); +const std::string padval("0xdeadbeaf"); constexpr const char* domain = onnxruntime::kMSDomain; constexpr int opset_ver = 1; @@ -220,7 +220,7 @@ TEST(ContribOpTest, TokenizerCharLevel_CyrillicCharsWithMarkersC) { InitTestAttr(test, true, {""}, 1); std::vector dims{2}; - std::vector input{u8"ÐбÑурд", u8"Кома"}; + std::vector input{"ÐбÑурд", "Кома"}; test.AddInput("T", dims, input); std::vector output_dims(dims); @@ -229,10 +229,10 @@ TEST(ContribOpTest, TokenizerCharLevel_CyrillicCharsWithMarkersC) { output_dims.push_back(int64_t(6 + 2)); std::vector output{ start_mark, - u8"Ð", u8"б", u8"Ñ", u8"у", u8"Ñ€", u8"д", + "Ð", "б", "Ñ", "у", "Ñ€", "д", end_mark, start_mark, - u8"К", u8"о", u8"м", u8"а", + "К", "о", "м", "а", end_mark, padval, padval}; @@ -254,7 +254,7 @@ TEST(ContribOpTest, TokenizerCharLevel_MixedCharsWithMarkersC) { InitTestAttr(test, true, {""}, 1); std::vector dims{2}; - std::vector input{u8"ÐбÑу中文", u8"Коñó"}; + std::vector input{"ÐбÑу中文", "Коñó"}; test.AddInput("T", dims, input); std::vector output_dims(dims); @@ -263,10 +263,10 @@ TEST(ContribOpTest, TokenizerCharLevel_MixedCharsWithMarkersC) { output_dims.push_back(int64_t(6 + 2)); std::vector output{ start_mark, - u8"Ð", u8"б", u8"Ñ", u8"у", u8"中", u8"æ–‡", + "Ð", "б", "Ñ", "у", "中", "æ–‡", end_mark, start_mark, - u8"К", u8"о", u8"ñ", u8"ó", + "К", "о", "ñ", "ó", end_mark, padval, padval}; @@ -285,7 +285,7 @@ TEST(ContribOpTest, TokenizerCharLevel_EmptyOutputC) { InitTestAttr(test, true, {""}, 1); std::vector dims{2}; - std::vector input{u8"", u8""}; + std::vector input{"", ""}; test.AddInput("T", dims, input); std::vector output_dims(dims); @@ -306,7 +306,7 @@ TEST(ContribOpTest, TokenizerCharLevel_EmptyOutputNC) { InitTestAttr(test, true, {""}, 1); std::vector dims{2, 2}; - std::vector input{u8"", u8"", u8"", u8""}; + std::vector input{"", "", "", ""}; test.AddInput("T", dims, input); std::vector output_dims(dims); @@ -325,13 +325,13 @@ TEST(ContribOpTest, TokenizerWithSeparators_MixCharsWithMarkersC) { // [C] dimensions // Output [C][D] { - std::string sepexp = u8"(у|ñ)"; + std::string sepexp = "(у|ñ)"; OpTester test("Tokenizer", opset_ver, domain); InitTestAttr(test, true, {sepexp}, 1); std::vector dims{2}; - std::vector input{u8"ÐбÑу中文", u8"Коñó"}; + std::vector input{"ÐбÑу中文", "Коñó"}; test.AddInput("T", dims, input); std::vector output_dims(dims); @@ -339,10 +339,10 @@ TEST(ContribOpTest, TokenizerWithSeparators_MixCharsWithMarkersC) { output_dims.push_back(int64_t(2 + 2)); std::vector output{ start_mark, - u8"ÐбÑ", u8"中文", + "ÐбÑ", "中文", end_mark, start_mark, - u8"Ко", u8"ó", + "Ко", "ó", end_mark}; test.AddOutput("Y", output_dims, output); @@ -355,13 +355,13 @@ TEST(ContribOpTest, TokenizerWithSeparators_MixCharsWithMarkersCompleteMatchEmpt // Test entire separators match so we get nothing // in the output { - std::string sepexp = u8"(ÐбÑу中文)|(Коñó)"; + std::string sepexp = "(ÐбÑу中文)|(Коñó)"; OpTester test("Tokenizer", opset_ver, domain); InitTestAttr(test, true, {sepexp}, 1); std::vector dims{2}; - std::vector input{u8"ÐбÑу中文", u8"Коñó"}; + std::vector input{"ÐбÑу中文", "Коñó"}; test.AddInput("T", dims, input); std::vector output_dims(dims); @@ -378,13 +378,13 @@ TEST(ContribOpTest, TokenizerWithSeparators_MixCharsWithMarkersCompleteMatchEmpt TEST(ContribOpTest, TokenizerWithSeparators_MixCharsWithMarkersStartMatchC) { // Match the start { - std::string sepexp = u8"(Ð)|(К)"; + std::string sepexp = "(Ð)|(К)"; OpTester test("Tokenizer", opset_ver, domain); InitTestAttr(test, true, {sepexp}, 1); std::vector dims{2}; - std::vector input{u8"ÐбÑу中文", u8"Коñó"}; + std::vector input{"ÐбÑу中文", "Коñó"}; test.AddInput("T", dims, input); std::vector output_dims(dims); @@ -392,10 +392,10 @@ TEST(ContribOpTest, TokenizerWithSeparators_MixCharsWithMarkersStartMatchC) { output_dims.push_back(int64_t(3)); std::vector output{ start_mark, - u8"бÑу中文", + "бÑу中文", end_mark, start_mark, - u8"оñó", + "оñó", end_mark}; test.AddOutput("Y", output_dims, output); @@ -407,13 +407,13 @@ TEST(ContribOpTest, TokenizerWithSeparators_MixCharsWithMarkersStartMatchC) { TEST(ContribOpTest, TokenizerWithSeparators_MixCharsWithMarkersEndMatchC) { // Match the end { - std::string sepexp = u8"(æ–‡)|(ó)"; + std::string sepexp = "(æ–‡)|(ó)"; OpTester test("Tokenizer", opset_ver, domain); InitTestAttr(test, true, {sepexp}, 1); std::vector dims{2}; - std::vector input{u8"ÐбÑу中文", u8"Коñó"}; + std::vector input{"ÐбÑу中文", "Коñó"}; test.AddInput("T", dims, input); std::vector output_dims(dims); @@ -421,10 +421,10 @@ TEST(ContribOpTest, TokenizerWithSeparators_MixCharsWithMarkersEndMatchC) { output_dims.push_back(int64_t(3)); std::vector output{ start_mark, - u8"ÐбÑу中", + "ÐбÑу中", end_mark, start_mark, - u8"Коñ", + "Коñ", end_mark}; test.AddOutput("Y", output_dims, output); @@ -436,13 +436,13 @@ TEST(ContribOpTest, TokenizerWithSeparators_MixCharsWithMarkersEndMatchC) { TEST(ContribOpTest, TokenizerWithSeparators_MixCharsWithMarkersEndMatchAtLeast4CharsC) { // Match the end, require at least 4 chars { - std::string sepexp = u8"(æ–‡)|(ó)"; + std::string sepexp = "(æ–‡)|(ó)"; OpTester test("Tokenizer", opset_ver, domain); InitTestAttr(test, true, {sepexp}, 4); std::vector dims{2}; - std::vector input{u8"ÐбÑу中文", u8"Коñó"}; + std::vector input{"ÐбÑу中文", "Коñó"}; test.AddInput("T", dims, input); std::vector output_dims(dims); @@ -451,7 +451,7 @@ TEST(ContribOpTest, TokenizerWithSeparators_MixCharsWithMarkersEndMatchAtLeast4C output_dims.push_back(int64_t(3)); std::vector output{ start_mark, - u8"ÐбÑу中", + "ÐбÑу中", end_mark, start_mark, end_mark, @@ -466,13 +466,13 @@ TEST(ContribOpTest, TokenizerWithSeparators_MixCharsWithMarkersEndMatchAtLeast4C TEST(ContribOpTest, TokenizerWithSeparators_MixCharsWithMarkersEmptyInputEmptyOutputC) { // Empty input for [C] should produce [C][0] { - std::string sepexp = u8"(æ–‡)|(ó)"; + std::string sepexp = "(æ–‡)|(ó)"; OpTester test("Tokenizer", opset_ver, domain); InitTestAttr(test, true, {sepexp}, 4); std::vector dims{2}; - std::vector input{u8"", u8""}; + std::vector input{"", ""}; test.AddInput("T", dims, input); std::vector output_dims(dims); @@ -488,13 +488,13 @@ TEST(ContribOpTest, TokenizerWithSeparators_MixCharsWithMarkersEmptyInputEmptyOu TEST(ContribOpTest, TokenizerWithSeparators_MixCharsWithMarkersEmptyInputEmptyOutputNC) { // Empty input for [N][C] should produce [N][C][0] { - std::string sepexp = u8"(æ–‡)|(ó)"; + std::string sepexp = "(æ–‡)|(ó)"; OpTester test("Tokenizer", opset_ver, domain); InitTestAttr(test, true, {sepexp}, 4); std::vector dims{2, 2}; - std::vector input{u8"", u8"æ–‡", u8"ó", u8""}; + std::vector input{"", "æ–‡", "ó", ""}; test.AddInput("T", dims, input); std::vector output_dims(dims); @@ -514,20 +514,20 @@ TEST(ContribOpTest, TokenizerWithSeparators_MixCharsNoMarkersSeparatorsOverlapSh { // In this case the first pattern must match first // and there would be no match for the second - std::vector separators = {u8"Ñу", u8"ÐбÑу"}; + std::vector separators = {"Ñу", "ÐбÑу"}; OpTester test("Tokenizer", opset_ver, domain); InitTestAttr(test, false, separators, 1); std::vector dims{1}; - std::vector input{u8"ÐбÑу中文"}; + std::vector input{"ÐбÑу中文"}; test.AddInput("T", dims, input); std::vector output_dims(dims); // must split in 2 with no two middle characters output_dims.push_back(int64_t(2)); std::vector output{ - u8"Ðб", u8"中文"}; + "Ðб", "中文"}; test.AddOutput("Y", output_dims, output); @@ -543,21 +543,21 @@ TEST(ContribOpTest, TokenizerWithSeparators_MixCharsNoMarkersSeparatorsOverlapLo // In this case the first pattern must match first // and there would be no match for the second std::vector separators = { - u8"ÐбÑу", - u8"Ñу"}; + "ÐбÑу", + "Ñу"}; OpTester test("Tokenizer", opset_ver, domain); InitTestAttr(test, false, separators, 1); std::vector dims{1}; - std::vector input{u8"ÐбÑу中文"}; + std::vector input{"ÐбÑу中文"}; test.AddInput("T", dims, input); std::vector output_dims(dims); // Must drop the beginning of the word that // also contains the second separator output_dims.push_back(int64_t(1)); - std::vector output{u8"中文"}; + std::vector output{"中文"}; test.AddOutput("Y", output_dims, output); @@ -573,21 +573,21 @@ TEST(ContribOpTest, TokenizerWithSeparators_MixCharsNoMarkersSeparatorsOverlapLo // In this case the first pattern must match first // and there would be no match for the second std::vector separators = { - u8"ÐбÑу", - u8"Ñу"}; + "ÐбÑу", + "Ñу"}; OpTester test("Tokenizer", opset_ver, domain); InitTestAttr(test, false, separators, 1); std::vector dims{1}; - std::vector input{u8"ÐбÑуÑуÑу中文"}; + std::vector input{"ÐбÑуÑуÑу中文"}; test.AddInput("T", dims, input); std::vector output_dims(dims); // Must drop the beginning of the word that // also contains the second separator output_dims.push_back(int64_t(1)); - std::vector output{u8"中文"}; + std::vector output{"中文"}; test.AddOutput("Y", output_dims, output); @@ -605,21 +605,21 @@ TEST(ContribOpTest, TokenizerWithSeparators_MixCharsNoMarkersSeparatorsOverlapin // so the earlier match for the first wins. // and there would be no match for the second std::vector separators = { - u8"уÑу", - u8"ÐбÑу"}; + "уÑу", + "ÐбÑу"}; OpTester test("Tokenizer", opset_ver, domain); InitTestAttr(test, false, separators, 1); std::vector dims{1}; - std::vector input{u8"ÐбÑуÑуÑу中文"}; + std::vector input{"ÐбÑуÑуÑу中文"}; test.AddInput("T", dims, input); std::vector output_dims(dims); // Must drop the beginning of the word that // also contains the second separator output_dims.push_back(int64_t(2)); - std::vector output{u8"ÐбÑ", u8"Ñу中文"}; + std::vector output{"ÐбÑ", "Ñу中文"}; test.AddOutput("Y", output_dims, output); @@ -633,14 +633,14 @@ TEST(ContribOpTest, TokenizerWithSeparators_MixCharCommonPrefixC) { // [C] dimensions // Output [C][D] std::vector separators = { - u8";", - u8";;;"}; + ";", + ";;;"}; OpTester test("Tokenizer", opset_ver, domain); InitTestAttr(test, true, separators, 1); std::vector dims{4}; - std::vector input{u8"a;b", u8"a;;;b", u8"b;c;;;d;e", u8"a;;b;;;c"}; + std::vector input{"a;b", "a;;;b", "b;c;;;d;e", "a;;b;;;c"}; test.AddInput("T", dims, input); std::vector output_dims(dims); @@ -648,27 +648,27 @@ TEST(ContribOpTest, TokenizerWithSeparators_MixCharCommonPrefixC) { output_dims.push_back(int64_t(6)); std::vector output{ start_mark, - u8"a", - u8"b", + "a", + "b", end_mark, padval, padval, start_mark, - u8"a", - u8"b", + "a", + "b", end_mark, padval, padval, start_mark, - u8"b", - u8"c", - u8"d", - u8"e", + "b", + "c", + "d", + "e", end_mark, start_mark, - u8"a", - u8"b", - u8"c", + "a", + "b", + "c", end_mark, padval, }; @@ -679,27 +679,27 @@ TEST(ContribOpTest, TokenizerWithSeparators_MixCharCommonPrefixC) { TEST(ContribOpTest, TokenizerExpression_RegEx) { OpTester test("Tokenizer", opset_ver, domain); - const std::string tokenexp(u8"a."); + const std::string tokenexp("a."); InitTestAttr(test, true, {}, 1, tokenexp); std::vector dims{4}; - std::vector input{u8"a;b", u8"a;;;b", u8"b;c;;;d;e", u8"a;;b;;;c"}; + std::vector input{"a;b", "a;;;b", "b;c;;;d;e", "a;;b;;;c"}; test.AddInput("T", dims, input); std::vector output_dims(dims); output_dims.push_back(int64_t(3)); std::vector output{ start_mark, - u8"a;", + "a;", end_mark, start_mark, - u8"a;", + "a;", end_mark, start_mark, end_mark, padval, start_mark, - u8"a;", + "a;", end_mark, }; @@ -709,11 +709,11 @@ TEST(ContribOpTest, TokenizerExpression_RegEx) { TEST(ContribOpTest, TokenizerExpression_RegRep) { OpTester test("Tokenizer", opset_ver, domain); - const std::string tokenexp(u8"c;+"); + const std::string tokenexp("c;+"); InitTestAttr(test, true, {}, 1, tokenexp); std::vector dims{4}; - std::vector input{u8"a;b", u8"a;;;b", u8"b;c;;;d;e", u8"a;;b;;;c"}; + std::vector input{"a;b", "a;;;b", "b;c;;;d;e", "a;;b;;;c"}; test.AddInput("T", dims, input); std::vector output_dims(dims); @@ -726,7 +726,7 @@ TEST(ContribOpTest, TokenizerExpression_RegRep) { end_mark, padval, start_mark, - u8"c;;;", + "c;;;", end_mark, start_mark, end_mark, @@ -738,31 +738,31 @@ TEST(ContribOpTest, TokenizerExpression_RegRep) { TEST(ContribOpTest, TokenizerExpression_Grouping) { OpTester test("Tokenizer", opset_ver, domain); - const std::string tokenexp(u8"(a;)|(b;)"); + const std::string tokenexp("(a;)|(b;)"); InitTestAttr(test, true, {}, 1, tokenexp); std::vector dims{4}; - std::vector input{u8"a;b", u8"a;;;b", u8"b;c;;;d;e", u8"a;;b;;;c"}; + std::vector input{"a;b", "a;;;b", "b;c;;;d;e", "a;;b;;;c"}; test.AddInput("T", dims, input); std::vector output_dims(dims); output_dims.push_back(int64_t(4)); std::vector output{ start_mark, - u8"a;", + "a;", end_mark, padval, start_mark, - u8"a;", + "a;", end_mark, padval, start_mark, - u8"b;", + "b;", end_mark, padval, start_mark, - u8"a;", - u8"b;", + "a;", + "b;", end_mark}; test.AddOutput("Y", output_dims, output); @@ -771,22 +771,22 @@ TEST(ContribOpTest, TokenizerExpression_Grouping) { TEST(ContribOpTest, TokenizerExpression_RegDot) { OpTester test("Tokenizer", opset_ver, domain); - const std::string tokenexp(u8"."); + const std::string tokenexp("."); InitTestAttr(test, true, {}, 1, tokenexp); std::vector dims{1}; - std::vector input{u8"a;;;b"}; + std::vector input{"a;;;b"}; test.AddInput("T", dims, input); std::vector output_dims(dims); output_dims.push_back(int64_t(7)); std::vector output{ start_mark, - u8"a", - u8";", - u8";", - u8";", - u8"b", + "a", + ";", + ";", + ";", + "b", end_mark}; test.AddOutput("Y", output_dims, output); @@ -795,19 +795,19 @@ TEST(ContribOpTest, TokenizerExpression_RegDot) { TEST(ContribOpTest, TokenizerExpression_RegChar) { OpTester test("Tokenizer", opset_ver, domain); - const std::string tokenexp(u8"\\w"); + const std::string tokenexp("\\w"); InitTestAttr(test, true, {}, 1, tokenexp); std::vector dims{1}; - std::vector input{u8"a;;;b"}; + std::vector input{"a;;;b"}; test.AddInput("T", dims, input); std::vector output_dims(dims); output_dims.push_back(int64_t(4)); std::vector output{ start_mark, - u8"a", - u8"b", + "a", + "b", end_mark}; test.AddOutput("Y", output_dims, output); diff --git a/onnxruntime/test/framework/function_test.cc b/onnxruntime/test/framework/function_test.cc index e126979532644..6e745776ab6b0 100644 --- a/onnxruntime/test/framework/function_test.cc +++ b/onnxruntime/test/framework/function_test.cc @@ -6,6 +6,7 @@ #include "onnx/defs/parser.h" #include "core/common/span_utils.h" +#include "core/framework/float8.h" #include "core/graph/model.h" #include "core/providers/cpu/cpu_execution_provider.h" #include "core/session/inference_session.h" @@ -69,7 +70,9 @@ static void Check(const char* source, float threshold = 0.001f; for (size_t i = 0; i < size; ++i) { - ASSERT_NEAR(data[i], output_values[i], threshold) << "at position i:" << i; + if (!std::isnan(data[i]) && !std::isnan(output_values[i])) { + ASSERT_NEAR(data[i], output_values[i], threshold) << "at position i:" << i; + } } } @@ -389,25 +392,13 @@ TEST(FunctionTest, AttrSaturateNan) { > agraph (float[N] x) => (float[N] y) { - y0 = local.myfun (x) - y1 = local.myfun (x) - y = Add (y0, y1) - } - - < - opset_import: [ "" : 19 ], - domain: "local" - > - myfun (x) => (y) { - x2 = Constant () - x2_ = Cast(x2) - x3 = CastLike(x2, x2_) - x3_ = Cast(x3) - y = Add (x, x3_) + x_E4M3FNUZ = Cast(x) + x_E4M3FNUZ_2 = CastLike(x, x_E4M3FNUZ) # NaN when OOR + y = Cast(x_E4M3FNUZ_2) } )"; - Check(code, "x", {1.0, 2.0, 1e6}, "y", {243.0, 245.0, 2000241}); // std::numeric_limits::quiet_NaN()}); + Check(code, "x", {1.0, 2.0, 1e6}, "y", {1.0, 2.0, std::numeric_limits::quiet_NaN()}); } #endif diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 077c6ff58e2da..2298e4afa6de0 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -2056,7 +2056,7 @@ TEST(InferenceSessionTests, TestStrictShapeInference) { ASSERT_STATUS_OK(session_options.config_options.AddConfigEntry(kOrtSessionOptionsConfigStrictShapeTypeInference, "1")); tester.Run(session_options, OpTester::ExpectResult::kExpectFailure, - "Mismatch between number of source and target dimensions. Source=1 Target=2", + "Mismatch between number of inferred and declared dimensions. inferred=1 declared=2", excluded_provider_types); } diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index 82e5efd92a8f1..e1ce1d4abf81d 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -850,6 +850,130 @@ TEST_F(SessionStateTestSharedInitalizersWithPrePacking, test3) { ASSERT_EQ(session_state_2.GetUsedSharedPrePackedWeightCounter(), static_cast(1)); } +// Pre-packing enabled + shared initializers + +// pre-packed weights container + subgraphs = +// caching enabled in pre-packed weights used in subgraphs +TEST_F(SessionStateTestSharedInitalizersWithPrePacking, test4) { + SessionOptions sess_options; + sess_options.enable_mem_pattern = true; + sess_options.execution_mode = ExecutionMode::ORT_SEQUENTIAL; + sess_options.use_deterministic_compute = false; + sess_options.enable_mem_reuse = true; + // Enable pre-packing + sess_options.config_options.configurations[kOrtSessionOptionsConfigDisablePrepacking] = "0"; + + // Enable shared initializer + OrtMemoryInfo mem_info(CPU, OrtDeviceAllocator); + std::vector float_data(1, 1); + auto value = std::make_unique(); + Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape(std::vector{1}), + reinterpret_cast(float_data.data()), mem_info, *value); + + ASSERT_STATUS_OK(sess_options.AddInitializer("if_shared", value.get())); + + // Enable pre-packed weights container + PrepackedWeightsContainer prepacked_weights_container; + + // First session/model + Model model_1("graph_main", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + domain_to_version, std::vector(), + DefaultLoggingManager().DefaultLogger()); + + CreateGraphWithSubgraph(model_1.MainGraph()); + PlaceAllNodesToCPUEP(model_1.MainGraph()); + SessionState session_state_1(model_1.MainGraph(), + execution_providers, + tp.get(), + nullptr, /*inter_op_thread_pool*/ + dtm, + DefaultLoggingManager().DefaultLogger(), + profiler, + sess_options, + &prepacked_weights_container); + + ASSERT_STATUS_OK(session_state_1.FinalizeSessionState(std::basic_string(), + kernel_registry_manager)); + + // At the main graph level, there should be no pre-packing calls as there are + // no initializers (shared or otherwise) consumed by any nodes in the main graph + ASSERT_EQ(session_state_1.GetNumberOfPrepacksCounter(), static_cast(0)); + + auto if_index_1 = 1; + if (session_state_1.GetKernel(0)->Node().OpType() == "If") { + if_index_1 = 0; + } + + const auto& subgraph_session_states = session_state_1.GetSubgraphSessionStateMap(); + const auto& if_node_session_states = subgraph_session_states.at(if_index_1); + const auto& session_state_1_then_branch_session_state = *if_node_session_states.at("then_branch"); + const auto& session_state_1_else_branch_session_state = *if_node_session_states.at("else_branch"); + + auto if_node_branches_prepack_counter_1 = + session_state_1_then_branch_session_state.GetNumberOfPrepacksCounter() + + session_state_1_else_branch_session_state.GetNumberOfPrepacksCounter(); + + // We should be seeing 2 pre-pack calls in the "If" node (one in each subgraph) + ASSERT_EQ(if_node_branches_prepack_counter_1, static_cast(2)); + + auto if_node_branches_shared_prepack_counter_1 = + session_state_1_then_branch_session_state.GetUsedSharedPrePackedWeightCounter() + + session_state_1_else_branch_session_state.GetUsedSharedPrePackedWeightCounter(); + + // We should only be seeing 1 shared pre-pack weights usage in the "If" node + // Either the "then branch" or "else branch" will be using the shared version + // depending on which branch writes to the shared container + ASSERT_EQ(if_node_branches_shared_prepack_counter_1, static_cast(1)); + + // Second session/model + Model model_2("graph_main", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + domain_to_version, std::vector(), + DefaultLoggingManager().DefaultLogger()); + + CreateGraphWithSubgraph(model_2.MainGraph()); + PlaceAllNodesToCPUEP(model_2.MainGraph()); + SessionState session_state_2(model_2.MainGraph(), + execution_providers, + tp.get(), + nullptr, /*inter_op_thread_pool*/ + dtm, + DefaultLoggingManager().DefaultLogger(), + profiler, + sess_options, + &prepacked_weights_container); + + ASSERT_STATUS_OK(session_state_2.FinalizeSessionState(std::basic_string(), + kernel_registry_manager)); + + // At the main graph level, there should be no pre-packing calls as there are + // no initializers (shared or otherwise) consumed by any nodes in the main graph + ASSERT_EQ(session_state_2.GetNumberOfPrepacksCounter(), static_cast(0)); + + auto if_index_2 = 1; + if (session_state_2.GetKernel(0)->Node().OpType() == "If") { + if_index_2 = 0; + } + + const auto& subgraph_session_states_2 = session_state_2.GetSubgraphSessionStateMap(); + const auto& if_node_session_states_2 = subgraph_session_states_2.at(if_index_2); + const auto& session_state_2_then_branch_session_state = *if_node_session_states_2.at("then_branch"); + const auto& session_state_2_else_branch_session_state = *if_node_session_states_2.at("else_branch"); + + auto if_node_branches_prepack_counter_2 = + session_state_2_then_branch_session_state.GetNumberOfPrepacksCounter() + + session_state_2_else_branch_session_state.GetNumberOfPrepacksCounter(); + + // We should be seeing 2 pre-pack calls in the "If" node (one in each subgraph) + ASSERT_EQ(if_node_branches_prepack_counter_2, static_cast(2)); + + auto if_node_branches_shared_prepack_counter_2 = + session_state_2_then_branch_session_state.GetUsedSharedPrePackedWeightCounter() + + session_state_2_else_branch_session_state.GetUsedSharedPrePackedWeightCounter(); + + // We should be seeing 2 shared pre-pack weights calls in the "If" node + // Both branches will be using the shared version coming from the first model. + ASSERT_EQ(if_node_branches_shared_prepack_counter_2, static_cast(2)); +} + INSTANTIATE_TEST_SUITE_P(SessionStateTests, SessionStatePrepackingTest, testing::Values(PrepackingTestParam{false, false}, diff --git a/onnxruntime/test/framework/tensor_test.cc b/onnxruntime/test/framework/tensor_test.cc index f24064a403c5d..38e3f184ebc18 100644 --- a/onnxruntime/test/framework/tensor_test.cc +++ b/onnxruntime/test/framework/tensor_test.cc @@ -214,6 +214,7 @@ TEST(TensorTest, Strided) { TensorShape shape({2, 3, 4}); auto alloc = TestCPUExecutionProvider()->CreatePreferredAllocators()[0]; void* data = alloc->Alloc(shape.Size() * sizeof(float)); + Tensor t(DataTypeImpl::GetType(), shape, data, alloc->Info()); EXPECT_TRUE(t.IsContiguous()); const TensorShapeVector strides{12, 4, 1}; @@ -227,6 +228,7 @@ TEST(TensorTest, Strided) { ASSERT_EQ(t.Shape(), new_shape); ASSERT_THAT(t.Strides(), testing::ContainerEq(gsl::make_span(new_strides))); ASSERT_EQ(t.SizeInBytes(), sizeof(float) * 24); + Tensor t2(DataTypeImpl::GetType(), new_shape, data, alloc->Info(), 0L, gsl::make_span(new_strides)); EXPECT_FALSE(t2.IsContiguous()); ASSERT_EQ(t2.Shape(), new_shape); @@ -237,7 +239,9 @@ TEST(TensorTest, Strided) { ASSERT_EQ(t2.Shape(), shape); ASSERT_THAT(t2.Strides(), testing::ContainerEq(gsl::make_span(strides))); ASSERT_EQ(t2.SizeInBytes(), sizeof(float) * 24); + alloc->Free(data); + data = alloc->Alloc(sizeof(int64_t)); const TensorShapeVector single_element_strides{0, 0, 0}; Tensor t3(DataTypeImpl::GetType(), shape, data, alloc->Info(), 0L, gsl::make_span(single_element_strides)); @@ -246,8 +250,10 @@ TEST(TensorTest, Strided) { ASSERT_THAT(t3.Strides(), testing::ContainerEq(gsl::make_span(single_element_strides))); ASSERT_EQ(t3.SizeInBytes(), sizeof(int64_t)); alloc->Free(data); + const TensorShapeVector zero_strides{0, 0, 0}; - Tensor t4(DataTypeImpl::GetType(), shape, alloc, zero_strides); + Tensor t4(DataTypeImpl::GetType(), shape, alloc); + t4.SetShapeAndStrides(shape, zero_strides); EXPECT_FALSE(t4.IsContiguous()); EXPECT_EQ(t4.Shape(), shape); ASSERT_THAT(t4.Strides(), testing::ContainerEq(gsl::make_span(zero_strides))); diff --git a/onnxruntime/test/framework/test_tensor_loader.cc b/onnxruntime/test/framework/test_tensor_loader.cc index f627409bb0e60..e71830be08b5e 100644 --- a/onnxruntime/test/framework/test_tensor_loader.cc +++ b/onnxruntime/test/framework/test_tensor_loader.cc @@ -1,11 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "gtest/gtest.h" + #include "core/common/common.h" #include "core/framework/callback.h" #include "core/framework/tensorprotoutils.h" -#include "gtest/gtest.h" -#include "file_util.h" +#include "test/util/include/file_util.h" +#include "test/util/include/asserts.h" #ifdef _WIN32 #include @@ -30,13 +32,14 @@ TEST(CApiTensorTest, load_simple_float_tensor_not_enough_space) { std::vector output(1); OrtValue value; OrtMemoryInfo cpu_memory_info(onnxruntime::CPU, OrtDeviceAllocator, OrtDevice(), 0, OrtMemTypeDefault); - auto st = utils::TensorProtoToMLValue(Env::Default(), nullptr, p, - MemBuffer(output.data(), output.size() * sizeof(float), cpu_memory_info), value); - // check the result - ASSERT_FALSE(st.IsOK()); + + ASSERT_STATUS_NOT_OK( + utils::TensorProtoToOrtValue(Env::Default(), nullptr, p, + MemBuffer(output.data(), output.size() * sizeof(float), cpu_memory_info), + value)); } -TEST(CApiTensorTest, load_simple_float_tensor) { +TEST(CApiTensorTest, load_simple_float_tensor_membuffer) { // construct a tensor proto onnx::TensorProto p; p.mutable_float_data()->Add(1.0f); @@ -51,9 +54,37 @@ TEST(CApiTensorTest, load_simple_float_tensor) { std::vector output(3); OrtValue value; OrtMemoryInfo cpu_memory_info(onnxruntime::CPU, OrtDeviceAllocator, OrtDevice(), 0, OrtMemTypeDefault); - auto st = utils::TensorProtoToMLValue(Env::Default(), nullptr, p, - MemBuffer(output.data(), output.size() * sizeof(float), cpu_memory_info), value); - ASSERT_TRUE(st.IsOK()) << st.ErrorMessage(); + ASSERT_STATUS_OK( + utils::TensorProtoToOrtValue(Env::Default(), nullptr, p, + MemBuffer(output.data(), output.size() * sizeof(float), cpu_memory_info), + value)); + float* real_output; + auto ort_st = g_ort->GetTensorMutableData(&value, (void**)&real_output); + ASSERT_EQ(ort_st, nullptr) << g_ort->GetErrorMessage(ort_st); + // check the result + ASSERT_EQ(real_output[0], 1.0f); + ASSERT_EQ(real_output[1], 2.2f); + ASSERT_EQ(real_output[2], 3.5f); + g_ort->ReleaseStatus(ort_st); +} + +TEST(CApiTensorTest, load_simple_float_tensor_allocator) { + // construct a tensor proto + onnx::TensorProto p; + p.mutable_float_data()->Add(1.0f); + p.mutable_float_data()->Add(2.2f); + p.mutable_float_data()->Add(3.5f); + p.mutable_dims()->Add(3); + p.set_data_type(onnx::TensorProto_DataType_FLOAT); + std::string s; + // save it to a buffer + ASSERT_TRUE(p.SerializeToString(&s)); + // deserialize it + AllocatorPtr tmp_allocator = std::make_shared(); + OrtValue value; + + ASSERT_STATUS_OK(utils::TensorProtoToOrtValue(Env::Default(), nullptr, p, tmp_allocator, value)); + float* real_output; auto ort_st = g_ort->GetTensorMutableData(&value, (void**)&real_output); ASSERT_EQ(ort_st, nullptr) << g_ort->GetErrorMessage(ort_st); @@ -106,9 +137,9 @@ static void run_external_data_test() { } OrtValue value; OrtMemoryInfo cpu_memory_info(onnxruntime::CPU, OrtDeviceAllocator, OrtDevice(), 0, OrtMemTypeDefault); - auto st = utils::TensorProtoToMLValue(Env::Default(), nullptr, p, - MemBuffer(output.data(), output.size() * sizeof(float), cpu_memory_info), value); - ASSERT_TRUE(st.IsOK()) << st.ErrorMessage(); + ASSERT_STATUS_OK(utils::TensorProtoToOrtValue( + Env::Default(), nullptr, p, MemBuffer(output.data(), output.size() * sizeof(float), cpu_memory_info), value)); + float* real_output; auto ort_st = g_ort->GetTensorMutableData(&value, (void**)&real_output); ASSERT_EQ(ort_st, nullptr) << g_ort->GetErrorMessage(ort_st); @@ -156,11 +187,10 @@ TEST(CApiTensorTest, load_huge_tensor_with_external_data) { std::vector output(total_ele_count); OrtValue value; OrtMemoryInfo cpu_memory_info(onnxruntime::CPU, OrtDeviceAllocator, OrtDevice(), 0, OrtMemTypeDefault); - auto st = utils::TensorProtoToMLValue(Env::Default(), nullptr, p, - MemBuffer(output.data(), output.size() * sizeof(int), cpu_memory_info), value); + ASSERT_STATUS_OK( + utils::TensorProtoToOrtValue(Env::Default(), nullptr, p, + MemBuffer(output.data(), output.size() * sizeof(int), cpu_memory_info), value)); - // check the result - ASSERT_TRUE(st.IsOK()) << "Error from TensorProtoToMLValue: " << st.ErrorMessage(); int* buffer; auto ort_st = g_ort->GetTensorMutableData(&value, (void**)&buffer); ASSERT_EQ(ort_st, nullptr) << "Error from OrtGetTensorMutableData: " << g_ort->GetErrorMessage(ort_st); diff --git a/onnxruntime/test/framework/tunable_op_test.cc b/onnxruntime/test/framework/tunable_op_test.cc index bfc46c56975e6..0d9e557ebc813 100644 --- a/onnxruntime/test/framework/tunable_op_test.cc +++ b/onnxruntime/test/framework/tunable_op_test.cc @@ -263,6 +263,7 @@ TEST(TunableOp, OpWrapsMutableFunctor) { class VecAddMoveOnlyFunctor { public: + VecAddMoveOnlyFunctor() = default; VecAddMoveOnlyFunctor(VecAddMoveOnlyFunctor&&) = default; ORT_DISALLOW_COPY_AND_ASSIGNMENT(VecAddMoveOnlyFunctor); @@ -288,6 +289,7 @@ TEST(TunableOp, OpWrapsMoveOnlyFunctor) { class VecAddWithIsSupportedMethod { public: + VecAddWithIsSupportedMethod() = default; VecAddWithIsSupportedMethod(VecAddWithIsSupportedMethod&&) = default; ORT_DISALLOW_COPY_AND_ASSIGNMENT(VecAddWithIsSupportedMethod); diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index dce1f2d40e8b9..6acf631d53cd9 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -1000,6 +1000,26 @@ TEST_F(GraphTransformationTests, ConstantFoldingAShapeNodeDeepInTheGraph) { ASSERT_TRUE(op_to_count.size() == 0U); } +// Test we don't fail when constant folding hits a string initializer +TEST_F(GraphTransformationTests, ConstantFoldingStringInitializer) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "gh_issue_17392.onnx"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["Identity"], 1); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + std::unique_ptr e = std::make_unique(CPUExecutionProviderInfo()); + ASSERT_STATUS_OK(graph_transformation_mgr.Register( + std::make_unique(*e.get(), false /*skip_dequantize_linear*/), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + op_to_count = CountOpsInGraph(graph); + + ASSERT_EQ(op_to_count.size(), 0U) << "Identity node should have been removed"; +} + // Check transformations in the case of a subgraph with constant inputs. TEST_F(GraphTransformationTests, SubgraphWithConstantInputs) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "constant-subgraph.onnx"; @@ -6280,7 +6300,7 @@ TEST_F(GraphTransformationTests, ConstantSharing_ShouldNotShareForGraphOutput) { TEST_F(GraphTransformationTests, GatherToSplitFusion) { auto build_test_case = [&](ModelTestBuilder& builder) { auto* data_arg = builder.MakeInput({{54}}); - auto* shape_arg = builder.MakeInput({{1}}); + auto* shape_arg = builder.MakeInput({{4}}); auto* reshape_out = builder.MakeIntermediate({{2, 3, 3, 3}}); auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(0)}); auto* gather_index_2 = builder.MakeInitializer({}, {static_cast(1)}); @@ -6393,7 +6413,7 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) { TEST_F(GraphTransformationTests, GatherToSplitFusion_NoSqueeze) { auto build_test_case = [&](ModelTestBuilder& builder) { auto* data_arg = builder.MakeInput({{54}}); - auto* shape_arg = builder.MakeInput({{1}}); + auto* shape_arg = builder.MakeInput({{4}}); auto* reshape_out = builder.MakeIntermediate({{2, 3, 3, 3}}); auto* gather_index_1 = builder.MakeInitializer({1}, {static_cast(0)}); auto* gather_index_2 = builder.MakeInitializer({1}, {static_cast(1)}); diff --git a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc index 1f671e90090ba..a55238396cea3 100755 --- a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc +++ b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc @@ -429,6 +429,40 @@ TEST_F(GraphTransformationTests, SimplifiedLayerNormFusionTest) { } } +// It tests the scenario when scale or bias are not Graph Inputs and not initialized in Graph +// To test this added a Identity node after Scale and Bias terms to ensure LayerNormFusion works properly +TEST_F(GraphTransformationTests, LayerNormScaleBiasTest) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/layer_norm_fusion_scale_bias.onnx"; + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::make_unique(), TransformerLevel::Level2)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["ReduceMean"], 0); + ASSERT_EQ(op_to_count["Sub"], 0); + ASSERT_EQ(op_to_count["Cast"], 0); + ASSERT_EQ(op_to_count["Pow"], 0); + ASSERT_EQ(op_to_count["Add"], 0); + ASSERT_EQ(op_to_count["Sqrt"], 0); + ASSERT_EQ(op_to_count["Div"], 0); + ASSERT_EQ(op_to_count["Mul"], 0); + ASSERT_EQ(op_to_count["LayerNormalization"], 1); + + for (const Node& node : graph.Nodes()) { + if (node.OpType() == "LayerNormalization") { + // LayerNormalization should have three inputs. + EXPECT_EQ(node.InputDefs().size(), 3u) << "LayerNormalization number of inputs does not equal to 3. Got:" << node.InputDefs().size(); + // LayerNormalization input "scale" and "bias" should have the same dimension. + const TensorShapeProto* scale_shape = node.InputDefs()[1]->Shape(); + EXPECT_EQ(scale_shape->dim_size(), 1) << "LayerNormalization scale should be 1D. Got: " << scale_shape->dim_size(); + } + } +} + // If EP is non-GPU EP or unknown, the sub-graph will be not fused because CPU impl for SimplifiedLayerNormalization // doesn't support input and scale having different data types. TEST_F(GraphTransformationTests, SimplifiedLayerNormWithCastsFusionTest) { diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index a438a61cb9b36..1bf1cbacf479e 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -2497,10 +2497,15 @@ TEST(QDQTransformerTests, Clip) { epsilon); }; + constexpr int16_t int16_min = std::numeric_limits::min(); + constexpr uint16_t uint16_min = std::numeric_limits::min(); + std::vector opsets{12, 18, 19}; for (auto opset : opsets) { test_case(.0235294122248888f, static_cast(-128), 0, opset); // [0, 6] test_case(.0235294122248888f, static_cast(-128), 0, opset, true); // [0, 6] contrib qdq + test_case(9.15541313801785e-5f, int16_min, 0, opset, true); // [0, 6] contrib 16-bit qdq + test_case(0.0009f, int16_min, 1, opset, true); // [0, 58.98] contrib 16-bit qdq test_case(.02f, static_cast(-128), 0, opset); // [0, 5.1] test_case(.02f, static_cast(-128), 0, opset, true); // [0, 5.1] contrib qdq test_case(.03f, static_cast(-128), 1, opset); // [0, 7.65] @@ -2513,6 +2518,8 @@ TEST(QDQTransformerTests, Clip) { test_case(.04f, static_cast(-97), 1, opset, true); // [-1.24, 8.96] contrib qdq test_case(.02352941176f, static_cast(0), 0, opset); // [0, 6] test_case(.02352941176f, static_cast(0), 0, opset, true); // [0, 6] contrib qdq + test_case(9.15541313801785e-5f, uint16_min, 0, opset, true); // [0, 6] contrib 16-bit qdq + test_case(0.0009f, uint16_min, 1, opset, true); // [0, 58.98] contrib 16-bit qdq test_case(.02f, static_cast(0), 0, opset); // [0, 5.1] test_case(.02f, static_cast(0), 0, opset, true); // [0, 5.1] contrib qdq test_case(.03f, static_cast(0), 1, opset); // [0, 7.65] @@ -3078,12 +3085,12 @@ TEST(QDQTransformerTests, QDQPropagation_Per_Layer_No_Propagation) { check_graph, TransformerLevel::Default, TransformerLevel::Level1, - 18); // disable TransposeOptimizer for simplicity + 18); TransformerTester(build_test_case, check_graph, TransformerLevel::Default, TransformerLevel::Level1, - 19); // disable TransposeOptimizer for simplicity + 19); }; test_case({1, 13, 13, 23}, {0, 2, 3, 1}, false /*use_contrib_qdq*/); diff --git a/onnxruntime/test/optimizer/transpose_optimizer_test.cc b/onnxruntime/test/optimizer/transpose_optimizer_test.cc index 1f4c499985ad0..4f4157bd7b1cf 100644 --- a/onnxruntime/test/optimizer/transpose_optimizer_test.cc +++ b/onnxruntime/test/optimizer/transpose_optimizer_test.cc @@ -12,6 +12,9 @@ #include "core/graph/node_attr_utils.h" #include "core/framework/op_node_proto_helper.h" #include "core/framework/utils.h" +#include "core/optimizer/transpose_optimization/onnx_transpose_optimization.h" +#include "core/optimizer/transpose_optimization/optimizer_api.h" +#include "core/optimizer/transpose_optimization/ort_optimizer_utils.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "test/test_environment.h" @@ -19,6 +22,7 @@ #include "test/providers/internal_testing/internal_testing_execution_provider.h" #include "test/util/include/asserts.h" #include "test/util/include/inference_session_wrapper.h" +#include "test/util/include/test_utils.h" namespace onnxruntime { namespace test { @@ -316,20 +320,6 @@ TEST(TransposeOptimizerTests, TestPadNonconst) { /*opset_version*/ {11, 18}); } -// The CUDA Resize kernel assumes that the input is NCHW and -// Resize can't be supported in ORT builds with CUDA enabled. -// TODO: Enable this once the CUDA Resize kernel is implemented -// "generically" (i.e.) aligning with the generic nature of the -// ONNX spec. -// See https://github.com/microsoft/onnxruntime/pull/10824 for -// a similar fix applied to the CPU Resize kernel. -// Per tests included in #10824, the ROCM EP also generates -// incorrect results when this handler is used, so the Resize -// handler is not enabled even for those builds. -// -// The QNN EP requires the input to be NHWC, so the Resize handler is also not enabled -// for QNN builds. -#if !defined(USE_CUDA) && !defined(USE_ROCM) && !defined(USE_QNN) TEST(TransposeOptimizerTests, TestResize) { auto build_test_case_1 = [&](ModelTestBuilder& builder) { auto* input0_arg = MakeInput(builder, {{4, -1, 2, -1}}, {4, 6, 2, 10}, 0.0, 1.0); @@ -358,7 +348,9 @@ TEST(TransposeOptimizerTests, TestResize) { TransformerTester(build_test_case_1, check_optimized_graph_1, TransformerLevel::Default, - TransformerLevel::Level1, + // need the level 2 TransposeOptimizer as pushing a Transpose through a Resize requires it to be + // assigned to the CPU EP first + TransformerLevel::Level2, /*opset_version*/ {10, 18}); } @@ -386,7 +378,9 @@ TEST(TransposeOptimizerTests, TestResizeOpset11) { TransformerTester(build_test_case_1, check_optimized_graph_1, TransformerLevel::Default, - TransformerLevel::Level1, + // need the level 2 TransposeOptimizer as pushing a Transpose through a Resize requires it to be + // assigned to the CPU EP first + TransformerLevel::Level2, /*opset_version*/ {11, 18}); } @@ -414,7 +408,9 @@ TEST(TransposeOptimizerTests, TestResizeOpset15) { TransformerTester(build_test_case_1, check_optimized_graph_1, TransformerLevel::Default, - TransformerLevel::Level1, + // need the level 2 TransposeOptimizer as pushing a Transpose through a Resize requires it to be + // assigned to the CPU EP first + TransformerLevel::Level2, /*opset_version*/ {15, 18}); } @@ -444,7 +440,9 @@ TEST(TransposeOptimizerTests, TestResizeSizeRoi) { TransformerTester(build_test_case_1, check_optimized_graph_1, TransformerLevel::Default, - TransformerLevel::Level1, + // need the level 2 TransposeOptimizer as pushing a Transpose through a Resize requires it to be + // assigned to the CPU EP first + TransformerLevel::Level2, /*opset_version*/ {15, 18}); } @@ -478,7 +476,9 @@ TEST(TransposeOptimizerTests, TestResizeRoiScalesZeroRank0) { TransformerTester(build_test_case_1, check_optimized_graph_1, TransformerLevel::Default, - TransformerLevel::Level1, + // need the level 2 TransposeOptimizer as pushing a Transpose through a Resize requires it to be + // assigned to the CPU EP first + TransformerLevel::Level2, {12, 18}); } @@ -507,7 +507,9 @@ TEST(TransposeOptimizerTests, TestResizeNonconst) { TransformerTester(build_test_case_1, check_optimized_graph_1, TransformerLevel::Default, - TransformerLevel::Level1, + // need the level 2 TransposeOptimizer as pushing a Transpose through a Resize requires it to be + // assigned to the CPU EP first + TransformerLevel::Level2, /*opset_version*/ {11, 18}); } @@ -536,11 +538,12 @@ TEST(TransposeOptimizerTests, TestResizeNonconstOpset13) { TransformerTester(build_test_case_1, check_optimized_graph_1, TransformerLevel::Default, - TransformerLevel::Level1, + // need the level 2 TransposeOptimizer as pushing a Transpose through a Resize requires it to be + // assigned to the CPU EP first + TransformerLevel::Level2, /*opset_version*/ {13, 18}); } -#endif TEST(TransposeOptimizerTests, TestAdd) { auto build_test_case_1 = [&](ModelTestBuilder& builder) { auto* input0_arg = builder.MakeInput({4, 6, 10}, 0.0, 1.0); @@ -4395,9 +4398,9 @@ TEST(TransposeOptimizerTests, RegressionTest_GitHubIssue9671) { SessionOptions so; so.session_logid = "TransposeOptimizerTests.RegressionTest_GitHubIssue9671"; - InferenceSession session_object{so, GetEnvironment()}; - ASSERT_STATUS_OK(session_object.Load(model_uri)); - ASSERT_STATUS_OK(session_object.Initialize()); // optimizers run during initialization + InferenceSession session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.Load(model_uri)); + ASSERT_STATUS_OK(session.Initialize()); // optimizers run during initialization } // regression test for a model where the transpose optimizations incorrectly removed a node providing an implicit @@ -4409,9 +4412,9 @@ TEST(TransposeOptimizerTests, RegressionTest_GitHubIssue10305) { SessionOptions so; so.session_logid = "TransposeOptimizerTests.RegressionTest_GitHubIssue10305"; - InferenceSession session_object{so, GetEnvironment()}; - ASSERT_STATUS_OK(session_object.Load(model_uri)); - ASSERT_STATUS_OK(session_object.Initialize()); // optimizers run during initialization + InferenceSession session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.Load(model_uri)); + ASSERT_STATUS_OK(session.Initialize()); // optimizers run during initialization } // regression test for a model with DQ node with per-axis dequantization followed by a Transpose. @@ -4432,30 +4435,31 @@ TEST(TransposeOptimizerTests, RegressionTest_GitHubIssue12151) { { so.graph_optimization_level = TransformerLevel::Default; // off - InferenceSession session_object{so, GetEnvironment()}; - ASSERT_STATUS_OK(session_object.Load(model_uri)); - ASSERT_STATUS_OK(session_object.Initialize()); - ASSERT_STATUS_OK(session_object.Run(feeds, output_names, &fetches_orig)); + InferenceSession session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.Load(model_uri)); + ASSERT_STATUS_OK(session.Initialize()); + ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches_orig)); } { so.graph_optimization_level = TransformerLevel::Level1; // enable transpose optimizer - InferenceSession session_object{so, GetEnvironment()}; - ASSERT_STATUS_OK(session_object.Load(model_uri)); - ASSERT_STATUS_OK(session_object.Initialize()); - ASSERT_STATUS_OK(session_object.Run(feeds, output_names, &fetches)); + InferenceSession session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.Load(model_uri)); + ASSERT_STATUS_OK(session.Initialize()); + ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches)); } ASSERT_THAT(fetches_orig[0].Get().DataAsSpan(), testing::ContainerEq(fetches[0].Get().DataAsSpan())); } +// These tests uses internal testing EP with static kernels which requires a full build, +// and the NHWC Conv which requires contrib ops +#if !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_CONTRIB_OPS) + // Test a Transpose node followed by a Reshape that is logically equivalent to an Transpose can be merged. // The test graph was extracted from a model we were trying to use with the QNN EP. TEST(TransposeOptimizerTests, QnnTransposeReshape) { - // test uses internal testing EP with static kernels which requires a full build, - // and the NHWC Conv with requires contrib ops -#if !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_CONTRIB_OPS) Status status; auto model_uri = ORT_TSTR("testdata/layout_transform_reshape.onnx"); @@ -4497,14 +4501,17 @@ TEST(TransposeOptimizerTests, QnnTransposeReshape) { for (const auto& node : graph.Nodes()) { EXPECT_TRUE(node.GetExecutionProviderType() == expected_ep) << node.OpType() << " node named '" << node.Name() << "' was not assigned to the internal testing EP."; + + if (node.Name() == "Mul_212" || node.Name() == "Add_213") { + // check that the special case in TransposeInputs for a single element input reconnects things back up correctly + const auto& inputs = node.InputDefs(); + EXPECT_EQ(inputs.size(), size_t(2)); + EXPECT_TRUE(inputs[1]->Exists()); + } } -#endif } TEST(TransposeOptimizerTests, QnnTransposeReshapeQDQ) { - // test uses internal testing EP with static kernels which requires a full build, - // and the NHWC Conv with requires contrib ops -#if !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_CONTRIB_OPS) Status status; auto model_uri = ORT_TSTR("testdata/layout_transform_reshape.qdq.onnx"); @@ -4541,7 +4548,128 @@ TEST(TransposeOptimizerTests, QnnTransposeReshapeQDQ) { EXPECT_TRUE(node.GetExecutionProviderType() == expected_ep) << node.OpType() << " node named '" << node.Name() << "' was not assigned to the internal testing EP."; } -#endif +} + +// Validate handling for EP with layout specific Resize that prefers NHWC +TEST(TransposeOptimizerTests, QnnResizeOpset11) { + Status status; + auto model_uri = ORT_TSTR("testdata/nhwc_resize_scales_opset11.onnx"); + + SessionOptions so; + // Uncomment to debug + // ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kDebugLayoutTransformation, "1")); + + using InternalTestingEP = onnxruntime::internal_testing_ep::InternalTestingExecutionProvider; + + // set the test EP to support all ops in the model so that the layout transform applies to all nodes + const std::unordered_set empty_set; + auto internal_testing_ep = std::make_unique(empty_set, empty_set, DataLayout::NHWC); + internal_testing_ep->EnableStaticKernels().TakeAllNodes(); + + InferenceSessionWrapper session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.RegisterExecutionProvider(std::move(internal_testing_ep))); + ASSERT_STATUS_OK(session.Load(model_uri)); + ASSERT_STATUS_OK(session.Initialize()); + + const auto& graph = session.GetGraph(); + // all nodes should be assigned to the internal testing EP, which also means they should be in NHWC layout + std::string expected_ep(onnxruntime::utils::kInternalTestingExecutionProvider); + for (const auto& node : graph.Nodes()) { + EXPECT_TRUE(node.GetExecutionProviderType() == expected_ep) << node.OpType() << " node named '" << node.Name() + << "' was not assigned to the internal testing EP."; + if (node.OpType() == "Resize") { + EXPECT_EQ(node.Domain(), kMSInternalNHWCDomain) << "Resize was not converted to NHWC layout"; + } + } + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["Transpose"], 2) << "Resize should have been wrapped in 2 Transpose nodes to convert to NHWC"; + + // And the post-Resize Transpose should have been pushed all the way to the end + GraphViewer viewer(graph); + EXPECT_EQ(graph.GetNode(viewer.GetNodesInTopologicalOrder().back())->OpType(), "Transpose"); +} +#endif // !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_CONTRIB_OPS) + +static void CheckSharedInitializerHandling(bool broadcast) { + auto model_uri = broadcast ? ORT_TSTR("testdata/transpose_optimizer_shared_initializers_broadcast.onnx") + : ORT_TSTR("testdata/transpose_optimizer_shared_initializers.onnx"); + + RandomValueGenerator random{123}; + std::vector input_dims{1, 2, 2, 3}; + std::vector input_data = random.Gaussian(input_dims, 0.0f, 1.0f); + + OrtValue input; + CreateMLValue(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], input_dims, input_data, &input); + + NameMLValMap feeds{{"input0", input}}; + + std::vector output_names{"output0"}; + std::vector fetches_orig; + std::vector fetches; + + SessionOptions so; + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kDebugLayoutTransformation, "1")); + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "1")); + + // get results with no modifications to the model + { + so.graph_optimization_level = TransformerLevel::Default; // off + InferenceSessionWrapper session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.Load(model_uri)); + ASSERT_STATUS_OK(session.Initialize()); + ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches_orig)); + } + + { + InferenceSessionWrapper session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.Load(model_uri)); + + // we call the ONNX transpose optimizer directly to simplify the model required to exercise the shared initializer + // handling. this means we don't need to disable optimizers that might alter the graph before the + // transpose optimizer runs (at a minimum ConstantFolding, CommonSubexpressionElimination and ConstantSharing). + Graph& graph = session.GetMutableGraph(); + CPUAllocator allocator; + + using namespace onnx_transpose_optimization; + auto api_graph = MakeApiGraph(graph, TestCPUExecutionProvider()->CreatePreferredAllocators()[0], + /*new_node_ep*/ nullptr); + + // default optimization cost check + OptimizeResult result = Optimize(*api_graph); + + ASSERT_EQ(result.error_msg, std::nullopt); + ASSERT_TRUE(result.graph_modified); + ASSERT_TRUE(graph.GraphResolveNeeded()); + + std::map op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["Transpose"], 0) << "The Transpose nodes should have been pushed through and canceled out."; + + ASSERT_STATUS_OK(graph.Resolve()); + + ASSERT_STATUS_OK(session.Initialize()); + ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches)); + } + + ASSERT_THAT(fetches_orig[0].Get().DataAsSpan(), + testing::ContainerEq(fetches[0].Get().DataAsSpan())); +} + +// test we re-use a modified shared initializer wherever possible. model has one initializer that is used by 3 DQ nodes +// and one initializer that is used by 2 Add nodes. both cases should be handled with the initializer being +// modified in-place for the first usage, and the Transpose added to the second usage being cancelled out when the +// original Transpose at the start of the model is pushed down. +TEST(TransposeOptimizerTests, SharedInitializerHandling) { + CheckSharedInitializerHandling(/*broadcast*/ false); +} + +// same setup as the above test, however the initializer is broadcast to bring UnsqueezeInput into play. +// the in-place modification of the initializer for the first usage results in +// -> Transpose -> Squeeze -> {DQ | Add} +// the later usages of the initializer should attempt to cancel out the Squeeze in UnsqueezeInput, +// followed by canceling out the Transpose in TransposeInput. +TEST(TransposeOptimizerTests, SharedInitializerHandlingBroadcast) { + CheckSharedInitializerHandling(/*broadcast*/ true); } } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/controlflow/scan_test.cc b/onnxruntime/test/providers/cpu/controlflow/scan_test.cc index 6d8e05b93510a..8008fd129c19b 100644 --- a/onnxruntime/test/providers/cpu/controlflow/scan_test.cc +++ b/onnxruntime/test/providers/cpu/controlflow/scan_test.cc @@ -578,7 +578,7 @@ TEST(Scan9, DISABLED_BadShape) { ShortSequenceOneInBatchOneLoopStateVar( options, "Node:concat Output:concat_out_1 [ShapeInferenceError] Mismatch between number of source and target dimensions. " - "Source=2 Target=1"); + "inferred=2 declared=1"); } TEST(Scan8, ShortSequenceTwoInBatchOneLoopStateVar) { diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index ef2d7e31654ba..da906ebf76f79 100644 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++ b/onnxruntime/test/providers/cpu/model_tests.cc @@ -1133,11 +1133,15 @@ ::std::vector<::std::basic_string> GetParameterStrings() { #if defined(NDEBUG) || defined(RUN_MODELTEST_IN_DEBUG_MODE) #ifdef _WIN32 ORT_STRING_VIEW model_test_root_path = ORT_TSTR("..\\models"); + // thus, only the root path should be mounted. + ORT_STRING_VIEW model_zoo_path = ORT_TSTR("..\\models\\zoo"); #else ORT_STRING_VIEW model_test_root_path = ORT_TSTR("../models"); + ORT_STRING_VIEW model_zoo_path = ORT_TSTR("../models/zoo"); #endif for (auto p : kvp.second) { paths.push_back(ConcatPathComponent(model_test_root_path, p)); + paths.push_back(ConcatPathComponent(model_zoo_path, p)); } #endif @@ -1168,6 +1172,12 @@ ::std::vector<::std::basic_string> GetParameterStrings() { ORT_TSTR("bvlc_alexnet"), ORT_TSTR("bvlc_reference_caffenet"), ORT_TSTR("coreml_VGG16_ImageNet"), + ORT_TSTR("VGG 16-fp32"), + ORT_TSTR("VGG 19-caffe2"), + ORT_TSTR("VGG 19-bn"), + ORT_TSTR("VGG 16-bn"), + ORT_TSTR("VGG 19"), + ORT_TSTR("VGG 16"), ORT_TSTR("faster_rcnn"), ORT_TSTR("GPT2"), ORT_TSTR("GPT2_LM_HEAD"), diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index c8343483b80a6..cb5fc8095982c 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -109,7 +109,7 @@ TEST(ConvFp16Test, Conv1D_Invalid_Input_Shape) { TestConvFp16Op(attrs, {X, dummy_vals}, {X_shape, dummy_shape}, dummy_vals, dummy_shape, false, OpTester::ExpectResult::kExpectFailure, "Node:node1 Output:Y [ShapeInferenceError] Can't merge shape info. " - "Both source and target dimension have values but they differ. Source=0 Target=2 Dimension=2", + "Both inferred and declared dimension have values but they differ. Inferred=0 Declared=2 Dimension=2", -1); // use latest opset for shape inferencing errors } @@ -132,7 +132,7 @@ TEST(ConvFp16Test, Conv2D_Invalid_Input_Shape) { TestConvFp16Op(attrs, {X, dummy_vals}, {X_shape, dummy_shape}, dummy_vals, dummy_shape, false, OpTester::ExpectResult::kExpectFailure, "Node:node1 Output:Y [ShapeInferenceError] Can't merge shape info. " - "Both source and target dimension have values but they differ. Source=1 Target=2 Dimension=0", + "Both inferred and declared dimension have values but they differ. Inferred=1 Declared=2 Dimension=0", -1); // use latest opset for shape inferencing errors } diff --git a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc index e01fd8c78e55f..5103aed50b152 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc @@ -249,7 +249,7 @@ TEST(ConvTest, Conv1D_Invalid_Input_Shape) { TestConvOp(attrs, {X, dummy_vals}, {X_shape, dummy_shape}, dummy_vals, dummy_shape, false, OpTester::ExpectResult::kExpectFailure, "Node:node1 Output:Y [ShapeInferenceError] Can't merge shape info. " - "Both source and target dimension have values but they differ. Source=0 Target=2 Dimension=2", + "Both inferred and declared dimension have values but they differ. Inferred=0 Declared=2 Dimension=2", -1); // use latest opset for shape inferencing errors } @@ -272,7 +272,7 @@ TEST(ConvTest, Conv2D_Invalid_Input_Shape) { TestConvOp(attrs, {X, dummy_vals}, {X_shape, dummy_shape}, dummy_vals, dummy_shape, false, OpTester::ExpectResult::kExpectFailure, "Node:node1 Output:Y [ShapeInferenceError] Can't merge shape info. " - "Both source and target dimension have values but they differ. Source=1 Target=2 Dimension=0", + "Both inferred and declared dimension have values but they differ. Inferred=1 Declared=2 Dimension=0", -1); // use latest opset for shape inferencing errors } diff --git a/onnxruntime/test/providers/cpu/nn/tfidfvectorizer_test.cc b/onnxruntime/test/providers/cpu/nn/tfidfvectorizer_test.cc index a22253dbb74d4..379b892f39135 100644 --- a/onnxruntime/test/providers/cpu/nn/tfidfvectorizer_test.cc +++ b/onnxruntime/test/providers/cpu/nn/tfidfvectorizer_test.cc @@ -91,7 +91,7 @@ TEST(TfIdfVectorizerTest, Int32_TF_onlyBigrams_Skip0_Empty_Dim1Fail) { test.Run(OpTester::ExpectResult::kExpectFailure, "Can't merge shape info. " - "Both source and target dimension have values but they differ. Source=7 Target=0 Dimension=0"); + "Both inferred and declared dimension have values but they differ. Inferred=7 Declared=0 Dimension=0"); } TEST(TfIdfVectorizerTest, Int32_TF_onlyBigrams_Skip0_Empty_Dim1Success) { @@ -136,7 +136,7 @@ TEST(TfIdfVectorizerTest, Int32_TF_onlyBigrams_Skip0_Empty_Dim2) { test.AddOutput("Y", out_dims, output); test.Run(OpTester::ExpectResult::kExpectFailure, - "Mismatch between number of source and target dimensions. Source=2 Target=1"); + "Mismatch between number of inferred and declared dimensions. inferred=2 declared=1"); } TEST(TfIdfVectorizerTest, Int32_TF_onlyBigrams_Skip01_Empty_Dim2) { @@ -159,7 +159,7 @@ TEST(TfIdfVectorizerTest, Int32_TF_onlyBigrams_Skip01_Empty_Dim2) { test.AddOutput("Y", out_dims, output); test.Run(OpTester::ExpectResult::kExpectFailure, - "Mismatch between number of source and target dimensions. Source=2 Target=1"); + "Mismatch between number of inferred and declared dimensions. inferred=2 declared=1"); } TEST(TfIdfVectorizerTest, Int32_TF_onlyBigrams_Skip0_Empty_Dim2N) { diff --git a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc index 832a8a744c08b..2ead9ec91f93f 100644 --- a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc @@ -99,9 +99,8 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_with_extr // CUDA: result mismatch due to not implementing NHWC support // TensorRT: results mismatch // ROCm: results mismatch - // QNN: conflict with layout transformer, need furture investigation test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider, kQnnExecutionProvider}); + {kCudaExecutionProvider, kTensorrtExecutionProvider, kRocmExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_with_extrapolation_uint8) { @@ -131,7 +130,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_with_extr test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kQnnExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_with_extrapolation_int8) { @@ -159,7 +158,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_with_extr 10, 10, 10}; test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider}); + test.Run(); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_extrapolation_uint8) { @@ -188,7 +187,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_e test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kQnnExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_extrapolation_int8) { @@ -215,7 +214,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_tf_crop_and_resize_without_e 0, 0, 0}; test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider}); + test.Run(); } TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear) { @@ -261,9 +260,8 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear) { test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch - // QNN: conflict with layout transformer, need furture investigation test.Run(OpTester::ExpectResult::kExpectSuccess, "", - {kCudaExecutionProvider, kRocmExecutionProvider, kQnnExecutionProvider}); + {kCudaExecutionProvider, kRocmExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_uint8) { @@ -287,7 +285,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_uint8) { test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kQnnExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_int8) { @@ -309,7 +307,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_int8) { std::vector Y = {0, 0}; test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider}); + test.Run(); } // Since NNAPI(TFLite) only using the scale calculate using the input/output size @@ -399,7 +397,9 @@ TEST(ResizeOpTest, ResizeOpLinearDownSampleTest_4DBilinear_align_corners) { std::vector Y = {1.0f, 4.0f}; test.AddOutput("Y", {N, C, static_cast(H * scales[2]), static_cast(W * scales[3])}, Y); - test.Run(); + + // QNN: result mismatch ("NaN" instead of 1.0f on QNN CPU backend) + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider}); }; run_test(false); @@ -435,7 +435,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_align_corners_uin test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kQnnExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider}); }; run_test(false); @@ -465,7 +465,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_align_corners_int test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y); // TensorRT: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kQnnExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); }; run_test(false); @@ -532,7 +532,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixe test.AddOutput("Y", {N, sizes[1], sizes[2], C}, Y); // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kQnnExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider}); } TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixel_int8) { @@ -560,7 +560,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearDownSampleTest_4DBilinear_pytorch_half_pixe std::vector Y = {0, 2, -9}; test.AddOutput("Y", {N, sizes[1], sizes[2], C}, Y); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kQnnExecutionProvider}); // TensorRT: results mismatch + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: results mismatch } TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_4DBilinear_asymmetric) { @@ -641,7 +641,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearUpSampleTest_4DBilinear_asymmetric_uint8) { Y, false, .0f, 1.0f); // CUDA: result mismatch due to not implementing NHWC support // ROCm: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider, kQnnExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kCudaExecutionProvider, kRocmExecutionProvider}); }; run_test(false); @@ -683,7 +683,7 @@ TEST(ResizeOpTest, NhwcResizeOpLinearUpSampleTest_4DBilinear_asymmetric_int8) { test.AddOutput("Y", {N, static_cast(H * scales[1]), static_cast(W * scales[2]), C}, Y, false, .0f, 1.0f); // TensorRT: results mismatch - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kQnnExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); }; run_test(false); @@ -780,7 +780,7 @@ TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_5DTrilinear_pytorch_half_pixel) { } TEST(ResizeOpTest, ResizeOpLinearScalesNoOpTest) { - // To test NNAPI EP, we need the sclaes/sizes to be in initializers + // To test NNAPI EP, we need the scales/sizes to be in initializers auto run_test = [](bool scales_in_initializer) { OpTester test("Resize", 13); std::vector roi{}; @@ -1079,7 +1079,7 @@ TEST(ResizeOpTest, ResizeOpNearestUpSample_Floor_Align_Corners) { 13.0f, 13.0f, 13.0f, 14.0f, 14.0f, 15.0f, 15.0f, 16.0f}; test.AddOutput("Y", {N, C, static_cast(H * scales[2]), static_cast(W * scales[3])}, Y); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kQnnExecutionProvider}); // QNN: result diff + test.Run(); } TEST(ResizeOpTest, ResizeOpNearest_OneToOneMappingBetweenInputAndOutputDataDims) { @@ -1887,7 +1887,7 @@ void TestAntialiasing(std::map attributes, test.AddOutput("Y", output_shape, output_data); // TensorRT 8.5 supports operators up to Opset 17. Temporarily exclude TensorRT EP due to accurarcy issue. - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kQnnExecutionProvider}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); } TEST(ResizeOpTest, Antialias_Bilinear_No_ExcludeOutside) { diff --git a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc index c334e0c5ddcb6..0e7ac5ed2b2f0 100644 --- a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc @@ -37,7 +37,7 @@ TEST(TransposeOpTest, PermRankDoesNotMatchTensorRank) { // This failure comes from shape inference, because in this case it knows the input dims. // But in the real world, the model can supply different input dims at runtime. test.Run(OpTester::ExpectResult::kExpectFailure, - "Node:node1 Output:Y [ShapeInferenceError] Mismatch between number of source and target dimensions. Source=3 Target=4"); + "Node:node1 Output:Y [ShapeInferenceError] Mismatch between number of inferred and declared dimensions. inferred=3 declared=4"); } // Some of the tests can't run on TensorrtExecutionProvider because of errors. diff --git a/onnxruntime/test/providers/kernel_compute_test_utils.cc b/onnxruntime/test/providers/kernel_compute_test_utils.cc index 0bcf3cbbfd089..977a5bd9ea7b8 100644 --- a/onnxruntime/test/providers/kernel_compute_test_utils.cc +++ b/onnxruntime/test/providers/kernel_compute_test_utils.cc @@ -5,6 +5,8 @@ #include "test/providers/kernel_compute_test_utils.h" +#include + #include "core/framework/execution_providers.h" #include "core/optimizer/optimizer_execution_frame.h" #include "test/util/include/default_providers.h" @@ -55,12 +57,20 @@ void KernelComputeTester::Run(std::unordered_set strided_outputs) { } #if defined(USE_CUDA) || defined(USE_ROCM) if ((provider_ == kCudaExecutionProvider || provider_ == kRocmExecutionProvider) && !data.is_cpu_data_) { - OrtValue gpu_value; const Tensor& tensor = data.value_.Get(); - Tensor::InitOrtValue(tensor.DataType(), tensor.Shape(), - execution_providers.Get(ep_type)->CreatePreferredAllocators()[0], gpu_value, - tensor.Strides()); - ASSERT_STATUS_OK(dtm.CopyTensor(tensor, *gpu_value.GetMutable())); + + Tensor gpu_tensor(tensor.DataType(), tensor.Shape(), + execution_providers.Get(ep_type)->CreatePreferredAllocators()[0]); + + if (const auto strides = tensor.Strides(); !strides.empty()) { + gpu_tensor.SetShapeAndStrides(tensor.Shape(), strides); + } + + ASSERT_STATUS_OK(dtm.CopyTensor(tensor, gpu_tensor)); + + OrtValue gpu_value; + Tensor::InitOrtValue(std::move(gpu_tensor), gpu_value); + initializer_map[name] = gpu_value; } #endif @@ -161,8 +171,7 @@ void KernelComputeTester::Run(std::unordered_set strided_outputs) { } else { const Tensor& tensor = outputs[i].Get(); Tensor::InitOrtValue(tensor.DataType(), tensor.Shape(), - execution_providers.Get(cpu_ep_type)->CreatePreferredAllocators()[0], cpu_value, - tensor.Strides()); + execution_providers.Get(cpu_ep_type)->CreatePreferredAllocators()[0], cpu_value); ASSERT_STATUS_OK(dtm.CopyTensor(tensor, *cpu_value.GetMutable())); } diff --git a/onnxruntime/test/providers/kernel_compute_test_utils.h b/onnxruntime/test/providers/kernel_compute_test_utils.h index aed5856fea5a2..a93fd24a3ad4f 100644 --- a/onnxruntime/test/providers/kernel_compute_test_utils.h +++ b/onnxruntime/test/providers/kernel_compute_test_utils.h @@ -75,7 +75,12 @@ class KernelComputeTester { OrtValue value; TensorShape shape(dims); auto allocator = AllocatorManager::Instance().GetAllocator(CPU); - Tensor::InitOrtValue(DataTypeImpl::GetType(), shape, std::move(allocator), value, strides); + Tensor::InitOrtValue(DataTypeImpl::GetType(), shape, std::move(allocator), value); + + if (!strides.empty()) { + value.GetMutable()->SetShapeAndStrides(shape, strides); + } + if (values) { Tensor* tensor = value.GetMutable(); auto* p_data = tensor->MutableData(); diff --git a/onnxruntime/test/providers/qnn/average_pool_test.cc b/onnxruntime/test/providers/qnn/average_pool_test.cc index 79ec07796c0e8..0ee52f7fec21a 100644 --- a/onnxruntime/test/providers/qnn/average_pool_test.cc +++ b/onnxruntime/test/providers/qnn/average_pool_test.cc @@ -32,7 +32,7 @@ static void RunAveragePoolOpTest(const std::string& op_type, provider_options["backend_path"] = "libQnnCpu.so"; #endif - RunQnnModelTest(BuildOpTestCase(op_type, input_defs, attrs), + RunQnnModelTest(BuildOpTestCase(op_type, input_defs, {}, attrs), provider_options, opset, expected_ep_assignment); @@ -53,8 +53,8 @@ static void RunQDQAveragePoolOpTest(const std::string& op_type, provider_options["backend_path"] = "libQnnHtp.so"; #endif - TestQDQModelAccuracy(BuildOpTestCase(op_type, input_defs, attrs), - BuildQDQOpTestCase(op_type, input_defs, attrs), + TestQDQModelAccuracy(BuildOpTestCase(op_type, input_defs, {}, attrs), + BuildQDQOpTestCase(op_type, input_defs, {}, attrs), provider_options, opset, expected_ep_assignment); diff --git a/onnxruntime/test/providers/qnn/clip_op_test.cc b/onnxruntime/test/providers/qnn/clip_op_test.cc new file mode 100644 index 0000000000000..15ba3b5de2fa1 --- /dev/null +++ b/onnxruntime/test/providers/qnn/clip_op_test.cc @@ -0,0 +1,188 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include + +#include "test/providers/qnn/qnn_test_utils.h" +#include "core/graph/node_attr_utils.h" + +#include "onnx/onnx_pb.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +// Runs a model with a Clip operator on the QNN CPU backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunClipTestOnCPU(const TestInputDef& input_def, + const std::vector>& min_max_defs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + + RunQnnModelTest(BuildOpTestCase("Clip", {input_def}, min_max_defs, {}), + provider_options, + opset, + expected_ep_assignment); +} + +// +// CPU tests: +// + +// Test that Clip with a dynamic min or max input is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, Clip_Dynamic_MinMax_Unsupported) { + // Dynamic min input is not supported. + RunClipTestOnCPU(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + {TestInputDef({}, false /* is_initializer */, {-5.0f})}, + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. + // Dynamic max input is not supported. + RunClipTestOnCPU(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + {TestInputDef({}, true, {-5.0f}), + TestInputDef({}, false, {5.0f})}, + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test Clip with default min/max. +TEST_F(QnnCPUBackendTests, Clip_4D_f32_DefaultMinMax) { + RunClipTestOnCPU(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + {}, // Don't specify min/max inputs. + ExpectedEPNodeAssignment::All); +} + +// Test Clip with 5D input. +TEST_F(QnnCPUBackendTests, Clip_5D_f32) { + RunClipTestOnCPU(TestInputDef({1, 1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + {TestInputDef({}, true, {-5.0f}), + TestInputDef({}, true, {5.0f})}, + ExpectedEPNodeAssignment::All); +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// +// HTP tests: +// + +// Runs a QDQ Clip model on the QNN (HTP) EP and the ORT CPU EP. Checks the graph node assignment and that inference +// running the QDQ model on QNN EP is at least as accurate as on ORT CPU EP (compared to the baseline float32 model). +template +static void RunQDQClipTestOnHTP(const TestInputDef& input_def, + const std::vector>& min_max_defs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13, + bool use_contrib_qdq = false) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + auto f32_model_builder = BuildOpTestCase("Clip", {input_def}, {min_max_defs}, {}); + auto qdq_model_builder = BuildQDQOpTestCase("Clip", {input_def}, {min_max_defs}, {}, + kOnnxDomain, use_contrib_qdq); + + TestQDQModelAccuracy(f32_model_builder, + qdq_model_builder, + provider_options, + opset, + expected_ep_assignment); +} + +// Test 8-bit QDQ Clip with default min/max. +// NOTE: The Clip operator is *optimized* away during L1 optimizations, so QNN EP does not get a graph with a Clip op. +// Instead, QNN EP will get a graph with a Q -> DQ. +// - Original sequence: Q1 -> DQ1 -> Clip -> Q2 -> DQ2 +// - ClipQuantFusion: Fuses Clip -> QuantizeLinear resulting in Q1 -> DQ1 -> Q2' -> DQ2 +// - DoubleQDQPairsRemover: Simplifies remaining Q1 -> DQ1 -> Q2' -> DQ2 sequence to Q1 -> DQ2. +TEST_F(QnnHTPBackendTests, Clip_U8_DefaultMinMax_Rank4) { + RunQDQClipTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + {}, // Don't specify min/max inputs. + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ Clip with default min/max. +// NOTE: The Clip operator is *optimized* away during L1 optimizations, so QNN EP does not get a graph with a Clip op. +// Instead, QNN EP will get a graph with a Q -> DQ. +// - Original sequence: Q1 -> DQ1 -> Clip -> Q2 -> DQ2 +// - ClipQuantFusion: Fuses Clip -> QuantizeLinear resulting in Q1 -> DQ1 -> Q2' -> DQ2 +// - DoubleQDQPairsRemover: Simplifies remaining Q1 -> DQ1 -> Q2' -> DQ2 sequence to Q1 -> DQ2. +TEST_F(QnnHTPBackendTests, Clip_U16_DefaultMinMax_Rank4) { + RunQDQClipTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + {}, // Don't specify min/max inputs. + ExpectedEPNodeAssignment::All, + 13, // opset + true); // Use com.microsoft Q/DQ ops +} + +// Test 8-bit QDQ Clip with non-default min and max inputs. QNN EP will get a graph with a Clip operator. +TEST_F(QnnHTPBackendTests, Clip_U8_Rank4) { + RunQDQClipTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + {TestInputDef({}, true, {-5.0f}), + TestInputDef({}, true, {5.0f})}, + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ Clip with non-default min and max inputs. QNN EP will get a graph with a Clip operator. +TEST_F(QnnHTPBackendTests, Clip_U16_Rank4) { + RunQDQClipTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + {TestInputDef({}, true, {-5.0f}), + TestInputDef({}, true, {5.0f})}, + ExpectedEPNodeAssignment::All, + 13, // opset + true); // Use com.microsoft Q/DQ ops +} + +// Test QDQ Clip of rank 5. +TEST_F(QnnHTPBackendTests, Clip_U8_Rank5) { + // We can't use the usual model-building functions because they add standalone Quantize and Dequantize nodes + // at the input and output. These Q/DQ ops get lowered to QNN's Quantize and Dequantize operators, which DO NOT + // support rank 5 tensors. Therefore, we have to create a test model that only instantiates the DQ -> Clip -> Q + // QDQ node group, which gets lowered to a single QNN Clip node. + GetTestModelFn model_fn = [](ModelTestBuilder& builder) { + // input (u8) -> DQ -> + NodeArg* quant_input = builder.MakeInput({1, 1, 2, 2, 2}, {0, 1, 6, 10, 20, 100, 128, 255}); + NodeArg* input_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(quant_input, 1.0f, 0, input_dq); // scale = 1.0, zp = 0 + + // Min/Max initializers + NodeArg* min_input = builder.MakeScalarInitializer(5.0f); + NodeArg* max_input = builder.MakeScalarInitializer(100.0f); + + // Clip -> + NodeArg* clip_output = builder.MakeIntermediate(); + builder.AddNode("Clip", {input_dq, min_input, max_input}, {clip_output}); + + // Q -> output (u8) + NodeArg* output = builder.MakeOutput(); + builder.AddQuantizeLinearNode(clip_output, 1.0f, 0, output); // scale = 1.0, zp = 0 + }; + + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + RunQnnModelTest(model_fn, + provider_options, + 13, // opset + ExpectedEPNodeAssignment::All); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +} // namespace test +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/flatten_op_test.cc b/onnxruntime/test/providers/qnn/flatten_op_test.cc new file mode 100644 index 0000000000000..637d3257ddea7 --- /dev/null +++ b/onnxruntime/test/providers/qnn/flatten_op_test.cc @@ -0,0 +1,202 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include + +#include "test/providers/qnn/qnn_test_utils.h" +#include "core/graph/node_attr_utils.h" + +#include "onnx/onnx_pb.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +// Runs a model with a Flatten operator on the QNN CPU backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunFlattenTestOnCPU(const TestInputDef& input_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + + RunQnnModelTest(BuildOpTestCase("Flatten", {input_def}, {}, attrs), + provider_options, + opset, + expected_ep_assignment); +} + +// +// CPU tests: +// + +// Test that Flatten input (rank4) with axis == 0. +TEST_F(QnnCPUBackendTests, Flatten_Rank4_Axis0) { + RunFlattenTestOnCPU(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + {utils::MakeAttribute("axis", static_cast(0))}, + ExpectedEPNodeAssignment::All); +} + +// Test that Flatten input (rank4) with axis == -1. +TEST_F(QnnCPUBackendTests, Flatten_Rank4_AxisNeg1) { + RunFlattenTestOnCPU(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + {utils::MakeAttribute("axis", static_cast(-1))}, + ExpectedEPNodeAssignment::All); +} + +// Test that Flatten input (rank5) with axis == 2. +TEST_F(QnnCPUBackendTests, Flatten_Rank5_Axis2) { + RunFlattenTestOnCPU(TestInputDef({1, 2, 3, 4, 4}, false, -10.0f, 10.0f), + {utils::MakeAttribute("axis", static_cast(2))}, + ExpectedEPNodeAssignment::All); +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// +// HTP tests: +// + +// Runs a model with a non-QDQ Flatten operator on the QNN HTP backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunFlattenTestOnHTP(const TestInputDef& input_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + RunQnnModelTest(BuildOpTestCase("Flatten", {input_def}, {}, attrs), + provider_options, + opset, + expected_ep_assignment); +} + +// Runs a QDQ Flatten model on the QNN (HTP) EP and the ORT CPU EP. Checks the graph node assignment and that inference +// running the QDQ model on QNN EP is at least as accurate as on ORT CPU EP (compared to the baseline float32 model). +template +static void RunQDQFlattenTestOnHTP(const TestInputDef& input_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13, + bool use_contrib_qdq = false) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + auto f32_model_builder = BuildOpTestCase("Flatten", {input_def}, {}, attrs); + auto qdq_model_builder = BuildQDQOpTestCase("Flatten", {input_def}, {}, attrs, kOnnxDomain, use_contrib_qdq); + TestQDQModelAccuracy(f32_model_builder, + qdq_model_builder, + provider_options, + opset, + expected_ep_assignment); +} + +// Test 8-bit QDQ Flatten input (rank4) with axis == 0. +TEST_F(QnnHTPBackendTests, Flatten_Rank4_Axis0) { + RunQDQFlattenTestOnHTP(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + {utils::MakeAttribute("axis", static_cast(0))}, + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ Flatten input (rank4) with axis == 0. +TEST_F(QnnHTPBackendTests, Flatten_Rank4_Axis0_U16) { + RunQDQFlattenTestOnHTP(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + {utils::MakeAttribute("axis", static_cast(0))}, + ExpectedEPNodeAssignment::All, + 13, // opset + true); // Use com.microsoft Q/DQ ops +} + +// Test 8-bit QDQ Flatten input (rank4) with axis == -1. +TEST_F(QnnHTPBackendTests, Flatten_Rank4_AxisNeg1) { + RunQDQFlattenTestOnHTP(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + {utils::MakeAttribute("axis", static_cast(-1))}, + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ Flatten input (rank4) with axis == -1. +TEST_F(QnnHTPBackendTests, Flatten_Rank4_AxisNeg1_U16) { + RunQDQFlattenTestOnHTP(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + {utils::MakeAttribute("axis", static_cast(-1))}, + ExpectedEPNodeAssignment::All, + 13, // opset + true); // Use com.microsoft Q/DQ ops +} + +// Test 8-bit QDQ Flatten with an input of rank5. +TEST_F(QnnHTPBackendTests, Flatten_QDQ8bit_Rank5) { + // We can't use the usual model-building functions because they add standalone Quantize and Dequantize nodes + // at the input and output. These Q/DQ ops get lowered to QNN's Quantize and Dequantize operators, which DO NOT + // support rank 5 tensors. Therefore, we have to create a test model that only instantiates the DQ -> Flatten -> Q + // QDQ node group, which gets lowered to a single QNN Reshape node. + GetTestModelFn model_fn = [](ModelTestBuilder& builder) { + // input (u8) -> DQ -> + NodeArg* quant_input = builder.MakeInput({1, 2, 3, 4, 5}, 0, 255); + NodeArg* input_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(quant_input, 1.0f, 0, input_dq); // scale = 1.0, zp = 0 + + // Flatten -> + NodeArg* flatten_output = builder.MakeIntermediate(); + Node& flatten_node = builder.AddNode("Flatten", {input_dq}, {flatten_output}); + flatten_node.AddAttribute("axis", static_cast(2)); + + // Q -> output (u8) + NodeArg* output = builder.MakeOutput(); + builder.AddQuantizeLinearNode(flatten_output, 1.0f, 0, output); // scale = 1.0, zp = 0 + }; + + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + RunQnnModelTest(model_fn, + provider_options, + 13, // opset + ExpectedEPNodeAssignment::All); +} + +// Test that int32 non-QDQ Flatten runs on HTP backend. +TEST_F(QnnHTPBackendTests, Flatten_Int32_Rank4_Axis2) { + std::vector input_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + RunFlattenTestOnHTP(TestInputDef({1, 3, 2, 2}, false, input_data), + {utils::MakeAttribute("axis", static_cast(2))}, + ExpectedEPNodeAssignment::All); +} + +// Test that rank 5 int32 Flatten runs on HTP backend. +TEST_F(QnnHTPBackendTests, Flatten_Int32_Rank5_Axis2) { + std::vector input_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23}; + RunFlattenTestOnHTP(TestInputDef({1, 3, 2, 2, 2}, false, input_data), + {utils::MakeAttribute("axis", static_cast(2))}, + ExpectedEPNodeAssignment::All); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +} // namespace test +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/gather_op_htp_test.cc b/onnxruntime/test/providers/qnn/gather_op_htp_test.cc index 5b05b39f34a27..37e0db906d054 100644 --- a/onnxruntime/test/providers/qnn/gather_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/gather_op_htp_test.cc @@ -5,6 +5,7 @@ #include #include "core/graph/graph.h" +#include "core/graph/node_attr_utils.h" #include "test/providers/qnn/qnn_test_utils.h" @@ -14,47 +15,14 @@ namespace onnxruntime { namespace test { #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) -// Function that builds a float model with a Gather op. -template -static GetTestModelFn BuildGatherOpTestCase(const TestInputDef& input_def, - const TestInputDef& indices_def, - int64_t axis = 0) { - return [input_def, indices_def, axis](ModelTestBuilder& builder) { - NodeArg* input = MakeTestInput(builder, input_def); - NodeArg* indices = MakeTestInput(builder, indices_def); - NodeArg* output = builder.MakeOutput(); - - Node& gather_node = builder.AddNode("Gather", {input, indices}, {output}); - gather_node.AddAttribute("axis", axis); - }; -} - -// Function that builds a QDQ model with a Gather op. -template -static GetTestQDQModelFn BuildQDQGatherOpTestCase(const TestInputDef& input_def, - const TestInputDef& indices_def, - int64_t axis = 0) { - return [input_def, indices_def, axis](ModelTestBuilder& builder, - std::vector>& output_qparams) { - NodeArg* input = MakeTestInput(builder, input_def); - QuantParams input_qparams = GetTestInputQuantParams(input_def); - NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point); - - NodeArg* indices = MakeTestInput(builder, indices_def); - - NodeArg* gather_output = builder.MakeIntermediate(); - Node& gather_node = builder.AddNode("Gather", {input_qdq, indices}, {gather_output}); - gather_node.AddAttribute("axis", axis); - - AddQDQNodePairWithOutputAsGraphOutput(builder, gather_output, output_qparams[0].scale, output_qparams[0].zero_point); - }; -} - // Test the accuracy of a QDQ Gather model on QNN EP. Checks if the QDQ model on QNN EP as accurate as the QDQ model on CPU EP // (compared to float32 model). template -static void RunQDQGatherOpTest(const TestInputDef& input_def, const TestInputDef& indices_def, - int64_t axis, int opset, ExpectedEPNodeAssignment expected_ep_assignment) { +static void RunQDQGatherOpTest(const TestInputDef& input_def, + const TestInputDef& indices_def, + const std::vector& attrs, + int opset, + ExpectedEPNodeAssignment expected_ep_assignment) { ProviderOptions provider_options; #if defined(_WIN32) provider_options["backend_path"] = "QnnHtp.dll"; @@ -62,12 +30,14 @@ static void RunQDQGatherOpTest(const TestInputDef& input_def, const TestI provider_options["backend_path"] = "libQnnHtp.so"; #endif - TestQDQModelAccuracy(BuildGatherOpTestCase(input_def, indices_def, axis), - BuildQDQGatherOpTestCase(input_def, indices_def, axis), + auto f32_model_builder = BuildOpTestCase("Gather", {input_def}, {indices_def}, attrs); + auto qdq_model_builder = BuildQDQOpTestCase("Gather", {input_def}, {indices_def}, attrs); + + TestQDQModelAccuracy(f32_model_builder, + qdq_model_builder, provider_options, opset, - expected_ep_assignment, - 1e-5f); + expected_ep_assignment); } // Test creates a DQ -> Gather -> Q -> DQ graph, and checks that all @@ -77,7 +47,7 @@ static void RunQDQGatherOpTest(const TestInputDef& input_def, const TestI TEST_F(QnnHTPBackendTests, GatherOp_IndicesStaticInt64_Axis0) { RunQDQGatherOpTest(TestInputDef({3, 2}, false, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.7f}), TestInputDef({2, 2}, true, {0, 1, 1, 2}), - 0, + {utils::MakeAttribute("axis", static_cast(0))}, 13, ExpectedEPNodeAssignment::All); } @@ -86,7 +56,7 @@ TEST_F(QnnHTPBackendTests, GatherOp_IndicesStaticInt64_Axis0) { TEST_F(QnnHTPBackendTests, GatherOp_IndicesDynamicInt64_Axis0) { RunQDQGatherOpTest(TestInputDef({3, 2}, false, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.7f}), TestInputDef({2, 2}, false, {0, 1, 1, 2}), - 0, + {utils::MakeAttribute("axis", static_cast(0))}, 13, ExpectedEPNodeAssignment::None); } @@ -98,7 +68,7 @@ TEST_F(QnnHTPBackendTests, GatherOp_IndicesDynamicInt64_Axis0) { TEST_F(QnnHTPBackendTests, GatherOp_IndicesStaticInt32_Axis0) { RunQDQGatherOpTest(TestInputDef({3, 2}, false, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.7f}), TestInputDef({2, 2}, true, {0, 1, 1, 2}), - 0, + {utils::MakeAttribute("axis", static_cast(0))}, 13, ExpectedEPNodeAssignment::All); } @@ -110,7 +80,7 @@ TEST_F(QnnHTPBackendTests, GatherOp_IndicesStaticInt32_Axis0) { TEST_F(QnnHTPBackendTests, GatherOp_IndicesDynamicInt32_Axis0) { RunQDQGatherOpTest(TestInputDef({3, 2}, false, {1.0f, 1.2f, 2.3f, 3.4f, 4.5f, 5.7f}), TestInputDef({2, 2}, false, {0, 1, 1, 2}), - 0, + {utils::MakeAttribute("axis", static_cast(0))}, 13, ExpectedEPNodeAssignment::All); } @@ -122,7 +92,7 @@ TEST_F(QnnHTPBackendTests, GatherOp_IndicesDynamicInt32_Axis0) { TEST_F(QnnHTPBackendTests, GatherOp_IndicesStaticInt32_Axis1) { RunQDQGatherOpTest(TestInputDef({3, 3}, false, {1.0f, 1.2f, 1.9f, 2.3f, 3.4f, 3.9f, 4.5f, 5.7f, 5.9f}), TestInputDef({1, 2}, true, {0, 2}), - 1, + {utils::MakeAttribute("axis", static_cast(1))}, 13, ExpectedEPNodeAssignment::All); } diff --git a/onnxruntime/test/providers/qnn/gemm_op_test.cc b/onnxruntime/test/providers/qnn/gemm_op_test.cc new file mode 100644 index 0000000000000..15f26717b06fd --- /dev/null +++ b/onnxruntime/test/providers/qnn/gemm_op_test.cc @@ -0,0 +1,341 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include +#include + +#include "test/providers/qnn/qnn_test_utils.h" +#include "core/graph/node_attr_utils.h" + +#include "onnx/onnx_pb.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +// Runs a model with a Gemm operator on the QNN CPU backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunGemmTestOnCPU(const std::vector>& input_defs, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + + RunQnnModelTest(BuildOpTestCase("Gemm", input_defs, {}, attrs), + provider_options, + opset, + expected_ep_assignment); +} + +// +// CPU tests: +// + +// Test that Gemm with non-default 'alpha' or 'beta' attributes is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, Gemm_NonDefaultAlphaBeta_Unsupported) { + // Check that alpha != 1.0f is not supported. + RunGemmTestOnCPU({TestInputDef({1, 2}, false, -10.0f, 10.0f), + TestInputDef({2, 4}, false, -10.0f, 10.0f)}, + {utils::MakeAttribute("alpha", 1.5f)}, + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. + + // Check that beta != 1.0f is not supported. + RunGemmTestOnCPU({TestInputDef({1, 2}, false, -10.0f, 10.0f), + TestInputDef({2, 4}, false, -10.0f, 10.0f), + TestInputDef({1, 4}, false, -1.0f, 1.0f)}, + {utils::MakeAttribute("beta", 1.2f)}, + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test that Gemm with general 2D bias (M, N) is NOT supported (unless M == 1). +// QNN's FullyConnected operator only supports `outputVector = ( inputAsVector * weightsMatrix ) + biasesVector` +TEST_F(QnnCPUBackendTests, Gemm_2D_Bias_Unsupported) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 12); + + // 2D matrix mul with bias not supported. + RunGemmTestOnCPU({TestInputDef({2, 3}, false, input_a_data), + TestInputDef({3, 4}, false, input_b_data), + TestInputDef({2, 4}, false, -1.0f, 1.0f)}, + {}, + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. + + // However, 2D matrix mul without a bias is supported. Input A's 0th dimension is interpreted as `batch_size`. + RunGemmTestOnCPU({TestInputDef({2, 3}, false, input_a_data), + TestInputDef({3, 4}, false, input_b_data)}, + {}, + ExpectedEPNodeAssignment::All); // Assigned to QNN EP. +} + +// Test Gemm with dynamic (i.e., not initializer) inputs (A, B, Bias). +TEST_F(QnnCPUBackendTests, Gemm_Dynamic_A_B_Bias) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunGemmTestOnCPU({TestInputDef({1, 6}, false, input_a_data), + TestInputDef({6, 4}, false, input_b_data), + TestInputDef({1, 4}, false, input_c_data)}, + {}, + ExpectedEPNodeAssignment::All); +} + +// Test Gemm with static B and Bias inputs. +TEST_F(QnnCPUBackendTests, Gemm_Static_B_And_Bias) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunGemmTestOnCPU({TestInputDef({1, 6}, false, input_a_data), + TestInputDef({6, 4}, true, input_b_data), + TestInputDef({1, 4}, true, input_c_data)}, + {}, + ExpectedEPNodeAssignment::All); +} + +// Test Gemm with transposed A/B and static B and Bias inputs. +TEST_F(QnnCPUBackendTests, Gemm_TransAB_Static_B_And_Bias) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunGemmTestOnCPU({TestInputDef({6, 1}, false, input_a_data), + TestInputDef({4, 6}, true, input_b_data), + TestInputDef({1, 4}, true, input_c_data)}, + {utils::MakeAttribute("transA", static_cast(1)), + utils::MakeAttribute("transB", static_cast(1))}, + ExpectedEPNodeAssignment::All); +} + +// Test Gemm with transposed A/B and dynamic (i.e., not initializer) B and Bias inputs. +TEST_F(QnnCPUBackendTests, Gemm_TransAB_Dynamic_B_And_Bias) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunGemmTestOnCPU({TestInputDef({6, 1}, false, input_a_data), + TestInputDef({4, 6}, false, input_b_data), + TestInputDef({1, 4}, false, input_c_data)}, + {utils::MakeAttribute("transA", static_cast(1)), + utils::MakeAttribute("transB", static_cast(1))}, + ExpectedEPNodeAssignment::All); +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// +// HTP tests: +// + +// Returns a function that builds a model with a QDQ Gemm node. +template +inline GetTestQDQModelFn BuildQDQGemmTestCase(const std::vector>& input_defs, + const std::vector& attrs, + bool use_contrib_qdq = false) { + return [input_defs, attrs, use_contrib_qdq](ModelTestBuilder& builder, + std::vector>& output_qparams) { + const size_t num_inputs = input_defs.size(); + assert(num_inputs == 2 || num_inputs == 3); + + std::vector op_inputs; + op_inputs.reserve(num_inputs); + + // Process input 0 + NodeArg* input0 = MakeTestInput(builder, input_defs[0]); + QuantParams input0_qparams = GetTestInputQuantParams(input_defs[0]); + NodeArg* input0_after_qdq = AddQDQNodePair(builder, input0, input0_qparams.scale, + input0_qparams.zero_point, use_contrib_qdq); + op_inputs.push_back(input0_after_qdq); + + // Process input 1 + NodeArg* input1 = MakeTestInput(builder, input_defs[1]); + QuantParams input1_qparams = GetTestInputQuantParams(input_defs[1]); + NodeArg* input1_after_qdq = AddQDQNodePair(builder, input1, input1_qparams.scale, + input1_qparams.zero_point, use_contrib_qdq); + op_inputs.push_back(input1_after_qdq); + + // Process bias + if (num_inputs == 3) { + NodeArg* bias_input = MakeTestQDQBiasInput(builder, input_defs[2], input0_qparams.scale * input1_qparams.scale, + use_contrib_qdq); + op_inputs.push_back(bias_input); + } + + // Op -> op_output + auto* gemm_output = builder.MakeIntermediate(); + Node& gemm_node = builder.AddNode("Gemm", op_inputs, {gemm_output}); + + for (const auto& attr : attrs) { + gemm_node.AddAttributeProto(attr); + } + + // op_output -> Q -> DQ -> output + AddQDQNodePairWithOutputAsGraphOutput(builder, gemm_output, output_qparams[0].scale, + output_qparams[0].zero_point, use_contrib_qdq); + }; +} + +// Runs a QDQ Gemm model on the QNN (HTP) EP and the ORT CPU EP. Checks the graph node assignment and that inference +// running the QDQ model on QNN EP is at least as accurate as on ORT CPU EP (compared to the baseline float32 model). +template +static void RunQDQGemmTestOnHTP(const std::vector>& input_defs, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13, + float f32_abs_err = 1e-4f, + bool use_contrib_qdq = false) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + auto f32_model_builder = BuildOpTestCase("Gemm", input_defs, {}, attrs); + auto qdq_model_builder = BuildQDQGemmTestCase(input_defs, attrs, use_contrib_qdq); + TestQDQModelAccuracy(f32_model_builder, + qdq_model_builder, + provider_options, + opset, + expected_ep_assignment, + f32_abs_err); +} + +// Test 8-bit QDQ Gemm with dynamic inputs A and Bias. The B input is an initializer. +TEST_F(QnnHTPBackendTests, Gemm_Dynamic_A_Static_B_Dynamic_Bias_U8) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunQDQGemmTestOnHTP({TestInputDef({1, 6}, false, input_a_data), + TestInputDef({6, 4}, true, input_b_data), + TestInputDef({1, 4}, false, input_c_data)}, + {}, + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ Gemm with dynamic inputs A and Bias. The B input is an initializer. +// TODO: Inaccuracy detected for output 'output_0', element 0. +// Output quant params: scale=0.001872879103757441, zero_point=0. +// Expected val: 120.73912048339844 +// QNN QDQ val: 0 (err 120.73912048339844) +// CPU QDQ val: 120.73889923095703 (err 0.00022125244140625) +TEST_F(QnnHTPBackendTests, DISABLED_Gemm_Dynamic_A_Static_B_Dynamic_Bias_U16) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunQDQGemmTestOnHTP({TestInputDef({1, 6}, false, input_a_data), + TestInputDef({6, 4}, true, input_b_data), + TestInputDef({1, 4}, false, input_c_data)}, + {}, + ExpectedEPNodeAssignment::All, + 13, // opset + 1e-4f, // f32_abs_err + true); // Use com.microsoft Q/DQ ops +} + +// Test QDQ Gemm (16bit act, 8bit weight) with dynamic inputs A and Bias. The B input is an initializer. +// TODO: Allow small inaccuracies based on % of expected value. +// Inaccuracy detected for output 'output_0', element 0. +// Output quant params: scale=0.001872879103757441, zero_point=0. +// Expected val: 120.73912048339844 +// QNN QDQ val: 120.48043823242188 (err 0.2586822509765625) +// CPU QDQ val: 120.48980712890625 (err 0.2493133544921875) +TEST_F(QnnHTPBackendTests, Gemm_Dynamic_A_Static_B_Dynamic_Bias_U16Act_U8Weight) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunQDQGemmTestOnHTP({TestInputDef({1, 6}, false, input_a_data), + TestInputDef({6, 4}, true, input_b_data), + TestInputDef({1, 4}, false, input_c_data)}, + {}, + ExpectedEPNodeAssignment::All, + 13, // opset + 0.15f, // f32_abs_err + true); // Use com.microsoft Q/DQ ops +} + +// Test QDQ Gemm with dynamic A and B inputs. The Bias is static. +// TODO: Inaccuracy detected for output 'output', element 0. +// Output quant params: scale=0.48132994771003723, zero_point=0. +// Expected val: 120.73912048339844 +// QNN QDQ val: 77.012794494628906 (err 43.726325988769531) +// CPU QDQ val: 119.85115814208984 (err 0.88796234130859375) +TEST_F(QnnHTPBackendTests, DISABLED_Gemm_Dynamic_A_B_Static_Bias) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunQDQGemmTestOnHTP({TestInputDef({1, 6}, false, input_a_data), + TestInputDef({6, 4}, false, input_b_data), // Dynamic => inaccuracy + TestInputDef({1, 4}, true, input_c_data)}, + {}, + ExpectedEPNodeAssignment::All); +} + +// Test QDQ Gemm with static B and Bias inputs. +TEST_F(QnnHTPBackendTests, Gemm_Static_B_And_Bias) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunQDQGemmTestOnHTP({TestInputDef({1, 6}, false, input_a_data), + TestInputDef({6, 4}, true, input_b_data), + TestInputDef({1, 4}, true, input_c_data)}, + {}, + ExpectedEPNodeAssignment::All); +} + +// Test 8-bit QDQ Gemm with transposed A/B and static B and Bias inputs. +TEST_F(QnnHTPBackendTests, Gemm_TransAB_Static_B_And_Bias_U8) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunQDQGemmTestOnHTP({TestInputDef({6, 1}, false, input_a_data), + TestInputDef({4, 6}, true, input_b_data), + TestInputDef({1, 4}, true, input_c_data)}, + {utils::MakeAttribute("transA", static_cast(1)), + utils::MakeAttribute("transB", static_cast(1))}, + ExpectedEPNodeAssignment::All); +} + +// Test QDQ Gemm (16bit activation, 8bit weight) with transposed A/B and static B and Bias inputs. +// TODO: Allow small inaccuracies based on % of expected value. +// Inaccuracy detected for output 'output_0', element 0. +// Output quant params: scale=0.00047966410056687891, zero_point=0. +// Expected val: 29.434776306152344 +// QNN QDQ val: 29.191877365112305 (err 0.24289894104003906) +// CPU QDQ val: 29.197153091430664 (err 0.23762321472167969) +TEST_F(QnnHTPBackendTests, Gemm_TransAB_Static_B_And_Bias_U16Act_U8Weight) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunQDQGemmTestOnHTP({TestInputDef({6, 1}, false, input_a_data), + TestInputDef({4, 6}, true, input_b_data), + TestInputDef({1, 4}, true, input_c_data)}, + {utils::MakeAttribute("transA", static_cast(1)), + utils::MakeAttribute("transB", static_cast(1))}, + ExpectedEPNodeAssignment::All, + 13, // opset + 0.15f, // f32_abs_err + true); // Use com.microsoft Q/DQ ops +} + +// Test QDQ Gemm with transposed A/B and dynamic (i.e., not initializer) B and Bias inputs. +TEST_F(QnnHTPBackendTests, Gemm_TransAB_Dynamic_B_And_Bias) { + std::vector input_a_data = GetFloatDataInRange(-10.0f, 10.0f, 6); + std::vector input_b_data = GetFloatDataInRange(-5.0f, 5.0f, 24); + std::vector input_c_data = GetFloatDataInRange(-1.0f, 1.0f, 4); + RunQDQGemmTestOnHTP({TestInputDef({6, 1}, false, input_a_data), + TestInputDef({4, 6}, false, input_b_data), + TestInputDef({1, 4}, false, input_c_data)}, + {utils::MakeAttribute("transA", static_cast(1)), + utils::MakeAttribute("transB", static_cast(1))}, + ExpectedEPNodeAssignment::All); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +} // namespace test +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/instance_norm_htp_test.cc b/onnxruntime/test/providers/qnn/instance_norm_htp_test.cc index 594973e37ef0b..f662ac14336f8 100644 --- a/onnxruntime/test/providers/qnn/instance_norm_htp_test.cc +++ b/onnxruntime/test/providers/qnn/instance_norm_htp_test.cc @@ -16,25 +16,6 @@ namespace onnxruntime { namespace test { #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) -// Function that builds a float32 model with an InstanceNormalization operator. -GetTestModelFn BuildInstanceNormTestCase(const TestInputDef& input_def, - const TestInputDef& scale_def, - const TestInputDef& bias_def, - const std::vector& attrs) { - return [input_def, scale_def, bias_def, attrs](ModelTestBuilder& builder) { - NodeArg* input = MakeTestInput(builder, input_def); - NodeArg* scale = MakeTestInput(builder, scale_def); - NodeArg* bias = MakeTestInput(builder, bias_def); - - NodeArg* output = builder.MakeOutput(); - Node& op_node = builder.AddNode("InstanceNormalization", {input, scale, bias}, {output}); - - for (const auto& attr : attrs) { - op_node.AddAttributeProto(attr); - } - }; -} - // Function that builds a QDQ model with an InstanceNormalization operator. template static GetTestQDQModelFn BuildQDQInstanceNormTestCase(const TestInputDef& input_def, @@ -93,7 +74,7 @@ static void RunInstanceNormQDQTest(const TestInputDef& input_def, #endif // Runs model with DQ-> InstanceNorm -> Q and compares the outputs of the CPU and QNN EPs. - TestQDQModelAccuracy(BuildInstanceNormTestCase(input_def, scale_def, bias_def, attrs), + TestQDQModelAccuracy(BuildOpTestCase("InstanceNormalization", {input_def, scale_def, bias_def}, {}, attrs), BuildQDQInstanceNormTestCase(input_def, scale_def, bias_def, attrs), provider_options, 18, diff --git a/onnxruntime/test/providers/qnn/layer_norm_test.cc b/onnxruntime/test/providers/qnn/layer_norm_test.cc index aa6c6a142e6d1..085454004e5a5 100644 --- a/onnxruntime/test/providers/qnn/layer_norm_test.cc +++ b/onnxruntime/test/providers/qnn/layer_norm_test.cc @@ -29,7 +29,7 @@ static void RunLayerNormCpuTest(const TestInputDef& input_def, provider_options["backend_path"] = "libQnnCpu.so"; #endif - RunQnnModelTest(BuildOpTestCase("LayerNormalization", {input_def, scale_def}, attrs), + RunQnnModelTest(BuildOpTestCase("LayerNormalization", {input_def, scale_def}, {}, attrs), provider_options, 17, expected_ep_assignment); @@ -114,7 +114,7 @@ static void RunLayerNormQDQTest(const TestInputDef& input_def, provider_options["backend_path"] = "libQnnHtp.so"; #endif - TestQDQModelAccuracy(BuildOpTestCase("LayerNormalization", {input_def, scale_def}, attrs), + TestQDQModelAccuracy(BuildOpTestCase("LayerNormalization", {input_def, scale_def}, {}, attrs), BuildQDQLayerNormTestCase(input_def, scale_def, attrs), provider_options, 17, // opset diff --git a/onnxruntime/test/providers/qnn/leakyrelu_op_htp_test.cc b/onnxruntime/test/providers/qnn/leakyrelu_op_htp_test.cc index a8237817c71df..e3077ec569923 100644 --- a/onnxruntime/test/providers/qnn/leakyrelu_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/leakyrelu_op_htp_test.cc @@ -5,6 +5,7 @@ #include #include "core/graph/graph.h" +#include "core/graph/node_attr_utils.h" #include "test/optimizer/qdq_test_utils.h" #include "test/providers/qnn/qnn_test_utils.h" @@ -15,42 +16,10 @@ namespace onnxruntime { namespace test { #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) -// Creates a function that builds a model with a LeakyRelu operator. -static GetTestModelFn BuildLeakyReluOpTestCase(const TestInputDef& input_def, float alpha) { - return [input_def, alpha](ModelTestBuilder& builder) { - NodeArg* input = MakeTestInput(builder, input_def); - NodeArg* output = builder.MakeOutput(); - Node& leakyrelu_node = builder.AddNode("LeakyRelu", {input}, {output}); - leakyrelu_node.AddAttribute("alpha", alpha); - }; -} - -// Creates a function that builds a QDQ model with a LeakyRelu operator. -template -static GetTestQDQModelFn BuildQDQLeakyReluOpTestCase(const TestInputDef& input_def, - float alpha) { - return [input_def, alpha](ModelTestBuilder& builder, - std::vector>& output_qparams) { - // input => Q => DQ => - NodeArg* input = MakeTestInput(builder, input_def); - QuantParams input_qparams = GetTestInputQuantParams(input_def); - NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point); - - // LeakryRelu - auto* leakyrelu_output = builder.MakeIntermediate(); - Node& leakyrelu_node = builder.AddNode("LeakyRelu", {input_qdq}, {leakyrelu_output}); - leakyrelu_node.AddAttribute("alpha", alpha); - - // => Q => DQ -> final output - AddQDQNodePairWithOutputAsGraphOutput(builder, leakyrelu_output, output_qparams[0].scale, - output_qparams[0].zero_point); - }; -} - // Checks the accuracy of a QDQ LeakyRelu model by comparing to ORT CPU EP. template static void RunLeakyReluOpQDQTest(const TestInputDef& input_def, - float alpha, + const std::vector& attrs, int opset, ExpectedEPNodeAssignment expected_ep_assignment) { ProviderOptions provider_options; @@ -60,12 +29,11 @@ static void RunLeakyReluOpQDQTest(const TestInputDef& input_def, provider_options["backend_path"] = "libQnnHtp.so"; #endif - TestQDQModelAccuracy(BuildLeakyReluOpTestCase(input_def, alpha), - BuildQDQLeakyReluOpTestCase(input_def, alpha), + TestQDQModelAccuracy(BuildOpTestCase("LeakyRelu", {input_def}, {}, attrs), + BuildQDQOpTestCase("LeakyRelu", {input_def}, {}, attrs), provider_options, opset, - expected_ep_assignment, - 1e-5f); + expected_ep_assignment); } // Test creates a DQ -> Gather -> Q -> DQ graph, and checks that all @@ -74,7 +42,7 @@ static void RunLeakyReluOpQDQTest(const TestInputDef& input_def, // - Uses uint8 as the quantization type. TEST_F(QnnHTPBackendTests, LeakyReluOpSet15) { RunLeakyReluOpQDQTest(TestInputDef({1, 2, 3}, false, {-40.0f, -20.0f, 0.0f, 10.0f, 30.0f, 40.0f}), - 0.2f, + {utils::MakeAttribute("alpha", 0.2f)}, 15, ExpectedEPNodeAssignment::All); } @@ -85,7 +53,7 @@ TEST_F(QnnHTPBackendTests, LeakyReluOpSet15) { // - Uses uint8 as the quantization type. TEST_F(QnnHTPBackendTests, LeakyReluOpSet16) { RunLeakyReluOpQDQTest(TestInputDef({1, 2, 3}, false, {-40.0f, -20.0f, 0.0f, 10.0f, 30.0f, 40.0f}), - 0.2f, + {utils::MakeAttribute("alpha", 0.2f)}, 16, ExpectedEPNodeAssignment::All); } diff --git a/onnxruntime/test/providers/qnn/max_min_op_test.cc b/onnxruntime/test/providers/qnn/max_min_op_test.cc index 09ea71e5f03eb..3deff121f3c72 100644 --- a/onnxruntime/test/providers/qnn/max_min_op_test.cc +++ b/onnxruntime/test/providers/qnn/max_min_op_test.cc @@ -27,7 +27,7 @@ static void RunCPUMinOrMaxOpTest(const std::string& op_type, provider_options["backend_path"] = "libQnnCpu.so"; #endif - RunQnnModelTest(BuildOpTestCase(op_type, input_defs, {}, kOnnxDomain), + RunQnnModelTest(BuildOpTestCase(op_type, input_defs, {}, {}, kOnnxDomain), provider_options, opset, expected_ep_assignment); @@ -48,12 +48,11 @@ static void RunQDQMinOrMaxOpTest(const std::string& op_type, provider_options["backend_path"] = "libQnnHtp.so"; #endif - TestQDQModelAccuracy(BuildOpTestCase(op_type, input_defs, {}, kOnnxDomain), // baseline float32 model - BuildQDQOpTestCase(op_type, input_defs, {}, kOnnxDomain), // QDQ model + TestQDQModelAccuracy(BuildOpTestCase(op_type, input_defs, {}, {}, kOnnxDomain), // baseline float32 model + BuildQDQOpTestCase(op_type, input_defs, {}, {}, kOnnxDomain), // QDQ model provider_options, opset, - expected_ep_assignment, - 1e-4f); + expected_ep_assignment); } // diff --git a/onnxruntime/test/providers/qnn/pool_op_test.cpp b/onnxruntime/test/providers/qnn/pool_op_test.cpp index fee10a542fb82..7ed9072a95b32 100644 --- a/onnxruntime/test/providers/qnn/pool_op_test.cpp +++ b/onnxruntime/test/providers/qnn/pool_op_test.cpp @@ -17,21 +17,6 @@ namespace onnxruntime { namespace test { -// Returns a function that creates a graph with a single MaxPool operator. -static GetTestModelFn BuildPoolTestCase(const std::string& op_type, - const TestInputDef& input_def, - const std::vector& attrs) { - return [op_type, input_def, attrs](ModelTestBuilder& builder) { - NodeArg* input = MakeTestInput(builder, input_def); - NodeArg* output = builder.MakeOutput(); - Node& pool_node = builder.AddNode(op_type, {input}, {output}); - - for (const auto& attr : attrs) { - pool_node.AddAttributeProto(attr); - } - }; -} - // Returns a function that creates a graph with a QDQ MaxPool operator. template GetTestQDQModelFn BuildPoolQDQTestCase(const std::string& op_type, @@ -74,7 +59,7 @@ static void RunPoolOpTest(const std::string& op_type, provider_options["backend_path"] = "libQnnCpu.so"; #endif - RunQnnModelTest(BuildPoolTestCase(op_type, input_def, attrs), + RunQnnModelTest(BuildOpTestCase(op_type, {input_def}, {}, attrs), provider_options, opset, expected_ep_assignment); @@ -95,7 +80,7 @@ static void RunQDQPoolOpTest(const std::string& op_type, provider_options["backend_path"] = "libQnnHtp.so"; #endif - TestQDQModelAccuracy(BuildPoolTestCase(op_type, input_def, attrs), + TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, attrs), BuildPoolQDQTestCase(op_type, input_def, attrs), provider_options, opset, diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.cc b/onnxruntime/test/providers/qnn/qnn_test_utils.cc index 724e9a11cd781..51df93f8853ec 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.cc +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.cc @@ -73,7 +73,7 @@ void RunQnnModelTest(const GetTestModelFn& build_test_case, const ProviderOption void InferenceModel(const std::string& model_data, const char* log_id, std::unique_ptr execution_provider, ExpectedEPNodeAssignment expected_ep_assignment, const NameMLValMap& feeds, - std::vector& output_names, std::vector& output_vals) { + std::vector& output_vals) { SessionOptions so; so.session_logid = log_id; RunOptions run_options; @@ -102,14 +102,12 @@ void InferenceModel(const std::string& model_data, const char* log_id, } const auto& outputs = graph.GetOutputs(); + std::vector output_names; - // fetch all outputs if necessary. - if (output_names.empty()) { - output_names.reserve(outputs.size()); - for (const auto* node_arg : outputs) { - if (node_arg->Exists()) { - output_names.push_back(node_arg->Name()); - } + output_names.reserve(outputs.size()); + for (const auto* node_arg : outputs) { + if (node_arg->Exists()) { + output_names.push_back(node_arg->Name()); } } diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h index fd572fa17f2b1..14c62f98f6a3e 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.h +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h @@ -213,13 +213,12 @@ inline QuantParams GetTestInputQuantParams(const TestInputDef& inp * \param execution_provider The EP on which to run the model. Set to nullptr for CPU EP. * \param expected_ep_assignment Describes "which nodes" should be assigned to the EP. * \param feeds The input feeds. - * \param output_names If empty, the function will write the output names. * \param output_vals Initialized to the inference results. */ void InferenceModel(const std::string& model_data, const char* log_id, std::unique_ptr execution_provider, ExpectedEPNodeAssignment expected_ep_assignment, const NameMLValMap& feeds, - std::vector& output_names, std::vector& output_vals); + std::vector& output_vals); /** * Tests the accuracy of a QDQ model on QNN EP by runnning 3 inferences: @@ -263,9 +262,8 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe // Run f32 model on CPU EP and collect outputs. std::vector cpu_f32_outputs; - std::vector output_names; InferenceModel(f32_model_data, "f32_model_logger", nullptr, ExpectedEPNodeAssignment::All, - f32_helper.feeds_, output_names, cpu_f32_outputs); + f32_helper.feeds_, cpu_f32_outputs); ASSERT_FALSE(cpu_f32_outputs.empty()); const size_t num_outputs = cpu_f32_outputs.size(); @@ -304,13 +302,13 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe // Run QDQ model on QNN EP and collect outputs. std::vector qnn_qdq_outputs; InferenceModel(qdq_model_data, "qdq_model_logger", QnnExecutionProviderWithOptions(qnn_options), - expected_ep_assignment, qdq_helper.feeds_, output_names, qnn_qdq_outputs); + expected_ep_assignment, qdq_helper.feeds_, qnn_qdq_outputs); if (expected_ep_assignment != ExpectedEPNodeAssignment::None) { // Run QDQ model on CPU EP and collect outputs. std::vector cpu_qdq_outputs; InferenceModel(qdq_model_data, "qdq_model_logger", nullptr, ExpectedEPNodeAssignment::All, - qdq_helper.feeds_, output_names, cpu_qdq_outputs); + qdq_helper.feeds_, cpu_qdq_outputs); ASSERT_EQ(cpu_qdq_outputs.size(), num_outputs); ASSERT_EQ(qnn_qdq_outputs.size(), num_outputs); @@ -320,7 +318,9 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe // Compare accuracy of QDQ results with float model. // QNN EP must be at least as accurate as CPU EP when running the QDQ model. + const std::string base_output_name = "output_"; for (size_t i = 0; i < num_outputs; i++) { + std::string debug_output_name = base_output_name + std::to_string(i); auto& cpu_qdq_tensor = cpu_qdq_outputs[i].Get(); auto& qnn_qdq_tensor = qnn_qdq_outputs[i].Get(); @@ -353,8 +353,7 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe } EXPECT_TRUE(is_as_accurate_as_cpu_qdq) - << "Inaccuracy detected for output '" - << output_names[i] + << "Inaccuracy detected for output '" << debug_output_name << "', element " << j << ".\nOutput quant params: scale=" << output_qparams[i].scale << ", zero_point=" << static_cast(output_qparams[i].zero_point) @@ -363,7 +362,7 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe << "CPU QDQ val: " << cpu_qdq_val << " (err " << cpu_err << ")"; } } else { - VerifyOutput(output_names[i], cpu_f32_outputs[i].Get(), qnn_qdq_tensor, fp32_abs_err); + VerifyOutput(debug_output_name, cpu_f32_outputs[i].Get(), qnn_qdq_tensor, fp32_abs_err); } } } @@ -438,25 +437,33 @@ NodeArg* MakeTestQDQBiasInput(ModelTestBuilder& builder, const TestInputDef +template inline GetTestModelFn BuildOpTestCase(const std::string& op_type, - const std::vector>& input_defs, + const std::vector>& input_defs_1, + const std::vector>& input_defs_2, const std::vector& attrs, const std::string& op_domain = kOnnxDomain) { - return [op_type, input_defs, attrs, op_domain](ModelTestBuilder& builder) { + return [op_type, input_defs_1, input_defs_2, attrs, op_domain](ModelTestBuilder& builder) { std::vector op_inputs; - op_inputs.reserve(input_defs.size()); + op_inputs.reserve(input_defs_1.size() + input_defs_2.size()); + + for (const auto& input_def : input_defs_1) { + NodeArg* input = MakeTestInput(builder, input_def); + op_inputs.push_back(input); + } - for (const auto& input_def : input_defs) { - NodeArg* input = MakeTestInput(builder, input_def); + for (const auto& input_def : input_defs_2) { + NodeArg* input = MakeTestInput(builder, input_def); op_inputs.push_back(input); } @@ -470,7 +477,8 @@ inline GetTestModelFn BuildOpTestCase(const std::string& op_type, } /** - * Returns a function that builds a model with a single QDQ operator with N inputs of the same element type. + * Returns a function that builds a model with a single QDQ operator with N float (quantizeable) inputs + * and M inputs of a potentially different type. * * \param op_type The operator to instantiate. * \param input_defs List of input definitions. @@ -478,25 +486,33 @@ inline GetTestModelFn BuildOpTestCase(const std::string& op_type, * \param op_domain The operator's domain. Defaults to the ONNX domain (i.e., ""). * \returns A model building function. */ -template -inline GetTestQDQModelFn BuildQDQOpTestCase(const std::string& op_type, - const std::vector>& input_defs, - const std::vector& attrs, - const std::string& op_domain = kOnnxDomain, - bool use_contrib_qdq = false) { - return [op_type, input_defs, attrs, op_domain, - use_contrib_qdq](ModelTestBuilder& builder, std::vector>& output_qparams) { +template +inline GetTestQDQModelFn BuildQDQOpTestCase(const std::string& op_type, + const std::vector>& quant_input_defs, + const std::vector>& non_quant_input_defs, + const std::vector& attrs, + const std::string& op_domain = kOnnxDomain, + bool use_contrib_qdq = false) { + return [op_type, quant_input_defs, non_quant_input_defs, attrs, op_domain, + use_contrib_qdq](ModelTestBuilder& builder, std::vector>& output_qparams) { std::vector op_inputs; - op_inputs.reserve(input_defs.size()); + op_inputs.reserve(quant_input_defs.size() + non_quant_input_defs.size()); - for (const auto& input_def : input_defs) { + // Create QDQ inputs + for (const auto& input_def : quant_input_defs) { NodeArg* input = MakeTestInput(builder, input_def); - QuantParams input_qparams = GetTestInputQuantParams(input_def); - NodeArg* input_after_qdq = AddQDQNodePair(builder, input, input_qparams.scale, - input_qparams.zero_point, use_contrib_qdq); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + NodeArg* input_after_qdq = AddQDQNodePair(builder, input, input_qparams.scale, + input_qparams.zero_point, use_contrib_qdq); op_inputs.push_back(input_after_qdq); } + // Create non-QDQ inputs + for (const auto& input_def : non_quant_input_defs) { + NodeArg* input = MakeTestInput(builder, input_def); + op_inputs.push_back(input); + } + // Op -> op_output auto* op_output = builder.MakeIntermediate(); Node& onnx_node = builder.AddNode(op_type, op_inputs, {op_output}, op_domain); @@ -506,8 +522,8 @@ inline GetTestQDQModelFn BuildQDQOpTestCase(const std::string& op_ty } // op_output -> Q -> DQ -> output - AddQDQNodePairWithOutputAsGraphOutput(builder, op_output, output_qparams[0].scale, - output_qparams[0].zero_point, use_contrib_qdq); + AddQDQNodePairWithOutputAsGraphOutput(builder, op_output, output_qparams[0].scale, + output_qparams[0].zero_point, use_contrib_qdq); }; } diff --git a/onnxruntime/test/providers/qnn/reshape_op_test.cc b/onnxruntime/test/providers/qnn/reshape_op_test.cc new file mode 100644 index 0000000000000..eb495e44ec770 --- /dev/null +++ b/onnxruntime/test/providers/qnn/reshape_op_test.cc @@ -0,0 +1,225 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include + +#include "test/providers/qnn/qnn_test_utils.h" +#include "core/graph/node_attr_utils.h" + +#include "onnx/onnx_pb.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +// Runs a model with a Reshape operator on the QNN CPU backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunReshapeTestOnCPU(const TestInputDef& input_def, + const TestInputDef& shape_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 19) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + + RunQnnModelTest(BuildOpTestCase("Reshape", {input_def}, {shape_def}, attrs), + provider_options, + opset, + expected_ep_assignment); +} + +// +// CPU tests: +// + +// Test that Reshape with a dynamic shape input is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, Reshape_DynamicShape_Unsupported) { + RunReshapeTestOnCPU(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({2}, false /* is_initializer */, {1, 48}), + {}, // Attributes + ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP. + 19); // Opset +} + +// Test that Reshape with an enabled 'allowzero' attribute is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, Reshape_AllowZeroAttr_Unsupported) { + RunReshapeTestOnCPU(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({2}, true, {1, 48}), + {utils::MakeAttribute("allowzero", static_cast(1))}, + ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP. + 19); // Opset +} + +// Test Reshape of rank 4 -> rank 2. +TEST_F(QnnCPUBackendTests, Reshape_4D_f32) { + RunReshapeTestOnCPU(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({2}, true, {1, 48}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19); // Opset +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// +// HTP tests: +// + +// Returns a function that creates a graph with a QDQ Reshape operator. +template +GetTestQDQModelFn BuildQDQReshapeTestCase(const TestInputDef& input_def, + const TestInputDef& shape_def, + const std::vector& attrs, + bool use_contrib_qdq = false) { + return [input_def, shape_def, attrs, + use_contrib_qdq](ModelTestBuilder& builder, + std::vector>& output_qparams) { + // input -> Q -> DQ -> + NodeArg* input = MakeTestInput(builder, input_def); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point, + use_contrib_qdq); + + // shape input + NodeArg* shape_input = MakeTestInput(builder, shape_def); + + // Reshape op + NodeArg* reshape_output = builder.MakeIntermediate(); + Node& reshape_node = builder.AddNode("Reshape", {input_qdq, shape_input}, {reshape_output}); + + for (const auto& attr : attrs) { + reshape_node.AddAttributeProto(attr); + } + + // op_output -> Q -> DQ -> output + // NOTE: Input and output quantization parameters must be equal for Reshape. + output_qparams[0] = input_qparams; // Overwrite! + AddQDQNodePairWithOutputAsGraphOutput(builder, reshape_output, input_qparams.scale, + input_qparams.zero_point, use_contrib_qdq); + }; +} + +// Runs a model with a non-QDQ Reshape operator on the QNN HTP backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunReshapeTestOnHTP(const TestInputDef& input_def, + const TestInputDef& shape_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 19) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + RunQnnModelTest(BuildOpTestCase("Reshape", {input_def}, {shape_def}, attrs), + provider_options, + opset, + expected_ep_assignment); +} + +// Runs a QDQ Reshape model on the QNN (HTP) EP and the ORT CPU EP. Checks the graph node assignment and that inference +// running the QDQ model on QNN EP is at least as accurate as on ORT CPU EP (compared to the baseline float32 model). +template +static void RunQDQReshapeTestOnHTP(const TestInputDef& input_def, + const TestInputDef& shape_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 19, + bool use_contrib_qdq = false) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + auto f32_model_builder = BuildOpTestCase("Reshape", {input_def}, {shape_def}, attrs); + auto qdq_model_builder = BuildQDQReshapeTestCase(input_def, shape_def, attrs, use_contrib_qdq); + TestQDQModelAccuracy(f32_model_builder, + qdq_model_builder, + provider_options, + opset, + expected_ep_assignment); +} + +// Test that QDQ Reshape with a dynamic shape input is not supported by QNN EP. +TEST_F(QnnHTPBackendTests, Reshape_DynamicShape_Unsupported) { + RunQDQReshapeTestOnHTP(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({2}, false /* is_initializer */, {1, 48}), + {}, // Attributes + ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP. + 19); // Opset +} + +// Test that QDQ Reshape with an enabled 'allowzero' attribute is not supported by QNN EP. +TEST_F(QnnHTPBackendTests, Reshape_AllowZeroAttr_Unsupported) { + RunQDQReshapeTestOnHTP(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({2}, true, {1, 48}), + {utils::MakeAttribute("allowzero", static_cast(1))}, + ExpectedEPNodeAssignment::None, // Should not be assigned to QNN EP. + 19); // Opset +} + +// Test 8-bit QDQ Reshape of rank 4 -> rank 2. +TEST_F(QnnHTPBackendTests, Reshape_4D_u8) { + RunQDQReshapeTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({2}, true, {1, 48}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19); // Opset +} + +// Test 16-bit QDQ Reshape of rank 4 -> rank 2. +TEST_F(QnnHTPBackendTests, Reshape_4D_u16) { + RunQDQReshapeTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({2}, true, {1, 48}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19, // Opset + true); // Use com.microsoft Q/DQ ops +} + +// Test that int32 Reshape runs on HTP backend. +TEST_F(QnnHTPBackendTests, Reshape_4D_int32) { + std::vector input_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + RunReshapeTestOnHTP(TestInputDef({1, 3, 2, 2}, false, input_data), + TestInputDef({3}, true, {1, 1, 12}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19); // Opset +} + +// Test QDQ Reshape with a shape value of 0 (copy dimension from input) +TEST_F(QnnHTPBackendTests, Reshape_4D_0MeansCopy) { + RunQDQReshapeTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({3}, true, {1, 0, 16}), // zero means copy => '(1, 3, 16)' + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19); // Opset +} + +// Test QDQ Reshape with a shape value of -1 (dimension is inferred from the expected number of elements) +TEST_F(QnnHTPBackendTests, Reshape_4D_Neg1MeansInfer) { + RunQDQReshapeTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({3}, true, {1, 3, -1}), // -1 means infer => '(1, 3, 16)' + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19); // Opset +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +} // namespace test +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/resize_test.cc b/onnxruntime/test/providers/qnn/resize_test.cc index cf336ca9eeb8b..cd6865d443cc0 100644 --- a/onnxruntime/test/providers/qnn/resize_test.cc +++ b/onnxruntime/test/providers/qnn/resize_test.cc @@ -120,7 +120,7 @@ static void RunCPUResizeOpTest(const TestInputDef& input_def, const std:: const std::string& mode, const std::string& coordinate_transformation_mode, const std::string& nearest_mode, ExpectedEPNodeAssignment expected_ep_assignment, - int opset = 11) { + int opset = 19) { ProviderOptions provider_options; #if defined(_WIN32) provider_options["backend_path"] = "QnnCpu.dll"; @@ -138,7 +138,7 @@ static void RunCPUResizeOpTestWithScales(const TestInputDef& input_def, c const std::string& mode, const std::string& coordinate_transformation_mode, const std::string& nearest_mode, ExpectedEPNodeAssignment expected_ep_assignment, - int opset = 11) { + int opset = 19) { ProviderOptions provider_options; #if defined(_WIN32) provider_options["backend_path"] = "QnnCpu.dll"; @@ -157,7 +157,8 @@ static void RunQDQResizeOpTest(const TestInputDef& input_def, const std::vector& sizes_data, const std::string& mode, const std::string& coordinate_transformation_mode, const std::string& nearest_mode, - ExpectedEPNodeAssignment expected_ep_assignment) { + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 19) { ProviderOptions provider_options; #if defined(_WIN32) provider_options["backend_path"] = "QnnHtp.dll"; @@ -169,27 +170,20 @@ static void RunQDQResizeOpTest(const TestInputDef& input_def, GetQDQResizeModelBuilder(input_def, sizes_data, mode, coordinate_transformation_mode, nearest_mode), provider_options, - 18, // opset - expected_ep_assignment, - 1e-5f); + opset, + expected_ep_assignment); } // // CPU tests: // -// TODO: Our QNN CPU translation of ONNX Resize with "nearest" mode uses QNN's ResizeNearestNeighbor -// operator, which does not have a way to specify rounding (i.e., "nearest_mode" in ONNX). It is not clear -// what kind of rounding QNN's ResizeNearestNeighbor uses. Therefore, we do not yet know how to compare -// ONNX Resize to QNN ResizeNearestNeighbor. These tests should remain disabled until this behavior is -// clarified. If, for example, it turns out that ResizeNearestNeighbor uses "floor" rounding, then we should -// only compare against ONNX resize with "floor" rounding. - // Upsample that uses "round_prefer_floor" as the "nearest_mode". // coordinate_transformation_mode: "half_pixel" -TEST_F(QnnCPUBackendTests, DISABLED_ResizeUpsampleNearestHalfPixel_rpf) { - RunCPUResizeOpTest(TestInputDef({1, 2, 7, 5}, false, -10.0f, 10.0f), // Random input w/ range [-10, 10] - {1, 2, 21, 10}, // Sizes +TEST_F(QnnCPUBackendTests, ResizeUpsampleNearestHalfPixel_rpf) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 70); + RunCPUResizeOpTest(TestInputDef({1, 2, 7, 5}, false, input_data), + {1, 2, 21, 10}, // Sizes "nearest", "half_pixel", "round_prefer_floor", @@ -198,57 +192,72 @@ TEST_F(QnnCPUBackendTests, DISABLED_ResizeUpsampleNearestHalfPixel_rpf) { // Upsample that uses "round_prefer_ceil" as the "nearest_mode". // coordinate_transformation_mode: "half_pixel" -TEST_F(QnnCPUBackendTests, DISABLED_ResizeUpsampleNearestHalfPixel_rpc) { - RunCPUResizeOpTest(TestInputDef({1, 1, 2, 4}, false, -10.0f, 10.0f), +TEST_F(QnnCPUBackendTests, ResizeUpsampleNearestHalfPixel_rpc) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 8); + RunCPUResizeOpTest(TestInputDef({1, 1, 2, 4}, false, input_data), {1, 1, 7, 5}, "nearest", "half_pixel", "round_prefer_ceil", ExpectedEPNodeAssignment::All); } // Downsample that uses "round_prefer_ceil" as the "nearest_mode". // coordinate_transformation_mode: "half_pixel" -TEST_F(QnnCPUBackendTests, DISABLED_ResizeDownsampleNearestHalfPixel_rpc) { - RunCPUResizeOpTest(TestInputDef({1, 1, 2, 4}, false, -10.0f, 10.0f), +TEST_F(QnnCPUBackendTests, ResizeDownsampleNearestHalfPixel_rpc) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 8); + RunCPUResizeOpTest(TestInputDef({1, 1, 2, 4}, false, input_data), {1, 1, 1, 3}, "nearest", "half_pixel", "round_prefer_ceil", ExpectedEPNodeAssignment::All); } // Downsample that uses "round_prefer_floor" as the "nearest_mode". // coordinate_transformation_mode: "half_pixel" -TEST_F(QnnCPUBackendTests, DISABLED_ResizeDownsampleNearestHalfPixel_rpf) { - RunCPUResizeOpTest(TestInputDef({1, 1, 2, 4}, false, -10.0f, 10.0f), +TEST_F(QnnCPUBackendTests, ResizeDownsampleNearestHalfPixel_rpf) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 8); + RunCPUResizeOpTest(TestInputDef({1, 1, 2, 4}, false, input_data), {1, 1, 1, 2}, "nearest", "half_pixel", "round_prefer_ceil", ExpectedEPNodeAssignment::All); } // Upsample that uses "round_prefer_floor" as the "nearest_mode". // coordinate_transformation_mode: "align_corners" -// QNN v2.13: index #50 don't match, which is 4.67152 from -1.93515 -TEST_F(QnnCPUBackendTests, DISABLED_ResizeUpsampleNearestAlignCorners_rpf) { - RunCPUResizeOpTest(TestInputDef({1, 2, 7, 5}, false, -10.0f, 10.0f), +TEST_F(QnnCPUBackendTests, ResizeUpsampleNearestAlignCorners_rpf) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 70); + RunCPUResizeOpTest(TestInputDef({1, 2, 7, 5}, false, input_data), {1, 2, 21, 10}, "nearest", "align_corners", "round_prefer_floor", ExpectedEPNodeAssignment::All); } +// Upsample that uses "round_prefer_floor" as the "nearest_mode". +// coordinate_transformation_mode: "asymmetric" +TEST_F(QnnCPUBackendTests, ResizeUpsampleNearestAsymmetric_rpf) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 70); + RunCPUResizeOpTest(TestInputDef({1, 2, 7, 5}, false, input_data), + {1, 2, 21, 10}, "nearest", "asymmetric", "round_prefer_floor", + ExpectedEPNodeAssignment::All); +} + // Upsample that uses "round_prefer_ceil" as the "nearest_mode". // coordinate_transformation_mode: "align_corners" -TEST_F(QnnCPUBackendTests, DISABLED_ResizeUpsampleNearestAlignCorners_rpc) { - RunCPUResizeOpTest(TestInputDef({1, 1, 2, 4}, false, -10.0f, 10.0f), +TEST_F(QnnCPUBackendTests, ResizeUpsampleNearestAlignCorners_rpc) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 8); + RunCPUResizeOpTest(TestInputDef({1, 1, 2, 4}, false, input_data), {1, 1, 7, 5}, "nearest", "align_corners", "round_prefer_ceil", ExpectedEPNodeAssignment::All); } // Downsample that uses "round_prefer_ceil" as the "nearest_mode". // coordinate_transformation_mode: "align_corners" -TEST_F(QnnCPUBackendTests, DISABLED_ResizeDownsampleNearestAlignCorners_rpc) { - RunCPUResizeOpTest(TestInputDef({1, 1, 2, 4}, false, -10.0f, 10.0f), +TEST_F(QnnCPUBackendTests, ResizeDownsampleNearestAlignCorners_rpc) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 8); + RunCPUResizeOpTest(TestInputDef({1, 1, 2, 4}, false, input_data), {1, 1, 1, 3}, "nearest", "align_corners", "round_prefer_ceil", ExpectedEPNodeAssignment::All); } // Downsample that uses "round_prefer_floor" as the "nearest_mode". // coordinate_transformation_mode: "align_corners" -TEST_F(QnnCPUBackendTests, DISABLED_ResizeDownsampleNearestAlignCorners_rpf) { - RunCPUResizeOpTest(TestInputDef({1, 1, 2, 4}, false, -10.0f, 10.0f), +TEST_F(QnnCPUBackendTests, ResizeDownsampleNearestAlignCorners_rpf) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 8); + RunCPUResizeOpTest(TestInputDef({1, 1, 2, 4}, false, input_data), {1, 1, 1, 2}, "nearest", "align_corners", "round_prefer_floor", ExpectedEPNodeAssignment::All); } @@ -258,76 +267,177 @@ TEST_F(QnnCPUBackendTests, DISABLED_ResizeDownsampleNearestAlignCorners_rpf) { // TEST_F(QnnCPUBackendTests, Resize2xLinearHalfPixel) { - RunCPUResizeOpTest(TestInputDef({1, 3, 4, 5}, false, -10.0f, 10.0f), + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 60); + RunCPUResizeOpTest(TestInputDef({1, 3, 4, 5}, false, input_data), {1, 3, 8, 10}, "linear", "half_pixel", "", ExpectedEPNodeAssignment::All); } TEST_F(QnnCPUBackendTests, Resize2xLinearHalfPixel_scales) { - RunCPUResizeOpTestWithScales(TestInputDef({1, 3, 4, 5}, false, -10.0f, 10.0f), + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 60); + RunCPUResizeOpTestWithScales(TestInputDef({1, 3, 4, 5}, false, input_data), {1.0f, 1.0f, 2.0f, 2.0f}, "linear", "half_pixel", "", ExpectedEPNodeAssignment::All); } TEST_F(QnnCPUBackendTests, Resize2xLinearAlignCorners) { - RunCPUResizeOpTest(TestInputDef({1, 3, 4, 5}, false, -10.0f, 10.0f), + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 60); + RunCPUResizeOpTest(TestInputDef({1, 3, 4, 5}, false, input_data), {1, 3, 8, 10}, "linear", "align_corners", "", ExpectedEPNodeAssignment::All); } TEST_F(QnnCPUBackendTests, Resize2xLinearAlignCorners_scales) { - RunCPUResizeOpTestWithScales(TestInputDef({1, 3, 4, 5}, false, -10.0f, 10.0f), + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 60); + RunCPUResizeOpTestWithScales(TestInputDef({1, 3, 4, 5}, false, input_data), {1.0f, 1.0f, 2.0f, 2.0f}, "linear", "align_corners", "", ExpectedEPNodeAssignment::All); } +// Test Resize downsample with mode: "linear", coordinate_transformation_mode: "align_corners" +// TODO: Enable ResizeOpTest.ResizeOpLinearDownSampleTest_4DBilinear_align_corners in cpu resize_op tests when fixed. +// +// Input f32[1,1,2,4]: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0 +// Expected output f32[1, 1, 1, 2]: 1.0, 4.0 +// Actual output f32[1, 1, 1, 2]: NaN, NaN +TEST_F(QnnCPUBackendTests, DISABLED_Resize_DownSample_Linear_AlignCorners_scales) { + std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; + RunCPUResizeOpTestWithScales(TestInputDef({1, 1, 2, 4}, false, input_data), + {1.0f, 1.0f, 0.6f, 0.6f}, "linear", "align_corners", "", + ExpectedEPNodeAssignment::All); +} + +// Test Resize downsample with mode: "linear", coordinate_transformation_mode: "half_pixel" +// TODO: Enable ResizeOpTest.ResizeOpLinearDownSampleTest_4DBilinear cpu resize_op tests when fixed. +// +// Input f32[1,1,2,4]: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0 +// Expected output f32[1, 1, 1, 2]: 2.6666 4.3333 +// Actual output f32[1, 1, 1, 2]: NaN, NaN +TEST_F(QnnCPUBackendTests, DISABLED_Resize_DownSample_Linear_HalfPixel_scales) { + std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; + RunCPUResizeOpTestWithScales(TestInputDef({1, 1, 2, 4}, false, input_data), + {1.0f, 1.0f, 0.6f, 0.6f}, "linear", "half_pixel", "", + ExpectedEPNodeAssignment::All); +} + #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) // // HTP tests: // +// Test QDQ Resize downsample with mode: "linear", coordinate_transformation_mode: "align_corners" +TEST_F(QnnHTPBackendTests, Resize_DownSample_Linear_AlignCorners) { + std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; + RunQDQResizeOpTest(TestInputDef({1, 1, 2, 4}, false, input_data), + {1, 1, 1, 2}, "linear", "align_corners", "", + ExpectedEPNodeAssignment::All); +} + +// Test QDQ Resize downsample with mode: "linear", coordinate_transformation_mode: "half_pixel" +TEST_F(QnnHTPBackendTests, Resize_DownSample_Linear_HalfPixel) { + std::vector input_data = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f}; + RunQDQResizeOpTest(TestInputDef({1, 1, 2, 4}, false, input_data), + {1, 1, 1, 2}, "linear", "half_pixel", "", + ExpectedEPNodeAssignment::All); +} + +// Test 2x QDQ Resize mode: "linear", coordinate_transformation_mode: "pytorch_half_pixel" +// QNN EP uses QNN's Resize op. TEST_F(QnnHTPBackendTests, ResizeU8_2xLinearPytorchHalfPixel) { - RunQDQResizeOpTest(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 48); + RunQDQResizeOpTest(TestInputDef({1, 3, 4, 4}, false, input_data), {1, 3, 8, 8}, "linear", "pytorch_half_pixel", "", ExpectedEPNodeAssignment::All); } -TEST_F(QnnHTPBackendTests, ResizeU8_2xNearestHalfPixelRoundPreferFloor) { - RunQDQResizeOpTest(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), - {1, 3, 8, 8}, "nearest", "half_pixel", "round_prefer_floor", +// Test 2x QDQ Resize mode: "linear", coordinate_transformation_mode: "half_pixel" +// QNN EP uses QNN's Resize op. +TEST_F(QnnHTPBackendTests, ResizeU8_2xLinearHalfPixel) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 48); + RunQDQResizeOpTest(TestInputDef({1, 3, 4, 4}, false, input_data), + {1, 3, 8, 8}, "linear", "half_pixel", "", ExpectedEPNodeAssignment::All); } -TEST_F(QnnHTPBackendTests, ResizeU8_2xNearestAsymmetricFloor) { - RunQDQResizeOpTest(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), - {1, 3, 8, 8}, "nearest", "asymmetric", "floor", +// Test 2x QDQ Resize mode: "linear", coordinate_transformation_mode: "align_corners" +// QNN EP uses QNN's Resize op. +TEST_F(QnnHTPBackendTests, ResizeU8_2xLinearAlignCorners) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 48); + RunQDQResizeOpTest(TestInputDef({1, 3, 4, 4}, false, input_data), + {1, 3, 8, 8}, "linear", "align_corners", "", ExpectedEPNodeAssignment::All); } -// TODO: Investigate with Qualcomm. The qnn-onnx-converter tool translates ONNX Resize [nearest, asymmetric, ceil] to -// QNN ResizeNearestNeighbor {align_corners: 0, half_pixel: 0}, which is NOT equivalent. It would be better to use -// QNN's own Resize operator (instead of ResizeNearestNeighbor), but it doesn't support the "asymmetric" coordinate -// transform mode. -// -// QNN v2.13: Inaccuracy detected for output 'output', element 189. -// Output quant params: scale=0.078431375324726105, zero_point=127. -// Expected val: -2.663428783416748 -// QNN QDQ val: 7.4509806632995605 (err 10.114409446716309) -// CPU QDQ val: -2.6666667461395264 (err 0.0032379627227783203) -TEST_F(QnnHTPBackendTests, DISABLED_ResizeU8_2xNearestAsymmetricCeil) { - RunQDQResizeOpTest(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), - {1, 3, 8, 8}, "nearest", "asymmetric", "ceil", +// Test 2x QDQ Resize mode: "linear", coordinate_transformation_mode: "asymmetric" +// QNN EP uses QNN's Resize op. +TEST_F(QnnHTPBackendTests, ResizeU8_2xLinearAsymmetric) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 48); + RunQDQResizeOpTest(TestInputDef({1, 3, 4, 4}, false, input_data), + {1, 3, 8, 8}, "linear", "asymmetric", "", ExpectedEPNodeAssignment::All); } +// Test 2x QDQ Resize mode: "nearest", coordinate_transformation_mode: "half_pixel", nearest_mode: "round_prefer_floor" +// QNN EP uses QNN's Resize op. +TEST_F(QnnHTPBackendTests, ResizeU8_2xNearestHalfPixelRoundPreferFloor) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 48); + RunQDQResizeOpTest(TestInputDef({1, 3, 4, 4}, false, input_data), + {1, 3, 8, 8}, "nearest", "half_pixel", "round_prefer_floor", + ExpectedEPNodeAssignment::All); +} + +// Test that the nearest_mode "ceil" is not supported on the HTP backend. +TEST_F(QnnHTPBackendTests, ResizeU8_NearestModeCeil_Unsupported) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 48); + RunQDQResizeOpTest(TestInputDef({1, 3, 4, 4}, false, input_data), + {1, 3, 8, 8}, "nearest", "asymmetric", "ceil", + ExpectedEPNodeAssignment::None); +} + +// Test 3x QDQ Resize mode: "nearest", coordinate_transformation_mode: "asymmetric", nearest_mode: "floor". +// QNN EP uses QNN's ResizeNearestNeighbor op. TEST_F(QnnHTPBackendTests, ResizeU8_3xNearestAsymmetricFloor) { - RunQDQResizeOpTest(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 48); + RunQDQResizeOpTest(TestInputDef({1, 3, 4, 4}, false, input_data), {1, 3, 12, 12}, "nearest", "asymmetric", "floor", ExpectedEPNodeAssignment::All); } +// Test 2x QDQ Resize mode: "nearest", coordinate_transformation_mode: "asymmetric", nearest_mode: "round_prefer_floor" +// QNN EP uses QNN's Resize op. +TEST_F(QnnHTPBackendTests, ResizeU8_2xNearestAsymmetricRoundPreferFloor) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 8); + RunQDQResizeOpTest(TestInputDef({1, 2, 2, 2}, false, input_data), + {1, 2, 4, 4}, "nearest", "asymmetric", "round_prefer_floor", + ExpectedEPNodeAssignment::All); +} + +// Test 3x QDQ Resize mode: "nearest", coordinate_transformation_mode: "asymmetric", nearest_mode: "round_prefer_floor" +// QNN EP uses QNN's Resize op. +// +// TODO: Inaccuracy detected for output 'output_0', element 2. +// Output quant params: scale=0.078431375324726105, zero_point=127. +// Expected val: -3.3333334922790527 +// QNN QDQ val: -9.960784912109375 (err 6.6274514198303223) +// CPU QDQ val: -3.2941176891326904 (err 0.039215803146362305) +// +// More debugging info: +// Input elements f32[1,1,2,2] = -10.0000000 -3.33333349 3.33333302 10.0000000 +// ORT CPU EP (f32 model) outputs: -10.0000000 -10.0000000 -3.33333349 -3.33333349 -3.33333349 -3.33333349 -10.00 ... +// ORT CPU EP (qdq model) outputs: -9.96078491 -9.96078491 -3.29411769 -3.29411769 -3.29411769 -3.29411769 -9.961 ... +// ORT QNN EP (qdq model) outputs: -9.96078491 -9.96078491 -9.96078491 -3.37254906 -3.37254906 -3.37254906 -9.961 ... +TEST_F(QnnHTPBackendTests, DISABLED_ResizeU8_3xNearestAsymmetricRoundPreferFloor) { + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 4); + RunQDQResizeOpTest(TestInputDef({1, 1, 2, 2}, false, input_data), + {1, 1, 6, 6}, "nearest", "asymmetric", "round_prefer_floor", + ExpectedEPNodeAssignment::All); +} + +// Test 0.5x QDQ Resize mode: "nearest", coordinate_transformation_mode: "asymmetric", nearest_mode: "floor" +// QNN EP uses QNN's ResizeNearestNeighbor op. TEST_F(QnnHTPBackendTests, ResizeU8_HalfNearestAsymmetricFloor) { - RunQDQResizeOpTest(TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + std::vector input_data = GetFloatDataInRange(-10.0f, 10.0f, 48); + RunQDQResizeOpTest(TestInputDef({1, 3, 4, 4}, false, input_data), {1, 3, 2, 2}, "nearest", "asymmetric", "floor", ExpectedEPNodeAssignment::All); } diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index 63498982930f5..f77c098f72116 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -32,7 +32,7 @@ static void RunOpTestOnCPU(const std::string& op_type, provider_options["backend_path"] = "libQnnCpu.so"; #endif - RunQnnModelTest(BuildOpTestCase(op_type, input_defs, attrs, op_domain), + RunQnnModelTest(BuildOpTestCase(op_type, input_defs, {}, attrs, op_domain), provider_options, opset_version, expected_ep_assignment); @@ -113,8 +113,8 @@ static void RunQDQOpTest(const std::string& op_type, provider_options["backend_path"] = "libQnnHtp.so"; #endif - TestQDQModelAccuracy(BuildOpTestCase(op_type, input_defs, attrs, op_domain), - BuildQDQOpTestCase(op_type, input_defs, attrs, op_domain, use_contrib_qdq), + TestQDQModelAccuracy(BuildOpTestCase(op_type, input_defs, {}, attrs, op_domain), + BuildQDQOpTestCase(op_type, input_defs, {}, attrs, op_domain, use_contrib_qdq), provider_options, opset_version, expected_ep_assignment, @@ -137,7 +137,7 @@ static void RunOpTest(const std::string& op_type, #endif // Runs model with a Q/DQ binary op and compares the outputs of the CPU and QNN EPs. - RunQnnModelTest(BuildOpTestCase(op_type, input_defs, attrs, op_domain), + RunQnnModelTest(BuildOpTestCase(op_type, input_defs, {}, attrs, op_domain), provider_options, opset_version, expected_ep_assignment); @@ -698,8 +698,8 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCacheTest) { // Runs model with DQ-> Atan-> Q and compares the outputs of the CPU and QNN EPs. // 1st run will generate the Qnn context cache binary file - TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}), - BuildQDQOpTestCase(op_type, {input_def}, {}), + TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), + BuildQDQOpTestCase(op_type, {input_def}, {}, {}), provider_options, 14, ExpectedEPNodeAssignment::All); @@ -708,8 +708,8 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCacheTest) { EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); // 2nd run will load and run from Qnn context cache binary file - TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}), - BuildQDQOpTestCase(op_type, {input_def}, {}), + TestQDQModelAccuracy(BuildOpTestCase(op_type, {input_def}, {}, {}), + BuildQDQOpTestCase(op_type, {input_def}, {}, {}), provider_options, 14, ExpectedEPNodeAssignment::All); diff --git a/onnxruntime/test/providers/qnn/slice_htp_test.cc b/onnxruntime/test/providers/qnn/slice_htp_test.cc index f7163f04736a5..edc079dc65276 100644 --- a/onnxruntime/test/providers/qnn/slice_htp_test.cc +++ b/onnxruntime/test/providers/qnn/slice_htp_test.cc @@ -16,51 +16,6 @@ namespace onnxruntime { namespace test { #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) -// Function that builds a model with a Slice operator. -template -GetTestModelFn BuildSliceTestCase(const TestInputDef& data_def, - const TestInputDef& starts_def, - const TestInputDef& ends_def, - const TestInputDef& axes_def, - const TestInputDef& steps_def) { - return [data_def, starts_def, ends_def, axes_def, steps_def](ModelTestBuilder& builder) { - NodeArg* data = MakeTestInput(builder, data_def); - NodeArg* starts = MakeTestInput(builder, starts_def); - NodeArg* ends = MakeTestInput(builder, ends_def); - NodeArg* axes = MakeTestInput(builder, axes_def); - NodeArg* steps = MakeTestInput(builder, steps_def); - - NodeArg* output = builder.MakeOutput(); - builder.AddNode("Slice", {data, starts, ends, axes, steps}, {output}); - }; -} - -// Function that builds a QDQ model with a Slice operator. -template -static GetTestQDQModelFn BuildQDQSliceTestCase(const TestInputDef& data_def, - const TestInputDef& starts_def, - const TestInputDef& ends_def, - const TestInputDef& axes_def, - const TestInputDef& steps_def) { - return [data_def, starts_def, ends_def, axes_def, steps_def](ModelTestBuilder& builder, - std::vector>& output_qparams) { - NodeArg* data = MakeTestInput(builder, data_def); - QuantParams data_qparams = GetTestInputQuantParams(data_def); - NodeArg* data_qdq = AddQDQNodePair(builder, data, data_qparams.scale, data_qparams.zero_point); - - NodeArg* starts = MakeTestInput(builder, starts_def); - NodeArg* ends = MakeTestInput(builder, ends_def); - NodeArg* axes = MakeTestInput(builder, axes_def); - NodeArg* steps = MakeTestInput(builder, steps_def); - - auto* slice_output = builder.MakeIntermediate(); - builder.AddNode("Slice", {data_qdq, starts, ends, axes, steps}, {slice_output}); - - // Add output -> Q -> output_u8 - AddQDQNodePairWithOutputAsGraphOutput(builder, slice_output, output_qparams[0].scale, output_qparams[0].zero_point); - }; -} - /** * Runs an Slice model on the QNN HTP backend. Checks the graph node assignment, and that inference * outputs for QNN and CPU match. @@ -86,13 +41,14 @@ static void RunSliceQDQTest(const TestInputDef& data_def, provider_options["backend_path"] = "libQnnHtp.so"; #endif - // Runs model with DQ-> Slice -> Q and compares the outputs of the CPU and QNN EPs. - TestQDQModelAccuracy(BuildSliceTestCase(data_def, starts_def, ends_def, axes_def, steps_def), - BuildQDQSliceTestCase(data_def, starts_def, ends_def, axes_def, steps_def), + const std::vector> f32_inputs = {data_def}; + const std::vector> int64_inputs = {starts_def, ends_def, axes_def, steps_def}; + + TestQDQModelAccuracy(BuildOpTestCase("Slice", f32_inputs, int64_inputs, {}), + BuildQDQOpTestCase("Slice", f32_inputs, int64_inputs, {}), provider_options, 18, - expected_ep_assignment, - 1e-5f); + expected_ep_assignment); } /** @@ -119,12 +75,12 @@ static void RunSliceNonQDQOnHTP(const TestInputDef& data_def, #else provider_options["backend_path"] = "libQnnHtp.so"; #endif - - RunQnnModelTest(BuildSliceTestCase(data_def, starts_def, ends_def, axes_def, steps_def), + auto f32_model_builder = BuildOpTestCase("Slice", {data_def}, + {starts_def, ends_def, axes_def, steps_def}, {}); + RunQnnModelTest(f32_model_builder, provider_options, 13, - expected_ep_assignment, - 1e-5f); + expected_ep_assignment); } // Check that QNN compiles DQ -> Slice -> Q as a single unit. diff --git a/onnxruntime/test/providers/qnn/split_op_test.cc b/onnxruntime/test/providers/qnn/split_op_test.cc new file mode 100644 index 0000000000000..57e4b211777bb --- /dev/null +++ b/onnxruntime/test/providers/qnn/split_op_test.cc @@ -0,0 +1,387 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include + +#include "test/providers/qnn/qnn_test_utils.h" + +#include "onnx/onnx_pb.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +template +GetTestModelFn BuildSplitTestCase(const TestInputDef& input_def, + const std::vector& split, bool split_is_input, + int64_t axis, int64_t num_outputs) { + return [input_def, split, split_is_input, axis, num_outputs](ModelTestBuilder& builder) { + std::vector op_inputs; + + op_inputs.push_back(MakeTestInput(builder, input_def)); + + if (split_is_input && !split.empty()) { + op_inputs.push_back(builder.Make1DInitializer(split)); + } + + // Determine the actual number of outputs from the 'split' or 'num_outputs' arguments. + // In opset 18, the num_outputs attribute or the split input can determine the actual number of outputs. + // In opset 13, the split input determines the number of actual outputs. + // In opsets < 13, the split attribute determines the number of actual outputs. + size_t actual_num_outputs = (num_outputs > -1) ? static_cast(num_outputs) : split.size(); + + std::vector split_outputs; + for (size_t i = 0; i < actual_num_outputs; i++) { + split_outputs.push_back(builder.MakeOutput()); + } + + Node& split_node = builder.AddNode("Split", op_inputs, split_outputs); + + if (!split_is_input && !split.empty()) { + split_node.AddAttribute("split", split); + } + + if (num_outputs > -1) { + split_node.AddAttribute("num_outputs", num_outputs); + } + + split_node.AddAttribute("axis", axis); + }; +} + +template +static void RunSplitOpTestOnCPU(const TestInputDef& input_def, + const std::vector& split, + int64_t axis, + int64_t num_outputs, + int opset, + ExpectedEPNodeAssignment expected_ep_assignment) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + + const bool split_is_input = opset >= 13; + RunQnnModelTest(BuildSplitTestCase(input_def, split, split_is_input, axis, num_outputs), + provider_options, + opset, + expected_ep_assignment); +} + +// +// CPU tests: +// + +// Test Split opset 18 on CPU backend: equal split of axis 0 via 'num_outputs' attribute +// and 'split' input. +TEST_F(QnnCPUBackendTests, Split_Equal_Axis0_Opset18) { + // Use 'split' input (initializer). + RunSplitOpTestOnCPU(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {2, 2}, // split + 0, // axis + -1, // num_outputs + 18, // opset + ExpectedEPNodeAssignment::All); + RunSplitOpTestOnCPU(TestInputDef({4, 2}, false, {1, 2, 3, 4, 5, 6, 7, 8}), + {2, 2}, // split + 0, // axis + -1, // num_outputs + 18, // opset + ExpectedEPNodeAssignment::All); + + // Use 'num_outputs' attribute. + RunSplitOpTestOnCPU(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {}, // split (use num_outputs instead) + 0, // axis + 2, // num_outputs + 18, // opset + ExpectedEPNodeAssignment::All); + RunSplitOpTestOnCPU(TestInputDef({4, 2}, false, {1, 2, 3, 4, 5, 6, 7, 8}), + {}, // split (use num_outputs instead) + 0, // axis + 2, // num_outputs + 18, // opset + ExpectedEPNodeAssignment::All); +} + +// Test Split opset 13 on CPU backend: equal split of axis 0 +TEST_F(QnnCPUBackendTests, Split_Equal_Axis0_Opset13) { + RunSplitOpTestOnCPU(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {2, 2}, // split + 0, // axis + -1, // num_outputs (not in opset 13) + 13, // opset + ExpectedEPNodeAssignment::All); + RunSplitOpTestOnCPU(TestInputDef({4, 2}, false, {1, 2, 3, 4, 5, 6, 7, 8}), + {2, 2}, // split + 0, // axis + -1, // num_outputs (not in opset 13) + 13, // opset + ExpectedEPNodeAssignment::All); +} + +// Test Split opset 11 on CPU backend: equal split of axis 0 +TEST_F(QnnCPUBackendTests, Split_Equal_Axis0_Opset11) { + RunSplitOpTestOnCPU(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {2, 2}, // split + 0, // axis + -1, // num_outputs (not in opset 11) + 11, // opset + ExpectedEPNodeAssignment::All); + RunSplitOpTestOnCPU(TestInputDef({4, 2}, false, {1, 2, 3, 4, 5, 6, 7, 8}), + {2, 2}, // split + 0, // axis + -1, // num_outputs (not in opset 11) + 11, // opset + ExpectedEPNodeAssignment::All); +} + +// Test Split opset 13 on CPU backend: unequal split of axis 1 +TEST_F(QnnCPUBackendTests, Split_Unequal_Axis1_Opset13) { + RunSplitOpTestOnCPU(TestInputDef({2, 4}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {1, 3}, // split + 1, // axis + -1, // num_outputs (not in opset 13) + 13, // opset + ExpectedEPNodeAssignment::All); + RunSplitOpTestOnCPU(TestInputDef({2, 4}, false, {1, 2, 3, 4, 5, 6, 7, 8}), + {1, 3}, // split + 1, // axis + -1, // num_outputs (not in opset 13) + 13, // opset + ExpectedEPNodeAssignment::All); +} + +// Test Split opset 11 on CPU backend: unequal split of axis 1 +TEST_F(QnnCPUBackendTests, Split_Unequal_Axis1_Opset11) { + RunSplitOpTestOnCPU(TestInputDef({2, 4}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {1, 3}, // split + 1, // axis + -1, // num_outputs (not in opset 11) + 11, // opset + ExpectedEPNodeAssignment::All); + RunSplitOpTestOnCPU(TestInputDef({2, 4}, false, {1, 2, 3, 4, 5, 6, 7, 8}), + {1, 3}, // split + 1, // axis + -1, // num_outputs (not in opset 11) + 11, // opset + ExpectedEPNodeAssignment::All); +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// +// HTP tests: +// + +// Return function that builds a model with a QDQ Split. +template +GetTestQDQModelFn BuildQDQSplitTestCase(const TestInputDef& input_def, + const std::vector& split, + bool split_is_input, + int64_t axis, + int64_t num_outputs, + bool use_contrib_qdq = false) { + return [input_def, split, split_is_input, axis, num_outputs, + use_contrib_qdq](ModelTestBuilder& builder, + std::vector>& output_qparams) { + std::vector op_inputs; + + // Add QDQ input + NodeArg* input = MakeTestInput(builder, input_def); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + NodeArg* input_after_qdq = AddQDQNodePair(builder, input, input_qparams.scale, + input_qparams.zero_point, use_contrib_qdq); + op_inputs.push_back(input_after_qdq); + + // Add split input + if (split_is_input && !split.empty()) { + op_inputs.push_back(builder.Make1DInitializer(split)); + } + + // Determine the actual number of outputs from the 'split' or 'num_outputs' arguments. + // In opset 18, the num_outputs attribute or the split input can determine the actual number of outputs. + // In opset 13, the split input determines the number of actual outputs. + // In opsets < 13, the split attribute determines the number of actual outputs. + size_t actual_num_outputs = (num_outputs > -1) ? static_cast(num_outputs) : split.size(); + + std::vector split_outputs; + for (size_t i = 0; i < actual_num_outputs; i++) { + split_outputs.push_back(builder.MakeIntermediate()); + } + + Node& split_node = builder.AddNode("Split", op_inputs, split_outputs); + + if (!split_is_input && !split.empty()) { + split_node.AddAttribute("split", split); + } + + if (num_outputs > -1) { + split_node.AddAttribute("num_outputs", num_outputs); + } + + split_node.AddAttribute("axis", axis); + + // op_output -> Q -> DQ -> output + assert(output_qparams.size() == actual_num_outputs); + for (size_t i = 0; i < actual_num_outputs; i++) { + // NOTE: Input and output quantization parameters must be equal for Split. + output_qparams[i] = input_qparams; + AddQDQNodePairWithOutputAsGraphOutput(builder, split_outputs[i], output_qparams[i].scale, + output_qparams[i].zero_point, use_contrib_qdq); + } + }; +} + +// Runs a non-QDQ Split operator on the HTP backend. +template +static void RunSplitOpTestOnHTP(const TestInputDef& input_def, + const std::vector& split, + int64_t axis, + int64_t num_outputs, + int opset, + ExpectedEPNodeAssignment expected_ep_assignment) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + const bool split_is_input = opset >= 13; + RunQnnModelTest(BuildSplitTestCase(input_def, split, split_is_input, axis, num_outputs), + provider_options, + opset, + expected_ep_assignment); +} + +// Runs a QDQ Split operator on the HTP backend. +template +static void RunQDQSplitOpTestOnHTP(const TestInputDef& input_def, + const std::vector& split, + int64_t axis, + int64_t num_outputs, + int opset, + ExpectedEPNodeAssignment expected_ep_assignment, + bool use_contrib_qdq = false) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + const bool split_is_input = opset >= 13; + auto f32_model_builder = BuildSplitTestCase(input_def, split, split_is_input, axis, num_outputs); + auto qdq_model_builder = BuildQDQSplitTestCase(input_def, split, split_is_input, axis, num_outputs, + use_contrib_qdq); + TestQDQModelAccuracy(f32_model_builder, + qdq_model_builder, + provider_options, + opset, + expected_ep_assignment); +} + +// Test that HTP can run non-QDQ Split (int32 input). +TEST_F(QnnHTPBackendTests, Split_Int32_Opset13) { + // Equal split. + RunSplitOpTestOnHTP(TestInputDef({4, 2}, false, {1, 2, 3, 4, 5, 6, 7, 8}), + {2, 2}, // split + 0, // axis + -1, // num_outputs (not in opset 13) + 13, // opset + ExpectedEPNodeAssignment::All); +} + +// Test 8-bit QDQ Split opset 18 on HTP backend: equal split of axis 0 via 'num_outputs' attribute +// and 'split' input. +TEST_F(QnnHTPBackendTests, Split_Equal_Axis0_Opset18) { + // Use 'split' input (initializer). + RunQDQSplitOpTestOnHTP(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {2, 2}, // split + 0, // axis + -1, // num_outputs + 18, // opset + ExpectedEPNodeAssignment::All); + + // Use 'num_outputs' attribute. + RunQDQSplitOpTestOnHTP(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {}, // split (use num_outputs instead) + 0, // axis + 2, // num_outputs + 18, // opset + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ Split opset 18 on HTP backend: equal split of axis 0 via 'num_outputs' attribute +// and 'split' input. +TEST_F(QnnHTPBackendTests, Split_Equal_Axis0_Opset18_U16) { + // Use 'split' input (initializer). + RunQDQSplitOpTestOnHTP(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {2, 2}, // split + 0, // axis + -1, // num_outputs + 18, // opset + ExpectedEPNodeAssignment::All, + true); // Use com.microsoft Q/DQ ops + + // Use 'num_outputs' attribute. + RunQDQSplitOpTestOnHTP(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {}, // split (use num_outputs instead) + 0, // axis + 2, // num_outputs + 18, // opset + ExpectedEPNodeAssignment::All, + true); // Use com.microsoft Q/DQ ops +} + +// Test QDQ Split op on HTP backend: equal split on axis 0 with opset 13. +TEST_F(QnnHTPBackendTests, Split_Equal_Axis0_Opset13) { + RunQDQSplitOpTestOnHTP(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {2, 2}, // split + 0, // axis + -1, // num_outputs (not in opset 13) + 13, // opset + ExpectedEPNodeAssignment::All); +} + +// Test QDQ Split op on HTP backend: equal split on axis 0 with opset 11. +TEST_F(QnnHTPBackendTests, Split_Equal_Axis0_Opset11) { + RunQDQSplitOpTestOnHTP(TestInputDef({4, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {2, 2}, // split + 0, // axis + -1, // num_outputs (not in opset 11) + 11, // opset + ExpectedEPNodeAssignment::All); +} + +// Test Split opset 13 on HTP backend: unequal split of axis 1 +TEST_F(QnnHTPBackendTests, Split_Unequal_Axis1_Opset13) { + RunQDQSplitOpTestOnHTP(TestInputDef({2, 4}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {1, 3}, // split + 1, // axis + -1, // num_outputs (not in opset 13) + 13, // opset + ExpectedEPNodeAssignment::All); +} + +// Test Split opset 11 on HTP backend: unequal split of axis 1 +TEST_F(QnnHTPBackendTests, Split_Unequal_Axis1_Opset11) { + RunQDQSplitOpTestOnHTP(TestInputDef({2, 4}, false, {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.f, 8.f}), + {1, 3}, // split + 1, // axis + -1, // num_outputs (not in opset 11) + 11, // opset + ExpectedEPNodeAssignment::All); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +} // namespace test +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/squeeze_unsqueeze_op_test.cc b/onnxruntime/test/providers/qnn/squeeze_unsqueeze_op_test.cc new file mode 100644 index 0000000000000..33d2f64c0315e --- /dev/null +++ b/onnxruntime/test/providers/qnn/squeeze_unsqueeze_op_test.cc @@ -0,0 +1,324 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include + +#include "test/providers/qnn/qnn_test_utils.h" + +#include "onnx/onnx_pb.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +// Runs a model with a Squeeze (or Unsqueeze) operator on the QNN CPU backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunSqueezeTestOnCPU(const std::string& op_type, // Squeeze or Unsqueeze + const TestInputDef& input_def, + const TestInputDef& axes_def, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + + RunQnnModelTest(BuildOpTestCase(op_type, {input_def}, {axes_def}, {}), + provider_options, + opset, + expected_ep_assignment); +} + +// +// CPU tests: +// + +// Test that Squeeze with a dynamic axes input is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, Squeeze_DynamicAxes_Unsupported) { + RunSqueezeTestOnCPU("Squeeze", + TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({1}, false /* is_initializer */, {0}), + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test that Unsqueeze with a dynamic axes input is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, Unsqueeze_DynamicAxes_Unsupported) { + RunSqueezeTestOnCPU("Unsqueeze", + TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({1}, false /* is_initializer */, {0}), + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test Squeeze of rank 5 -> rank 2. +TEST_F(QnnCPUBackendTests, Squeeze_Rank5_Rank2_f32) { + RunSqueezeTestOnCPU("Squeeze", + TestInputDef({1, 3, 1, 2, 4}, false, -10.0f, 10.0f), + TestInputDef({2}, true, {0, 2}), // Squeeze axes 0 and 2 => (3, 2, 4) + ExpectedEPNodeAssignment::All); +} + +// Test Squeeze of rank 4 -> rank 3 with a negative axes value. +TEST_F(QnnCPUBackendTests, Squeeze_Rank4_Rank3_NegAxes_f32) { + RunSqueezeTestOnCPU("Squeeze", + TestInputDef({1, 3, 2, 1}, false, -10.0f, 10.0f), + TestInputDef({1}, true, {-1}), // Squeeze last axis => (1, 3, 2) + ExpectedEPNodeAssignment::All); +} + +// Test Unsqueeze of rank 3 -> rank 5. +TEST_F(QnnCPUBackendTests, Unsqueeze_Rank3_Rank5_f32) { + RunSqueezeTestOnCPU("Unsqueeze", + TestInputDef({3, 2, 4}, false, -10.0f, 10.0f), + TestInputDef({2}, true, {0, 2}), // Add 1's => (1, 3, 1, 2, 4) + ExpectedEPNodeAssignment::All); +} + +// Test Unsqueeze of rank 3 -> rank 4 with a negative axes value. +TEST_F(QnnCPUBackendTests, Unsqueeze_Rank3_Rank4_NegAxes_f32) { + RunSqueezeTestOnCPU("Unsqueeze", + TestInputDef({1, 3, 2}, false, -10.0f, 10.0f), + TestInputDef({1}, true, {-1}), // Add 1 as last axis => (1, 3, 2, 1) + ExpectedEPNodeAssignment::All); +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// +// HTP tests: +// + +// Returns a function that creates a graph with a QDQ (Un)Squeeze operator. +template +GetTestQDQModelFn BuildQDQSqueezeTestCase(const std::string& op_type, // Squeeze or Unsqueeze + const TestInputDef& input_def, + const TestInputDef& axes_def, + bool use_contrib_qdq = false) { + return [op_type, input_def, axes_def, + use_contrib_qdq](ModelTestBuilder& builder, + std::vector>& output_qparams) { + // input -> Q -> DQ -> + NodeArg* input = MakeTestInput(builder, input_def); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point, + use_contrib_qdq); + + // axes input + NodeArg* axes_input = MakeTestInput(builder, axes_def); + + // (Un)Squeeze op + NodeArg* op_output = builder.MakeIntermediate(); + builder.AddNode(op_type, {input_qdq, axes_input}, {op_output}); + + // op_output -> Q -> DQ -> output + // NOTE: Input and output quantization parameters must be equal for (Un)Squeeze. + output_qparams[0] = input_qparams; // Overwrite! + AddQDQNodePairWithOutputAsGraphOutput(builder, op_output, input_qparams.scale, + input_qparams.zero_point, use_contrib_qdq); + }; +} + +// Runs a model with a non-QDQ (Un)Squeeze operator on the QNN HTP backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunSqueezeTestOnHTP(const std::string& op_type, // Squeeze or Unsqueeze + const TestInputDef& input_def, + const TestInputDef& axes_def, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + RunQnnModelTest(BuildOpTestCase(op_type, {input_def}, {axes_def}, {}), + provider_options, + opset, + expected_ep_assignment); +} + +// Runs a QDQ (Un)Squeeze model on the QNN (HTP) EP and the ORT CPU EP. Checks the graph node assignment and +// that inference running the QDQ model on QNN EP is at least as accurate as on ORT CPU EP +// (when compared to the baseline float32 model). +template +static void RunQDQSqueezeTestOnHTP(const std::string& op_type, + const TestInputDef& input_def, + const TestInputDef& axes_def, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13, + bool use_contrib_qdq = false) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + auto f32_model_builder = BuildOpTestCase(op_type, {input_def}, {axes_def}, {}); + auto qdq_model_builder = BuildQDQSqueezeTestCase(op_type, input_def, axes_def, use_contrib_qdq); + + TestQDQModelAccuracy(f32_model_builder, + qdq_model_builder, + provider_options, + opset, + expected_ep_assignment); +} + +// Test that QDQ Squeeze with a dynamic axes input is not supported by QNN EP. +TEST_F(QnnHTPBackendTests, Squeeze_DynamicAxes_Unsupported) { + RunQDQSqueezeTestOnHTP("Squeeze", + TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({1}, false /* is_initializer */, {0}), + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test that Unsqueeze with a dynamic axes input is not supported by QNN EP. +TEST_F(QnnHTPBackendTests, Unsqueeze_DynamicAxes_Unsupported) { + RunQDQSqueezeTestOnHTP("Unsqueeze", + TestInputDef({1, 3, 4, 4}, false, -10.0f, 10.0f), + TestInputDef({1}, false /* is_initializer */, {0}), + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test Squeeze of rank 5 -> rank 2. +TEST_F(QnnHTPBackendTests, Squeeze_Rank5_Rank2_f32) { + // We can't use the usual model-building functions because they add standalone Quantize and Dequantize nodes + // at the input and output. These Q/DQ ops get lowered to QNN's Quantize and Dequantize operators, which DO NOT + // support rank 5 tensors. Therefore, we have to create a test model that only instantiates the DQ -> Squeeze -> Q + // QDQ node group, which gets lowered to a single QNN Reshape node. + GetTestModelFn model_fn = [](ModelTestBuilder& builder) { + // input (u8) -> DQ -> + NodeArg* quant_input = builder.MakeInput({1, 3, 1, 2, 4}, 0, 255); + NodeArg* input_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(quant_input, 1.0f, 0, input_dq); // scale = 1.0, zp = 0 + + // axes_input -> + NodeArg* axes_input = builder.Make1DInitializer({0, 2}); // Squeeze axes 0 and 2 => (3, 2, 4) + + // Squeeze -> + NodeArg* squeeze_output = builder.MakeIntermediate(); + builder.AddNode("Squeeze", {input_dq, axes_input}, {squeeze_output}); + + // Q -> output (u8) + NodeArg* output = builder.MakeOutput(); + builder.AddQuantizeLinearNode(squeeze_output, 1.0f, 0, output); // scale = 1.0, zp = 0 + }; + + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + RunQnnModelTest(model_fn, + provider_options, + 13, // opset + ExpectedEPNodeAssignment::All); +} + +// Test 8-bit QDQ Squeeze of rank 4 -> rank 3 with a negative axes value. +TEST_F(QnnHTPBackendTests, Squeeze_Rank4_Rank3_NegAxes_u8) { + RunQDQSqueezeTestOnHTP("Squeeze", + TestInputDef({1, 3, 2, 1}, false, -10.0f, 10.0f), + TestInputDef({1}, true, {-1}), // Squeeze last axis => (1, 3, 2) + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ Squeeze of rank 4 -> rank 3 with a negative axes value. +TEST_F(QnnHTPBackendTests, Squeeze_Rank4_Rank3_NegAxes_u16) { + RunQDQSqueezeTestOnHTP("Squeeze", + TestInputDef({1, 3, 2, 1}, false, -10.0f, 10.0f), + TestInputDef({1}, true, {-1}), // Squeeze last axis => (1, 3, 2) + ExpectedEPNodeAssignment::All, + 13, // opset + true); // Use com.microsoft Q/DQ ops +} + +// Test QDQ Unsqueeze of rank 3 -> rank 5. +TEST_F(QnnHTPBackendTests, Unsqueeze_Rank3_Rank5_f32) { + // We can't use the usual model-building functions because they add standalone Quantize and Dequantize nodes + // at the input and output. These Q/DQ ops get lowered to QNN's Quantize and Dequantize operators, which DO NOT + // support rank 5 tensors. Therefore, we have to create a test model that only instantiates the DQ -> Unsqueeze -> Q + // QDQ node group, which gets lowered to a single QNN Reshape node. + GetTestModelFn model_fn = [](ModelTestBuilder& builder) { + // input (u8) -> DQ -> + NodeArg* quant_input = builder.MakeInput({3, 2, 4}, 0, 255); + NodeArg* input_dq = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(quant_input, 1.0f, 0, input_dq); // scale = 1.0, zp = 0 + + // axes_input -> + NodeArg* axes_input = builder.Make1DInitializer({0, 2}); // Add 1's => (1, 3, 1, 2, 4) + + // Unsqueeze -> + NodeArg* unsqueeze_output = builder.MakeIntermediate(); + builder.AddNode("Unsqueeze", {input_dq, axes_input}, {unsqueeze_output}); + + // Q -> output (u8) + NodeArg* output = builder.MakeOutput(); + builder.AddQuantizeLinearNode(unsqueeze_output, 1.0f, 0, output); // scale = 1.0, zp = 0 + }; + + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + RunQnnModelTest(model_fn, + provider_options, + 13, // opset + ExpectedEPNodeAssignment::All); +} + +// Test 8-bit QDQ Unsqueeze of rank 3 -> rank 4 with a negative axes value. +TEST_F(QnnHTPBackendTests, Unsqueeze_Rank3_Rank4_NegAxes_u8) { + RunQDQSqueezeTestOnHTP("Unsqueeze", + TestInputDef({1, 3, 2}, false, -10.0f, 10.0f), + TestInputDef({1}, true, {-1}), // Add 1 as last axis => (1, 3, 2, 1) + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ Unsqueeze of rank 3 -> rank 4 with a negative axes value. +TEST_F(QnnHTPBackendTests, Unsqueeze_Rank3_Rank4_NegAxes_u16) { + RunQDQSqueezeTestOnHTP("Unsqueeze", + TestInputDef({1, 3, 2}, false, -10.0f, 10.0f), + TestInputDef({1}, true, {-1}), // Add 1 as last axis => (1, 3, 2, 1) + ExpectedEPNodeAssignment::All, + 13, // opset + true); // Use com.microsoft Q/DQ ops +} + +// Test that int32 Squeeze runs on HTP backend. +TEST_F(QnnHTPBackendTests, Squeeze_Int32_Rank4_Rank3) { + std::vector input_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + RunSqueezeTestOnHTP("Squeeze", + TestInputDef({1, 3, 2, 2}, false, input_data), + TestInputDef({1}, true, {0}), // Squeeze 0th axis => (3, 2, 2) + ExpectedEPNodeAssignment::All); +} + +// Test that int32 Unsqueeze runs on HTP backend. +TEST_F(QnnHTPBackendTests, Unsqueeze_Int32_Rank3_Rank4) { + std::vector input_data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + RunSqueezeTestOnHTP("Unsqueeze", + TestInputDef({3, 2, 2}, false, input_data), + TestInputDef({1}, true, {0}), // Unsqueeze 0th axis => (1, 3, 2, 2) + ExpectedEPNodeAssignment::All); +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +} // namespace test +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/tile_op_test.cc b/onnxruntime/test/providers/qnn/tile_op_test.cc new file mode 100644 index 0000000000000..2b35c730ee5fe --- /dev/null +++ b/onnxruntime/test/providers/qnn/tile_op_test.cc @@ -0,0 +1,132 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include + +#include "test/providers/qnn/qnn_test_utils.h" +#include "core/graph/node_attr_utils.h" + +#include "onnx/onnx_pb.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +// Runs a model with a Tile operator on the QNN CPU backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunTileTestOnCPU(const TestInputDef& input_def, + const TestInputDef& repeats_def, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + + RunQnnModelTest(BuildOpTestCase("Tile", {input_def}, {repeats_def}, {}), + provider_options, + opset, + expected_ep_assignment); +} + +// Test that Tile with a dynamic repeats input is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, Tile_DynamicRepeats_Unsupported) { + RunTileTestOnCPU(TestInputDef({2, 2}, false, {1.0f, 2.0f, 3.0f, 4.0f}), + TestInputDef({2}, false /* is_initializer */, {1, 2}), + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test that Tile with rank 4 float input. +TEST_F(QnnCPUBackendTests, Tile_F32_Rank4) { + std::vector input_data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 2.0f, 3.0f, 4.0f}; + RunTileTestOnCPU(TestInputDef({1, 2, 2, 2}, false, input_data), + TestInputDef({4}, true /* is_initializer */, {1, 2, 1, 1}), + ExpectedEPNodeAssignment::All); +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// +// HTP tests: +// + +// Returns a function that creates a graph with a QDQ Tile operator. +template +GetTestQDQModelFn BuildQDQTileTestCase(const TestInputDef& input_def, + const TestInputDef& repeats_def, + bool use_contrib_qdq = false) { + return [input_def, repeats_def, use_contrib_qdq](ModelTestBuilder& builder, + std::vector>& output_qparams) { + // input -> Q -> DQ -> + NodeArg* input = MakeTestInput(builder, input_def); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point, + use_contrib_qdq); + + // repeats input + NodeArg* repeats_input = MakeTestInput(builder, repeats_def); + + // Tile op + NodeArg* tile_output = builder.MakeIntermediate(); + builder.AddNode("Tile", {input_qdq, repeats_input}, {tile_output}); + + // op_output -> Q -> DQ -> output + // NOTE: Input and output quantization parameters must be equal for Tile. + output_qparams[0] = input_qparams; // Overwrite! + AddQDQNodePairWithOutputAsGraphOutput(builder, tile_output, input_qparams.scale, + input_qparams.zero_point, use_contrib_qdq); + }; +} + +// Runs a QDQ Tile model on the QNN (HTP) EP and the ORT CPU EP. Checks the graph node assignment and that inference +// running the QDQ model on QNN EP is at least as accurate as on ORT CPU EP (compared to the baseline float32 model). +template +static void RunQDQTileTestOnHTP(const TestInputDef& input_def, + const TestInputDef& repeats_def, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 13, + bool use_contrib_qdq = false) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + auto f32_model_builder = BuildOpTestCase("Tile", {input_def}, {repeats_def}, {}); + auto qdq_model_builder = BuildQDQTileTestCase(input_def, repeats_def, use_contrib_qdq); + TestQDQModelAccuracy(f32_model_builder, + qdq_model_builder, + provider_options, + opset, + expected_ep_assignment); +} + +// Test 8-bit QDQ Tile with rank 4 input. +TEST_F(QnnHTPBackendTests, Tile_U8_Rank4) { + std::vector input_data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 2.0f, 3.0f, 4.0f}; + RunQDQTileTestOnHTP(TestInputDef({1, 2, 2, 2}, false, input_data), + TestInputDef({4}, true /* is_initializer */, {1, 2, 1, 1}), + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ Tile with rank 4 input. +TEST_F(QnnHTPBackendTests, Tile_U16_Rank4) { + std::vector input_data = {-4.0f, -3.0f, -1.0f, 0.0f, 1.0f, 2.0f, 3.0f, 4.0f}; + RunQDQTileTestOnHTP(TestInputDef({1, 2, 2, 2}, false, input_data), + TestInputDef({4}, true /* is_initializer */, {1, 2, 1, 1}), + ExpectedEPNodeAssignment::All, + 13, // opset + true); // Use com.microsoft Q/DQ ops +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +} // namespace test +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/topk_op_test.cc b/onnxruntime/test/providers/qnn/topk_op_test.cc new file mode 100644 index 0000000000000..93e725af5f20e --- /dev/null +++ b/onnxruntime/test/providers/qnn/topk_op_test.cc @@ -0,0 +1,209 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#if !defined(ORT_MINIMAL_BUILD) + +#include + +#include "test/providers/qnn/qnn_test_utils.h" +#include "core/graph/node_attr_utils.h" + +#include "onnx/onnx_pb.h" +#include "gtest/gtest.h" + +namespace onnxruntime { +namespace test { + +// Returns a function that builds a model with a TopK operator. +template +inline GetTestModelFn BuildTopKTestCase(const TestInputDef& input_def, + const TestInputDef& k_def, + const std::vector& attrs, + bool cast_output_indices = true) { + return [input_def, k_def, attrs, cast_output_indices](ModelTestBuilder& builder) { + NodeArg* input = MakeTestInput(builder, input_def); + NodeArg* k_input = MakeTestInput(builder, k_def); + + NodeArg* values_output = builder.MakeOutput(); + NodeArg* indices_output = cast_output_indices ? builder.MakeIntermediate() : builder.MakeOutput(); + Node& topk_node = builder.AddNode("TopK", {input, k_input}, {values_output, indices_output}); + + for (const auto& attr : attrs) { + topk_node.AddAttributeProto(attr); + } + + // Cast indices to uint32 + if (cast_output_indices) { + auto* uint32_indices_output = builder.MakeOutput(); + Node& cast_node = builder.AddNode("Cast", {indices_output}, {uint32_indices_output}); + const auto dst_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT32; + cast_node.AddAttribute("to", static_cast(dst_type)); + } + }; +} + +// Runs a model with a TopK operator on the QNN CPU backend. Checks the graph node assignment +// and that inference outputs for QNN EP and CPU EP match. +template +static void RunTopKTestOnCPU(const TestInputDef& input_def, + const TestInputDef& k_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 19) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnCpu.dll"; +#else + provider_options["backend_path"] = "libQnnCpu.so"; +#endif + + RunQnnModelTest(BuildTopKTestCase(input_def, k_def, attrs, false /*cast_output_indices*/), + provider_options, + opset, + expected_ep_assignment); +} + +// +// CPU tests: +// + +// Test that TopK with a dynamic K input is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, TopK_DynamicK_Unsupported) { + RunTopKTestOnCPU(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({1}, false /* is_initializer */, {2}), + {}, // Attributes + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test that TopK with an axis attribute that is not the last dimension is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, TopK_NonLastAxis_Unsupported) { + RunTopKTestOnCPU(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({1}, true /* is_initializer */, {2}), + {utils::MakeAttribute("axis", static_cast(1))}, + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test that TopK that returns the top k minimum values is not supported by QNN EP. +TEST_F(QnnCPUBackendTests, TopK_MinValues_Unsupported) { + RunTopKTestOnCPU(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({1}, true /* is_initializer */, {2}), + {utils::MakeAttribute("largest", static_cast(0))}, + ExpectedEPNodeAssignment::None); // Should not be assigned to QNN EP. +} + +// Test TopK on CPU backend: top 2 largest floats from last axis +TEST_F(QnnCPUBackendTests, TopK_LargestFloats_LastAxis) { + RunTopKTestOnCPU(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({1}, true /* is_initializer */, {2}), + {}, // Attributes + ExpectedEPNodeAssignment::All); +} + +// Test TopK on CPU backend: top 2 largest int32s from last axis +TEST_F(QnnCPUBackendTests, TopK_LargestInt32s_LastAxis) { + std::vector input_data = {-6, -5, -4, -3, -2, 0, 1, 2, 3, 4, 5, 6}; + RunTopKTestOnCPU(TestInputDef({1, 2, 2, 3}, false, input_data), + TestInputDef({1}, true /* is_initializer */, {2}), + {}, // Attributes + ExpectedEPNodeAssignment::All); +} + +#if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +// +// HTP tests: +// + +// Returns a function that creates a graph with a QDQ TopK operator. +template +GetTestQDQModelFn BuildQDQTopKTestCase(const TestInputDef& input_def, + const TestInputDef& k_def, + const std::vector& attrs, + bool use_contrib_qdq = false) { + return [input_def, k_def, attrs, use_contrib_qdq](ModelTestBuilder& builder, + std::vector>& output_qparams) { + // input -> Q -> DQ -> + NodeArg* input = MakeTestInput(builder, input_def); + QuantParams input_qparams = GetTestInputQuantParams(input_def); + NodeArg* input_qdq = AddQDQNodePair(builder, input, input_qparams.scale, input_qparams.zero_point, + use_contrib_qdq); + + // K input + NodeArg* k_input = MakeTestInput(builder, k_def); + + // Reshape op + NodeArg* values_output = builder.MakeIntermediate(); + NodeArg* indices_output = builder.MakeIntermediate(); + Node& topk_node = builder.AddNode("TopK", {input_qdq, k_input}, {values_output, indices_output}); + + for (const auto& attr : attrs) { + topk_node.AddAttributeProto(attr); + } + + // op_output -> Q -> DQ -> output + // NOTE: Input and output quantization parameters must be equal for Reshape. + output_qparams[0] = input_qparams; // Overwrite! + AddQDQNodePairWithOutputAsGraphOutput(builder, values_output, input_qparams.scale, + input_qparams.zero_point, use_contrib_qdq); + + // Cast indices to uint32 (HTP backend does not support int64 graph outputs) + auto* uint32_indices_output = builder.MakeOutput(); + Node& cast_node = builder.AddNode("Cast", {indices_output}, {uint32_indices_output}); + const auto dst_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_UINT32; + cast_node.AddAttribute("to", static_cast(dst_type)); + }; +} + +// Runs a QDQ TopK model on the QNN (HTP) EP and the ORT CPU EP. Checks the graph node assignment and that inference +// running the QDQ model on QNN EP is at least as accurate as on ORT CPU EP (compared to the baseline float32 model). +template +static void RunQDQTopKTestOnHTP(const TestInputDef& input_def, + const TestInputDef& k_def, + const std::vector& attrs, + ExpectedEPNodeAssignment expected_ep_assignment, + int opset = 19, + bool use_contrib_qdq = false) { + ProviderOptions provider_options; + +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + auto f32_model_builder = BuildTopKTestCase(input_def, k_def, attrs, true /*cast_output_indices*/); + auto qdq_model_builder = BuildQDQTopKTestCase(input_def, k_def, attrs, use_contrib_qdq); + TestQDQModelAccuracy(f32_model_builder, + qdq_model_builder, + provider_options, + opset, + expected_ep_assignment); +} + +// Test 8-bit QDQ TopK on HTP backend: top 2 largest floats from last axis +TEST_F(QnnHTPBackendTests, TopK_LargestFloats_U8_LastAxis) { + RunQDQTopKTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-10.0f, 10.0f, 48)), + TestInputDef({1}, true /* is_initializer */, {2}), + {}, // Attributes + ExpectedEPNodeAssignment::All); +} + +// Test 16-bit QDQ TopK on HTP backend: top 2 largest floats from last axis +// TODO: Inaccuracy detected for output 'output_0', element 6. +// Output quant params: scale=0.00061036087572574615, zero_point=32768. +// Expected val: -7.2340402603149414 +// QNN QDQ val: -17.446556091308594 (err 10.212515830993652) +// CPU QDQ val: -7.2339968681335449 (err 4.3392181396484375e-05) +TEST_F(QnnHTPBackendTests, DISABLED_TopK_LargestFloats_U16_LastAxis) { + RunQDQTopKTestOnHTP(TestInputDef({1, 3, 4, 4}, false, GetFloatDataInRange(-20.0f, 20.0f, 48)), + TestInputDef({1}, true /* is_initializer */, {2}), + {}, // Attributes + ExpectedEPNodeAssignment::All, + 19, // opset + true); // Use com.microsoft Q/DQ ops +} + +#endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) +} // namespace test +} // namespace onnxruntime +#endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/python/onnxruntime_test_collective.py b/onnxruntime/test/python/onnxruntime_test_collective.py index db1ebb5384730..4882b403c3c91 100644 --- a/onnxruntime/test/python/onnxruntime_test_collective.py +++ b/onnxruntime/test/python/onnxruntime_test_collective.py @@ -155,6 +155,7 @@ def _create_alltoall_ut_model_for_boolean_tensor( ) return ORTBertPretrainTest._create_model_with_opsets(graph_def) + @unittest.skipIf(not ort.has_collective_ops(), reason="onnx not compiled with mpi support") @parameterized.expand( [ (np.float32, TensorProto.FLOAT), @@ -193,6 +194,7 @@ def test_all_reduce(self, np_elem_type, elem_type): outputs[0], size * input, err_msg=f"{rank}: AllGather ({np_elem_type}, {elem_type}): results mismatch" ) + @unittest.skipIf(not ort.has_collective_ops(), reason="onnx not compiled with mpi support") @parameterized.expand( [ (np.float32, TensorProto.FLOAT, TensorProto.FLOAT), @@ -231,6 +233,7 @@ def test_all_gather(self, np_elem_type, elem_type, communication_elem_type): err_msg=f"{rank}: AllGather (axis0) ({np_elem_type}, {elem_type}, {communication_elem_type}): results mismatch", ) + @unittest.skipIf(not ort.has_collective_ops(), reason="onnx not compiled with mpi support") def test_all_gather_bool(self): model = self._create_allgather_ut_model((4,), 0, TensorProto.INT64, TensorProto.INT64) rank, _ = self._get_rank_size() @@ -250,6 +253,7 @@ def test_all_gather_bool(self): np.testing.assert_allclose(y, y_expected, err_msg=f"{rank}: AllGather (bool): results mismatch") + @unittest.skipIf(not ort.has_collective_ops(), reason="onnx not compiled with mpi support") def test_all_gather_axis1(self): model = self._create_allgather_ut_model((128, 128), 1) rank, size = self._get_rank_size() @@ -268,6 +272,7 @@ def test_all_gather_axis1(self): np.testing.assert_allclose(outputs[0], expected_output, err_msg=f"{rank}: AllGather (axis1): results mismatch") + @unittest.skipIf(not ort.has_collective_ops(), reason="onnx not compiled with mpi support") @parameterized.expand( [ (np.float32, TensorProto.FLOAT, TensorProto.FLOAT), @@ -349,6 +354,7 @@ def test_all_to_all(self, np_elem_type, elem_type, communication_elem_type): err_msg=f"{rank}: AllToAll ({np_elem_type}, {elem_type}, {communication_elem_type}): results mismatch", ) + @unittest.skipIf(not ort.has_collective_ops(), reason="onnx not compiled with mpi support") def test_all_to_all_bool(self): rank, _ = self._get_rank_size() diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index 59f7781bb4f8a..1d954fe4370ad 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -80,11 +80,7 @@ def test_model_serialization(self): so.log_severity_level = 1 so.logid = "TestModelSerialization" so.optimized_model_filepath = "./PythonApiTestOptimizedModel.onnx" - onnxrt.InferenceSession( - get_name("mul_1.onnx"), - sess_options=so, - providers=["CPUExecutionProvider"], - ) + onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=so) self.assertTrue(os.path.isfile(so.optimized_model_filepath)) os.remove(so.optimized_model_filepath) except Fail as onnxruntime_error: @@ -107,11 +103,7 @@ def test_model_serialization_with_external_initializers(self): "session.optimized_model_external_initializers_file_name", external_initializers_file ) so.add_session_config_entry("session.optimized_model_external_initializers_min_size_in_bytes", "100") - onnxrt.InferenceSession( - get_name("mnist.onnx"), - sess_options=so, - providers=["CPUExecutionProvider"], - ) + onnxrt.InferenceSession(get_name("mnist.onnx"), sess_options=so) self.assertTrue(os.path.isfile(so.optimized_model_filepath)) self.assertTrue(os.path.isfile(external_initializers_file)) os.remove(so.optimized_model_filepath) @@ -137,7 +129,7 @@ def test_model_serialization_with_external_initializers_to_directory(self): "session.optimized_model_external_initializers_file_name", external_initializers_file ) so.add_session_config_entry("session.optimized_model_external_initializers_min_size_in_bytes", "100") - onnxrt.InferenceSession(get_name("mnist.onnx"), sess_options=so, providers=["CPUExecutionProvider"]) + onnxrt.InferenceSession(get_name("mnist.onnx"), sess_options=so) self.assertTrue(os.path.isfile(so.optimized_model_filepath)) self.assertTrue(os.path.isfile(os.path.join(directory, external_initializers_file))) os.remove(so.optimized_model_filepath) @@ -163,9 +155,7 @@ def test_model_serialization_with_original_external_initializers_to_directory(se "session.optimized_model_external_initializers_file_name", external_initializers_file ) so.add_session_config_entry("session.optimized_model_external_initializers_min_size_in_bytes", "100") - onnxrt.InferenceSession( - get_name("model_with_orig_ext_data.onnx"), sess_options=so, providers=["CPUExecutionProvider"] - ) + onnxrt.InferenceSession(get_name("model_with_orig_ext_data.onnx"), sess_options=so) self.assertTrue(os.path.isfile(so.optimized_model_filepath)) self.assertTrue(os.path.isfile(os.path.join(directory, external_initializers_file))) os.remove(so.optimized_model_filepath) @@ -198,9 +188,7 @@ def test_model_serialization_with_original_external_initializers_to_current_dire # still refers to the original external data file. We shall fix this issue so that the # optimized model only refers to one external data file. so.add_session_config_entry("session.optimized_model_external_initializers_min_size_in_bytes", "10") - session1 = onnxrt.InferenceSession( - get_name("model_with_orig_ext_data.onnx"), sess_options=so, providers=["CPUExecutionProvider"] - ) + session1 = onnxrt.InferenceSession(get_name("model_with_orig_ext_data.onnx"), sess_options=so) del session1 self.assertTrue(os.path.isfile(optimized_model_filepath)) self.assertTrue(os.path.isfile(external_initializers_file)) @@ -216,9 +204,7 @@ def test_model_serialization_with_original_external_initializers_to_current_dire # verify that we can load the optimized model with external data in current directory and save # optimized model with external data to current directory. - session2 = onnxrt.InferenceSession( - optimized_model_filepath, sess_options=so2, providers=["CPUExecutionProvider"] - ) + session2 = onnxrt.InferenceSession(optimized_model_filepath, sess_options=so2) del session2 self.assertTrue(os.path.isfile(optimized_model_filepath_2)) self.assertTrue(os.path.isfile(external_initializers_file_2)) @@ -227,9 +213,7 @@ def test_model_serialization_with_original_external_initializers_to_current_dire os.remove(optimized_model_filepath) os.remove(external_initializers_file) - session3 = onnxrt.InferenceSession( - optimized_model_filepath_2, sess_options=onnxrt.SessionOptions(), providers=["CPUExecutionProvider"] - ) + session3 = onnxrt.InferenceSession(optimized_model_filepath_2, sess_options=onnxrt.SessionOptions()) del session3 os.remove(optimized_model_filepath_2) diff --git a/onnxruntime/test/python/quantization/op_test_utils.py b/onnxruntime/test/python/quantization/op_test_utils.py index e94ac5c961583..f26b6297cdbda 100644 --- a/onnxruntime/test/python/quantization/op_test_utils.py +++ b/onnxruntime/test/python/quantization/op_test_utils.py @@ -279,6 +279,9 @@ def check_model_correctness( ops_set = set(node.op_type for node in model_onnx.graph.node) check_reference_evaluator = not (ops_set & {"EmbedLayerNormalization", "Conv", "Attention", "Transpose"}) + with open(model_path_to_check, "rb") as f: + model_check = onnx.load(f) + if check_reference_evaluator and onnx_recent_enough: ref = ReferenceEvaluator(model_path_origin) ref_origin_results = ref.run(None, inputs) @@ -289,7 +292,7 @@ def check_model_correctness( output, rtol=rtol, atol=atol, - err_msg=f"Model {model_path_to_check!r} failed for providers={providers!r}.", + err_msg=f"Model {model_path_origin!r} failed for providers={providers!r}.", ) # Verifies the shapes in the quantized model. @@ -301,40 +304,52 @@ def check_model_correctness( expected_shapes[init.name] = tuple(init.dims) checked = 0 f8_quantization = False - with open(model_path_to_check, "rb") as f: - model_check = onnx.load(f) - for init in model_check.graph.initializer: - if init.name.endswith("_quantized"): - name = init.name.replace("_quantized", "") - expected = expected_shapes[name] - shape = tuple(init.dims) - if not dynamic and expected != shape: - raise AssertionError( - f"Shape mismatch for initializer {init.name!r} from {init.name!r}, " - f"shape={shape} != {expected} (expected)." - ) - else: - checked += 1 - if "zero_point" in init.name: - dt = init.data_type - f8_quantization = f8_quantization or dt in ( - TensorProto.FLOAT8E4M3FN, - TensorProto.FLOAT8E4M3FNUZ, - TensorProto.FLOAT8E5M2, - TensorProto.FLOAT8E5M2FNUZ, + for init in model_check.graph.initializer: + if init.name.endswith("_quantized"): + name = init.name.replace("_quantized", "") + expected = expected_shapes[name] + shape = tuple(init.dims) + if not dynamic and expected != shape: + raise AssertionError( + f"Shape mismatch for initializer {init.name!r} from {init.name!r}, " + f"shape={shape} != {expected} (expected)." ) - if checked == 0: - raise AssertionError( - f"Unable to check expected shape, expected_shapes={expected_shapes}, " - f"names={[init.name for init in model_check.graph.initializer]}." + else: + checked += 1 + if "zero_point" in init.name: + dt = init.data_type + f8_quantization = f8_quantization or dt in ( + TensorProto.FLOAT8E4M3FN, + TensorProto.FLOAT8E4M3FNUZ, + TensorProto.FLOAT8E5M2, + TensorProto.FLOAT8E5M2FNUZ, ) + if checked == 0: + raise AssertionError( + f"Unable to check expected shape, expected_shapes={expected_shapes}, " + f"names={[init.name for init in model_check.graph.initializer]}." + ) if f8_quantization: check_sign_f8_quantization(model_path_origin, model_path_to_check) # Verifies the expected outputs. if check_reference_evaluator and onnx_recent_enough: + reference_new_ops = [QGemm] + has_missing_reference_ops = any( + node.domain not in ["", "ai.onnx"] + and not any( + node.domain == new_node.op_domain and node.op_type == new_node.__name__ + for new_node in reference_new_ops + ) + for node in model_check.graph.node + ) + if has_missing_reference_ops: + # We need to skip the test if the model contains ops that are not supported. + testcase.skipTest( + f"Model {model_path_to_check!r} contains ops that are not supported by the reference evaluator." + ) # Needs pv.Version(onnx.__version__) >= pv.Version("1.16.0") - ref = ReferenceEvaluator(model_path_to_check, new_ops=[QGemm]) + ref = ReferenceEvaluator(model_check, new_ops=reference_new_ops) target_results = ref.run(None, inputs) testcase.assertEqual(len(origin_results), len(target_results), "result count are different") for idx, ref_output in enumerate(origin_results): diff --git a/onnxruntime/test/python/quantization/resnet_code.py b/onnxruntime/test/python/quantization/resnet_code.py new file mode 100644 index 0000000000000..2f78047c824a6 --- /dev/null +++ b/onnxruntime/test/python/quantization/resnet_code.py @@ -0,0 +1,13757 @@ +import numpy +from onnx import numpy_helper +from onnx.helper import make_graph, make_model, make_node, make_opsetid, make_tensor_value_info, set_model_props + + +def create_model(): + initializers = [] + nodes = [] + inputs = [] + outputs = [] + functions = [] + + # opsets + opsets = {"": 13} + + # initializers + + list_value = [ + -0.013732648454606533, + -0.005861935671418905, + 0.06889285147190094, + -0.1172710582613945, + 0.08841240406036377, + -0.03748627379536629, + 0.016256270930171013, + -0.1059316024184227, + 0.08246039599180222, + 0.14295539259910583, + -0.32958757877349854, + 0.1631188541650772, + 0.05412565544247627, + -0.10758306831121445, + 0.12607362866401672, + -0.4987836182117462, + 0.7441706657409668, + -0.24774713814258575, + -0.30415549874305725, + 0.4033295810222626, + -0.13447114825248718, + 0.04623159021139145, + 0.2380414456129074, + -1.226112723350525, + 2.150630235671997, + -1.702580213546753, + 0.5305419564247131, + -0.06836353242397308, + -0.20055373013019562, + 0.7035881280899048, + -0.8389442563056946, + -0.1904432326555252, + 1.2609282732009888, + -1.0670661926269531, + 0.4142579436302185, + 0.04739700257778168, + -0.3265092074871063, + 1.1873037815093994, + -1.6817731857299805, + 0.9709527492523193, + -0.09095840901136398, + -0.12556785345077515, + 0.0835147574543953, + -0.24109329283237457, + 0.032948240637779236, + 0.46304041147232056, + -0.6594106554985046, + 0.349990576505661, + -0.04113377630710602, + 0.016451245173811913, + 0.008994563482701778, + -0.028321878984570503, + -0.05336569994688034, + 0.16036668419837952, + -0.12088149785995483, + 0.031160499900579453, + -0.0618649423122406, + 0.07205374538898468, + 0.15965768694877625, + -0.3389044404029846, + 0.21603335440158844, + 0.04029613360762596, + -0.0813034325838089, + 0.1019665077328682, + -0.4873599112033844, + 0.7873126268386841, + -0.2951086163520813, + -0.43754327297210693, + 0.5905176401138306, + -0.21821773052215576, + 0.06022067740559578, + 0.26326146721839905, + -1.6453089714050293, + 2.606400728225708, + -1.8939754962921143, + 0.5196341276168823, + 0.0055860355496406555, + -0.2335057258605957, + 0.9807199239730835, + -1.2137882709503174, + -0.2699125409126282, + 1.7379733324050903, + -1.4401814937591553, + 0.435971736907959, + -0.04829222336411476, + -0.24543480575084686, + 1.3292583227157593, + -2.0375823974609375, + 1.2458536624908447, + -0.08251484483480453, + -0.14181238412857056, + 0.10612589120864868, + -0.21671657264232635, + 0.1129523366689682, + 0.3666985034942627, + -0.7546612024307251, + 0.42979565262794495, + -0.0976259633898735, + -0.0008812264422886074, + 0.02994859404861927, + -0.07027778774499893, + 0.01393035613000393, + 0.07363647222518921, + -0.10249849408864975, + 0.06602989137172699, + -0.012129798531532288, + 0.10730132460594177, + -0.04546127840876579, + -0.16065146028995514, + 0.14788293838500977, + -0.05488971993327141, + 0.03601694852113724, + 0.07513345777988434, + -0.23953600227832794, + 0.48062530159950256, + -0.42057543992996216, + -0.02402813360095024, + 0.17920851707458496, + -0.10703158378601074, + -0.028666120022535324, + 0.2815375030040741, + -0.860264241695404, + 1.4422725439071655, + -1.2058128118515015, + 0.5272247791290283, + -0.06504356116056442, + -0.20021803677082062, + 0.44968947768211365, + -0.3856053650379181, + -0.1589551419019699, + 0.7579770684242249, + -0.8349987268447876, + 0.3225692808628082, + 0.08153475821018219, + -0.43163740634918213, + 0.8742384910583496, + -0.9722443222999573, + 0.579015851020813, + -0.06688100844621658, + -0.12384293973445892, + 0.08289378881454468, + -0.10082041472196579, + -0.11204896867275238, + 0.3934254050254822, + -0.4511864185333252, + 0.32745760679244995, + -0.06534548103809357, + -0.028830429539084435, + 0.021844232454895973, + 0.01775779016315937, + -0.004250001162290573, + 0.013087524101138115, + -0.001250433037057519, + -0.040545206516981125, + -0.014049320481717587, + -0.024194253608584404, + -0.023865194991230965, + -0.0038033330347388983, + 0.00920871365815401, + -0.006582418456673622, + 0.0032474950421601534, + -0.0369916632771492, + -0.16640843451023102, + -0.28968843817710876, + -0.3531132638454437, + -0.26307201385498047, + -0.13392697274684906, + -0.03747623786330223, + 0.08083077520132065, + 0.2026241272687912, + 0.25018608570098877, + 0.2529378831386566, + 0.2307336926460266, + 0.13928599655628204, + 0.08631229400634766, + 0.13893137872219086, + 0.4867081344127655, + 0.7170669436454773, + 0.8331555724143982, + 0.6734364032745361, + 0.3549460768699646, + 0.16798041760921478, + -0.14487245678901672, + -0.47733625769615173, + -0.7670150995254517, + -0.875726580619812, + -0.6291986703872681, + -0.2910463213920593, + -0.09991979598999023, + -0.009158087894320488, + 0.018850643187761307, + 0.02646111696958542, + -0.009077857248485088, + 0.029430989176034927, + -0.03707962855696678, + -0.05111744999885559, + -0.02076525054872036, + 0.011828843504190445, + 0.017857171595096588, + 0.02548048458993435, + -0.009077494964003563, + 0.0022066361270844936, + -0.02064262516796589, + -0.008582246489822865, + -0.022748643532395363, + -0.03038850985467434, + 0.0006585497176274657, + -0.0016039719339460135, + -0.01612498238682747, + 0.013966801576316357, + -0.05851661041378975, + -0.21422894299030304, + -0.33863192796707153, + -0.3720807433128357, + -0.3030800521373749, + -0.1737397164106369, + -0.05903157964348793, + 0.15018144249916077, + 0.27454254031181335, + 0.31182464957237244, + 0.30118387937545776, + 0.24605700373649597, + 0.14123573899269104, + 0.14992672204971313, + 0.20660799741744995, + 0.5046274662017822, + 0.7706091403961182, + 0.8978630900382996, + 0.7368614673614502, + 0.3929724097251892, + 0.23079657554626465, + -0.21169082820415497, + -0.5920398235321045, + -0.893406867980957, + -0.9499238729476929, + -0.730407178401947, + -0.3615736961364746, + -0.15422092378139496, + -0.024615347385406494, + 0.005115498788654804, + 0.024657316505908966, + 0.028517475351691246, + 0.027910854667425156, + -0.009482389315962791, + -0.042242538183927536, + -0.017875321209430695, + 0.00430292496457696, + 0.015949612483382225, + 0.003636278910562396, + -0.018156034871935844, + -0.0009349065367132425, + -0.0010362856555730104, + -0.013051170855760574, + -0.009141271002590656, + -8.714485738892108e-05, + 0.02399279735982418, + 0.01753612607717514, + -0.013710699044167995, + -0.014245252124965191, + -0.0028008236549794674, + -0.08206935226917267, + -0.1098734438419342, + -0.10250325500965118, + -0.08874496072530746, + -0.031079040840268135, + 0.004536658991128206, + 0.03923843801021576, + 0.08478657901287079, + 0.07715648412704468, + 0.018803801387548447, + 0.013921198435127735, + 0.015864359214901924, + 0.04947463795542717, + 0.039856068789958954, + 0.1712094396352768, + 0.362756609916687, + 0.4192918539047241, + 0.2668488621711731, + 0.11430513113737106, + 0.06648365408182144, + -0.058979276567697525, + -0.24177154898643494, + -0.3709423542022705, + -0.3979431986808777, + -0.29706764221191406, + -0.11569518595933914, + -0.01848490908741951, + -0.015523962676525116, + 0.05081642046570778, + 0.09057094901800156, + 0.08520761132240295, + 0.04497350752353668, + -0.019453801214694977, + -0.06109466031193733, + 0.011463015340268612, + -0.008522219955921173, + -0.005283404141664505, + -0.017313135787844658, + -0.0015744483098387718, + -0.011845857836306095, + -0.016727561131119728, + -0.006708915811032057, + 0.0008860539528541267, + -0.010050912387669086, + -0.028460539877414703, + -0.0165643822401762, + -0.016545938327908516, + -0.00567589420825243, + -0.0032017906196415424, + -0.0130555285140872, + -0.026848897337913513, + -0.02615198865532875, + 0.002669057110324502, + -0.027966763824224472, + -0.03851256147027016, + -0.014509409666061401, + -0.029059220105409622, + -0.007284109480679035, + 0.04045313969254494, + 0.10005538910627365, + 0.014574537053704262, + -0.044292762875556946, + -0.01750861294567585, + -0.02231375314295292, + -0.004432118032127619, + 0.10051869601011276, + 0.1443023532629013, + 0.0508832149207592, + -0.04350621998310089, + -0.0025447055231779814, + -0.014583000913262367, + -0.02153291553258896, + 0.018860718235373497, + 0.03618147224187851, + 0.007304056081920862, + -0.029104959219694138, + 0.00576505484059453, + -0.016025763005018234, + -0.025094063952565193, + -0.05296780914068222, + -0.037012189626693726, + -0.04414081946015358, + -0.053135257214307785, + -0.028890708461403847, + -0.010220452211797237, + -0.027575822547078133, + -0.01087758969515562, + -0.027209162712097168, + -0.030827227979898453, + -0.007646164856851101, + -0.016133273020386696, + 0.000639698002487421, + -0.0034172122832387686, + 0.03914793208241463, + 0.030786357820034027, + 0.005965455900877714, + 0.020923329517245293, + -0.03435938432812691, + -0.0026781477499753237, + 0.04278327897191048, + 0.20045910775661469, + 0.21770593523979187, + 0.09422573447227478, + 0.03198440372943878, + -0.021056609228253365, + 0.028007682412862778, + 0.19196027517318726, + 0.4791645109653473, + 0.5333831906318665, + 0.3014310598373413, + 0.103666290640831, + -0.03651479259133339, + 0.027079502120614052, + 0.19239209592342377, + 0.5168290138244629, + 0.5564895868301392, + 0.2977963089942932, + 0.07770062237977982, + -0.042239490896463394, + -0.017265107482671738, + 0.08760321140289307, + 0.2775075435638428, + 0.312491774559021, + 0.12284757196903229, + 0.019664151594042778, + -0.026643047109246254, + 0.0009152573184110224, + 0.016156431287527084, + 0.09042830765247345, + 0.08991760015487671, + 0.013326293788850307, + 0.02613811008632183, + 0.021025240421295166, + 0.0198842640966177, + 0.03375901281833649, + 0.028616728261113167, + 0.026605166494846344, + 0.04126269370317459, + 0.029309948906302452, + 0.01408455427736044, + -0.003831037785857916, + 0.01922326348721981, + -0.018229445442557335, + -0.013015883974730968, + 0.017597628757357597, + -0.007964612916111946, + 0.045263469219207764, + 0.0184696726500988, + -0.001163159729912877, + -0.1809321641921997, + -0.22486254572868347, + -0.08606110513210297, + 0.001087217591702938, + 0.037091098725795746, + -0.013625397346913815, + -0.178089901804924, + -0.5483279824256897, + -0.612791895866394, + -0.32531827688217163, + -0.06506585329771042, + 0.05076128616929054, + -0.007585812360048294, + -0.20981833338737488, + -0.6155760884284973, + -0.7119701504707336, + -0.354442298412323, + -0.04236743599176407, + 0.045713260769844055, + 0.03192479908466339, + -0.07216271013021469, + -0.310979425907135, + -0.3656359910964966, + -0.13522450625896454, + 0.008291869424283504, + 0.03362602740526199, + -0.0009240762447007, + 0.01604474149644375, + -0.055634208023548126, + -0.06180194392800331, + 0.0222025066614151, + 0.027704820036888123, + -0.034385330975055695, + -0.07050742954015732, + -0.06287489086389542, + 0.03521641716361046, + -0.00020920530369039625, + 0.05458284169435501, + 0.058752644807100296, + -0.08097169548273087, + -0.01668735221028328, + 0.18557283282279968, + 0.26208117604255676, + 0.1253771185874939, + 0.07758381962776184, + -0.022084739059209824, + 0.016727397218346596, + 0.23247942328453064, + 0.35444316267967224, + 0.21802566945552826, + -0.04409221559762955, + -0.08573070168495178, + -0.0994141548871994, + 0.07754423469305038, + 0.14311672747135162, + 0.04036660119891167, + -0.29222917556762695, + -0.38828015327453613, + -0.26185816526412964, + -0.12845511734485626, + 0.04763585329055786, + -0.017382778227329254, + -0.16010743379592896, + -0.2395028918981552, + -0.2049665004014969, + -0.041346337646245956, + 0.091490738093853, + -0.005191737785935402, + -0.07687077671289444, + -0.08105621486902237, + -0.05329642817378044, + -0.03404862806200981, + 0.11478845030069351, + 0.13328343629837036, + -0.037197597324848175, + -0.01787363924086094, + -0.016605347394943237, + 0.007853846065700054, + 0.029950136318802834, + 0.10808859020471573, + 0.02873288467526436, + -0.1766187697649002, + -0.17560969293117523, + -0.03922238200902939, + 0.14447443187236786, + 0.1534212827682495, + 0.11272227019071579, + 0.008810695260763168, + -0.1485181748867035, + 0.07839693129062653, + 0.43013128638267517, + 0.4898712635040283, + 0.26522761583328247, + 0.10202436149120331, + -0.07163076847791672, + 0.09933187812566757, + 0.47377726435661316, + 0.6340300440788269, + 0.36741772294044495, + -0.04812543839216232, + -0.17370514571666718, + -0.17513291537761688, + 0.22105705738067627, + 0.3226463794708252, + 0.09850790351629257, + -0.4044247269630432, + -0.6237908601760864, + -0.4679968059062958, + -0.1954391747713089, + 0.09878316521644592, + -0.004430827684700489, + -0.31550562381744385, + -0.5235733985900879, + -0.4510284662246704, + -0.13843706250190735, + 0.10064390301704407, + -0.006748788990080357, + -0.12714813649654388, + -0.2107744812965393, + -0.18755048513412476, + -0.05646044388413429, + 0.12781813740730286, + 0.18928050994873047, + -0.04337320104241371, + -0.04973407834768295, + -0.04690375551581383, + 0.0245530866086483, + 0.10698680579662323, + 0.1646823137998581, + 0.081840381026268, + -0.01471243891865015, + -0.03138890117406845, + -0.04195617139339447, + 0.012708203867077827, + 0.033312954008579254, + 0.02409377694129944, + -0.0036440726835280657, + -0.06239784508943558, + 0.0037516560405492783, + 0.11261500418186188, + 0.13069754838943481, + 0.05901307612657547, + 0.048614490777254105, + -0.027712708339095116, + 0.027247682213783264, + 0.19195327162742615, + 0.2688453793525696, + 0.1509387195110321, + 0.020540937781333923, + -0.004100556951016188, + -0.012650247663259506, + 0.039176344871520996, + 0.09037251025438309, + -0.004689970053732395, + -0.23859903216362, + -0.2364242821931839, + -0.15189304947853088, + -0.0761493444442749, + -0.0028172829188406467, + -0.04328106716275215, + -0.16187387704849243, + -0.21743592619895935, + -0.1282283067703247, + -0.024501819163560867, + 0.04029383510351181, + -0.027387680485844612, + -0.05414740741252899, + -0.08344019204378128, + -0.06591048091650009, + 0.012637111358344555, + 0.06905930489301682, + 0.08426016569137573, + -0.0030199100729078054, + 0.034059297293424606, + 0.01111840270459652, + 0.013492933474481106, + 0.0674189031124115, + 0.08242739737033844, + 0.006129032466560602, + -0.07763395458459854, + -0.03002289868891239, + -0.055725954473018646, + 0.008795201778411865, + 0.02994825504720211, + -0.06114519387483597, + -0.0560108907520771, + -0.008179228752851486, + -0.07149285078048706, + -0.02700420655310154, + -0.01306728646159172, + 0.06276566535234451, + 0.007125973701477051, + -0.03540417551994324, + -0.039717916399240494, + 0.009147526696324348, + -0.06517947465181351, + 0.0720859095454216, + -0.05035398155450821, + 0.06659520417451859, + -0.01841895841062069, + 0.004233633633702993, + -0.020911216735839844, + -0.004646372981369495, + 1.6690073013305664, + 0.4517613649368286, + -0.07667035609483719, + 0.005556757096201181, + -0.02638973295688629, + 0.044588603079319, + -0.020916732028126717, + 0.2571280598640442, + -0.009559552185237408, + -0.043380800634622574, + 0.03196016326546669, + -0.03783237189054489, + -0.03076902963221073, + 0.03180111199617386, + 0.06352709978818893, + 0.020281998440623283, + -0.00741154421120882, + -0.0009214285528287292, + -0.0476187989115715, + -0.07208544760942459, + -0.05323023349046707, + -0.011103631928563118, + 0.02877136506140232, + -0.05324484035372734, + -0.10076326876878738, + 0.026193000376224518, + 0.03536469116806984, + 0.045722659677267075, + -0.03756006807088852, + 0.022998394444584846, + 0.0019359687576070428, + 0.01654801517724991, + 0.047304198145866394, + -0.08431598544120789, + -0.0645647644996643, + -0.17326746881008148, + -0.10692577064037323, + -0.08416426181793213, + -0.04107839986681938, + -0.0012680464424192905, + -0.02600814774632454, + -0.014215772971510887, + 0.2114446610212326, + -0.040954578667879105, + -0.05050172284245491, + 0.004194092936813831, + -0.0025900816544890404, + -0.1359374076128006, + 0.03946976363658905, + 2.3023669719696045, + 0.7484877109527588, + -0.1994970589876175, + -0.06490366160869598, + 0.007983183488249779, + -0.017937449738383293, + -0.12516839802265167, + 0.3313288688659668, + 0.11946671456098557, + -0.16942338645458221, + -0.007721045054495335, + 0.02824605070054531, + -0.05310647189617157, + -0.1122083067893982, + -0.17094524204730988, + -0.08465421944856644, + -0.09679102897644043, + -0.03848385065793991, + 0.040121182799339294, + -0.06661732494831085, + 0.0005764663219451904, + -0.05729356408119202, + -0.04778655245900154, + -0.034835152328014374, + -0.07634143531322479, + -0.05054831504821777, + 0.00597620103508234, + 0.04499154910445213, + -0.03308190405368805, + -0.04915233701467514, + -0.05842791870236397, + 0.003590918146073818, + 0.055837079882621765, + -0.02547842636704445, + -0.018847621977329254, + -0.2073899656534195, + -0.14987564086914062, + -0.03971748799085617, + 0.05886378139257431, + 0.020922083407640457, + -0.039155181497335434, + -0.028855402022600174, + 0.08688661456108093, + -0.1402827501296997, + -0.05810496211051941, + 0.037841811776161194, + -0.04082907736301422, + -0.1191127747297287, + -0.10852136462926865, + 1.6274418830871582, + 0.3678200840950012, + -0.2865799367427826, + -0.05291350558400154, + 0.023858532309532166, + -0.046683818101882935, + -0.2307816743850708, + -0.001670230645686388, + -0.17716962099075317, + -0.16724731028079987, + 0.040194038301706314, + -0.023075448349118233, + -0.01538322027772665, + -0.07914327085018158, + -0.19621343910694122, + -0.11628971993923187, + -0.05851752683520317, + 0.06313594430685043, + 0.017808571457862854, + 0.02447943389415741, + 0.048611078411340714, + -0.009247995913028717, + 0.00789090245962143, + 0.06673033535480499, + 0.0661577433347702, + 0.019111329689621925, + 0.038164373487234116, + 0.029342610388994217, + -0.03547409921884537, + -0.11017149686813354, + -0.11077891290187836, + 0.001108204829506576, + -0.0330691784620285, + -0.05039837956428528, + 0.017638904973864555, + 0.277705579996109, + 0.5606598258018494, + 0.5469182133674622, + 0.13591277599334717, + 0.012421006336808205, + 0.046348799020051956, + -0.02721901424229145, + -0.5645118355751038, + -1.072814702987671, + -0.9852984547615051, + -0.3608386516571045, + -0.010197073221206665, + -0.09785731136798859, + -0.02597353421151638, + 0.4627133309841156, + 1.1483618021011353, + 0.9505703449249268, + 0.17471027374267578, + -0.016467586159706116, + 0.026623696088790894, + 0.04765752702951431, + -0.4000166058540344, + -0.8956774473190308, + -0.6268588304519653, + -0.09439487755298615, + 0.02861764468252659, + -0.004155704285949469, + 0.08989865332841873, + 0.27384331822395325, + 0.6518518328666687, + 0.4184596836566925, + 0.13106893002986908, + 0.0050344159826636314, + 0.007061495911329985, + -0.016157688573002815, + -0.1364346295595169, + -0.27324289083480835, + -0.14245718717575073, + -0.04623992741107941, + -0.015541884116828442, + 0.030779436230659485, + 0.03756715729832649, + 0.01957445964217186, + -0.04964561015367508, + -0.0211405660957098, + 0.044496409595012665, + -0.026335055008530617, + -0.11620140820741653, + -0.11803250014781952, + 0.18242181837558746, + 0.5057784914970398, + 0.5045838952064514, + 0.03748183697462082, + 0.05692485347390175, + 0.1608155369758606, + 0.02245517633855343, + -0.7651812434196472, + -1.5504053831100464, + -1.3563542366027832, + -0.4314505457878113, + -0.028384560719132423, + -0.12238024920225143, + 0.106974296271801, + 1.11427903175354, + 2.173083543777466, + 1.747692346572876, + 0.5455064177513123, + 0.03363418206572533, + 0.11388687789440155, + -0.05905687436461449, + -0.8059568405151367, + -1.6196117401123047, + -1.1898213624954224, + -0.2654758095741272, + -0.004251840524375439, + -0.0916782096028328, + -0.024067873135209084, + 0.22692462801933289, + 0.6695711612701416, + 0.3673460781574249, + -0.017016466706991196, + -0.029604146257042885, + 0.020365707576274872, + 0.03215239942073822, + 0.0070981839671730995, + -0.14026938378810883, + -0.02425236999988556, + 0.059152450412511826, + -0.006319367326796055, + 0.003989882301539183, + 0.048541076481342316, + 0.003988460637629032, + -0.03105335496366024, + -0.08329232037067413, + 0.03226872906088829, + 0.02119620516896248, + -0.0953872874379158, + -0.15174035727977753, + 0.07963212579488754, + 0.29094186425209045, + 0.2690921127796173, + -0.020104877650737762, + 0.024988379329442978, + 0.15326620638370514, + 0.1256464123725891, + -0.40941280126571655, + -0.946648120880127, + -0.8358487486839294, + -0.14284957945346832, + -0.07980851829051971, + -0.1435413807630539, + 0.038134895265102386, + 0.8021518588066101, + 1.552701473236084, + 1.2496209144592285, + 0.38152581453323364, + 0.07136060297489166, + 0.14329172670841217, + -0.06546801328659058, + -0.5923707485198975, + -1.253793478012085, + -0.9458200335502625, + -0.156633198261261, + -0.04217473417520523, + -0.11199303716421127, + -0.07520301640033722, + 0.15331010520458221, + 0.4794600307941437, + 0.2449675053358078, + -0.10396319627761841, + 0.0034801275469362736, + 0.04475663974881172, + 0.024035215377807617, + 0.056806568056344986, + -0.07363307476043701, + -0.001563104335218668, + 0.05157755687832832, + 0.043718185275793076, + 0.02102719619870186, + 0.11859089881181717, + 0.08675580471754074, + -0.13180124759674072, + -0.15522590279579163, + 0.03273458778858185, + -0.0019622649997472763, + 0.1011638194322586, + -0.10800585150718689, + -0.6884365677833557, + -0.5495791435241699, + 0.0780424103140831, + 0.33674973249435425, + -0.21274283528327942, + -0.4183696210384369, + -0.8053947687149048, + 0.03347628563642502, + 1.3938312530517578, + 0.9454176425933838, + -0.012210174463689327, + 0.04924672842025757, + 0.16284359991550446, + 1.1340152025222778, + 2.0020322799682617, + 0.2796843647956848, + -0.968036413192749, + -0.5768532752990723, + 0.17757350206375122, + 0.37485063076019287, + 0.11534234136343002, + -1.2916942834854126, + -1.692176103591919, + -0.30523377656936646, + 0.14307916164398193, + 0.03928302228450775, + -0.19196964800357819, + -0.4533900022506714, + -0.3294944167137146, + 0.5480389595031738, + 0.4497548043727875, + 0.2170887440443039, + -0.05817069113254547, + -0.06957870721817017, + 0.03169052675366402, + 0.23751793801784515, + 0.0823391005396843, + -0.04811413958668709, + -0.051265716552734375, + -0.0395645909011364, + -0.03849785774946213, + 0.04607917368412018, + 0.09946659207344055, + -0.029992828145623207, + -0.05369366332888603, + -0.005230880342423916, + 0.012808755040168762, + 0.1821947544813156, + 0.05478882044553757, + -0.47736144065856934, + -0.44480830430984497, + -0.036321353167295456, + 0.13646431267261505, + -0.04045571759343147, + -0.21837295591831207, + -0.6888197660446167, + -0.08431777358055115, + 0.96018385887146, + 0.6788493990898132, + 0.011028020642697811, + 0.05917810648679733, + 0.02488739602267742, + 0.6898419857025146, + 1.4259209632873535, + 0.13193827867507935, + -0.8078985810279846, + -0.31056249141693115, + 0.018122224137187004, + 0.137860506772995, + 0.051947757601737976, + -0.9757952094078064, + -1.1060559749603271, + 0.06675099581480026, + 0.2091575562953949, + -0.029623042792081833, + -0.0705878809094429, + -0.18514159321784973, + -0.07947035878896713, + 0.5719470381736755, + 0.2286168485879898, + -0.03433626517653465, + 0.0036030709743499756, + 0.006251791957765818, + 0.04144154116511345, + 0.08598234504461288, + -0.050599172711372375, + -0.10440917313098907, + -0.02927244082093239, + -0.04102599248290062, + -0.07101748138666153, + -0.03579306975007057, + 0.03586365282535553, + 0.06752362847328186, + 0.048901572823524475, + -0.020898710936307907, + -0.009411930106580257, + 0.10169848799705505, + 0.1812015175819397, + -0.014482695609331131, + -0.12548771500587463, + -0.060731250792741776, + -0.034499138593673706, + 0.0829617902636528, + 0.04616715386509895, + -0.20867496728897095, + -0.1990129053592682, + 0.1773940473794937, + 0.13156233727931976, + -0.03437860682606697, + 0.04012921825051308, + -0.11132699251174927, + -0.023460939526557922, + 0.2713286876678467, + -0.06662362813949585, + -0.2709292471408844, + -0.0030232456047087908, + -0.10379529744386673, + -0.07136038690805435, + 0.03757762163877487, + -0.20515622198581696, + -0.1231834888458252, + 0.26915228366851807, + 0.0998353362083435, + -0.031466737389564514, + 0.04657471179962158, + 0.07664929330348969, + 0.10308870673179626, + 0.23429608345031738, + -0.06942534446716309, + -0.09051290899515152, + 0.03243685141205788, + 0.04053235426545143, + -0.021392958238720894, + -0.05330868810415268, + -0.11525140702724457, + -0.03889385238289833, + 0.01636480540037155, + -0.009352890774607658, + 0.13151532411575317, + -0.14738643169403076, + -0.18289834260940552, + 0.15955400466918945, + -0.001023759599775076, + 0.028809679672122, + 0.012261062860488892, + 0.29654747247695923, + -0.285063236951828, + -0.40187928080558777, + 0.3713407516479492, + 0.009383893571794033, + -0.023022817447781563, + -0.003799814498052001, + 0.48470190167427063, + -0.43402406573295593, + -0.5858806371688843, + 0.5751441717147827, + 0.05045031011104584, + -0.05559438094496727, + -0.02045449987053871, + 0.5281224250793457, + -0.5058223605155945, + -0.5950849056243896, + 0.6492323279380798, + 0.013408469036221504, + -0.05940670147538185, + -0.0044364179484546185, + 0.3112560212612152, + -0.34908774495124817, + -0.42427319288253784, + 0.43349501490592957, + 0.03724945709109306, + -0.05263671651482582, + -0.010485195554792881, + 0.1261255145072937, + -0.1349790245294571, + -0.2524855136871338, + 0.24608080089092255, + 0.036001257598400116, + -0.028843939304351807, + 0.0056989979930222034, + 0.04458172619342804, + -0.06122935935854912, + -0.166972354054451, + 0.14557687938213348, + 0.018050044775009155, + 0.032598987221717834, + -0.0055792503990232944, + 0.24355076253414154, + -0.21433626115322113, + -0.29646870493888855, + 0.1958809792995453, + 0.015435033477842808, + 0.05235098674893379, + 0.010786890983581543, + 0.47903597354888916, + -0.4127257168292999, + -0.6203306317329407, + 0.47024452686309814, + 0.0823090448975563, + -0.04538045823574066, + -0.004072466865181923, + 0.7509317994117737, + -0.6508772969245911, + -0.8481631278991699, + 0.7875698208808899, + 0.0966777428984642, + -0.10461349785327911, + 0.0063789174892008305, + 0.7535857558250427, + -0.8082649111747742, + -0.8165622353553772, + 0.9064085483551025, + 0.04986630380153656, + -0.10200339555740356, + 0.0314355194568634, + 0.46324053406715393, + -0.5523763298988342, + -0.5632953643798828, + 0.6378755569458008, + 0.07833302766084671, + -0.07979781180620193, + 0.031164664775133133, + 0.1967470794916153, + -0.21681970357894897, + -0.29283079504966736, + 0.3367702066898346, + 0.034929461777210236, + -0.047199901193380356, + -0.0033645557705312967, + 0.05454660952091217, + -0.11264829337596893, + -0.190998375415802, + 0.17961400747299194, + 0.0009085010970011353, + -0.0001827089727157727, + 0.04841821268200874, + 0.019923821091651917, + -0.07004066556692123, + -0.10590090602636337, + 0.054114967584609985, + 0.04302384704351425, + 0.00462615629658103, + 0.022948985919356346, + 0.1673787385225296, + -0.1319379210472107, + -0.2711219787597656, + 0.2387620061635971, + 0.05667697265744209, + -0.018639734014868736, + -0.07672597467899323, + 0.3503187298774719, + -0.2981504797935486, + -0.38647517561912537, + 0.4072522521018982, + 0.010913677513599396, + -0.05246961489319801, + -0.04058554396033287, + 0.39216771721839905, + -0.3605193495750427, + -0.34857264161109924, + 0.46899959444999695, + -0.03358001261949539, + -0.05188553035259247, + -0.023204902186989784, + 0.17140533030033112, + -0.2120431810617447, + -0.2144550085067749, + 0.2837989032268524, + -0.0191226527094841, + -0.020922169089317322, + 0.004324179142713547, + 0.038136694580316544, + -0.042803723365068436, + -0.11487454175949097, + 0.11820490658283234, + 0.003412557765841484, + 0.0035020115319639444, + 0.03646541014313698, + -0.010104459710419178, + -0.010897459462285042, + -0.09292570501565933, + 0.06823977828025818, + 0.02677192911505699, + 0.020071662962436676, + 0.005776307079941034, + 0.02613351307809353, + 0.017107944935560226, + -0.0002623539185151458, + -0.039298396557569504, + -0.0314190648496151, + -0.019773684442043304, + -0.01924789510667324, + 0.04253160580992699, + 0.09694722294807434, + 0.1925637573003769, + 0.1901547759771347, + 0.09470294415950775, + -0.00296174269169569, + -0.03602522239089012, + 0.03572473302483559, + 0.08787581324577332, + 0.1773553043603897, + 0.20970025658607483, + 0.14899243414402008, + 0.05427362397313118, + -0.032429151237010956, + 0.023915717378258705, + 0.06557436287403107, + 0.13488733768463135, + 0.17550915479660034, + 0.17485061287879944, + 0.10260436683893204, + -0.005381361581385136, + -0.05573735386133194, + -0.09410752356052399, + -0.07940010726451874, + -0.03424998000264168, + 0.007975265383720398, + 0.028827181085944176, + 0.023788832128047943, + -0.02962818741798401, + -0.13474339246749878, + -0.22529757022857666, + -0.20413516461849213, + -0.14711618423461914, + -0.05960607901215553, + 0.04579121991991997, + 0.005325576290488243, + -0.11592217534780502, + -0.2260522097349167, + -0.2467145025730133, + -0.22054187953472137, + -0.13919179141521454, + 0.0016459478065371513, + 0.0515579916536808, + 0.060555730015039444, + 0.040788713842630386, + -0.017907800152897835, + -0.026459651067852974, + -0.02488812990486622, + 0.015644825994968414, + 0.10543125867843628, + 0.19312354922294617, + 0.28380078077316284, + 0.28878358006477356, + 0.16968156397342682, + 0.04848042502999306, + -0.00986899808049202, + 0.06337545067071915, + 0.16356752812862396, + 0.2444516271352768, + 0.29273414611816406, + 0.2314801961183548, + 0.12695762515068054, + -0.022283215075731277, + 0.018402203917503357, + 0.07152476161718369, + 0.14247483015060425, + 0.18759845197200775, + 0.20828258991241455, + 0.14114585518836975, + -0.047197990119457245, + -0.13794781267642975, + -0.17509934306144714, + -0.1696663200855255, + -0.1206701323390007, + -0.036128126084804535, + 0.007180679589509964, + 0.006984225939959288, + -0.09600912779569626, + -0.22975720465183258, + -0.33287662267684937, + -0.2942708134651184, + -0.20305578410625458, + -0.08411446958780289, + 0.042896877974271774, + -0.020053744316101074, + -0.16365791857242584, + -0.3145587742328644, + -0.3321540057659149, + -0.2667454183101654, + -0.1542910486459732, + -0.006954069249331951, + 0.020191870629787445, + 0.014010002836585045, + 0.0016916356980800629, + -0.04649524390697479, + -0.014931428246200085, + -0.017954425886273384, + -0.020003901794552803, + 0.03831968829035759, + 0.08447518199682236, + 0.14068123698234558, + 0.13400419056415558, + 0.08205568045377731, + -0.0004489773709792644, + -0.019211264327168465, + 0.023363608866930008, + 0.08738930523395538, + 0.12299696356058121, + 0.13070489466190338, + 0.09040816128253937, + 0.03286544978618622, + -0.006979941390454769, + -0.0010930931894108653, + 0.04313739389181137, + 0.10121051222085953, + 0.11390950530767441, + 0.11383924633264542, + 0.06694260239601135, + -0.00425445893779397, + -0.0666416585445404, + -0.09225274622440338, + -0.0977785512804985, + -0.07118111103773117, + -0.026749763637781143, + -0.019425569102168083, + 0.03321055322885513, + -0.0033978468272835016, + -0.08309262245893478, + -0.15557922422885895, + -0.14969374239444733, + -0.07188998907804489, + -0.018716221675276756, + 0.022834330797195435, + 0.004232254344969988, + -0.04141783341765404, + -0.125192329287529, + -0.14545302093029022, + -0.12225300818681717, + -0.05844716727733612, + 0.010607236064970493, + 0.024218380451202393, + -0.002702374942600727, + -0.030814893543720245, + 0.03507756441831589, + -0.0506589449942112, + 0.03415676951408386, + 0.0011444400297477841, + 0.0026324463542550802, + 0.028514407575130463, + -0.01849454641342163, + -0.030959082767367363, + -0.05565863475203514, + 0.05771413818001747, + 0.003916156478226185, + -0.004474544432014227, + 0.04403551295399666, + 0.1733711212873459, + -0.37650829553604126, + 0.22322984039783478, + 0.0032540319953113794, + -0.01139416079968214, + -0.039046600461006165, + 0.0021948080975562334, + 0.5777754783630371, + -1.1944804191589355, + 0.769478976726532, + -0.1349843591451645, + 0.0004430754925124347, + -0.0061850035563111305, + -0.08340868353843689, + 0.8327823877334595, + -1.649588942527771, + 1.126111388206482, + -0.2918313145637512, + 0.003614947199821472, + 0.0016799914883449674, + -0.03255167230963707, + 0.6123784184455872, + -1.1993682384490967, + 0.8305437564849854, + -0.13622376322746277, + 0.00905851274728775, + -0.006772476714104414, + 0.07578610628843307, + 0.05859832838177681, + -0.4543764293193817, + 0.26330503821372986, + 0.0259060300886631, + -0.0007997890934348106, + 0.01269856933504343, + 0.006897627376019955, + -0.02491801232099533, + -0.03139931708574295, + 0.0028456314466893673, + 0.0008253560517914593, + -0.01086023822426796, + -0.004186873324215412, + 0.06299160420894623, + -0.039931319653987885, + -0.09315146505832672, + 0.05495935305953026, + 0.027547571808099747, + -0.010900916531682014, + -0.025233760476112366, + 0.060600072145462036, + 0.21010243892669678, + -0.5445898771286011, + 0.35070353746414185, + -0.033771682530641556, + -0.0269146841019392, + -0.025363197550177574, + -0.021729450672864914, + 0.70921790599823, + -1.4368270635604858, + 0.9582043290138245, + -0.1708265244960785, + 0.010022420436143875, + -0.032301150262355804, + -0.08667651563882828, + 1.0338889360427856, + -1.913576364517212, + 1.262008547782898, + -0.23795078694820404, + -0.032233912497758865, + -0.01397701445966959, + -0.05402921140193939, + 0.7621430158615112, + -1.387437343597412, + 0.8621506094932556, + -0.14765247702598572, + -0.004747485741972923, + 0.0017516895895823836, + 0.08154146373271942, + 0.16601374745368958, + -0.5324177742004395, + 0.27442997694015503, + 0.03274058923125267, + -0.008812552317976952, + 0.005774920806288719, + 0.04165825620293617, + -0.011749272234737873, + -0.01953396573662758, + -0.009672109968960285, + 0.01170953270047903, + 0.003071938641369343, + -0.018979815766215324, + 0.062123894691467285, + -0.004921444226056337, + -0.03380037844181061, + 0.01310884952545166, + 0.007953890599310398, + -0.0012086924398317933, + -0.03317898139357567, + -0.0015596294542774558, + 0.08166785538196564, + -0.2291223704814911, + 0.11783571541309357, + -0.016078786924481392, + 0.018957575783133507, + 0.025793947279453278, + -0.09036394208669662, + 0.3833881616592407, + -0.5794023871421814, + 0.4610825777053833, + -0.14165280759334564, + -0.007412370759993792, + 0.05252876877784729, + -0.21435455977916718, + 0.6177686452865601, + -0.8516795635223389, + 0.667263925075531, + -0.22572898864746094, + -0.004465761594474316, + 0.02589319832623005, + -0.1893543303012848, + 0.43213585019111633, + -0.6462821364402771, + 0.434274822473526, + -0.15750259160995483, + -0.01198036689311266, + -2.4281514924950898e-05, + 0.039562296122312546, + 0.11126027256250381, + -0.23193514347076416, + 0.1412443071603775, + -0.011839920654892921, + 0.007880321703851223, + 0.02950354479253292, + 0.011689653620123863, + -0.07272310554981232, + -0.03319466486573219, + -0.003948990721255541, + 0.03549842908978462, + -0.02165558747947216, + -0.09912239760160446, + -0.08742356300354004, + 0.30591821670532227, + 0.23934677243232727, + 0.02658180706202984, + -0.022127188742160797, + -0.02769642136991024, + 0.16399237513542175, + 0.5140998959541321, + 0.007951628416776657, + -0.5589093565940857, + -0.24106110632419586, + -0.02753414213657379, + 0.06947467476129532, + 0.048558495938777924, + -0.5370690822601318, + -0.761831521987915, + 0.16272802650928497, + 0.29426246881484985, + 0.07943751662969589, + -0.022394873201847076, + -0.217612162232399, + -0.03093647211790085, + 0.5945476293563843, + 0.2873935103416443, + -0.16481661796569824, + -0.02931203693151474, + -0.029083512723445892, + 0.06754925847053528, + 0.20200076699256897, + -0.07271742075681686, + -0.1976277083158493, + -0.04189611226320267, + 0.06403793394565582, + -0.00022445111244451255, + -0.01032529678195715, + -0.03415631130337715, + 0.009091783314943314, + 0.04317992925643921, + 0.07196266949176788, + -0.025028688833117485, + -0.02722775563597679, + -0.017168480902910233, + -0.027666645124554634, + -0.06734028458595276, + 0.10843724757432938, + 0.08066407591104507, + -0.027849983423948288, + -0.0045820740051567554, + -0.03388727456331253, + 0.16772156953811646, + 0.651636004447937, + 0.34874194860458374, + -0.1454945057630539, + -0.18056720495224, + 0.11703842133283615, + 0.43017855286598206, + 0.7624525427818298, + -0.3420296907424927, + -1.272199273109436, + -0.5284644365310669, + -0.005667245015501976, + 0.08240436762571335, + -0.13299596309661865, + -1.3164156675338745, + -1.659982442855835, + 0.19898656010627747, + 0.6253566741943359, + 0.25137946009635925, + -0.18244975805282593, + -0.5360167622566223, + -0.06195700913667679, + 1.2547520399093628, + 1.0296341180801392, + 0.10651036351919174, + -0.023540280759334564, + -0.07594245672225952, + 0.1492130160331726, + 0.5033117532730103, + 0.09394379705190659, + -0.22459803521633148, + -0.22473134100437164, + -0.04738321527838707, + 0.04127531498670578, + 0.0682951882481575, + -0.02095615118741989, + -0.1233135387301445, + -0.10028401762247086, + -0.008111395873129368, + -0.000617706507910043, + 0.018859047442674637, + 0.028446361422538757, + -0.06159031391143799, + -0.1292838156223297, + 0.051308393478393555, + 0.11001072078943253, + -0.02056661807000637, + -0.012175443582236767, + -0.1313694268465042, + 0.0067574759013950825, + 0.4612729251384735, + 0.323080450296402, + -0.09392253309488297, + -0.1256203055381775, + 0.03537299111485481, + 0.2556088864803314, + 0.6467183232307434, + -0.16340143978595734, + -0.8799455165863037, + -0.3312987685203552, + 0.01464154850691557, + 0.07046713680028915, + 0.053634822368621826, + -0.8514915108680725, + -1.176972508430481, + 0.2056443840265274, + 0.4998764395713806, + 0.1268644779920578, + -0.10905193537473679, + -0.3750888705253601, + -0.06701061874628067, + 0.9052186608314514, + 0.6792045831680298, + -0.00323892361484468, + -0.0007412935374304652, + -0.03608793020248413, + 0.1009129211306572, + 0.36775916814804077, + 0.035214491188526154, + -0.2273784875869751, + -0.15815992653369904, + -0.004773923195898533, + 0.06374036520719528, + 0.04737555980682373, + -0.0563247986137867, + -0.09587392956018448, + -0.043853096663951874, + 0.032572731375694275, + -0.0036250585690140724, + 0.07889056205749512, + -0.03589344769716263, + -0.019771328195929527, + 0.04937156289815903, + 0.039052557200193405, + -0.013377528637647629, + -0.0841481015086174, + -0.03358105197548866, + -0.2128981053829193, + -0.14468812942504883, + 0.14675867557525635, + 0.2550889551639557, + 0.22369499504566193, + -0.0032973098568618298, + 0.006679064594209194, + -0.11752036958932877, + 0.025247232988476753, + 0.23064176738262177, + 0.25043538212776184, + 0.3474777638912201, + 0.2151806503534317, + 0.051294319331645966, + 0.16301114857196808, + 0.25422143936157227, + -0.1796918362379074, + -0.6128425598144531, + -0.42049655318260193, + 0.07740531116724014, + -0.007960617542266846, + 0.2504507601261139, + 0.2932300865650177, + -0.5157915949821472, + -1.2904177904129028, + -1.0362532138824463, + -0.22443994879722595, + 0.007411653641611338, + 0.16024430096149445, + 0.33939966559410095, + -0.2748318016529083, + -0.8487470149993896, + -0.5955387949943542, + 0.033155132085084915, + -0.09185351431369781, + -0.05639262869954109, + 0.17084303498268127, + 0.11292264610528946, + -0.046329669654369354, + 0.11495561897754669, + 0.31740760803222656, + -0.13903948664665222, + 0.05507560819387436, + 0.10180198401212692, + -0.1369788944721222, + -0.10618618875741959, + -0.001083499751985073, + 0.16340164840221405, + 0.07591762393712997, + 0.3417445123195648, + 0.27897438406944275, + -0.32192930579185486, + -0.5731648206710815, + -0.46150147914886475, + -0.03230089321732521, + 0.04096771031618118, + 0.22242987155914307, + 0.027000218629837036, + -0.4113498628139496, + -0.433158278465271, + -0.5252256393432617, + -0.3510502874851227, + -0.133863165974617, + -0.38554033637046814, + -0.45547229051589966, + 0.2475612610578537, + 1.154951572418213, + 0.8282179236412048, + -0.13197137415409088, + -0.03350961208343506, + -0.5282800197601318, + -0.5297923684120178, + 0.9037952423095703, + 2.516275405883789, + 2.086421489715576, + 0.3573826849460602, + -0.010694397613406181, + -0.31418153643608093, + -0.5325371026992798, + 0.48083701729774475, + 1.7732245922088623, + 1.2747145891189575, + -0.06401863694190979, + 0.14296381175518036, + 0.07267159968614578, + -0.28001847863197327, + -0.29204103350639343, + 0.12853951752185822, + -0.1998838633298874, + -0.6375644207000732, + 0.06310836225748062, + -0.020014479756355286, + -0.08150970935821533, + 0.08175478130578995, + 0.07667485624551773, + 0.0025236753281205893, + -0.08504530042409897, + -0.035742271691560745, + -0.1332666128873825, + -0.15150736272335052, + 0.18459312617778778, + 0.3363596200942993, + 0.2501969635486603, + 0.029292423278093338, + -0.060296736657619476, + -0.1142202764749527, + -0.05918247997760773, + 0.18826954066753387, + 0.2183520495891571, + 0.21247169375419617, + 0.14935970306396484, + 0.09923429787158966, + 0.21808095276355743, + 0.21930061280727386, + -0.060535889118909836, + -0.5729222297668457, + -0.4199080169200897, + 0.058897778391838074, + 0.050647757947444916, + 0.2784770131111145, + 0.2754706144332886, + -0.40136128664016724, + -1.3269731998443604, + -1.124815583229065, + -0.11878778040409088, + -0.005137663800269365, + 0.17839783430099487, + 0.2115524858236313, + -0.24165289103984833, + -0.9655010104179382, + -0.7425088286399841, + 0.0304054357111454, + -0.07012742757797241, + -0.015557953156530857, + 0.1128007024526596, + 0.18957749009132385, + -0.07996463775634766, + 0.09505810588598251, + 0.34419506788253784, + -0.3072076439857483, + 0.03868290036916733, + 0.11494885385036469, + 0.03748936951160431, + 0.0797261893749237, + -0.003397951368242502, + -0.07380004972219467, + -0.11507676541805267, + -0.10298885405063629, + 0.10698320716619492, + 0.06602972000837326, + 0.08226803690195084, + 0.0037747276946902275, + -0.162277951836586, + 0.01671667955815792, + 0.09137773513793945, + 0.18799471855163574, + 0.04144813120365143, + 0.1285877376794815, + 0.1820434182882309, + 0.04940629005432129, + 0.0991915687918663, + 0.10219171643257141, + -0.013141660951077938, + -0.051191627979278564, + 0.05468929558992386, + 0.087598517537117, + 0.15897324681282043, + 0.11863455921411514, + -0.00814050156623125, + -0.07701541483402252, + -0.14013728499412537, + -0.044140227138996124, + -0.05328791216015816, + 0.06760499626398087, + 0.12053386867046356, + 0.09780212491750717, + -0.053725965321063995, + -0.07915244251489639, + -0.0032519602682441473, + 0.019637396559119225, + 0.07848430424928665, + 0.019138827919960022, + 0.1460287868976593, + 0.1281038075685501, + 0.024417784065008163, + 0.059176862239837646, + 0.0658111497759819, + -0.016405148431658745, + -0.18877744674682617, + 0.16666102409362793, + 0.1610611230134964, + 0.08374520391225815, + 0.11570518463850021, + 0.11903064697980881, + 0.1294964700937271, + 0.06379758566617966, + 0.08417274057865143, + 0.12754113972187042, + 0.025328608229756355, + 0.05170705169439316, + 0.0835295170545578, + 0.07477264851331711, + 0.11244285851716995, + 0.11559426784515381, + 0.045258160680532455, + -0.14825093746185303, + -0.08153342455625534, + 0.06288623809814453, + 0.11952362209558487, + 0.11784297972917557, + 0.011141132563352585, + -0.21666541695594788, + -0.29976174235343933, + -0.2279169261455536, + -0.11828474700450897, + 0.12436322867870331, + 0.10465826094150543, + -0.09751085937023163, + -0.292611300945282, + -0.37374064326286316, + -0.31437963247299194, + -0.25637903809547424, + 0.06173908710479736, + 0.14131486415863037, + 0.008434675633907318, + -0.23816508054733276, + -0.30330890417099, + -0.22094152867794037, + -0.11608295142650604, + 0.13235151767730713, + 0.15353602170944214, + 0.15839524567127228, + 0.012247815728187561, + -0.08126968890428543, + -0.003756331978365779, + 0.10660683363676071, + 0.21976575255393982, + -0.04188326746225357, + 0.15462253987789154, + 0.06303395330905914, + 0.006879634689539671, + 0.008284888230264187, + 0.07084798067808151, + 0.1211942657828331, + 0.10190404951572418, + 0.02935362420976162, + -0.05645999684929848, + -0.16800500452518463, + -0.1850246787071228, + -0.09476880729198456, + -0.025327544659376144, + 0.054355036467313766, + -0.035813912749290466, + -0.18694879114627838, + -0.34871891140937805, + -0.3151862621307373, + -0.1943007856607437, + -0.09755205363035202, + 0.014881589449942112, + -0.14875493943691254, + -0.37112873792648315, + -0.37739917635917664, + -0.3241480886936188, + -0.2915399968624115, + -0.11268249899148941, + -0.019726404920220375, + -0.2510305941104889, + -0.38005372881889343, + -0.3622463345527649, + -0.2932804226875305, + -0.28574010729789734, + -0.1505027860403061, + -0.004947682376950979, + -0.18587322533130646, + -0.34759166836738586, + -0.28965193033218384, + -0.21052972972393036, + -0.18780536949634552, + -0.07400713860988617, + 0.11154936999082565, + -0.03556853160262108, + -0.1896934062242508, + -0.18135806918144226, + -0.10117948800325394, + -0.0393117293715477, + 0.06517928093671799, + -0.016659021377563477, + -0.011290309950709343, + -0.007930322550237179, + 0.008189777843654156, + 0.03678786754608154, + 0.021890517324209213, + 0.0034292477648705244, + 0.02200375869870186, + 0.0014921070542186499, + -0.0800287202000618, + -0.17657361924648285, + -0.18702608346939087, + -0.12880444526672363, + -0.022084584459662437, + 0.026420501992106438, + -0.023968446999788284, + -0.07948111742734909, + -0.16741475462913513, + -0.18733707070350647, + -0.16539834439754486, + -0.07347387820482254, + -0.009723886847496033, + -0.02016977220773697, + -0.061092622578144073, + -0.13145211338996887, + -0.15919029712677002, + -0.15043555200099945, + -0.10107766091823578, + 0.0016151965828612447, + 0.0627974420785904, + 0.08695066720247269, + 0.11727584898471832, + 0.11745581030845642, + 0.11329426616430283, + 0.0533670075237751, + -0.016355818137526512, + 0.008450252935290337, + 0.06448577344417572, + 0.1538505256175995, + 0.21232697367668152, + 0.14713847637176514, + 0.039088234305381775, + -0.015588105656206608, + 0.026483291760087013, + 0.060862988233566284, + 0.18265819549560547, + 0.23042462766170502, + 0.168768972158432, + 0.034099943935871124, + -0.018249109387397766, + -0.0321880541741848, + -0.03254542127251625, + -0.03061222843825817, + -0.0026304698549211025, + 0.017764942720532417, + 0.010707704350352287, + 0.009254949167370796, + -0.04533161595463753, + -0.1483704000711441, + -0.2637183666229248, + -0.2678598165512085, + -0.1737881749868393, + -0.049990858882665634, + 0.013515918515622616, + -0.054345693439245224, + -0.1467861533164978, + -0.24911582469940186, + -0.2831358015537262, + -0.22300836443901062, + -0.13739243149757385, + -0.017879672348499298, + -0.040345460176467896, + -0.09990613907575607, + -0.16936856508255005, + -0.2266550064086914, + -0.2020808756351471, + -0.1509508341550827, + 0.014163740910589695, + 0.07591170817613602, + 0.09185601025819778, + 0.10455341637134552, + 0.09514842182397842, + 0.09877350926399231, + 0.053898438811302185, + 0.005704578943550587, + 0.0591997392475605, + 0.13600079715251923, + 0.21777905523777008, + 0.2574957311153412, + 0.20117221772670746, + 0.11415109038352966, + -0.001181072206236422, + 0.09470006823539734, + 0.18978413939476013, + 0.3073742389678955, + 0.36875811219215393, + 0.3069853186607361, + 0.1708926260471344, + -0.0325310118496418, + -0.02656698040664196, + 0.016060845926404, + 0.02459372952580452, + 0.04165660962462425, + 0.033969976007938385, + 0.012855498120188713, + 0.030497560277581215, + 0.004896117839962244, + -0.030887477099895477, + -0.13454437255859375, + -0.1294785887002945, + -0.06398608535528183, + 0.016156472265720367, + 0.03577340394258499, + -0.0033482143189758062, + -0.07112833857536316, + -0.16465041041374207, + -0.1621057391166687, + -0.09478478878736496, + -0.03555302321910858, + -0.001592929707840085, + -0.01719600521028042, + -0.06598587334156036, + -0.1411861628293991, + -0.1496778130531311, + -0.11535074561834335, + -0.0905962884426117, + -0.013807609677314758, + 0.029542237520217896, + 0.039138730615377426, + 0.03988270089030266, + 0.02665030211210251, + 0.049553126096725464, + -0.0015685928519815207, + -0.018007200211286545, + 0.009533192962408066, + 0.06910547614097595, + 0.1034330427646637, + 0.15017645061016083, + 0.10221225768327713, + 0.020978443324565887, + -0.023747621104121208, + 0.02295384369790554, + 0.09313814342021942, + 0.1771395057439804, + 0.21169933676719666, + 0.17989481985569, + 0.05862005427479744, + -0.004540165886282921, + 0.021994179114699364, + -0.003493826137855649, + -0.000224211675231345, + 0.031808022409677505, + -0.05090906098484993, + 0.001970196608453989, + 0.01633802428841591, + 0.0049764602445065975, + 0.0006027702474966645, + -0.005952450912445784, + -0.009886081330478191, + -0.08520589768886566, + 0.030780712142586708, + 0.00037104589864611626, + 0.011886775493621826, + -0.023506291210651398, + 0.08029806613922119, + -0.005086984951049089, + -0.07738454639911652, + 0.06721897423267365, + -0.02397127076983452, + 0.006669329944998026, + -0.016343094408512115, + 0.06056324020028114, + 0.15656796097755432, + -0.49836501479148865, + 0.2475810945034027, + -0.009270203299820423, + -0.006855266634374857, + 0.0034896093420684338, + -0.027938276529312134, + 0.5722692012786865, + -1.1357109546661377, + 0.5644665956497192, + 0.015787361189723015, + -0.015141892246901989, + -0.0032788251992315054, + -0.04797150194644928, + 0.6196744441986084, + -1.1540743112564087, + 0.6065864562988281, + 0.0019708566833287477, + 0.006332532037049532, + 0.014192940667271614, + 0.03773411735892296, + 0.27323007583618164, + -0.594700813293457, + 0.2488076239824295, + -0.008853388018906116, + 0.005692378617823124, + 0.000576167949475348, + -0.027197014540433884, + 0.022015029564499855, + -0.02571249194443226, + 0.004507753532379866, + -0.002439734758809209, + -0.01994609646499157, + 0.03601142391562462, + 0.008136607706546783, + 0.01658148691058159, + -0.06548810750246048, + 0.022721221670508385, + -0.0038820707704871893, + -0.0007800398161634803, + 0.001392301986925304, + 0.09576108306646347, + -0.014628835022449493, + -0.14505760371685028, + 0.07135403156280518, + -0.00839388556778431, + -0.004555124789476395, + -0.04466082155704498, + 0.1456393599510193, + 0.3475525975227356, + -0.7879117131233215, + 0.36262738704681396, + 0.008226356469094753, + 0.0055343699641525745, + -0.061139706522226334, + 0.08975803852081299, + 0.9340736269950867, + -1.7307822704315186, + 0.796896755695343, + -0.024700213223695755, + -0.013090251013636589, + -0.05148586630821228, + 0.050525497645139694, + 0.927090048789978, + -1.7473385334014893, + 0.7727715373039246, + -0.005721901543438435, + 0.010676853358745575, + -0.012798544019460678, + 0.11131046712398529, + 0.4181194007396698, + -0.8475598096847534, + 0.33206430077552795, + 0.018843427300453186, + 0.0006885005859658122, + 0.027498219162225723, + 0.00207257061265409, + 0.0032615051604807377, + -0.021950624883174896, + -0.008452882058918476, + -0.007631891872733831, + -0.028561849147081375, + 0.04865337535738945, + -0.0023105579894036055, + -0.026170270517468452, + -0.011794357560575008, + 0.004327487666159868, + 0.01756221242249012, + 0.0011611212976276875, + -0.008793564513325691, + 0.0741758644580841, + -0.057649385184049606, + -0.006000686902552843, + -0.022717488929629326, + -0.0047143916599452496, + 0.005709030199795961, + -0.05611564591526985, + 0.05792170390486717, + 0.1873699128627777, + -0.3856293857097626, + 0.1371920108795166, + 0.018953431397676468, + 0.015250314958393574, + -0.0016827551880851388, + -0.08515634387731552, + 0.6517581939697266, + -0.9557326436042786, + 0.46986615657806396, + -0.014306572265923023, + -0.01625121757388115, + -0.016088897362351418, + -0.13429272174835205, + 0.6437729001045227, + -1.0167845487594604, + 0.5061463117599487, + 0.00879831612110138, + -0.008598369546234608, + 0.02747279778122902, + 0.007245234213769436, + 0.2527446150779724, + -0.47163763642311096, + 0.15560215711593628, + 0.005050336476415396, + -0.024848125874996185, + -0.0006449198699556291, + -0.008673148229718208, + -0.06940636038780212, + -0.016248086467385292, + 0.1250494420528412, + 0.026387182995676994, + 0.009615709073841572, + -0.0025482974015176296, + -0.04534498229622841, + -0.2626228630542755, + -0.2753732204437256, + 0.052055053412914276, + 0.010792221873998642, + 0.007360508665442467, + 0.10271529853343964, + 0.1113760769367218, + -0.31120774149894714, + -0.49849262833595276, + -0.2206398844718933, + 0.04994913563132286, + 0.054614756256341934, + 0.27786919474601746, + 0.56647789478302, + 0.20970205962657928, + -0.22717078030109406, + -0.17321231961250305, + -0.07836200296878815, + -0.09607961028814316, + 0.10685958713293076, + 0.40848156809806824, + 0.34087467193603516, + -0.005242985673248768, + -0.0682876780629158, + -0.0694413110613823, + -0.1886596381664276, + -0.04473332315683365, + 0.18096435070037842, + 0.1961163580417633, + 0.0014336564345285296, + 0.014584851451218128, + 0.0462430939078331, + -0.1556192934513092, + -0.12809665501117706, + 0.0213937908411026, + 0.10984069108963013, + -0.023050926625728607, + -0.013447473756968975, + 0.007857509888708591, + -0.027979737147688866, + -0.04768490046262741, + -0.09350565075874329, + -0.1659490317106247, + 0.007927919737994671, + 0.26641780138015747, + 0.03398526459932327, + 0.02118881419301033, + -0.006898822728544474, + -0.15209096670150757, + -0.4939330220222473, + -0.42655149102211, + 0.08215854316949844, + 0.02115131914615631, + 0.08892140537500381, + 0.2164168655872345, + 0.12431265413761139, + -0.47813764214515686, + -0.6588870882987976, + -0.3097454905509949, + 0.0837375745177269, + 0.1548176258802414, + 0.49661529064178467, + 0.7337944507598877, + 0.1966201215982437, + -0.29367199540138245, + -0.2547970116138458, + -0.11655519157648087, + -0.11720486730337143, + 0.21941716969013214, + 0.5902130603790283, + 0.42572125792503357, + 0.020460324361920357, + -0.12768393754959106, + -0.12030418962240219, + -0.2582310736179352, + -0.0355166494846344, + 0.2766987085342407, + 0.28080257773399353, + 0.08665957301855087, + 0.027141664177179337, + 0.02690703421831131, + -0.25276950001716614, + -0.23180679976940155, + 0.015180152840912342, + 0.11523276567459106, + 0.041165824979543686, + 0.017444534227252007, + 0.0009439520072191954, + -0.025763530284166336, + -0.022880665957927704, + -0.024819007143378258, + -0.04901815578341484, + 0.027672944590449333, + 0.11211585998535156, + 0.024664992466568947, + -0.010093869641423225, + 0.009466213174164295, + -0.043605536222457886, + -0.17007218301296234, + -0.1366996467113495, + 0.08740171790122986, + -0.014591479673981667, + -0.0031720874831080437, + 0.0835830345749855, + 0.028662094846367836, + -0.21436777710914612, + -0.24753160774707794, + -0.06092096120119095, + 0.03788171336054802, + 0.04295210912823677, + 0.19064708054065704, + 0.3095722496509552, + 0.08003447204828262, + -0.09509303420782089, + -0.05495578795671463, + -0.052218906581401825, + -0.07204427570104599, + 0.07710819691419601, + 0.18033725023269653, + 0.0834946483373642, + -0.049662720412015915, + -0.06561554968357086, + -0.013351643458008766, + -0.11217659711837769, + 0.031957074999809265, + 0.12180440872907639, + 0.06891122460365295, + -0.013705568388104439, + 0.0011150656500831246, + 0.03281388059258461, + -0.11285661906003952, + -0.06422404199838638, + 0.04218210279941559, + 0.014165353029966354, + -0.006244795396924019, + 0.01745765097439289, + 0.08924975246191025, + 0.01710040494799614, + -0.14013372361660004, + -0.21913501620292664, + 0.03613810986280441, + 0.14273521304130554, + 0.05801931768655777, + 0.021427493542432785, + 0.23185034096240997, + 0.2427377849817276, + -0.4384608566761017, + -0.7205182909965515, + -0.18313364684581757, + 0.033575087785720825, + -0.0809125304222107, + 0.04173902049660683, + 0.7251381874084473, + 1.1058244705200195, + -0.015065462328493595, + -0.6434917449951172, + -0.3080260753631592, + -0.090518057346344, + -0.3659006655216217, + -0.4520319700241089, + 0.5924424529075623, + 1.4148176908493042, + 0.5285682082176208, + -0.027211233973503113, + -0.07359065115451813, + -0.08583711832761765, + -0.5631492137908936, + -1.0246236324310303, + -0.1835726648569107, + 0.3307121694087982, + 0.22562064230442047, + 0.05237145721912384, + 0.13263091444969177, + 0.13899636268615723, + -0.1626550555229187, + -0.3918432295322418, + -0.03585565462708473, + 0.06904798001050949, + 0.029870154336094856, + 0.04289601743221283, + 0.05758490040898323, + 0.10055387020111084, + -0.011962685734033585, + -0.13269846141338348, + 0.0012237781193107367, + 0.05511128902435303, + 0.03764793649315834, + -0.07580426335334778, + -0.1750984787940979, + 0.0189101230353117, + 0.08156414330005646, + 0.01691802591085434, + 0.004023027140647173, + 0.18009696900844574, + 0.22744491696357727, + -0.38747039437294006, + -0.6413040161132812, + -0.19208981096744537, + 0.01971367374062538, + -0.036756888031959534, + 0.004946697968989611, + 0.7331712245941162, + 1.1178003549575806, + 0.03220612183213234, + -0.5881579518318176, + -0.24453559517860413, + -0.11856977641582489, + -0.43593257665634155, + -0.5339378118515015, + 0.49467018246650696, + 1.3376370668411255, + 0.5238692164421082, + 0.04584280773997307, + 0.004761924035847187, + -0.032823480665683746, + -0.5419207811355591, + -1.0093209743499756, + -0.19847697019577026, + 0.20687319338321686, + 0.12301573902368546, + 0.07981085777282715, + 0.14125365018844604, + 0.19885297119617462, + -0.1678825318813324, + -0.4042292535305023, + 0.004483209457248449, + 0.03009556047618389, + 0.010802071541547775, + 0.005967534612864256, + 0.0892769992351532, + 0.07342032343149185, + -0.0588892325758934, + -0.09044717997312546, + 0.06307072192430496, + -0.012583961710333824, + -0.006880680099129677, + 0.0030021765269339085, + 0.01633061282336712, + 0.06990820169448853, + 0.0070900083519518375, + -0.03546716272830963, + -0.022131899371743202, + -0.02906683459877968, + 0.010664403438568115, + -0.18731924891471863, + -0.158770352602005, + 0.08571326732635498, + 0.039154618978500366, + 0.032578419893980026, + -0.005781106185168028, + 0.17460086941719055, + 0.2787456810474396, + -0.13416190445423126, + -0.23966801166534424, + -0.004878139588981867, + 0.02796499989926815, + -0.06610933691263199, + -0.19162042438983917, + 0.11163146048784256, + 0.371842622756958, + 0.06444671750068665, + 0.016595548018813133, + 0.01164282951503992, + 0.08330011367797852, + -0.03192862868309021, + -0.2867860198020935, + -0.07080501317977905, + -0.016348646953701973, + -0.06306261569261551, + -0.016291450709104538, + 0.010558445006608963, + 0.13014638423919678, + 0.06202690303325653, + -0.03361419215798378, + 0.0691375732421875, + 0.003561250865459442, + -0.013095442205667496, + -0.050333790481090546, + -0.019117066636681557, + 0.0012089330703020096, + -0.004555183462798595, + -0.022682132199406624, + 0.04747068136930466, + -0.06425238400697708, + -0.0010437731398269534, + -0.0071629988960921764, + -0.04302623122930527, + -0.04830477759242058, + -0.04069536179304123, + -0.06627446413040161, + -0.011470981873571873, + 0.03961857780814171, + 0.026594260707497597, + -0.020662540569901466, + -0.05999285355210304, + -0.053548794239759445, + -0.025959201157093048, + -0.015834785997867584, + 0.013910192996263504, + -0.015868371352553368, + -0.056620921939611435, + -0.06785159558057785, + -0.061030179262161255, + -0.03560228645801544, + -0.04177624359726906, + -0.024657463654875755, + -0.04889696091413498, + 0.004557035397738218, + 0.15414470434188843, + 0.21642963588237762, + 0.035425592213869095, + -0.04339970648288727, + -0.05034525692462921, + -0.08522290736436844, + 0.10652441531419754, + 0.6791198253631592, + 0.7785530686378479, + 0.19941796362400055, + -0.05430706962943077, + -0.02583213709294796, + -0.055139996111392975, + 0.17940561473369598, + 0.6757862567901611, + 0.8240399360656738, + 0.25826773047447205, + -0.062254682183265686, + -0.026456547901034355, + -0.027271386235952377, + -0.0026193747762590647, + 0.11893659085035324, + 0.1915995329618454, + 0.013776157051324844, + 0.08452087640762329, + 0.009950258769094944, + 0.01774573139846325, + 0.06609759479761124, + 0.06512798368930817, + 0.07601971179246902, + 0.09192144125699997, + 0.007696932647377253, + -0.056120894849300385, + -0.03937293961644173, + 0.043086692690849304, + 0.055803027004003525, + 0.08208976686000824, + 0.03658852353692055, + 0.025779196992516518, + -0.0340605266392231, + 0.03186321631073952, + 0.09720855951309204, + 0.10651290416717529, + 0.09562067687511444, + 0.08120692521333694, + 0.06832587718963623, + 0.03940538689494133, + 0.09561086446046829, + -0.03726261481642723, + -0.3520663380622864, + -0.4187469184398651, + -0.11643502116203308, + 0.06203937157988548, + 0.056670401245355606, + 0.11540547758340836, + -0.2742924690246582, + -1.1301417350769043, + -1.2482489347457886, + -0.4411431849002838, + 0.08538330346345901, + 0.036888301372528076, + 0.08759869635105133, + -0.32129940390586853, + -1.1163593530654907, + -1.26430082321167, + -0.48638999462127686, + 0.1056363582611084, + 0.042436979711055756, + 0.07075526565313339, + -0.08341801166534424, + -0.30567145347595215, + -0.39268070459365845, + -0.10187282413244247, + -0.02507772110402584, + -0.0044433241710066795, + -0.009278317913413048, + -0.02964872494339943, + -0.018799586221575737, + -0.03760084509849548, + -0.030454028397798538, + -0.004638439975678921, + 0.026587119325995445, + 0.0095819728448987, + -0.007110759150236845, + -0.006491640582680702, + -0.028083719313144684, + -0.009543413296341896, + -0.005706887226551771, + 0.013012710027396679, + -0.010281933471560478, + -0.0544208325445652, + -0.023230208083987236, + -0.05344587564468384, + -0.04052828997373581, + -0.028035156428813934, + -0.011922319419682026, + -0.045427750796079636, + 0.020700184628367424, + 0.2117788940668106, + 0.21090814471244812, + 0.07214333862066269, + -0.019348343834280968, + -0.014455118216574192, + -0.03561105206608772, + 0.17339389026165009, + 0.49509289860725403, + 0.5219546556472778, + 0.26121678948402405, + -0.029803339391946793, + -0.013761913403868675, + -0.04028521850705147, + 0.17008572816848755, + 0.45583003759384155, + 0.4757367670536041, + 0.22357690334320068, + -0.050064269453287125, + -0.021086007356643677, + -0.039873600006103516, + 0.06433176249265671, + 0.20187893509864807, + 0.2078690379858017, + 0.07802058011293411, + 0.022050827741622925, + -0.05272649601101875, + -0.024311071261763573, + -0.12387345731258392, + -0.20065246522426605, + 0.0262442696839571, + 0.20101603865623474, + 0.056791841983795166, + -0.008266052231192589, + 0.025132112205028534, + -0.23289933800697327, + -0.5296569466590881, + -0.282010018825531, + 0.025113720446825027, + 0.13172000646591187, + 0.16999290883541107, + 0.31588253378868103, + 0.05583454668521881, + -0.5321000814437866, + -0.5585035085678101, + -0.23885560035705566, + 0.0461968369781971, + 0.13807418942451477, + 0.6536149382591248, + 0.6385176777839661, + -0.15636183321475983, + -0.5484278798103333, + -0.5470613241195679, + -0.06269911676645279, + -0.06726553291082382, + 0.5561463236808777, + 1.0985187292099, + 0.6801460385322571, + 0.12841203808784485, + -0.21693651378154755, + -0.19168342649936676, + -0.43073776364326477, + -0.15226863324642181, + 0.41150590777397156, + 0.47421786189079285, + 0.25146934390068054, + -0.017203813418745995, + -0.09694849699735641, + -0.4082376956939697, + -0.3549531400203705, + -0.023591510951519012, + 0.12086013704538345, + 0.08050766587257385, + -0.044960521161556244, + -0.0031193571630865335, + 0.014398006722331047, + -0.005931032355874777, + 0.01548685971647501, + 0.05407734215259552, + -0.006386967841535807, + 0.021660227328538895, + 0.01656133122742176, + 0.002835798542946577, + 0.0008500503608956933, + 0.021802745759487152, + 0.13470955193042755, + 0.06802596151828766, + 0.0033256933093070984, + 0.03037848509848118, + 0.054654810577631, + -0.034221138805150986, + 0.015171438455581665, + 0.23395732045173645, + 0.24771827459335327, + 0.16352902352809906, + -0.07505007833242416, + -0.0814652070403099, + -0.21493901312351227, + -0.3109704852104187, + 0.013416547328233719, + 0.12807825207710266, + 0.12044191360473633, + -0.007915153168141842, + 0.0100772799924016, + -0.15165796875953674, + -0.4013277292251587, + -0.24811144173145294, + -0.06641282886266708, + 0.022568246349692345, + 0.061083581298589706, + 0.09920243173837662, + 0.0695505365729332, + -0.12213064730167389, + -0.12606006860733032, + -0.04593949392437935, + -0.040190644562244415, + 0.03899035230278969, + 0.12688779830932617, + 0.114081971347332, + -0.013348283246159554, + 0.03325115144252777, + 0.007111718878149986, + 0.048056699335575104, + -0.003726312192156911, + 0.05401211231946945, + 0.05355936661362648, + 0.21303032338619232, + 0.2944865822792053, + -0.13604623079299927, + -0.3770989775657654, + -0.0808275118470192, + -0.006103217601776123, + -0.02005188539624214, + 0.37605899572372437, + 0.7776278853416443, + 0.32064270973205566, + -0.23708422482013702, + -0.23380732536315918, + -0.22103570401668549, + -0.45596328377723694, + 0.07213663309812546, + 0.9384943246841431, + 0.8762810230255127, + 0.3557227551937103, + -0.09239326417446136, + -0.25462013483047485, + -0.9858288168907166, + -0.9860153198242188, + 0.2600172162055969, + 0.7731484770774841, + 0.7665594816207886, + 0.14806008338928223, + 0.13109923899173737, + -0.6917864680290222, + -1.580305814743042, + -0.9557210803031921, + -0.16357193887233734, + 0.3189502954483032, + 0.28703632950782776, + 0.5599567890167236, + 0.2459551841020584, + -0.5451022982597351, + -0.6926754713058472, + -0.4368602931499481, + 0.027606861665844917, + 0.025857241824269295, + 0.5376880764961243, + 0.535673975944519, + 0.09012678265571594, + -0.14688564836978912, + -0.1812361180782318, + 0.050619762390851974, + 0.021388273686170578, + -0.05923623591661453, + -0.006538081914186478, + 0.05171535536646843, + -0.051560595631599426, + -0.007643367163836956, + 0.027748188003897667, + 0.0024676925968378782, + -0.008760283701121807, + 0.13039670884609222, + 0.18568934500217438, + 0.06342563778162003, + 0.030788781121373177, + -0.004423442296683788, + -0.041261281818151474, + 0.013299684040248394, + 0.22491391003131866, + 0.27831292152404785, + 0.0883866474032402, + 0.048967570066452026, + 0.0012756097130477428, + -0.03215779736638069, + 0.02710782177746296, + 0.20178261399269104, + 0.22446107864379883, + 0.06052157282829285, + 0.019020315259695053, + 0.02715166099369526, + -0.03146626800298691, + -0.017960363999009132, + 0.11820292472839355, + 0.16114193201065063, + 0.05221821367740631, + -0.02201441302895546, + -0.026308327913284302, + 0.008580431342124939, + -0.02444308064877987, + 0.061380185186862946, + 0.11184953153133392, + 0.006053542252629995, + -0.03248603641986847, + -0.037558719515800476, + 0.01881473697721958, + -0.02349863201379776, + 0.02150980569422245, + 0.09881952404975891, + 0.03962325677275658, + -0.0031782283913344145, + -0.0030868228059262037, + -0.007606725674122572, + -0.06136326491832733, + 0.022755015641450882, + 0.09683670848608017, + 0.0016674631042405963, + 0.01306125894188881, + 0.011335537768900394, + -0.01769089885056019, + 0.005807302892208099, + 0.19103741645812988, + 0.2631426155567169, + 0.10424992442131042, + 0.025223100557923317, + -0.024689532816410065, + -0.03370697423815727, + 0.0512213259935379, + 0.2983294129371643, + 0.37597405910491943, + 0.18788966536521912, + 0.056492965668439865, + -0.006051253993064165, + -0.027141474187374115, + 0.06733105331659317, + 0.29171472787857056, + 0.32160115242004395, + 0.14176633954048157, + 0.008538221009075642, + -0.013039524666965008, + -0.04279422387480736, + 0.03345612436532974, + 0.19111940264701843, + 0.25728005170822144, + 0.09830093383789062, + -0.03371569141745567, + -0.05277566984295845, + -0.0011038694065064192, + -0.013657800853252411, + 0.10037966072559357, + 0.1724642813205719, + 0.04436478391289711, + -0.02240786701440811, + -0.02181128039956093, + 0.019526727497577667, + -0.050060197710990906, + 0.017275504767894745, + 0.07785085588693619, + -0.001727179973386228, + -0.0014453287003561854, + 0.019352080300450325, + -0.003202121239155531, + -0.04241566359996796, + 0.005586653482168913, + 0.06037082523107529, + 0.014115821570158005, + -0.00568200321868062, + 0.018071964383125305, + -0.0007147599244490266, + 0.011219227686524391, + 0.10582104325294495, + 0.15557849407196045, + 0.06189450994133949, + 0.014160261489450932, + 0.00814653467386961, + -0.028064200654625893, + 0.026086319237947464, + 0.1474728286266327, + 0.18273885548114777, + 0.06638553738594055, + 0.019263381138443947, + 0.028977060690522194, + -0.02551555074751377, + 0.01937149092555046, + 0.12000202387571335, + 0.1285850703716278, + 0.047506313771009445, + -0.011383740231394768, + 0.02826755866408348, + -0.009583448991179466, + -0.02093282900750637, + 0.07994058728218079, + 0.0926218256354332, + 0.0318426676094532, + -0.024409465491771698, + 0.020994359627366066, + 0.03295197710394859, + -0.034276511520147324, + 0.037398867309093475, + 0.0794353187084198, + 0.022805212065577507, + 0.0015407208120450377, + 0.013169347308576107, + 0.038584139198064804, + -0.002118688775226474, + 0.03358406573534012, + 0.09085306525230408, + 0.04255761206150055, + 0.010275964625179768, + 0.025351760908961296, + 0.04205995053052902, + 0.1319226324558258, + 0.049708493053913116, + -0.03743802383542061, + -0.04293569549918175, + -0.07646205276250839, + -0.04986324533820152, + 0.15992362797260284, + 0.011027384549379349, + -0.32150742411613464, + -0.3761928677558899, + -0.1654653549194336, + -0.08728181570768356, + 0.044714685529470444, + -0.007500737439841032, + -0.41376256942749023, + -0.6625701189041138, + -0.21809393167495728, + 0.10641554743051529, + 0.09274336695671082, + 0.10189083218574524, + -0.1175118163228035, + -0.2905261516571045, + -0.06248515099287033, + 0.4791955053806305, + 0.49865299463272095, + 0.23415400087833405, + 0.12729482352733612, + -0.05814196541905403, + -0.003843356389552355, + 0.16410382091999054, + 0.40895968675613403, + 0.22034852206707, + 0.021014101803302765, + -0.05658271536231041, + -0.012199933640658855, + 0.034277670085430145, + 0.09565535932779312, + 0.18921032547950745, + 0.010441004298627377, + -0.07427560538053513, + -0.09049694985151291, + -0.00554919708520174, + 0.021386168897151947, + 0.0297325998544693, + 0.06431404501199722, + -0.07367311418056488, + -0.08734254539012909, + -0.059512097388505936, + 0.11382041126489639, + 0.19622667133808136, + 0.02534862980246544, + -0.09704668819904327, + -0.10857658833265305, + -0.10241919010877609, + -0.037928055971860886, + 0.17917697131633759, + -0.0396210141479969, + -0.472421795129776, + -0.5453466176986694, + -0.23921693861484528, + -0.06353127211332321, + 0.033679377287626266, + -0.011634309776127338, + -0.523267924785614, + -0.8400278091430664, + -0.3026646375656128, + 0.17986975610256195, + 0.20296970009803772, + 0.14190459251403809, + -0.12953802943229675, + -0.3968985378742218, + -0.13779792189598083, + 0.548722505569458, + 0.7039015293121338, + 0.4025704264640808, + 0.19535738229751587, + -0.08568660169839859, + -0.0589536651968956, + 0.1868993639945984, + 0.5782724618911743, + 0.43018248677253723, + 0.08876730501651764, + -0.10219226032495499, + -0.04660544916987419, + 0.018129168078303337, + 0.14359626173973083, + 0.3174169361591339, + 0.07668197154998779, + -0.13716676831245422, + -0.2058524489402771, + -0.023707473650574684, + 0.03213014453649521, + 0.06718969345092773, + 0.0917893499135971, + -0.10766899585723877, + -0.206499844789505, + -0.12713390588760376, + -0.03174767270684242, + 0.046395305544137955, + 0.018318502232432365, + -0.002416136907413602, + -0.027143845334649086, + -0.0036621293984353542, + -0.019220896065235138, + 0.05427055433392525, + 0.05058867856860161, + -0.05274957790970802, + -0.11321325600147247, + -0.07062514126300812, + -0.01720590703189373, + -0.00901520811021328, + 0.01746262051165104, + -0.08946436643600464, + -0.2304752618074417, + -0.1021895483136177, + 0.013501768000423908, + 0.029721295461058617, + -0.010094762779772282, + 0.009764805436134338, + -0.06424269080162048, + -0.03032868541777134, + 0.13044297695159912, + 0.12166891992092133, + 0.07157951593399048, + 0.029467372223734856, + -0.03827595338225365, + -0.031337328255176544, + -0.026486340910196304, + 0.05953369289636612, + 0.029497025534510612, + 0.022669093683362007, + -0.01055963709950447, + -0.025020133703947067, + 0.002589448355138302, + 0.017152298241853714, + 0.062067389488220215, + 0.008266719058156013, + 0.00563611788675189, + -0.0044869836419820786, + 0.003065212396904826, + 0.014371387660503387, + 0.013636622577905655, + 0.021183570846915245, + -0.012462744489312172, + -0.02493542619049549, + 0.009652925655245781, + -0.09309647232294083, + -0.09614148736000061, + 0.020278261974453926, + 0.262399286031723, + 0.0025974283926188946, + -0.09532646089792252, + -0.0391894206404686, + -0.003332971129566431, + -0.25919869542121887, + -0.2104814499616623, + 0.5975717306137085, + 0.20378711819648743, + -0.20521192252635956, + 0.005045099183917046, + 0.16707547008991241, + -0.08322134613990784, + -1.1734565496444702, + 0.4060916006565094, + 0.9109339714050293, + -0.22450445592403412, + -0.14085394144058228, + 0.19534644484519958, + 0.6220589280128479, + -1.0614460706710815, + -1.2444484233856201, + 1.1965712308883667, + 0.5032565593719482, + -0.26604175567626953, + -0.13583213090896606, + 0.6453277468681335, + 0.4994892477989197, + -1.7917202711105347, + -0.15182015299797058, + 0.7891079783439636, + 0.10711944103240967, + -0.11587982624769211, + 0.08287231624126434, + 0.7848142981529236, + -0.1764022707939148, + -1.0492321252822876, + 0.15281184017658234, + 0.3100045919418335, + -0.0461110882461071, + -0.06824400275945663, + 0.25544390082359314, + 0.3444065451622009, + -0.3189513683319092, + -0.3503313362598419, + 0.05462741479277611, + -0.041028521955013275, + 0.00624969182536006, + -0.0014677124563604593, + 0.10383514314889908, + -0.03467189520597458, + -0.03946290910243988, + 0.012734192423522472, + -0.003676857566460967, + -0.1616411954164505, + -0.034441810101270676, + 0.34758275747299194, + -0.0017601394793018699, + -0.17407774925231934, + 0.05167992413043976, + 0.12394318729639053, + -0.018228475004434586, + -0.71342533826828, + 0.39672648906707764, + 0.4870489537715912, + -0.27272745966911316, + -0.02687050960958004, + 0.09090551733970642, + 0.46698617935180664, + -0.6089348196983337, + -0.7488552331924438, + 0.8327828645706177, + 0.19947239756584167, + -0.17806877195835114, + -0.09197663515806198, + 0.3198661506175995, + 0.42619431018829346, + -1.1321229934692383, + -0.05452701821923256, + 0.4155597984790802, + -0.001295815804041922, + -0.06596186012029648, + -0.05821318179368973, + 0.4515152871608734, + 0.06321248412132263, + -0.6065720319747925, + 0.10882120579481125, + 0.13767170906066895, + 0.01809641905128956, + -0.070295050740242, + 0.04035783186554909, + 0.22459834814071655, + -0.048405971378088, + -0.14622822403907776, + -0.01119917817413807, + 0.00666345190256834, + 0.04815478250384331, + -0.017866114154458046, + -0.04813665896654129, + -0.02366034686565399, + 0.03589487820863724, + -0.0066519430838525295, + 0.0004148671869188547, + -0.014153627678751945, + 0.04403751716017723, + 0.04098428785800934, + -0.10525348782539368, + -0.0078808031976223, + 0.0444580540060997, + -0.027595041319727898, + 0.010916849598288536, + -0.1390431821346283, + 0.20334453880786896, + -0.006475532427430153, + -0.16053295135498047, + 0.06964287906885147, + -0.025649840012192726, + 0.12622858583927155, + -0.09694403409957886, + -0.09791161119937897, + 0.2617567479610443, + -0.06268735229969025, + -0.03128494322299957, + -0.017743078991770744, + -0.02372320368885994, + 0.2195650041103363, + -0.2456466406583786, + 0.031090563163161278, + 0.010196326300501823, + -0.04323133826255798, + 0.02746250294148922, + -0.079569011926651, + 0.06894756853580475, + 0.11414647102355957, + -0.12175147980451584, + 0.025397513061761856, + 0.006027852185070515, + 0.013360690325498581, + -0.024561991915106773, + -0.10966529697179794, + 0.04913714900612831, + 0.09801583737134933, + 0.00013951699656900018, + -0.03194398432970047, + 0.002382949460297823, + -0.003335593966767192, + 0.023621119558811188, + 0.024585755541920662, + -0.016027197241783142, + -0.02846739999949932, + -0.012949706055223942, + -0.020852699875831604, + -0.016913240775465965, + 0.016088848933577538, + 0.141468346118927, + 0.07285624742507935, + -0.008997173048555851, + -0.033306676894426346, + -0.03418722003698349, + -0.15127411484718323, + -0.047440435737371445, + 0.2687169015407562, + 0.17237843573093414, + 0.03505166247487068, + -0.06994523108005524, + -0.031143782660365105, + -0.3024960458278656, + -0.1552041918039322, + 0.33517369627952576, + 0.28441429138183594, + 0.06471730768680573, + -0.0613982267677784, + -0.02271229960024357, + -0.29379361867904663, + -0.3259792923927307, + 0.16062304377555847, + 0.29220375418663025, + 0.10862076282501221, + -0.005909152328968048, + 0.049116987735033035, + -0.20140305161476135, + -0.3278747797012329, + -0.02566053718328476, + 0.14338354766368866, + 0.006411381531506777, + -0.007274044211953878, + 0.08232597261667252, + -0.04198717698454857, + -0.17330540716648102, + -0.01131037063896656, + 0.08018575608730316, + -0.02374250255525112, + -0.002276432001963258, + 0.00019528658594936132, + -0.024716932326555252, + 0.026509074494242668, + 0.08361849933862686, + 0.012956380844116211, + -0.06030649319291115, + -0.020338360220193863, + -0.03577016666531563, + -0.06858085840940475, + 0.008245388977229595, + 0.25225168466567993, + 0.16135559976100922, + -0.03690743073821068, + -0.09188401699066162, + -0.10410526394844055, + -0.25971388816833496, + -0.07926154136657715, + 0.3933144509792328, + 0.33186599612236023, + 0.059405017644166946, + -0.11824909597635269, + -0.10528354346752167, + -0.4808295667171478, + -0.25224801898002625, + 0.4267246127128601, + 0.4853539764881134, + 0.16933484375476837, + -0.073345847427845, + -0.02648857608437538, + -0.4723232388496399, + -0.4904792010784149, + 0.1938265562057495, + 0.44070878624916077, + 0.22439399361610413, + 0.03877745941281319, + 0.08536087721586227, + -0.31432414054870605, + -0.5158097743988037, + -0.09537900239229202, + 0.20227058231830597, + 0.07895126938819885, + 0.059195615351200104, + 0.14728911221027374, + -0.059377528727054596, + -0.2884902060031891, + -0.12288203090429306, + 0.05220698565244675, + -0.045279599726200104, + 0.019795719534158707, + -0.009819806553423405, + -0.013713877648115158, + 0.0012175077572464943, + 0.03281072899699211, + 0.0017424041870981455, + -0.028847966343164444, + -0.0032059827353805304, + -0.020358575507998466, + 0.0009416870889253914, + -0.007760196924209595, + 0.07921157032251358, + 0.03826644644141197, + -0.02976907789707184, + -0.03300238028168678, + -0.017963968217372894, + -0.055836472660303116, + -0.03299689665436745, + 0.15166012942790985, + 0.06786434352397919, + 0.008589516393840313, + -0.05790036544203758, + -0.0029997669626027346, + -0.14070068299770355, + -0.08799122273921967, + 0.19680362939834595, + 0.14703704416751862, + 0.03569985553622246, + -0.02847554162144661, + 0.03601403906941414, + -0.1339161992073059, + -0.20527805387973785, + 0.1060374304652214, + 0.16269326210021973, + 0.0575268417596817, + 0.0029672966338694096, + 0.018848277628421783, + -0.1029881089925766, + -0.19446833431720734, + -0.055140964686870575, + 0.09632515162229538, + 0.01196608692407608, + 0.01994382217526436, + 0.0030014747753739357, + 0.0029817752074450254, + -0.09395840018987656, + -0.038611751049757004, + 0.03793984279036522, + -0.006295992527157068, + 0.01736803539097309, + -0.0961727425456047, + 0.1318971812725067, + 0.00169672432821244, + 0.02773740515112877, + -0.03737606480717659, + -0.02413480542600155, + -0.07371329516172409, + 0.04465596005320549, + 0.34972262382507324, + 0.269726425409317, + 0.14907677471637726, + -0.15323053300380707, + -0.24987848103046417, + -0.32931339740753174, + 0.05209995433688164, + 0.5192161798477173, + 0.5108750462532043, + 0.2627664804458618, + -0.26889729499816895, + -0.49891141057014465, + -0.5081418752670288, + 0.13535383343696594, + 0.7318623661994934, + 0.7116816639900208, + 0.2973657250404358, + -0.38982102274894714, + -0.7131763100624084, + -0.5916072130203247, + 0.1200462281703949, + 0.7752112746238708, + 0.6947993636131287, + 0.21100594103336334, + -0.5576100945472717, + -0.7797606587409973, + -0.6058254837989807, + 0.08617032319307327, + 0.6432424187660217, + 0.522933304309845, + 0.16018754243850708, + -0.5134027004241943, + -0.6838728189468384, + -0.5088241100311279, + 0.10101393610239029, + 0.4321025311946869, + 0.3330003023147583, + 0.10116448998451233, + -0.2786642014980316, + -0.4134466052055359, + -0.3247438967227936, + 0.009768294170498848, + 0.008712833747267723, + -0.029476309195160866, + 0.007709377445280552, + 0.025279967114329338, + 0.01615188643336296, + 0.01585867628455162, + -0.0031516810413450003, + -0.06462288647890091, + -0.055517926812171936, + -0.013180199079215527, + -0.014849795028567314, + 0.05535515025258064, + 0.04162544384598732, + 0.0022392054088413715, + -0.09408581256866455, + -0.07889631390571594, + -0.032870080322027206, + 0.0382377915084362, + 0.07495865970849991, + 0.08439645916223526, + 0.008036677725613117, + -0.1167779192328453, + -0.10782196372747421, + -0.06854722648859024, + 0.06310252100229263, + 0.09643208235502243, + 0.08629462122917175, + -0.016969647258520126, + -0.10456187278032303, + -0.10410942137241364, + -0.017384463921189308, + 0.03931420296430588, + 0.11296819150447845, + 0.08688211441040039, + -0.018024103716015816, + -0.0985492691397667, + -0.10534191876649857, + 0.016594627872109413, + 0.024613894522190094, + 0.09626104682683945, + 0.056779902428388596, + -0.01314453687518835, + -0.1173979789018631, + -0.07576211541891098, + -0.00741730397567153, + 0.04463285952806473, + 0.06365535408258438, + 0.029472019523382187, + 0.06097950413823128, + -0.0884813666343689, + -0.020469073206186295, + -0.004499382339417934, + 0.006147715728729963, + 0.0061135985888540745, + 0.046618249267339706, + -0.024977274239063263, + -0.2809607684612274, + -0.20776452124118805, + -0.10792756825685501, + 0.10520339012145996, + 0.2195160835981369, + 0.27846819162368774, + -0.0425783209502697, + -0.4539273977279663, + -0.4210258722305298, + -0.24160517752170563, + 0.2377386838197708, + 0.4254952371120453, + 0.40258923172950745, + -0.08894401043653488, + -0.6261403560638428, + -0.6177268624305725, + -0.2941279113292694, + 0.36115866899490356, + 0.6176164746284485, + 0.5170959234237671, + -0.12760992348194122, + -0.6392932534217834, + -0.6288641095161438, + -0.20397846400737762, + 0.4859760105609894, + 0.7283636927604675, + 0.5233575105667114, + -0.08038943260908127, + -0.513219952583313, + -0.4611802101135254, + -0.08622774481773376, + 0.41959214210510254, + 0.6145293116569519, + 0.4252074360847473, + -0.08993257582187653, + -0.3586794435977936, + -0.23889268934726715, + -0.07402873039245605, + 0.2362663745880127, + 0.33187127113342285, + 0.24442552030086517, + -0.10037989169359207, + -0.1200498715043068, + -0.06188809871673584, + 0.009648810140788555, + 0.07703708112239838, + -0.07734857499599457, + -0.16337357461452484, + -0.13160429894924164, + -0.037760209292173386, + 0.10750655829906464, + 0.21975228190422058, + 0.21332265436649323, + 0.1482381671667099, + -0.012174196541309357, + -0.03128019720315933, + 0.06983920931816101, + 0.2055918425321579, + 0.16611628234386444, + 0.20955723524093628, + 0.21407610177993774, + 0.13214662671089172, + 0.01558306161314249, + 0.20919384062290192, + 0.21453723311424255, + 0.10980720072984695, + 0.10323476791381836, + 0.1754676252603531, + 0.16320686042308807, + 0.076839879155159, + 0.2669583261013031, + 0.29500535130500793, + 0.18005967140197754, + 0.14900699257850647, + 0.2337430715560913, + 0.2607984244823456, + -0.08909865468740463, + 0.12383633106946945, + 0.27329200506210327, + 0.2634970247745514, + 0.2298160344362259, + 0.22673286497592926, + 0.1753624528646469, + -0.14258335530757904, + -0.033422429114580154, + 0.09338828176259995, + 0.21975602209568024, + 0.2488732784986496, + 0.21165378391742706, + 0.08514796197414398, + 0.0776415765285492, + -0.028732767328619957, + -0.0827818363904953, + -0.14784079790115356, + -0.06101813539862633, + -0.10570015013217926, + -0.07298385351896286, + -0.03352680057287216, + -0.08094660192728043, + -0.08546923100948334, + -0.025722583755850792, + -0.04828448221087456, + -0.15816760063171387, + -0.22295169532299042, + -0.04976325109601021, + -0.12255501747131348, + -0.04869991913437843, + 0.09818085283041, + 0.2285904735326767, + 0.015187943354249, + -0.19952231645584106, + -0.1415022611618042, + -0.09511925280094147, + 0.10828559100627899, + 0.35640013217926025, + 0.5399265289306641, + 0.3026861250400543, + -0.10532847791910172, + -0.0455780103802681, + -0.09365752339363098, + 0.2482689470052719, + 0.5483031272888184, + 0.6572608947753906, + 0.4098849594593048, + -0.0039499495178461075, + -0.11641024053096771, + -0.22666053473949432, + -0.03133581206202507, + 0.2815704643726349, + 0.3229265809059143, + 0.009749597869813442, + -0.19616934657096863, + -0.05046992748975754, + -0.15597671270370483, + -0.22775587439537048, + -0.14872166514396667, + -0.12174414098262787, + -0.23433859646320343, + -0.238412007689476, + 0.09725375473499298, + 0.08522887527942657, + 0.006490080617368221, + -0.024619178846478462, + 0.07278231531381607, + 0.13406167924404144, + 0.22993306815624237, + 0.10250072181224823, + 0.09119024127721786, + -0.07687287777662277, + -0.1012108325958252, + -0.09500063210725784, + -0.10082961618900299, + 0.09466016292572021, + 0.11299365013837814, + -0.033278487622737885, + -0.20269805192947388, + -0.21449527144432068, + -0.08820098638534546, + -0.18970704078674316, + -0.050536416471004486, + -0.03471578657627106, + -0.13205547630786896, + -0.18150201439857483, + -0.03963223099708557, + 0.13029472529888153, + -0.11594776809215546, + -0.173879474401474, + 0.017406627535820007, + -0.11885572224855423, + -0.06966021656990051, + 0.1687183529138565, + 0.2677668035030365, + -0.020446041598916054, + -0.11710261553525925, + 0.044354867190122604, + -0.10054060816764832, + -0.1287878155708313, + -0.03600803390145302, + -0.03198331966996193, + -0.22372953593730927, + -0.11045534163713455, + 0.22963544726371765, + 0.16736479103565216, + -0.023956498131155968, + -0.0882943719625473, + -0.11904646456241608, + -0.10481738299131393, + 0.083598293364048, + 0.058089643716812134, + -0.04821285232901573, + 0.16764044761657715, + -0.13788309693336487, + -0.1412951946258545, + 0.059633608907461166, + 0.012824267148971558, + -0.03141501545906067, + -0.017422236502170563, + 0.3908282518386841, + -0.31520241498947144, + -0.27876099944114685, + 0.17109407484531403, + 0.011913848109543324, + -0.04440265893936157, + 0.05610174685716629, + 0.5290316343307495, + -0.4506116211414337, + -0.2946499288082123, + 0.2802693545818329, + 0.04180249199271202, + -0.05673402547836304, + 0.0445592887699604, + 0.4933576285839081, + -0.4903600513935089, + -0.3259376883506775, + 0.26069584488868713, + 0.047843094915151596, + -0.053804315626621246, + 0.029928382486104965, + 0.3588394224643707, + -0.39090782403945923, + -0.18598265945911407, + 0.1703576147556305, + 0.010407418012619019, + 0.019840527325868607, + -0.017079327255487442, + 0.21012797951698303, + -0.1586841642856598, + -0.12738685309886932, + 0.12431345880031586, + 0.028149213641881943, + 0.05083676427602768, + -0.07053223252296448, + 0.12090320140123367, + -0.13737183809280396, + -0.09807822853326797, + 0.07203921675682068, + -0.01965559460222721, + 0.036479320377111435, + -0.02657422423362732, + 0.2924504280090332, + -0.19397024810314178, + -0.20908842980861664, + 0.07435549795627594, + 0.011985386721789837, + -0.051603686064481735, + 0.039122600108385086, + 0.5911946892738342, + -0.45937344431877136, + -0.43863579630851746, + 0.23180224001407623, + 0.05592876672744751, + -0.10227655619382858, + 0.1371937245130539, + 0.7193072438240051, + -0.6789532899856567, + -0.5275344252586365, + 0.4098500609397888, + 0.09136661887168884, + -0.08802130073308945, + 0.12226735055446625, + 0.6819202303886414, + -0.7316576838493347, + -0.5229181051254272, + 0.37578293681144714, + 0.09086397290229797, + -0.05128701403737068, + 0.09287497401237488, + 0.5103837251663208, + -0.6150248646736145, + -0.3208717107772827, + 0.29780012369155884, + 0.071808360517025, + 0.04605705663561821, + 0.028153980150818825, + 0.30872926115989685, + -0.32211968302726746, + -0.1925150454044342, + 0.18948692083358765, + 0.07391810417175293, + 0.08546463400125504, + -0.07042243331670761, + 0.14390304684638977, + -0.22509464621543884, + -0.12615789473056793, + 0.09681600332260132, + 0.0030679223127663136, + 0.06206878274679184, + -0.0493885837495327, + 0.11675205081701279, + -0.09476804733276367, + -0.0708041712641716, + 0.027848264202475548, + 0.018535451963543892, + 0.01112216804176569, + -0.023546719923615456, + 0.2808285057544708, + -0.2312571257352829, + -0.16320407390594482, + 0.15229304134845734, + -0.007220278959721327, + -0.026767488569021225, + -0.008487970568239689, + 0.39064091444015503, + -0.3746477961540222, + -0.22930599749088287, + 0.23297259211540222, + -0.020648201927542686, + -0.03918099403381348, + -0.03193120285868645, + 0.37857353687286377, + -0.38306936621665955, + -0.25103962421417236, + 0.2414209097623825, + 0.007709929719567299, + -0.041483473032712936, + -0.001570625347085297, + 0.315625935792923, + -0.276553213596344, + -0.13154125213623047, + 0.17517149448394775, + 0.03219839558005333, + 0.002647437620908022, + -0.012777225114405155, + 0.17064248025417328, + -0.13943275809288025, + -0.10204917937517166, + 0.09418098628520966, + 0.026260169222950935, + 0.05167905613780022, + -0.024634944275021553, + 0.0931941494345665, + -0.11875593662261963, + -0.0752263143658638, + 0.0569780170917511, + 0.00024334408226422966, + -0.001991289434954524, + -0.012094452045857906, + -0.0012201170902699232, + 0.01342268567532301, + 0.006425719242542982, + 0.01147665549069643, + -0.002208880614489317, + -0.019385183230042458, + -0.024868011474609375, + 0.00465290667489171, + 0.009205960668623447, + 0.0016242304118350148, + 0.0059639886021614075, + -0.03436571732163429, + 0.01672518253326416, + 0.008815832436084747, + 0.06389293074607849, + 0.06249547377228737, + 0.06542838364839554, + 0.043118152767419815, + 0.04117512330412865, + 0.014435848221182823, + 0.0065850247628986835, + 0.03811212629079819, + -0.006077034864574671, + -0.004025861620903015, + 0.006247953977435827, + 0.014478449709713459, + 0.0009701942908577621, + -0.002422194229438901, + 0.009390920400619507, + -0.052253514528274536, + -0.05192738026380539, + -0.010346310213208199, + -0.001328076352365315, + -0.002972622634842992, + 0.0015572139527648687, + 0.022503724321722984, + -0.002475353656336665, + 0.001927886507473886, + 0.02994818612933159, + 0.02062363363802433, + -0.0010653833160176873, + -0.005995174869894981, + 0.024450020864605904, + 0.013005194254219532, + 0.0496530681848526, + 0.029475165531039238, + 0.004157512914389372, + -0.0007043799851089716, + 0.01860312558710575, + 0.03839566186070442, + 0.00014980587002355605, + 0.018569663166999817, + 0.05668198689818382, + 0.04645680636167526, + 0.01642409712076187, + 0.03577466681599617, + 0.03575601801276207, + -0.03680748492479324, + -0.01865880750119686, + 0.041660092771053314, + 0.033268485218286514, + 0.03338993713259697, + 0.04665865749120712, + -0.03322917968034744, + -0.2860279381275177, + -0.28877392411231995, + -0.09617949277162552, + 0.014234350994229317, + 0.038012001663446426, + -0.016850680112838745, + -0.27252569794654846, + -0.6714493632316589, + -0.686245322227478, + -0.3376169502735138, + -0.0812990590929985, + 0.003058002796024084, + -0.026376569643616676, + -0.29216718673706055, + -0.6779875159263611, + -0.6917123198509216, + -0.3184400796890259, + -0.058261968195438385, + 0.06338769942522049, + 0.03199980780482292, + -0.09837217628955841, + -0.3355932831764221, + -0.30900436639785767, + -0.04878076910972595, + 0.061543505638837814, + 0.04651529714465141, + 0.0263908002525568, + 0.0030237447936087847, + -0.10458099842071533, + -0.07959774881601334, + 0.05430716276168823, + 0.056767694652080536, + 0.00796051137149334, + -0.016737859696149826, + -0.042338743805885315, + -0.0198048185557127, + -0.03085070475935936, + -0.058721307665109634, + -0.036032311618328094, + -0.0035414688754826784, + -8.359456842299551e-05, + -0.02213932015001774, + 0.02032857947051525, + 0.021788733080029488, + -0.03522418439388275, + -0.025317413732409477, + -0.042937491089105606, + -0.05680134892463684, + -0.012510996311903, + 0.226289302110672, + 0.24401520192623138, + 0.022300971671938896, + -0.030825607478618622, + -0.05485948920249939, + 0.007590078748762608, + 0.2208130657672882, + 0.6964298486709595, + 0.7457719445228577, + 0.3470557630062103, + 0.06941442936658859, + -0.03543366119265556, + 0.035853609442710876, + 0.2872598171234131, + 0.7504303455352783, + 0.7509996294975281, + 0.34327855706214905, + 0.024429334327578545, + -0.05711393058300018, + -0.034500252455472946, + 0.057939525693655014, + 0.33292675018310547, + 0.3141649067401886, + 0.033748809248209, + -0.062175147235393524, + -0.041224412620067596, + -0.01891348883509636, + -0.014519350603222847, + 0.08635713160037994, + 0.03148616850376129, + -0.08749162405729294, + -0.05658482387661934, + 0.00018510188965592533, + 0.002624311950057745, + -0.003570129396393895, + 0.0067627751268446445, + -0.01349653396755457, + -0.003961967770010233, + 0.0034001911990344524, + -0.00385954394005239, + 0.018012456595897675, + -0.018755480647087097, + -0.03163064643740654, + -0.0035233700182288885, + 0.011690095998346806, + -0.014693490229547024, + 0.017746854573488235, + 0.05693097040057182, + 0.1272590607404709, + 0.23477119207382202, + 0.19823509454727173, + 0.05071045830845833, + -0.007188393268734217, + -0.05571149289608002, + -0.06468938291072845, + -0.017831332981586456, + -0.07572834193706512, + -0.19599483907222748, + -0.15608063340187073, + -0.039450764656066895, + -0.035583946853876114, + -0.1605951488018036, + -0.5041624307632446, + -0.6836286783218384, + -0.3773191571235657, + -0.08623629808425903, + -0.04881078004837036, + 0.029403403401374817, + 0.15516817569732666, + 0.4108496308326721, + 0.6393839716911316, + 0.4688946008682251, + 0.2135964334011078, + 0.0623941570520401, + 0.02426956780254841, + -8.065254223765805e-05, + -0.00816427543759346, + -0.09353788942098618, + -0.06872912496328354, + -0.029405562207102776, + 0.012364620342850685, + 0.0060868943110108376, + 0.017015695571899414, + -0.0076495204120874405, + -0.006090708542615175, + -0.016521835699677467, + 0.009218892082571983, + 0.030833140015602112, + -0.0002345978282392025, + 0.03332215175032616, + 0.0030349211301654577, + 0.009600857272744179, + 0.05706647038459778, + 0.06095677986741066, + -0.016137542203068733, + 0.03195658698678017, + 0.13535599410533905, + 0.28229761123657227, + 0.4573267698287964, + 0.39102476835250854, + 0.17547546327114105, + 0.005337159149348736, + -0.07699840515851974, + -0.12667469680309296, + -0.16613735258579254, + -0.2908898890018463, + -0.44942277669906616, + -0.34229782223701477, + -0.16225378215312958, + -0.1100199744105339, + -0.4044281840324402, + -0.9058251976966858, + -1.1549302339553833, + -0.7502554059028625, + -0.2716369032859802, + -0.13495275378227234, + 0.08614412695169449, + 0.3164423108100891, + 0.7155097723007202, + 1.0356683731079102, + 0.7939887642860413, + 0.39567017555236816, + 0.16957539319992065, + 0.02675812318921089, + 0.048314403742551804, + 0.053107086569070816, + -0.009243623353540897, + -0.011442561633884907, + 0.004911235999315977, + 0.012210517190396786, + 0.006660772021859884, + -0.004562888294458389, + -0.009606098756194115, + -0.01610635593533516, + -0.03475078567862511, + 0.007796770427376032, + 0.02015513926744461, + 0.020311446860432625, + 0.009043446741998196, + -0.01929326355457306, + -0.04183953255414963, + -0.003052672604098916, + 0.020744286477565765, + 0.01371331699192524, + 0.004048139322549105, + 0.0692848190665245, + 0.16867054998874664, + 0.2799474000930786, + 0.28119951486587524, + 0.13579942286014557, + -0.0015732255997136235, + -0.05406518653035164, + -0.05831173434853554, + -0.034435681998729706, + -0.11925295740365982, + -0.2570647895336151, + -0.19120880961418152, + -0.09981344640254974, + -0.011702792719006538, + -0.22477947175502777, + -0.5395713448524475, + -0.7111374139785767, + -0.4207299053668976, + -0.11811137199401855, + -0.035199034959077835, + 0.024358956143260002, + 0.16262274980545044, + 0.46769100427627563, + 0.677872896194458, + 0.4637402892112732, + 0.15558630228042603, + 0.04467496648430824, + 0.03221412003040314, + 0.02430277317762375, + -0.006398700177669525, + -0.07235423475503922, + -0.03669704124331474, + -0.000992153538390994, + 0.02220241352915764, + -0.03329842537641525, + 0.05199713259935379, + -0.14053553342819214, + 0.1906905472278595, + -0.13544943928718567, + 0.08535720407962799, + -0.009813228622078896, + 0.03578176349401474, + -0.05863757058978081, + 0.33848440647125244, + -0.49837300181388855, + 0.15308170020580292, + 0.14865124225616455, + -0.12349266558885574, + -0.025796135887503624, + 0.17790427803993225, + -0.7813658714294434, + 0.853188693523407, + 0.2489670068025589, + -0.7378701567649841, + 0.2207188457250595, + 0.05207442864775658, + -0.4280349314212799, + 1.1408430337905884, + -0.24505679309368134, + -1.5490919351577759, + 1.4560288190841675, + -0.31143030524253845, + -0.03536878153681755, + 0.5640448331832886, + -0.6874421834945679, + -1.210310697555542, + 2.6637399196624756, + -1.6589887142181396, + 0.2221546173095703, + 0.10179737955331802, + -0.4354941248893738, + 0.034149203449487686, + 1.480568528175354, + -2.072199821472168, + 0.9205833673477173, + 0.021510563790798187, + -0.07755836099386215, + 0.17983688414096832, + 0.040537625551223755, + -0.5325585603713989, + 0.550999641418457, + -0.11060550063848495, + -0.09052976220846176, + -0.048361390829086304, + 0.03450514376163483, + -0.11854307353496552, + 0.23462797701358795, + -0.17563995718955994, + 0.0653814822435379, + -0.009748813696205616, + 0.07013920694589615, + -0.08628369867801666, + 0.3019683063030243, + -0.630340576171875, + 0.274477481842041, + 0.15417183935642242, + -0.036220982670784, + -0.07344137132167816, + 0.2339126616716385, + -1.0395091772079468, + 1.2002928256988525, + 0.085142120718956, + -0.7080597281455994, + 0.23101751506328583, + 0.016307154670357704, + -0.45877355337142944, + 1.617128849029541, + -0.6593433618545532, + -1.8957709074020386, + 1.746606469154358, + -0.37062564492225647, + 0.01213759370148182, + 0.5851964354515076, + -1.0307577848434448, + -1.4803766012191772, + 3.812014102935791, + -2.0028398036956787, + 0.12008816003799438, + 0.01813559979200363, + -0.5065457820892334, + 0.17598780989646912, + 2.0418734550476074, + -2.680522918701172, + 0.7466094493865967, + 0.16271913051605225, + -0.04379571974277496, + 0.21930621564388275, + 0.041255541145801544, + -0.6644601821899414, + 0.481300413608551, + 0.05410065874457359, + -0.09025495499372482, + 0.01954805478453636, + 0.01899997517466545, + -0.1337241530418396, + 0.19821906089782715, + -0.06395180523395538, + -0.03586877882480621, + 0.01973363384604454, + 0.013873124495148659, + -0.09288538247346878, + 0.4300728440284729, + -0.4235192537307739, + 0.03646458685398102, + 0.10077393800020218, + -0.07569073140621185, + -0.08176662772893906, + 0.3834531605243683, + -0.747482419013977, + 0.4493187367916107, + 0.2960513234138489, + -0.5245057344436646, + 0.27831950783729553, + 0.0731748417019844, + -0.45574328303337097, + 0.6987965703010559, + 0.019539732486009598, + -1.1160184144973755, + 1.0756875276565552, + -0.3804619312286377, + -0.040626902133226395, + 0.2780243456363678, + -0.32946258783340454, + -0.8122196793556213, + 1.9535348415374756, + -1.300661563873291, + 0.3443142771720886, + 0.04858396574854851, + -0.17409801483154297, + -0.07783844321966171, + 1.0875797271728516, + -1.5148566961288452, + 0.8014272451400757, + -0.19643208384513855, + -0.033590562641620636, + 0.11178025603294373, + 0.08284300565719604, + -0.5165408849716187, + 0.5841389894485474, + -0.24739950895309448, + 0.027926180511713028, + -0.028708497062325478, + 0.0037401756271719933, + -0.0047450135461986065, + 0.008427698165178299, + 0.009801353327929974, + -0.0029346586670726538, + -0.010193527676165104, + 0.014876358211040497, + 0.009861295111477375, + -0.005554665345698595, + -0.06270359456539154, + -0.0316256619989872, + 0.006706684362143278, + 0.04316525161266327, + 0.008637072518467903, + -0.03666357323527336, + -0.0719730481505394, + -0.1525861918926239, + -0.14396126568317413, + -0.05387119948863983, + 0.01955549605190754, + 0.007112634833902121, + -0.05175568535923958, + -0.16772602498531342, + -0.20807777345180511, + -0.18768996000289917, + -0.17093753814697266, + -0.03334345668554306, + 0.0011808606795966625, + -0.01579100452363491, + -0.12589050829410553, + -0.17219413816928864, + -0.19648219645023346, + -0.21980451047420502, + -0.04920821264386177, + 0.0012217299081385136, + 0.023885242640972137, + -0.056074876338243484, + -0.13907776772975922, + -0.19139252603054047, + -0.13652737438678741, + -0.0027339402586221695, + 0.004720518831163645, + -0.00037206560955382884, + 0.017924504354596138, + -0.02118082158267498, + -0.06553903222084045, + -0.0435921773314476, + 0.02721239998936653, + 0.020702000707387924, + 0.024033410474658012, + 0.005382229574024677, + -0.01273527555167675, + -0.01742861233651638, + 0.007402990944683552, + 0.010333286598324776, + 0.02598601020872593, + 0.012456837110221386, + -0.03471057116985321, + -0.10051856189966202, + -0.08084382116794586, + -0.023420603945851326, + 0.031205907464027405, + 0.00424322672188282, + -0.03734385594725609, + -0.1152661070227623, + -0.2012551724910736, + -0.1995576024055481, + -0.07972321659326553, + -0.011126434430480003, + -0.0185835100710392, + -0.06944561004638672, + -0.21481844782829285, + -0.26795628666877747, + -0.24916253983974457, + -0.17833945155143738, + -0.06658200174570084, + -0.00305415247566998, + -0.054028186947107315, + -0.19072681665420532, + -0.256619930267334, + -0.26868295669555664, + -0.21621295809745789, + -0.06564134359359741, + 0.0031192339956760406, + 0.013205861672759056, + -0.08044812828302383, + -0.18137820065021515, + -0.23007699847221375, + -0.13054916262626648, + -0.01135951280593872, + 0.013734308071434498, + 0.010981118306517601, + -0.02249351143836975, + -0.05804377421736717, + -0.10652261227369308, + -0.04163172468543053, + 0.017101088538765907, + -0.028687385842204094, + -0.0019976652693003416, + 0.009987232275307178, + 0.010130539536476135, + 0.0015575449215248227, + -0.000983694102615118, + -0.012845008634030819, + 0.01329281460493803, + 0.0029350779950618744, + -0.003755913581699133, + -0.036475058645009995, + -0.0245466697961092, + -0.0020879909861832857, + 0.025867130607366562, + -0.0065954397432506084, + 0.008656582795083523, + -0.04037104919552803, + -0.11718368530273438, + -0.13506115972995758, + -0.024255141615867615, + 0.014097613282501698, + -0.0009370348998345435, + -0.010953565128147602, + -0.12869219481945038, + -0.18789908289909363, + -0.19098156690597534, + -0.12795749306678772, + -0.002666366985067725, + -0.004907527007162571, + -0.014610078185796738, + -0.11913872510194778, + -0.19921070337295532, + -0.21869640052318573, + -0.1849898099899292, + -0.03470952808856964, + 0.0064156935550272465, + 0.03401843458414078, + -0.04000416398048401, + -0.12354391813278198, + -0.16908879578113556, + -0.10385500639677048, + 0.002833302365615964, + -0.036176733672618866, + -0.001048827893100679, + 0.010002595372498035, + -0.020798830315470695, + -0.0488261841237545, + -0.002972641494125128, + 0.016395021229982376, + -0.045770127326250076, + -0.12710650265216827, + -0.1637774109840393, + -0.1411965787410736, + 0.20447289943695068, + 0.509396493434906, + 0.07264503091573715, + 0.12041529268026352, + -0.015143441036343575, + -0.2673257887363434, + -0.3589763641357422, + 0.11289574205875397, + 0.8517020344734192, + 0.7068799138069153, + 0.067301444709301, + -0.02102830447256565, + -0.5235708355903625, + -1.2064802646636963, + -0.856619656085968, + 0.26774707436561584, + 0.6825867295265198, + 0.13516077399253845, + 0.3054035007953644, + -0.0727991834282875, + -1.4912222623825073, + -1.906838297843933, + -0.8574200868606567, + -0.15282419323921204, + 0.39327505230903625, + 0.9758505821228027, + 1.2323224544525146, + 0.18179064989089966, + -0.947610080242157, + -0.6657719016075134, + -0.19935055077075958, + -0.09150458872318268, + 0.34379544854164124, + 1.2025749683380127, + 0.9517407417297363, + -0.12023784220218658, + -0.3146151900291443, + -0.1049022302031517, + -0.34867578744888306, + -0.32945582270622253, + 0.28920575976371765, + 0.7844374179840088, + 0.35520124435424805, + 0.007452746387571096, + 0.018862545490264893, + -0.0021927610505372286, + 0.0321974977850914, + 0.05439181253314018, + -0.030729038640856743, + -0.03517322614789009, + -0.037830010056495667, + -0.056672073900699615, + -0.017769837751984596, + 0.06385952979326248, + 0.08161566406488419, + 0.07809178531169891, + 0.06333671510219574, + -0.036322008818387985, + -0.06432312726974487, + -0.03629852458834648, + 0.010879911482334137, + 0.088901087641716, + 0.0021402277052402496, + 0.09618857502937317, + 0.02661084569990635, + -0.03414442762732506, + -0.08736730366945267, + -0.048222169280052185, + 0.03507986292243004, + -0.053828027099370956, + 0.006044292356818914, + 0.04232194274663925, + 0.001624415279366076, + -0.028371643275022507, + -0.08724038302898407, + -0.005835397634655237, + 0.01057528518140316, + 0.04210871085524559, + 0.06106603890657425, + 0.04250370338559151, + 0.0028668276499956846, + -0.07583706080913544, + -0.06849333643913269, + -0.08538331836462021, + -0.021475542336702347, + 0.044341571629047394, + 0.03604369983077049, + 0.05146002024412155, + 0.00280605535954237, + -0.004615028854459524, + -0.07857430726289749, + -0.03716180846095085, + 0.010876243002712727, + -0.03418488800525665, + 0.007391764782369137, + 0.05969953536987305, + 0.08769611269235611, + 0.066011443734169, + -0.10404568910598755, + -0.27194535732269287, + -0.05224551260471344, + -0.03618992492556572, + -0.023098375648260117, + 0.13832588493824005, + 0.21510572731494904, + -0.07285867631435394, + -0.489085853099823, + -0.33285844326019287, + -0.04830349236726761, + 0.014211038127541542, + 0.2612524926662445, + 0.6911754608154297, + 0.5294638276100159, + -0.2706173360347748, + -0.39350029826164246, + -0.05156399682164192, + -0.16490484774112701, + 0.1161464974284172, + 0.8029336929321289, + 1.1809980869293213, + 0.5025736689567566, + 0.07084998488426208, + -0.1901131123304367, + -0.4918227195739746, + -0.603122889995575, + -0.09460704773664474, + 0.5786081552505493, + 0.35392242670059204, + 0.1328991800546646, + -0.008106965571641922, + -0.2159435749053955, + -0.6369062662124634, + -0.5241336822509766, + 0.06276796758174896, + 0.1139409989118576, + 0.05483332276344299, + 0.1703934520483017, + 0.14603517949581146, + -0.16187912225723267, + -0.4139055907726288, + -0.14918148517608643, + -0.06163417547941208, + 0.005302567034959793, + 0.015524876303970814, + -0.11895350366830826, + -0.19724233448505402, + 0.03412429615855217, + 0.10862118750810623, + 0.08550503104925156, + -0.008599682711064816, + -0.03031114675104618, + -0.33224624395370483, + -0.27994298934936523, + 0.196475550532341, + 0.31109708547592163, + 0.17151644825935364, + -0.04994147643446922, + -0.167176753282547, + -0.5247878432273865, + -0.21136601269245148, + 0.54701828956604, + 0.6110883951187134, + 0.04194486886262894, + -0.27640673518180847, + -0.0795169249176979, + -0.360530287027359, + 0.3472684621810913, + 1.5428175926208496, + 1.0249378681182861, + -0.2724844515323639, + -0.3013695478439331, + 0.020736562088131905, + -0.019495302811264992, + 0.7758124470710754, + 1.5381159782409668, + 0.028625331819057465, + -1.289720892906189, + -0.5894255638122559, + 0.0526396706700325, + 0.11443997919559479, + 0.5935031771659851, + 0.47169724106788635, + -1.2507063150405884, + -1.351940631866455, + -0.03894977271556854, + 0.05095001682639122, + 0.01581231690943241, + 0.11137383431196213, + -0.22327138483524323, + -0.9629225730895996, + -0.2607772946357727, + 0.5907121300697327, + 0.006906076334416866, + 0.002633580705150962, + 0.01940075121819973, + 0.0143396882340312, + 0.020781584084033966, + -0.07249777764081955, + -0.016355905681848526, + 0.016553230583667755, + -0.027528395876288414, + 0.0244428887963295, + 0.024910561740398407, + 0.027229825034737587, + -0.04104151204228401, + 0.007100561633706093, + 0.0157785601913929, + -0.06626633554697037, + 0.006520191207528114, + 0.021171070635318756, + 0.036674920469522476, + -0.06950324773788452, + -0.03003627620637417, + 2.178798422391992e-05, + -0.07278106361627579, + 0.014382920227944851, + 0.0982266515493393, + 0.1454961597919464, + -0.10096189379692078, + 0.022237209603190422, + -0.00040665315464138985, + -0.013766243122518063, + 0.06440296769142151, + 0.21751047670841217, + 0.02519127167761326, + -0.23383572697639465, + 0.0038903038948774338, + -0.042271602898836136, + -0.012596859596669674, + 0.023778460919857025, + 0.07685687392950058, + -0.21480663120746613, + -0.19205358624458313, + 0.04876565560698509, + -0.016765035688877106, + -0.02620583213865757, + 0.01641852967441082, + 0.02201787941157818, + -0.07457322627305984, + -0.003633625339716673, + 0.07550841569900513, + 0.024774253368377686, + 0.04710151255130768, + 0.09110233932733536, + -0.017366377636790276, + -0.04366954043507576, + -0.039786458015441895, + 0.005311290733516216, + 0.037867460399866104, + 0.05367766693234444, + 0.07434491813182831, + -0.07251215726137161, + -0.04231821000576019, + -0.023427855223417282, + 0.036294277757406235, + 0.07782749086618423, + 0.11835407465696335, + 0.08753973245620728, + -0.20742319524288177, + -0.13341759145259857, + -0.008225077763199806, + 0.07292432337999344, + 0.006392402108758688, + 0.021914338693022728, + -0.09218581020832062, + -0.44192466139793396, + -0.1744878888130188, + 0.014938815496861935, + 0.10678526759147644, + -0.012087192386388779, + -0.024533385410904884, + -0.1804407387971878, + -0.3253834545612335, + 0.040678758174180984, + 0.2011708915233612, + 0.17262929677963257, + -0.0045212251134216785, + -0.033313386142253876, + -0.10575363039970398, + -0.07636692374944687, + 0.20343273878097534, + 0.28330928087234497, + 0.043149981647729874, + -0.01109551265835762, + -0.0027725452091544867, + 0.003926735837012529, + 0.029440222308039665, + 0.23945140838623047, + 0.09122566133737564, + -0.15140119194984436, + 0.08737201988697052, + 0.07120998948812485, + 0.05722665786743164, + -0.04388495534658432, + 0.02116825245320797, + 0.023315919563174248, + 0.10898162424564362, + 0.11808467656373978, + 0.03412344306707382, + 0.002771642990410328, + -0.1959579437971115, + -0.05181330814957619, + -0.0044630044139921665, + 0.12481725960969925, + 0.09140311926603317, + 0.03444851189851761, + -0.10931172221899033, + -0.3204459846019745, + -0.21193139255046844, + -0.11101037263870239, + 0.04186606407165527, + -0.07420916110277176, + -0.2004990428686142, + -0.26937955617904663, + -0.12928874790668488, + 0.20819628238677979, + -0.17379426956176758, + -0.2181481271982193, + 0.005387924611568451, + -0.24132733047008514, + -0.23942433297634125, + 0.41489261388778687, + 1.0702778100967407, + 0.024913936853408813, + -0.28405970335006714, + 0.083008773624897, + -0.11059781163930893, + -0.17623695731163025, + -0.17386195063591003, + 0.010644182562828064, + -0.32716259360313416, + -0.2135595828294754, + 0.1223129853606224, + 0.07060510665178299, + -0.048680394887924194, + -0.3332099914550781, + -0.25886017084121704, + -0.18619979918003082, + -0.00733158877119422, + 0.03393476828932762, + -0.010564662516117096, + -0.01817108877003193, + -0.05650597810745239, + -0.01891104131937027, + -0.0554141066968441, + -0.004592927638441324, + -0.0013615720672532916, + -0.05552899092435837, + -0.0560498908162117, + -0.1080632209777832, + -0.013965745456516743, + -0.03290533646941185, + -0.02599845454096794, + -0.02877708151936531, + -0.05670137703418732, + -0.07158109545707703, + -0.08808472007513046, + -0.03919175639748573, + -0.08478893339633942, + -0.08045543730258942, + -0.10066724568605423, + -0.048338882625103, + -0.06750114262104034, + 0.08164039999246597, + 0.3343777060508728, + 0.004952755756676197, + -0.14891156554222107, + 0.032855477184057236, + -0.03277512267231941, + 0.0474768728017807, + 0.6316664814949036, + 1.2214386463165283, + 0.2548498213291168, + -0.13185030221939087, + -0.018188906833529472, + -0.07653989642858505, + -0.01643386110663414, + 0.06630122661590576, + 0.23864209651947021, + -0.013703612610697746, + -0.09347789734601974, + -0.0900193303823471, + -0.04930814355611801, + -0.02791711315512657, + -0.15441712737083435, + -0.01623091846704483, + -0.0447690524160862, + -0.06071227043867111, + -0.04737209901213646, + -0.059769801795482635, + -0.04375007003545761, + -0.00650476710870862, + 0.021540174260735512, + -0.05590728670358658, + -0.13030850887298584, + -0.022067781537771225, + -0.05066747963428497, + 0.00609770929440856, + 0.108611099421978, + 0.1621929407119751, + 0.05232185125350952, + -0.049729123711586, + -0.11906369775533676, + -0.030973592773079872, + 0.057787079364061356, + 0.1610448956489563, + 0.18756121397018433, + 0.07277501374483109, + -0.05777435004711151, + -0.05227195844054222, + 0.14434091746807098, + 0.1889694482088089, + 0.26951169967651367, + 0.4710105359554291, + 0.2164669781923294, + 0.05052375793457031, + -0.0038236663676798344, + 0.20267778635025024, + 0.31214746832847595, + 0.7506387829780579, + 1.2302387952804565, + 0.4363090693950653, + 0.16759593784809113, + -0.049752235412597656, + 0.044786907732486725, + 0.14537742733955383, + 0.2227499932050705, + 0.37362414598464966, + 0.16590620577335358, + 0.0864599421620369, + -0.14058542251586914, + -0.04404178634285927, + -0.0325944609940052, + -0.019113417714834213, + 0.17414243519306183, + 0.11160623282194138, + -0.034911543130874634, + 0.1523953527212143, + 0.04554234445095062, + -0.054958827793598175, + -0.11794494092464447, + -0.19570015370845795, + -0.21358126401901245, + -0.1885669231414795, + -0.08286706358194351, + -0.29818814992904663, + -0.52330082654953, + -0.6190353631973267, + -0.682529091835022, + -0.6171367764472961, + -0.4793100655078888, + -0.11180876195430756, + -0.3490432798862457, + -0.5531057715415955, + -0.6426181793212891, + -0.6420838832855225, + -0.4970071613788605, + -0.27038174867630005, + -0.09740017354488373, + -0.1929621547460556, + -0.30848363041877747, + -0.27204805612564087, + -0.2515120208263397, + -0.07497832179069519, + 0.03551386669278145, + -0.05060403421521187, + 0.08276989310979843, + 0.14321963489055634, + 0.3583574593067169, + 0.40667927265167236, + 0.39398193359375, + 0.27561235427856445, + 0.005085935816168785, + 0.2793635427951813, + 0.48155927658081055, + 0.7088037729263306, + 0.7394692897796631, + 0.6158861517906189, + 0.3986552655696869, + 0.025508087128400803, + 0.38533228635787964, + 0.5305332541465759, + 0.6659612059593201, + 0.6396889090538025, + 0.5396444797515869, + 0.39010515809059143, + -0.03072960674762726, + 0.014305810444056988, + 0.029885446652770042, + 0.038084372878074646, + 0.012448564171791077, + 0.034353457391262054, + 0.048626724630594254, + 0.048866890370845795, + 0.07561437785625458, + 0.09152165800333023, + 0.08432324975728989, + 0.09332144260406494, + 0.07517607510089874, + 0.049146559089422226, + 0.03146318346261978, + 0.06335246562957764, + 0.06438779830932617, + 0.06851581484079361, + 0.09263566881418228, + 0.06460423022508621, + 0.011992924846708775, + 0.03396693989634514, + 0.04433950409293175, + 0.04642309248447418, + 0.0022602551616728306, + -0.0361824594438076, + -0.0005105047021061182, + 0.030808264389634132, + 0.0022333709057420492, + -0.017826544120907784, + -0.03796307370066643, + -0.012887164019048214, + -0.028499294072389603, + -0.03367336839437485, + -0.03668365254998207, + -0.02807682938873768, + -0.07444571703672409, + -0.081318199634552, + -0.09610070288181305, + -0.05368436127901077, + -0.09006591141223907, + -0.10038736462593079, + -0.04115951433777809, + -0.056811004877090454, + -0.09935522079467773, + -0.11107856035232544, + -0.07852742075920105, + -0.0942930206656456, + -0.07625897973775864, + -0.12966541945934296, + -0.038938648998737335, + 0.04580259323120117, + 0.10179819911718369, + 0.17127273976802826, + 0.17857632040977478, + 0.13426578044891357, + 0.04687841981649399, + 0.2424812912940979, + 0.42633309960365295, + 0.5291624069213867, + 0.6012980937957764, + 0.5449428558349609, + 0.3945220708847046, + 0.07037744671106339, + 0.26918724179267883, + 0.44614800810813904, + 0.5331310629844666, + 0.568580687046051, + 0.43367546796798706, + 0.25516101717948914, + 0.08428427577018738, + 0.177769735455513, + 0.24885930120944977, + 0.2178547978401184, + 0.13834305107593536, + 0.07452446967363358, + 0.005187708884477615, + 0.050621017813682556, + -0.08428733795881271, + -0.15576106309890747, + -0.25531095266342163, + -0.34646397829055786, + -0.3276817202568054, + -0.24377694725990295, + 0.02817704901099205, + -0.2531633675098419, + -0.3907041549682617, + -0.5944734811782837, + -0.6062930822372437, + -0.5171639919281006, + -0.3501560389995575, + -0.019397703930735588, + -0.2758809030056, + -0.4118667244911194, + -0.5375933051109314, + -0.5525977611541748, + -0.44681206345558167, + -0.2748269736766815, + -0.04229651764035225, + -0.005005967803299427, + -0.011332424357533455, + 0.011387092061340809, + -0.015463154762983322, + -0.012038768269121647, + 0.011360889300704002, + 0.03551746904850006, + 0.05123865604400635, + 0.020377267152071, + 0.1065637394785881, + 0.18875306844711304, + 0.18516196310520172, + 0.12519532442092896, + -0.042940977960824966, + -0.03246130794286728, + -0.016645772382616997, + 0.07807288318872452, + -0.7815885543823242, + -0.5930942296981812, + 0.03312799707055092, + -0.04537777230143547, + -0.022234303876757622, + 0.009241255931556225, + 0.16947965323925018, + -0.0700032040476799, + -0.06346366554498672, + 0.09555318206548691, + 0.02858082763850689, + 0.009246457368135452, + 0.03902693837881088, + 0.007071994710713625, + 0.10085106641054153, + 0.0881502702832222, + 0.011019160971045494, + 0.006030070595443249, + -0.012882355600595474, + -0.01701420359313488, + 0.022596944123506546, + -0.05345382168889046, + 0.02355102449655533, + -0.0091088330373168, + 0.00015542628534603864, + -0.0004997836658731103, + -0.006951311603188515, + 0.01267238613218069, + -0.0033983420580625534, + -0.0030770134180784225, + 0.02975126914680004, + 0.010702245868742466, + -0.016947058960795403, + 0.007774800062179565, + 0.09566964209079742, + 0.07426714897155762, + 0.1621979922056198, + 0.12728945910930634, + 0.06112523376941681, + 0.06061968579888344, + 0.07934501022100449, + 0.11534841358661652, + 0.10001469403505325, + 0.15475066006183624, + 0.1828109323978424, + 0.02134544588625431, + -0.015320047736167908, + 0.012000483460724354, + -0.014393450692296028, + -1.5520576238632202, + -1.2115217447280884, + 0.017239907756447792, + -0.007013735361397266, + 0.0019166347337886691, + 0.025112343952059746, + 0.1803419440984726, + -0.30807924270629883, + -0.33957329392433167, + 0.10846519470214844, + 0.06151076406240463, + 0.054799750447273254, + 0.06235412135720253, + 0.09605015069246292, + 0.16495031118392944, + 0.12624189257621765, + 0.12234552949666977, + 0.006969878450036049, + 0.0033541936427354813, + 0.008165130391716957, + 0.035377491265535355, + -0.03170061111450195, + 0.019396571442484856, + -0.011411413550376892, + 0.019043665379285812, + 0.00957057997584343, + 0.0055394587107002735, + 0.05569477006793022, + 0.0076510305516421795, + 0.018707536160945892, + 0.06073765829205513, + 0.006503407843410969, + -0.0058801183477044106, + -0.03229741007089615, + 0.0386439748108387, + 0.03167358413338661, + 0.027749545872211456, + -0.04634377732872963, + -0.00019781991431955248, + 0.024982664734125137, + 0.009453915059566498, + 0.1091528981924057, + 0.21055325865745544, + 0.23810525238513947, + 0.13829846680164337, + -0.019112061709165573, + -0.0014926757430657744, + 0.01856786385178566, + 0.10649964213371277, + -0.8599057793617249, + -0.6383436322212219, + 0.10839059948921204, + -0.038730181753635406, + -0.030203847214579582, + -0.033147793263196945, + 0.18132103979587555, + -0.1427767276763916, + -0.11132896691560745, + 0.10957232862710953, + -0.00349965482018888, + 0.03486581891775131, + 0.016247740015387535, + 0.060106489807367325, + 0.1439678966999054, + 0.07201634347438812, + 0.07603273540735245, + -0.0072280303575098515, + 0.01600506529211998, + -0.012912745587527752, + 0.015192546881735325, + -0.034853674471378326, + 0.026164958253502846, + 0.001483929343521595, + 0.0508253313601017, + -0.010546445846557617, + -0.024398569017648697, + -0.0043407524935901165, + 0.0030393539927899837, + -0.009643012657761574, + -0.008882591500878334, + 0.01182172168046236, + 0.003359999740496278, + -0.01145304087549448, + -7.34154018573463e-05, + 0.007416137028485537, + -0.012022661976516247, + 0.013550116680562496, + -0.005982181057333946, + -0.019205773249268532, + -0.0811527743935585, + -0.06323252618312836, + -0.026379290968179703, + -0.04671972244977951, + -0.006205265875905752, + 0.05242094770073891, + 0.05065605416893959, + 0.01961991749703884, + 0.021542323753237724, + 0.04147094115614891, + 0.04451332613825798, + 0.05155060812830925, + 0.15659169852733612, + 0.4448348879814148, + 0.7207449078559875, + 0.8680058717727661, + 0.7269517779350281, + 0.36259666085243225, + 0.10394725203514099, + -0.20449180901050568, + -0.42664405703544617, + -0.7290332317352295, + -0.9376083016395569, + -0.735107958316803, + -0.3541502356529236, + -0.23789332807064056, + -0.10901623964309692, + -0.26809337735176086, + -0.38465574383735657, + -0.44440212845802307, + -0.4070444703102112, + -0.22405119240283966, + -0.14190013706684113, + 0.07151509076356888, + 0.21848519146442413, + 0.41893038153648376, + 0.4783499836921692, + 0.4281534254550934, + 0.28631147742271423, + 0.057699400931596756, + 0.0029010034631937742, + -0.02580493874847889, + -0.02152368798851967, + -0.025850815698504448, + 0.004789783153682947, + 0.021941278129816055, + 0.00574735039845109, + -0.004016151186078787, + -0.014377521350979805, + -0.0828985944390297, + -0.06380187720060349, + -0.048879947513341904, + -0.04580164700746536, + -0.030843649059534073, + 0.024663949385285378, + 0.03409295156598091, + 0.060452476143836975, + 0.037006158381700516, + 0.058853648602962494, + 0.07275765389204025, + 0.02882941998541355, + 0.14549848437309265, + 0.4268765151500702, + 0.7150183320045471, + 0.8942612409591675, + 0.7532845139503479, + 0.3846176564693451, + 0.15604183077812195, + -0.19108416140079498, + -0.42633384466171265, + -0.7508237361907959, + -0.9448286890983582, + -0.719300389289856, + -0.3583783805370331, + -0.2060524821281433, + -0.10382426530122757, + -0.2624296545982361, + -0.4049411416053772, + -0.4338999092578888, + -0.41390693187713623, + -0.22797809541225433, + -0.14593803882598877, + 0.08197329193353653, + 0.2430788278579712, + 0.3906225562095642, + 0.47147202491760254, + 0.42429792881011963, + 0.29326340556144714, + 0.06683206558227539, + 0.004355552606284618, + -0.007973028346896172, + 0.0035172239877283573, + -0.0018502225866541266, + -0.015291260555386543, + 0.0025160792283713818, + 0.0015979957534000278, + 0.011951611377298832, + -0.0004334237310104072, + -0.00172338483389467, + 0.017284434288740158, + -0.00445173867046833, + -0.004828867502510548, + 0.004030159674584866, + 0.03321678191423416, + -0.016998661682009697, + -0.029765218496322632, + -0.07912255078554153, + -0.0494595468044281, + 0.012136446312069893, + 0.029541414231061935, + -0.01129366084933281, + 0.09502168744802475, + 0.21533286571502686, + 0.3453419804573059, + 0.22987395524978638, + 0.04720258712768555, + 0.0032486498821526766, + -0.0042808204889297485, + -0.10162857174873352, + -0.21601493656635284, + -0.3040534257888794, + -0.19600912928581238, + -0.0568307563662529, + -0.0062937624752521515, + -0.021828925237059593, + -0.03831009939312935, + -0.08992031216621399, + -0.08103442937135696, + -0.07600760459899902, + -0.02319694682955742, + -0.008472982794046402, + -0.004151565954089165, + 0.05002164468169212, + 0.0985124409198761, + 0.11273156106472015, + 0.10279814153909683, + 0.032678257673978806, + -0.023295480757951736, + -0.022312145680189133, + 0.032877422869205475, + 0.08301658928394318, + -0.049675002694129944, + -0.05956050381064415, + 0.006878976244479418, + 0.011597251519560814, + -0.03617611899971962, + -0.005020621232688427, + 0.0066283573396503925, + 0.061849869787693024, + 0.0668889507651329, + -0.1120104044675827, + 0.0215831957757473, + -0.008177083916962147, + 0.019240612164139748, + -0.03794482350349426, + -0.21581093966960907, + 0.3248063623905182, + 0.0525924488902092, + -0.13873063027858734, + -0.030904211103916168, + -0.004122832324355841, + 0.2784009277820587, + -0.42068102955818176, + -0.15351417660713196, + 0.4266241192817688, + -0.10780557245016098, + 0.03840374946594238, + -0.15116721391677856, + 0.2292502224445343, + 0.23400554060935974, + -0.5023872256278992, + 0.14868289232254028, + 0.09809935092926025, + 0.03480924293398857, + -0.046804867684841156, + -0.14212554693222046, + 0.3073779344558716, + -0.029529480263590813, + -0.13998086750507355, + -0.02750661037862301, + 0.010526027530431747, + 0.032874979078769684, + -0.07645174115896225, + -0.02746269293129444, + 0.10902399569749832, + -0.00446560001000762, + -0.01339190173894167, + 0.003540819976478815, + -0.04410126060247421, + -0.10884726047515869, + 0.016081949695944786, + 0.15211890637874603, + 0.04027504846453667, + -0.05552368983626366, + 0.04718002676963806, + 0.014503135345876217, + -0.2764658033847809, + -0.16068166494369507, + 0.3356778621673584, + 0.06485499441623688, + -0.07164154946804047, + 0.084479421377182, + 0.2702949047088623, + -0.1339409202337265, + -0.9642015695571899, + 0.47433769702911377, + 0.4715694189071655, + -0.17669782042503357, + -0.04434441775083542, + 0.2641690671443939, + 0.7357130646705627, + -1.2222046852111816, + -0.8205837607383728, + 0.9091072678565979, + 0.14896778762340546, + -0.09332367032766342, + -0.16173647344112396, + 0.8782246708869934, + 0.3819980323314667, + -1.619883418083191, + 0.059255462139844894, + 0.42745286226272583, + -0.03186821565032005, + -0.16420172154903412, + 0.12124066799879074, + 0.8650834560394287, + -0.3728218674659729, + -0.5816569328308105, + 0.10949260741472244, + -0.010671291500329971, + -0.07903271913528442, + -0.09700250625610352, + 0.3192030191421509, + 0.2756008505821228, + -0.2616698145866394, + -0.11051242798566818, + 0.016789941117167473, + -0.0484573096036911, + -0.12333080172538757, + 0.0158428642898798, + 0.11172449588775635, + 0.014953864738345146, + -0.011746960692107677, + 0.05310823395848274, + 0.030244171619415283, + -0.23969320952892303, + -0.1039247065782547, + 0.285805881023407, + -0.04652552306652069, + -0.05380000174045563, + 0.05430186912417412, + 0.25547218322753906, + -0.06164371967315674, + -0.7386756539344788, + 0.4393811821937561, + 0.2623714804649353, + -0.1849273294210434, + -0.049713607877492905, + 0.1656467467546463, + 0.6638666391372681, + -0.899787187576294, + -0.5747878551483154, + 0.7465870976448059, + -0.025567445904016495, + -0.051771312952041626, + -0.19754628837108612, + 0.6828271746635437, + 0.4451557695865631, + -1.2559787034988403, + 0.07448688894510269, + 0.27905938029289246, + 0.003908769693225622, + -0.18454433977603912, + -0.011183545924723148, + 0.7449039816856384, + -0.228777676820755, + -0.47592073678970337, + 0.13784541189670563, + 0.019371675327420235, + -0.06424596160650253, + -0.1660400629043579, + 0.2080633044242859, + 0.2942465841770172, + -0.20263032615184784, + -0.0709841251373291, + -0.0021153483539819717, + -0.028180474415421486, + -0.021557176485657692, + 0.012511649169027805, + 0.06533018499612808, + 0.006560645066201687, + -0.01908997632563114, + -0.020228691399097443, + 0.10450740903615952, + 0.04476405307650566, + -0.20389842987060547, + -0.36356496810913086, + -0.18690945208072662, + 0.06581642478704453, + 0.005246834829449654, + -0.14777734875679016, + 0.04554577171802521, + 0.7314760088920593, + 1.1759854555130005, + 0.7747871279716492, + 0.08771117031574249, + 0.04425497353076935, + 0.14875195920467377, + -0.05036012455821037, + -1.0561891794204712, + -1.7835016250610352, + -1.313464879989624, + -0.4041728973388672, + -0.08825081586837769, + -0.18483860790729523, + -0.09619659930467606, + 0.6506555676460266, + 1.2331949472427368, + 1.057729721069336, + 0.3030258119106293, + 0.053314659744501114, + 0.10696353763341904, + 0.19720971584320068, + -0.19457301497459412, + -0.3546113669872284, + -0.3773464560508728, + 0.007737448439002037, + 0.007112926337867975, + -0.026632368564605713, + -0.07708505541086197, + 0.016982559114694595, + 0.03331448882818222, + 0.03235285356640816, + -0.04479134455323219, + 0.0062864539213478565, + -0.04983896017074585, + -0.014209658838808537, + 0.025105496868491173, + 0.07187403738498688, + -0.019782420247793198, + -0.0387532040476799, + 0.01098113413900137, + 0.10765481740236282, + -0.005502769257873297, + -0.29967597126960754, + -0.5370010733604431, + -0.25729984045028687, + 0.0341138020157814, + -0.01927473582327366, + -0.11736954003572464, + 0.09457080066204071, + 0.8881804943084717, + 1.5049697160720825, + 1.0347492694854736, + 0.22410355508327484, + -0.004720119293779135, + 0.1449226438999176, + -0.11916695535182953, + -1.2009364366531372, + -2.080855369567871, + -1.5549882650375366, + -0.5231477618217468, + -0.005029830615967512, + -0.11258674412965775, + 0.03710457682609558, + 0.9192798137664795, + 1.525830626487732, + 1.3018689155578613, + 0.44408130645751953, + 0.006972550880163908, + 0.07937697321176529, + 0.060622286051511765, + -0.4068094491958618, + -0.5964561104774475, + -0.6058750152587891, + -0.1743212193250656, + -0.0038881103973835707, + -0.04932431876659393, + -0.04989266395568848, + 0.07228495925664902, + 0.10359980911016464, + 0.11054171621799469, + 0.017031395807862282, + -0.012849675491452217, + -0.02224516123533249, + -0.019851619377732277, + 0.04567919671535492, + 0.12134519219398499, + 0.018673665821552277, + -0.03933878242969513, + 0.03506385162472725, + 0.07499910145998001, + -0.004981306381523609, + -0.269795298576355, + -0.4478399455547333, + -0.3141564130783081, + 0.014856644906103611, + -0.01102763693779707, + -0.11778493225574493, + -0.00048367868294008076, + 0.46917271614074707, + 0.8380635976791382, + 0.5829758048057556, + 0.14924737811088562, + 0.00504975114017725, + 0.1242799386382103, + 0.027800291776657104, + -0.5343790054321289, + -0.9185061454772949, + -0.6974499225616455, + -0.1733488291501999, + 0.028415951877832413, + -0.07513032108545303, + 0.010947657749056816, + 0.5501428246498108, + 0.8556726574897766, + 0.6854383945465088, + 0.21023745834827423, + -0.04757346957921982, + 0.028925150632858276, + -0.05005616322159767, + -0.4106282889842987, + -0.5990055203437805, + -0.5274976491928101, + -0.18928098678588867, + 0.007199999876320362, + 0.004744168370962143, + -0.006203897297382355, + 0.16117095947265625, + 0.20310591161251068, + 0.17358633875846863, + 0.057794276624917984, + 0.0018837900133803487, + -0.021730661392211914, + 0.03705505281686783, + 0.048999205231666565, + 0.017187459394335747, + -0.04760497808456421, + -0.06534644961357117, + 0.027641354128718376, + -0.02722003310918808, + -0.09557735174894333, + 0.2721945643424988, + 0.06861108541488647, + -0.17862513661384583, + 0.029542427510023117, + -0.028343068435788155, + -0.24357359111309052, + 0.2928915321826935, + 0.6317090392112732, + -0.5675624012947083, + -0.31298428773880005, + 0.119928739964962, + -0.04503166303038597, + 0.1997436285018921, + 0.9068917632102966, + -0.6105388402938843, + -1.176649808883667, + 0.391012579202652, + 0.21436090767383575, + 0.06404570490121841, + 0.4306352436542511, + -0.18372972309589386, + -1.6093186140060425, + 0.5129231810569763, + 0.8333584666252136, + -0.11607109010219574, + 0.024050598964095116, + -0.027272621169686317, + -0.8072280883789062, + 0.15613007545471191, + 1.0115277767181396, + -0.1886059194803238, + -0.1662863790988922, + -0.07484262436628342, + -0.11359186470508575, + -0.05765556916594505, + 0.48085057735443115, + 0.031143836677074432, + -0.20803743600845337, + 0.005643316078931093, + -0.011422591283917427, + -0.02063453011214733, + 0.010139239020645618, + 0.026931140571832657, + 0.02650240994989872, + 0.014503400772809982, + -0.030498046427965164, + 0.01038119662553072, + -0.041832923889160156, + -0.11747029423713684, + 0.24838468432426453, + 0.08126607537269592, + -0.17684465646743774, + 0.009867151267826557, + -0.04349489137530327, + -0.22892898321151733, + 0.3097872734069824, + 0.6229272484779358, + -0.5710748434066772, + -0.2540203332901001, + 0.15970031917095184, + -0.05765099450945854, + 0.24631772935390472, + 0.9121918678283691, + -0.6539115309715271, + -1.1680796146392822, + 0.43742635846138, + 0.1981748640537262, + 0.060766786336898804, + 0.48115089535713196, + -0.2704729437828064, + -1.668082594871521, + 0.6258481740951538, + 0.8217618465423584, + -0.17844447493553162, + 0.07583325356245041, + -0.031355466693639755, + -0.884739100933075, + 0.21298757195472717, + 1.0279508829116821, + -0.2118954360485077, + -0.16616611182689667, + -0.025157395750284195, + -0.11329160630702972, + -0.08147483319044113, + 0.46636614203453064, + 0.023730026558041573, + -0.21343427896499634, + -0.015201984904706478, + -0.00498165050521493, + 0.022955382242798805, + 0.020228328183293343, + -0.029405873268842697, + -0.032065436244010925, + 0.047389160841703415, + -0.01793060638010502, + 0.01669210195541382, + 0.05227159336209297, + -0.11703876405954361, + 0.006789325270801783, + 0.03741219639778137, + -0.04651298373937607, + -0.012846981175243855, + 0.024231625720858574, + -0.13399703800678253, + -0.024073680862784386, + 0.2970501184463501, + -0.1497301310300827, + -0.04287628084421158, + 0.08405227214097977, + -0.06020639091730118, + -0.01648692972958088, + 0.4150170087814331, + -0.17000712454319, + -0.43461430072784424, + 0.27202337980270386, + 0.006708468310534954, + -0.04474359005689621, + 0.15199843049049377, + -0.03348325565457344, + -0.6591396331787109, + 0.4057810306549072, + 0.25226324796676636, + -0.16070741415023804, + 0.03464199975132942, + 0.023064177483320236, + -0.35642316937446594, + 0.22774185240268707, + 0.37138837575912476, + -0.24171461164951324, + -0.023513946682214737, + 0.028774995356798172, + -0.02702418342232704, + -0.012504744343459606, + 0.17893734574317932, + -0.1554262489080429, + -0.09501983970403671, + 0.06177212670445442, + -0.013536165468394756, + 0.012441401369869709, + 0.006566522642970085, + -0.018207622691988945, + 0.003373368876054883, + -0.034891802817583084, + 0.002223123563453555, + 0.006169564090669155, + 0.022658145055174828, + -0.005327044054865837, + -0.023764559999108315, + -0.004386506043374538, + -0.02777106687426567, + 0.01950058527290821, + 0.004401096608489752, + 0.02882237359881401, + 0.01790205016732216, + -0.007827110588550568, + -0.005222277250140905, + -0.05361752584576607, + 0.008359426632523537, + -0.026494475081562996, + -0.015572195872664452, + -0.04412947595119476, + -0.006163781508803368, + 0.180303692817688, + 0.17117105424404144, + -0.014117442071437836, + 0.014543564058840275, + 0.03875281661748886, + 0.002004631096497178, + 0.11982911080121994, + 0.609316349029541, + 0.5792325735092163, + 0.10267578810453415, + -0.02287464588880539, + -0.011516223661601543, + -0.02587946131825447, + 0.019127164036035538, + 0.2742871046066284, + 0.23896890878677368, + -0.013414637185633183, + 0.012439075857400894, + 0.01148916780948639, + 0.0024075021501630545, + -0.028374193236231804, + -0.02938784286379814, + -0.061723873019218445, + -0.03288640081882477, + 0.010918691754341125, + 0.01171314436942339, + 0.00894222967326641, + -0.0050367508083581924, + 0.00322812981903553, + -0.01958087645471096, + 0.000401448953198269, + 0.00655051926150918, + 0.008647873997688293, + -0.015351405367255211, + -0.022286182269454002, + -0.0018973759142681956, + -0.032965533435344696, + 0.009401706047356129, + 0.01680464670062065, + 0.01722409576177597, + 0.017367251217365265, + -0.0012145076179876924, + 0.015895379707217216, + -0.013976357877254486, + 0.01587546430528164, + -0.019388504326343536, + -0.004597584251314402, + -0.026080038398504257, + 0.020517753437161446, + 0.20680218935012817, + 0.20302064716815948, + 0.03813354671001434, + 0.027738921344280243, + 0.02183712273836136, + 0.023807305842638016, + 0.14632326364517212, + 0.5991678237915039, + 0.608651340007782, + 0.15929070115089417, + -0.02112223394215107, + -0.020013611763715744, + -0.03723381832242012, + 0.032139480113983154, + 0.27032363414764404, + 0.24862462282180786, + 0.02374681644141674, + 0.007894856855273247, + 0.00042308925185352564, + -0.004832752980291843, + -0.024313796311616898, + -0.0018940505106002092, + -0.02681432105600834, + 0.002362651750445366, + 0.013330202549695969, + 0.012553646229207516, + 0.002630018163472414, + 0.002979951212182641, + 0.0015847217291593552, + -0.03376828506588936, + -0.010844729840755463, + -0.002748559694737196, + 0.012938202358782291, + -0.011872833594679832, + -0.0025761008728295565, + 0.003677211469039321, + -0.04305516183376312, + 0.001133457524701953, + 0.0020396243780851364, + 0.01797032356262207, + 0.016580887138843536, + 0.04445189982652664, + 0.013270077295601368, + -0.04839251935482025, + 0.011546633206307888, + -0.015829432755708694, + 0.019473392516374588, + -0.011464826762676239, + 0.018693143501877785, + 0.18201367557048798, + 0.16157257556915283, + 0.02082117274403572, + 0.015915032476186752, + 0.010720869526267052, + -0.0020238866563886404, + 0.09329187124967575, + 0.46998023986816406, + 0.5186727046966553, + 0.09814783185720444, + -0.016547314822673798, + 0.00325066689401865, + -0.028936590999364853, + 0.01002424769103527, + 0.21822214126586914, + 0.22012007236480713, + 0.008229314349591732, + 0.015599996782839298, + 0.014740276150405407, + 0.0019725109450519085, + 0.003613655688241124, + -0.03043546713888645, + -0.06308998167514801, + 0.014664110727608204, + 0.06775129586458206, + -0.12990300357341766, + -0.03638269379734993, + -0.03883139044046402, + 0.05194637551903725, + 0.03896122798323631, + -0.05132362246513367, + -0.07234688848257065, + -0.36106064915657043, + -0.2839237451553345, + -0.11496391147375107, + 0.3026673197746277, + 0.3528609871864319, + 0.21559017896652222, + -0.11970120668411255, + -0.5473688244819641, + -0.5362005233764648, + -0.21015112102031708, + 0.4089161455631256, + 0.6033567786216736, + 0.38614287972450256, + -0.12437233328819275, + -0.6394402384757996, + -0.6945835947990417, + -0.3482857942581177, + 0.5189254283905029, + 0.8457668423652649, + 0.6248002648353577, + -0.12700730562210083, + -0.6978924870491028, + -0.7764106392860413, + -0.4171960651874542, + 0.44747814536094666, + 0.8406224846839905, + 0.6821274161338806, + -0.07793218642473221, + -0.5459966659545898, + -0.6139025092124939, + -0.35998886823654175, + 0.27800890803337097, + 0.6048891544342041, + 0.591307520866394, + -0.04850815609097481, + -0.3863481283187866, + -0.3542836606502533, + -0.2491992861032486, + 0.1616278886795044, + 0.3402666747570038, + 0.4610227644443512, + -0.010262396186590195, + 0.0408165417611599, + 0.006382474210113287, + -0.011430315673351288, + -0.027895113453269005, + -0.009767768904566765, + 0.005882019177079201, + 0.05225436016917229, + 0.0415218211710453, + 0.08244743943214417, + 0.026765575632452965, + -0.05404946208000183, + -0.06101839989423752, + -0.028233220800757408, + 0.03128793090581894, + 0.07133004069328308, + 0.0718698799610138, + 0.042146697640419006, + -0.08380170166492462, + -0.09263177216053009, + -0.07569421827793121, + 0.032425008714199066, + 0.12351400405168533, + 0.09103626012802124, + -0.004768018145114183, + -0.05960838869214058, + -0.11922567337751389, + -0.10132396221160889, + 0.044341862201690674, + 0.100867860019207, + 0.09607693552970886, + -0.00129030947573483, + -0.05481477826833725, + -0.1278291642665863, + -0.12058380991220474, + 0.016678951680660248, + 0.09958931058645248, + 0.08456224203109741, + 0.061599165201187134, + -0.049776893109083176, + -0.11354166269302368, + -0.09844806790351868, + 0.004753128159791231, + 0.07868346571922302, + 0.06464104354381561, + 0.020981626585125923, + -0.010770543478429317, + -0.08838209509849548, + -0.07265795767307281, + -0.058313023298978806, + 0.10897739976644516, + 0.026735201478004456, + 0.03972309082746506, + -0.019998662173748016, + -0.048948734998703, + 0.03377270698547363, + 0.053406376391649246, + 0.27304399013519287, + 0.20850272476673126, + 0.07890326529741287, + -0.22241365909576416, + -0.2816997468471527, + -0.1745096743106842, + 0.08957889676094055, + 0.4962941110134125, + 0.4586986303329468, + 0.20177948474884033, + -0.3625744581222534, + -0.47758376598358154, + -0.32412785291671753, + 0.0669194757938385, + 0.5394997596740723, + 0.601328432559967, + 0.24388420581817627, + -0.4319041073322296, + -0.6893490552902222, + -0.5106037259101868, + 0.10174300521612167, + 0.5457565784454346, + 0.6549625992774963, + 0.38772058486938477, + -0.3778320252895355, + -0.6820934414863586, + -0.551069438457489, + 0.049600999802351, + 0.45137161016464233, + 0.5143972039222717, + 0.3713279068470001, + -0.26546329259872437, + -0.5121409893035889, + -0.47691628336906433, + 0.03843758627772331, + 0.30808231234550476, + 0.3185756504535675, + 0.22629432380199432, + -0.14860986173152924, + -0.2915389835834503, + -0.3552006185054779, + -0.003137432038784027, + -0.01327254343777895, + -0.027139298617839813, + 0.04800891876220703, + 0.05380738899111748, + -0.01380784809589386, + 0.0022881641052663326, + -0.012132279574871063, + 0.06182793900370598, + 0.03762871399521828, + 0.0966145321726799, + 0.08963571488857269, + 0.06551238149404526, + 0.031640589237213135, + -0.010532311163842678, + 0.07195396721363068, + 0.11343465745449066, + 0.11621421575546265, + 0.047318290919065475, + 0.1111951395869255, + 0.044054243713617325, + 0.016777141019701958, + 0.03392713516950607, + 0.06047024950385094, + -0.7924502491950989, + -0.7310910224914551, + 0.031088173389434814, + 0.0906061977148056, + 0.022829236462712288, + 0.04470035433769226, + 0.025999872013926506, + -0.8246837258338928, + -0.723675549030304, + 0.15835590660572052, + 0.07358791679143906, + -0.015819497406482697, + -0.014207872562110424, + 0.08506257086992264, + 0.08868777751922607, + 0.0976945012807846, + 0.11740022897720337, + 0.016287995502352715, + -0.024363648146390915, + 0.04249691963195801, + 0.02909177541732788, + 0.12011238187551498, + 0.10729824751615524, + 0.05927390977740288, + 0.04731644690036774, + 0.008210064843297005, + 0.03859357163310051, + -0.005175672471523285, + 0.01984376832842827, + -0.0011626111809164286, + -0.0010909241391345859, + 0.02311880886554718, + 0.007646523881703615, + 0.04582137614488602, + -0.0027255103923380375, + 0.027656713500618935, + 0.02781369723379612, + 0.015750093385577202, + 0.040563344955444336, + -0.007784596644341946, + 0.006534814368933439, + 0.002403199439868331, + -0.020037032663822174, + -0.011717663146555424, + 0.07826739549636841, + 0.018203573301434517, + 0.021228624507784843, + 0.014112413860857487, + -0.02866269089281559, + -0.9502679109573364, + -0.825043797492981, + 0.05938851460814476, + 0.06553053110837936, + 0.015418429858982563, + 0.0616452619433403, + -0.0094453701749444, + -0.9471839666366577, + -0.7922234535217285, + 0.13069523870944977, + 0.04939320683479309, + 0.007429714780300856, + 0.022599652409553528, + 0.0820123627781868, + 0.06440276652574539, + 0.09897352755069733, + 0.0856291800737381, + 0.006608777679502964, + -0.0005533680086955428, + 0.021656949073076248, + 0.014818831346929073, + 0.03757459297776222, + -0.001428246614523232, + 0.03473127633333206, + 0.03607869893312454, + 0.017313262447714806, + 0.0025767614133656025, + -0.033292777836322784, + 0.027883101254701614, + -0.007534499745815992, + -0.04302362725138664, + -0.01795666106045246, + -0.007667913101613522, + 0.012547189369797707, + -0.021762438118457794, + 0.03789107874035835, + 0.06384614109992981, + 0.0014223429607227445, + -0.01393786258995533, + -0.041693057864904404, + -0.01813604310154915, + 0.065328449010849, + 0.15736474096775055, + 0.1531635969877243, + 0.09920474886894226, + -0.04044449329376221, + 0.010558396577835083, + 0.05559245124459267, + 0.10931257158517838, + -0.5784384608268738, + -0.5109886527061462, + 0.17690584063529968, + 0.07484250515699387, + 0.010378374718129635, + 0.0890144556760788, + 0.13172735273838043, + -0.6058865785598755, + -0.49908995628356934, + 0.1835336685180664, + 0.005293308291584253, + -0.03870566934347153, + -0.025229454040527344, + 0.12571711838245392, + 0.14792272448539734, + 0.14905226230621338, + 0.0700206533074379, + -0.035034529864788055, + 0.013128797523677349, + 0.015581230632960796, + 0.005400130525231361, + 0.07070232182741165, + 0.03829728811979294, + -0.013876918703317642, + -0.019958000630140305, + -0.020086020231246948, + -0.019999003037810326, + -0.015111410059034824, + 0.11963249742984772, + -0.08270428329706192, + -0.0025947154499590397, + -0.010668564587831497, + 0.016670405864715576, + -0.03206938877701759, + -0.053453829139471054, + 0.1236601173877716, + -0.020077411085367203, + 0.00779569149017334, + -0.0318986251950264, + 0.03579804673790932, + -0.060723867267370224, + -0.009301809594035149, + 0.09249342232942581, + -0.13378725945949554, + 0.17496798932552338, + -0.0935625433921814, + 0.06569044291973114, + -0.18187756836414337, + 0.06397300213575363, + 0.3793930113315582, + -0.5664302706718445, + 0.23658618330955505, + -0.03206830099225044, + 0.03155658766627312, + 0.039305318146944046, + -0.6008145213127136, + 1.0417630672454834, + -0.5062726140022278, + -0.04698493704199791, + 0.0979752242565155, + -0.037326715886592865, + 0.26255178451538086, + -0.590207576751709, + 0.4195419251918793, + 0.12212422490119934, + -0.26122942566871643, + 0.06442253291606903, + -0.07682429254055023, + 0.12608948349952698, + -0.13872937858104706, + -0.030260663479566574, + 0.2047160565853119, + -0.13068141043186188, + 0.016608506441116333, + -0.021629147231578827, + 0.04659907519817352, + 0.024417348206043243, + 0.06751634925603867, + -0.1705978959798813, + 0.0655774399638176, + -0.0041802311316132545, + -0.02263445220887661, + -0.014069054275751114, + 0.06242800131440163, + 0.08984102308750153, + -0.19382472336292267, + 0.09380361437797546, + -0.0032764992211014032, + -0.03950225189328194, + -0.08896161615848541, + 0.28387022018432617, + 0.1668996810913086, + -0.5457127094268799, + 0.21796099841594696, + 0.012032964266836643, + 0.030721815302968025, + -0.4431600570678711, + 0.3104412257671356, + 1.0070439577102661, + -1.1077969074249268, + 0.08187273889780045, + 0.1387241780757904, + 0.09014563262462616, + -0.25378379225730896, + -0.9253583550453186, + 1.9745515584945679, + -0.6605072617530823, + -0.4394792318344116, + 0.11501576751470566, + 0.03007262572646141, + 0.2538164258003235, + -1.1462018489837646, + 0.7988958954811096, + 0.46934643387794495, + -0.4244523048400879, + -0.0001816617150325328, + -0.04351970925927162, + 0.20500127971172333, + -0.40710335969924927, + -0.15871365368366241, + 0.4640160799026489, + -0.06024328991770744, + -0.016036653891205788, + -0.012419192120432854, + 0.05552554875612259, + 0.050986770540475845, + -0.0171927809715271, + -0.12105240672826767, + 0.03947274759411812, + 0.009537882171571255, + -0.026668362319469452, + 0.017273351550102234, + 0.10812800377607346, + -0.015008139424026012, + -0.14154496788978577, + 0.08008233457803726, + -0.01306608971208334, + -0.05574854835867882, + -0.06091056764125824, + 0.2888447940349579, + 0.05022002384066582, + -0.4581625759601593, + 0.21146118640899658, + -0.01495362538844347, + 0.02946372702717781, + -0.38554418087005615, + 0.30167311429977417, + 0.7605867981910706, + -0.898481547832489, + 0.11953620612621307, + 0.12686115503311157, + 0.09949854761362076, + -0.14409342408180237, + -0.7404491901397705, + 1.5449001789093018, + -0.5307857394218445, + -0.3347839415073395, + 0.09940771013498306, + 0.009087899699807167, + 0.3081797957420349, + -0.9053899049758911, + 0.5102643370628357, + 0.4646914303302765, + -0.36200836300849915, + -0.043260715901851654, + -0.05309509113430977, + 0.22480911016464233, + -0.2674587666988373, + -0.25316888093948364, + 0.435017466545105, + -0.017485838383436203, + -0.049459364265203476, + 0.012460661120712757, + -0.02262282371520996, + -0.04392899200320244, + 0.013330060057342052, + 0.05963548645377159, + -0.020561739802360535, + -0.013496879488229752, + -0.02310933545231819, + -0.06549905985593796, + 0.12132573872804642, + 0.22165189683437347, + -0.07683887332677841, + -0.12427931278944016, + 0.05543455854058266, + 0.009089780040085316, + 0.19844494760036469, + 0.07650767266750336, + -0.48934996128082275, + -0.35080164670944214, + 0.13422781229019165, + 0.022217294201254845, + -0.006589306052774191, + -0.18357548117637634, + -0.6055922508239746, + 0.09492127597332001, + 0.7073907256126404, + 0.1777055710554123, + -0.05434347689151764, + 0.04566245526075363, + -0.023967979475855827, + 0.4856843054294586, + 0.8131930828094482, + -0.2068077027797699, + -0.3863125145435333, + 0.02887917123734951, + -0.05048410966992378, + 0.051201049238443375, + 0.057671088725328445, + -0.6412642002105713, + -0.39739903807640076, + 0.11036981642246246, + 0.06687764078378677, + -0.018151026219129562, + 0.0022760110441595316, + -0.09328305721282959, + 0.1352599710226059, + 0.19680921733379364, + 0.032235175371170044, + -0.06123670935630798, + -0.013810456730425358, + -0.01821190118789673, + -0.029903864488005638, + 0.027588335797190666, + 0.0762094110250473, + -0.046041399240493774, + 0.017117975279688835, + -0.018925148993730545, + 0.00423092395067215, + 0.2065701186656952, + 0.157025545835495, + -0.26491472125053406, + -0.24569831788539886, + 0.0873267725110054, + 0.004694689530879259, + 0.1838335543870926, + -0.18973900377750397, + -0.9744532108306885, + -0.41959065198898315, + 0.409589946269989, + 0.22223009169101715, + -0.0989728644490242, + -0.40883490443229675, + -0.8418471813201904, + 0.40256521105766296, + 1.4742398262023926, + 0.4913789629936218, + -0.14741277694702148, + -0.0028576564509421587, + 0.0861843004822731, + 1.0056577920913696, + 1.479182481765747, + -0.21940617263317108, + -0.8383130431175232, + -0.30560192465782166, + 0.12028121203184128, + 0.24013034999370575, + 0.11750353127717972, + -1.1071972846984863, + -0.9066778421401978, + -0.055051110684871674, + 0.15361995995044708, + 0.0032418384216725826, + -0.08823435008525848, + -0.3188804090023041, + -0.02160414680838585, + 0.2972750663757324, + 0.17006494104862213, + 0.03401973098516464, + 0.017106015235185623, + 0.010733614675700665, + 0.004688877146691084, + 0.02985573373734951, + 0.046415988355875015, + -0.05177726596593857, + -0.04624386876821518, + 0.026672907173633575, + 0.03479000926017761, + 0.22761401534080505, + 0.12049756944179535, + -0.23494181036949158, + -0.2207801640033722, + 0.06036320701241493, + 0.02112250216305256, + 0.16173022985458374, + -0.14196650683879852, + -0.8236543536186218, + -0.3530665934085846, + 0.3715725541114807, + 0.25781863927841187, + -0.09806561470031738, + -0.341796338558197, + -0.7201419472694397, + 0.2111824005842209, + 1.1648427248001099, + 0.3866075575351715, + -0.1955428272485733, + -0.13164694607257843, + -0.06048528477549553, + 0.7989920973777771, + 1.143347144126892, + -0.19509637355804443, + -0.6719933152198792, + -0.26912447810173035, + 0.16733723878860474, + 0.32526257634162903, + 0.1910397708415985, + -0.8516904711723328, + -0.6005953550338745, + 0.10627525299787521, + 0.16700856387615204, + 0.032433755695819855, + -0.11345972120761871, + -0.270126610994339, + -0.012052524834871292, + 0.25489771366119385, + 0.14647918939590454, + -0.014324051328003407, + -0.011148945428431034, + -0.0011708218371495605, + -0.018903911113739014, + -0.010648071765899658, + -0.017981043085455894, + 0.014055400155484676, + -0.020784996449947357, + -0.030126383528113365, + 0.1150858998298645, + -0.1112036183476448, + -0.023664508014917374, + 0.1651369333267212, + -0.055412910878658295, + -0.007318025920540094, + -0.07404221594333649, + 0.3068569302558899, + -0.6175673007965088, + 0.35226404666900635, + 0.1940349042415619, + -0.22921296954154968, + 0.06411048769950867, + 0.001689439988695085, + 0.23336739838123322, + -0.9470900893211365, + 1.2042961120605469, + -0.44587329030036926, + -0.15847182273864746, + 0.07572423666715622, + 0.11138042062520981, + -0.2075018584728241, + -0.2651064693927765, + 0.8896074295043945, + -0.7130936980247498, + 0.10370831191539764, + 0.07730382680892944, + 0.02368813008069992, + -0.20520009100437164, + 0.13611918687820435, + 0.31062978506088257, + -0.471883624792099, + 0.21489326655864716, + -0.0216743852943182, + -0.04020361602306366, + -0.022920167073607445, + 0.16054102778434753, + -0.002624030224978924, + -0.14670424163341522, + 0.12018264085054398, + -0.043656397610902786, + -0.005084550939500332, + 0.03873870149254799, + -0.07967288792133331, + -0.007439201697707176, + 0.027688704431056976, + 0.08916077762842178, + -0.0036629599053412676, + -0.01389122661203146, + 0.1402083784341812, + -0.2923351228237152, + -0.01932896114885807, + 0.224355086684227, + -0.013193303719162941, + -0.03984276205301285, + -0.04474477842450142, + 0.3302844762802124, + -0.9746807217597961, + 0.5603556036949158, + 0.3556183874607086, + -0.2713812589645386, + 0.01890619471669197, + 0.06983876973390579, + 0.09052442759275436, + -1.3613605499267578, + 1.8220031261444092, + -0.40902698040008545, + -0.31302449107170105, + 0.03893759846687317, + 0.11448371410369873, + -0.4220678210258484, + -0.3677598237991333, + 1.539440631866455, + -0.8297391533851624, + -0.08504960685968399, + 0.0629446730017662, + -0.016804160550236702, + -0.31778836250305176, + 0.2363198846578598, + 0.6452136635780334, + -0.700931191444397, + 0.09927428513765335, + 0.0019635935313999653, + -0.05397690460085869, + -0.014552262611687183, + 0.2352754771709442, + 0.09991656988859177, + -0.28891685605049133, + 0.07818552106618881, + -0.021534763276576996, + -0.009461677633225918, + -0.01069199200719595, + -0.008059840649366379, + -0.0129952197894454, + 0.038492631167173386, + 0.018906958401203156, + -0.025432486087083817, + -0.03420932963490486, + 0.09104404598474503, + -0.10342919826507568, + -0.035048507153987885, + 0.1415904313325882, + -0.052986644208431244, + -0.021596742793917656, + -0.049690280109643936, + 0.3079117238521576, + -0.5487046837806702, + 0.27024003863334656, + 0.15158434212207794, + -0.16488635540008545, + 0.027642132714390755, + 0.004561549983918667, + 0.21555493772029877, + -0.9188903570175171, + 1.0972669124603271, + -0.3528037667274475, + -0.07574182748794556, + 0.021962830796837807, + 0.08826783299446106, + -0.18681983649730682, + -0.2789378762245178, + 0.864517331123352, + -0.5642455816268921, + 0.07469761371612549, + 0.03803368657827377, + 0.014268620871007442, + -0.17712704837322235, + 0.1349189728498459, + 0.3181247115135193, + -0.45067182183265686, + 0.1391848623752594, + 0.009777083061635494, + -0.028080958873033524, + -0.03586730733513832, + 0.14503192901611328, + -0.014655024744570255, + -0.1472700834274292, + 0.07361634075641632, + -0.0029754601418972015, + -0.006887470372021198, + -0.019166842103004456, + 0.0034907464869320393, + -0.015169994905591011, + 0.053831856697797775, + -0.028789488598704338, + -0.02033298648893833, + 0.0018537036376073956, + 0.07567961513996124, + -0.07041627168655396, + -0.047083087265491486, + 0.17573483288288116, + -0.04860217124223709, + 0.013171656988561153, + 0.020158233121037483, + -0.006270059384405613, + -0.28434091806411743, + 0.2760852873325348, + 0.32198208570480347, + -0.43535903096199036, + 0.03188510239124298, + 0.019360313192009926, + -0.20063988864421844, + 0.04450676590204239, + 0.9678076505661011, + -0.683987021446228, + -0.3979112207889557, + 0.2618143558502197, + -0.049711134284734726, + -0.06456997990608215, + 0.6518288850784302, + -0.1357039213180542, + -1.1304017305374146, + 0.4881652295589447, + 0.19583553075790405, + -0.03677722439169884, + 0.21429045498371124, + 0.09559855610132217, + -0.7311355471611023, + 0.10988117009401321, + 0.4949330687522888, + -0.17359353601932526, + 0.03822369873523712, + 0.011371256783604622, + -0.1900172382593155, + -0.04778448864817619, + 0.2897090017795563, + -0.02235160581767559, + -0.05582524091005325, + 0.007624597754329443, + -0.027456223964691162, + -0.029680097475647926, + -0.023810429498553276, + 0.15409281849861145, + 0.013284318149089813, + -0.0788225457072258, + -0.025637971237301826, + 0.01406402699649334, + -0.13676859438419342, + 0.027384959161281586, + 0.30458444356918335, + -0.11150643229484558, + -0.06806201487779617, + 0.009601237252354622, + -0.0866582989692688, + -0.2328706979751587, + 0.5188567638397217, + 0.3787381649017334, + -0.655829906463623, + 0.0072118742391467094, + -0.0031494891736656427, + -0.2424815446138382, + 0.28893929719924927, + 1.2396824359893799, + -1.0406886339187622, + -0.6376030445098877, + 0.4103420078754425, + -0.05929668992757797, + 0.03918358311057091, + 0.9274081587791443, + -0.28890565037727356, + -1.6682262420654297, + 0.66976398229599, + 0.35488471388816833, + 0.027932289987802505, + 0.3169145882129669, + 0.09107685089111328, + -1.2099432945251465, + 0.11623579263687134, + 0.7632684707641602, + -0.16506360471248627, + 0.037474747747182846, + -0.005203985143452883, + -0.35939401388168335, + -0.17138688266277313, + 0.525232195854187, + 0.10247340798377991, + -0.14317406713962555, + 0.007572649512439966, + -0.006046198774129152, + 0.06188087910413742, + -0.050851333886384964, + 0.032844241708517075, + 0.0544477179646492, + -0.07947597652673721, + -0.03073730878531933, + 0.04025515541434288, + -0.010001083835959435, + -0.11831062287092209, + 0.17422229051589966, + -0.05468267202377319, + -0.04996664077043533, + 0.023996006697416306, + 0.02888253889977932, + -0.18709556758403778, + 0.13987921178340912, + 0.32867854833602905, + -0.31714990735054016, + 0.019951285794377327, + 0.027247004210948944, + -0.19416090846061707, + -0.006519266404211521, + 0.7540720105171204, + -0.5474190711975098, + -0.27137213945388794, + 0.20772530138492584, + -0.042619917541742325, + -0.09566087275743484, + 0.548494815826416, + -0.1599852293729782, + -0.9178788661956787, + 0.5456539988517761, + 0.07497559487819672, + 0.003984459210187197, + 0.18640351295471191, + 0.12121234089136124, + -0.7249511480331421, + 0.2559764087200165, + 0.4684237241744995, + -0.19216996431350708, + 0.018075481057167053, + 0.02684594877064228, + -0.221074178814888, + -0.09164194762706757, + 0.3596596121788025, + -0.08310746401548386, + -0.10815230011940002, + -0.015406409278512001, + -0.011985878460109234, + 0.028467312455177307, + -0.0879230722784996, + 0.0347294844686985, + 0.05081191286444664, + 0.00362736196257174, + 0.010529003106057644, + -0.002672453410923481, + 0.025318201631307602, + -0.06232529878616333, + 0.008822780102491379, + 0.06744717806577682, + 0.003999210894107819, + -0.0022885131184011698, + -0.046704765409231186, + 0.13673964142799377, + -0.2590992748737335, + -0.022161437198519707, + 0.258914053440094, + -0.10650330036878586, + 0.023435762152075768, + 0.06992689520120621, + 0.03760937228798866, + -0.5444027185440063, + 0.4131152629852295, + 0.25325170159339905, + -0.2482522875070572, + 0.010479461401700974, + 0.045747850090265274, + -0.1541248857975006, + -0.35291528701782227, + 0.9078133702278137, + -0.34428781270980835, + -0.14787709712982178, + -0.024105649441480637, + -0.007651817053556442, + -0.14991067349910736, + 0.17544956505298615, + 0.3692120611667633, + -0.46861159801483154, + 0.10201738774776459, + 0.003734431229531765, + -0.010433703660964966, + 0.022045455873012543, + 0.0944862961769104, + 0.01679016835987568, + -0.16537833213806152, + 0.07900089025497437, + -0.004211293533444405, + -0.01076442189514637, + 0.09729930013418198, + -0.1490965485572815, + -0.02511671558022499, + 0.0766475573182106, + 0.010980346240103245, + -0.010220799595117569, + -0.0004861881607212126, + 0.09204736351966858, + -0.179045170545578, + -0.025164175778627396, + 0.15608654916286469, + 0.004787537269294262, + -0.0005253870622254908, + 0.034556396305561066, + 0.1509256660938263, + -0.5432079434394836, + -0.03155849874019623, + 0.513609766960144, + -0.14458952844142914, + 0.015178131870925426, + 0.09172039479017258, + -0.12612608075141907, + -0.926306962966919, + 0.8281942009925842, + 0.5954549908638, + -0.492740273475647, + 0.007195526268333197, + -0.018258413299918175, + -0.4074647128582001, + -0.43008187413215637, + 1.7370752096176147, + -0.350849986076355, + -0.5158001780509949, + -0.017458094283938408, + -0.08306471258401871, + -0.2334563285112381, + 0.445117712020874, + 0.7808031439781189, + -0.7913723587989807, + -0.11814796179533005, + -0.00913319457322359, + 0.0223994143307209, + 0.1012248545885086, + 0.25349485874176025, + 0.028286214917898178, + -0.4809858798980713, + 0.05953341722488403, + 0.015634188428521156, + 0.005101620219647884, + 0.10901974141597748, + -0.11964976042509079, + -0.09117673337459564, + 0.0734483003616333, + 0.01821213960647583, + 5.350751234800555e-05, + -0.020279232412576675, + 0.1097220927476883, + -0.1354990452528, + -0.08653146773576736, + 0.11775246262550354, + -0.012575668282806873, + 0.0310806967318058, + 0.010271146893501282, + 0.20337054133415222, + -0.3854014277458191, + -0.09943562000989914, + 0.3921409249305725, + -0.08432158827781677, + 0.010676748119294643, + 0.040244489908218384, + -0.0015478944405913353, + -0.7022866010665894, + 0.49858638644218445, + 0.42338883876800537, + -0.2982582449913025, + -0.005396307446062565, + -0.008777705952525139, + -0.2325415015220642, + -0.4083922803401947, + 1.186205506324768, + -0.26399391889572144, + -0.2621048092842102, + -0.015712907537817955, + -0.04675402492284775, + -0.1797540783882141, + 0.2992522716522217, + 0.4747498333454132, + -0.5266988277435303, + 0.04581758379936218, + -0.04037958011031151, + 0.0071074217557907104, + 0.047499995678663254, + 0.16617828607559204, + -0.03973710536956787, + -0.2953551113605499, + 0.10628587752580643, + -0.00904526561498642, + 0.010427894070744514, + 0.08035022020339966, + 0.03841109946370125, + -0.06335253268480301, + -0.06992083787918091, + 0.015409895218908787, + -0.026900725439190865, + -0.04523912072181702, + 0.08087682723999023, + 0.12542113661766052, + 0.018750213086605072, + -0.23430712521076202, + 0.11755944788455963, + -0.019747508689761162, + -0.03171322122216225, + -0.12132623791694641, + 0.2640603184700012, + 0.38445138931274414, + -0.5724408030509949, + 0.15661633014678955, + 0.01949799247086048, + -0.021771302446722984, + -0.18984957039356232, + -0.23499636352062225, + 1.2112919092178345, + -0.7037869095802307, + -0.14260035753250122, + 0.01848726160824299, + 0.06443414837121964, + -0.11740390956401825, + -0.8794785141944885, + 1.4160369634628296, + 0.016899125650525093, + -0.5444768071174622, + 0.017313210293650627, + 0.0508052259683609, + 0.11102095246315002, + -0.790285587310791, + 0.3501206636428833, + 0.7238660454750061, + -0.49468666315078735, + -0.019021952524781227, + -0.01212992612272501, + 0.15032203495502472, + -0.3573611080646515, + -0.1293754130601883, + 0.45295456051826477, + -0.08407819271087646, + -0.008717959746718407, + 0.022566653788089752, + -0.012640242464840412, + 0.03181227669119835, + 0.0638526976108551, + -0.058120664209127426, + -0.042917650192976, + 0.02129550836980343, + -0.018790805712342262, + -0.00655191857367754, + 0.05951414257287979, + 0.12890471518039703, + -0.1886381357908249, + 0.059096939861774445, + -0.016928592696785927, + 0.02327263168990612, + -0.17282842099666595, + 0.13812857866287231, + 0.38889989256858826, + -0.5282873511314392, + 0.07564643770456314, + -0.006128210574388504, + -0.00876594614237547, + -0.18427829444408417, + -0.26697441935539246, + 1.2529815435409546, + -0.6549165844917297, + -0.2111111879348755, + 0.011410325765609741, + 0.07089994102716446, + -0.12627695500850677, + -0.8245998024940491, + 1.4581915140151978, + -0.01822204887866974, + -0.5626582503318787, + -0.01661459542810917, + 0.03759436681866646, + 0.10841676592826843, + -0.7652962803840637, + 0.4360819458961487, + 0.7012669444084167, + -0.47011038661003113, + 0.01529701892286539, + -0.0033166150096803904, + 0.12170535326004028, + -0.3871544301509857, + -0.05247795954346657, + 0.4504147171974182, + -0.11442532390356064, + -0.00882577896118164, + 0.005190832540392876, + -0.05153197422623634, + 0.0055236960761249065, + 0.09320031106472015, + -0.03762076050043106, + -0.021778371185064316, + 0.00750907463952899, + 0.014965789392590523, + -0.015135630965232849, + -0.037086039781570435, + 0.08020154386758804, + -0.04429963231086731, + 0.0038218852132558823, + -0.01712334342300892, + 0.053772956132888794, + -0.05226677283644676, + -0.024439912289381027, + 0.12774989008903503, + -0.18722355365753174, + 0.0683830976486206, + -0.010828870348632336, + -0.012880662456154823, + 0.02679484151303768, + -0.13696907460689545, + 0.46868517994880676, + -0.322968989610672, + 0.052930932492017746, + 0.009463602676987648, + -0.046861011534929276, + 0.07714711129665375, + -0.35792097449302673, + 0.5517901182174683, + -0.13382655382156372, + -0.12921281158924103, + 0.018562642857432365, + -0.03842621296644211, + 0.10284601897001266, + -0.28243398666381836, + 0.13314206898212433, + 0.20769073069095612, + -0.1551610678434372, + 0.018036767840385437, + -0.03553476929664612, + 0.036686040461063385, + -0.09568552672863007, + 0.008917863480746746, + 0.11340243369340897, + -0.04745811969041824, + 0.005833764094859362, + -0.04174824804067612, + 0.022730106487870216, + 0.0013601485406979918, + -0.07473982870578766, + -0.004801879171282053, + 0.05632775276899338, + -0.04081303998827934, + 0.11509573459625244, + 0.004507652949541807, + -0.24791881442070007, + 0.43171870708465576, + -0.1362573653459549, + -0.10758046060800552, + 0.02746163308620453, + -0.2954745888710022, + 0.30186471343040466, + 0.3135572075843811, + -1.2296111583709717, + 0.8754236102104187, + -0.11699853837490082, + 0.022482017055153847, + 0.24945153295993805, + -0.7858022451400757, + 0.5181443095207214, + 1.4243930578231812, + -1.876152515411377, + 0.4689188003540039, + 0.04258054122328758, + -0.030832920223474503, + 0.9340220093727112, + -1.512351632118225, + -0.3731614947319031, + 2.021338701248169, + -0.7801089286804199, + -0.09288544207811356, + -0.12423597276210785, + -0.36861127614974976, + 1.1679530143737793, + -0.4960964024066925, + -1.0398281812667847, + 0.686152458190918, + 0.02052121050655842, + 0.07246638089418411, + -0.01763315312564373, + -0.37442535161972046, + 0.33217450976371765, + 0.22260302305221558, + -0.2657756209373474, + 0.00016369696822948754, + 0.008136127144098282, + -0.03592197597026825, + 0.022231513634324074, + 0.041430093348026276, + -0.06439317017793655, + 0.03496818616986275, + -0.05143435671925545, + 0.09930871427059174, + 0.017110232263803482, + -0.3834381699562073, + 0.44344815611839294, + -0.00280396337620914, + -0.11487428843975067, + 0.050503507256507874, + -0.22837062180042267, + 0.47540077567100525, + 0.5802375674247742, + -1.7325034141540527, + 0.8587368130683899, + 0.10429240018129349, + -0.02456486038863659, + 0.1340152472257614, + -1.2299835681915283, + 0.7986555099487305, + 2.2204456329345703, + -2.4498374462127686, + 0.33742472529411316, + 0.1001473218202591, + 0.08700849115848541, + 0.9933257102966309, + -2.5278031826019287, + -0.5935835242271423, + 2.710871934890747, + -0.87749183177948, + -0.06125229224562645, + -0.19061818718910217, + -0.04017600044608116, + 1.7519460916519165, + -0.7798219919204712, + -1.28012216091156, + 0.7500321269035339, + 0.02245335467159748, + 0.08263842761516571, + -0.1563340127468109, + -0.3502165377140045, + 0.5060794949531555, + 0.11768018454313278, + -0.2394258826971054, + 0.0027446788735687733, + -0.0012661140644922853, + 0.010839025489985943, + 0.04500429332256317, + -0.04333498701453209, + -0.027386408299207687, + 0.04357098788022995, + -0.04407481476664543, + 0.08443310111761093, + -0.08108946681022644, + -0.20346391201019287, + 0.3825778365135193, + -0.16498182713985443, + -0.04287993535399437, + 0.05340999737381935, + -0.14011172950267792, + 0.29446643590927124, + 0.2738667130470276, + -1.1299961805343628, + 0.7827413082122803, + -0.07552053779363632, + -0.03602323681116104, + 0.16167275607585907, + -0.6924317479133606, + 0.4478289783000946, + 1.2428895235061646, + -1.4833877086639404, + 0.4690392315387726, + -0.00820756796747446, + -0.09873292595148087, + 0.692342221736908, + -1.0981175899505615, + -0.3906446695327759, + 1.438644528388977, + -0.719068169593811, + 0.026173872873187065, + -0.09383898228406906, + -0.3282022774219513, + 1.0363390445709229, + -0.23960772156715393, + -0.7638148069381714, + 0.5488630533218384, + -0.015319733880460262, + 0.11911362409591675, + 0.017409542575478554, + -0.4231888949871063, + 0.23724795877933502, + 0.1191876158118248, + -0.15694500505924225, + -0.03534351661801338, + 0.06342366337776184, + 0.17738288640975952, + 0.012300643138587475, + -0.06408121436834335, + -0.06030220910906792, + 0.0018237337935715914, + 0.07659764587879181, + 0.1820947527885437, + 0.24410061538219452, + -0.06998514384031296, + -0.1491813361644745, + -0.06184092164039612, + 0.04607890918850899, + 0.15362663567066193, + 0.18308304250240326, + 0.08175522834062576, + -0.305602103471756, + -0.2915116548538208, + -0.08144206553697586, + 0.07138665020465851, + -0.03521484509110451, + -0.0914112851023674, + -0.2766699492931366, + -0.6285344362258911, + -0.38168880343437195, + -0.0033710987772792578, + 0.14477019011974335, + -0.03885374590754509, + -0.11367184668779373, + -0.1979650855064392, + -0.3575190007686615, + 0.016150522977113724, + 0.28292712569236755, + 0.2836199402809143, + -0.016672370955348015, + -0.034946177154779434, + -0.014770845882594585, + -0.0004113636096008122, + 0.29938748478889465, + 0.3562523126602173, + 0.13313128054141998, + -0.029499055817723274, + 0.007187174167484045, + 0.0636785551905632, + 0.047712039202451706, + 0.20670579373836517, + 0.10999035090208054, + -0.1150810718536377, + 0.00879934523254633, + -0.009125287644565105, + -0.013732590712606907, + 0.04738131910562515, + 0.0549951009452343, + -0.014094026759266853, + -0.01195482350885868, + -0.017125386744737625, + -0.071754589676857, + -0.023961570113897324, + 0.013098018243908882, + 0.05972208455204964, + -0.032899752259254456, + -0.024354496970772743, + -0.013116234913468361, + -0.05865325778722763, + -0.006360829807817936, + 0.12809234857559204, + 0.14038555324077606, + -0.022946689277887344, + -0.039698828011751175, + 0.05144746974110603, + -0.025034509599208832, + 0.08764739334583282, + 0.24594412744045258, + 0.19307002425193787, + -0.04085381329059601, + -0.020323628559708595, + 0.022060081362724304, + 0.01799374632537365, + 0.09039195626974106, + 0.1681770235300064, + 0.0016234283102676272, + -0.23777234554290771, + -0.11634974926710129, + -0.014439117163419724, + -0.034799374639987946, + 0.0457066111266613, + 0.049919649958610535, + -0.1926913857460022, + -0.2680967450141907, + 0.0018220803467556834, + -0.012749310582876205, + -0.04389086738228798, + 0.0060565415769815445, + -0.012036234140396118, + -0.12737582623958588, + -0.05777670815587044, + 0.09932202100753784, + 0.09969642758369446, + -0.1296343356370926, + -0.2964152693748474, + -0.05487265810370445, + 0.12073978036642075, + 0.06634647399187088, + 0.004042446613311768, + -0.1586746722459793, + -0.6267098784446716, + -0.5184157490730286, + -0.032286129891872406, + 0.28023189306259155, + 0.12663227319717407, + -0.08828771114349365, + -0.2600027620792389, + -0.5287090539932251, + -0.0994620993733406, + 0.7820600271224976, + 0.9638882279396057, + 0.2193463146686554, + -0.13466303050518036, + 0.042050741612911224, + -0.02292742393910885, + 0.7523098587989807, + 1.7435946464538574, + 1.111282229423523, + -0.2104763388633728, + -0.35129284858703613, + 0.08224371820688248, + 0.11167984455823898, + 0.6513852477073669, + 0.9696454405784607, + -0.1501394510269165, + -1.1777327060699463, + -0.7738466262817383, + 0.01114045549184084, + 0.004884988535195589, + 0.2849186658859253, + 0.14232710003852844, + -1.0306764841079712, + -1.2078118324279785, + -0.14658716320991516, + 0.036605384200811386, + 0.0001495486794738099, + 0.12111346423625946, + -0.24653346836566925, + -0.7028710246086121, + -0.18977169692516327, + 0.5171932578086853, + -0.02514370158314705, + 0.0885375589132309, + -0.1023016944527626, + 0.023200739175081253, + 0.11839435249567032, + -0.09749021381139755, + 0.008283962495625019, + 0.0106261121109128, + -0.031724803149700165, + -0.1594654619693756, + 0.433218389749527, + -0.33944255113601685, + 0.14406877756118774, + -0.0339396670460701, + 0.09370072185993195, + -0.35916459560394287, + 0.7577320337295532, + -0.5531823635101318, + -0.016844574362039566, + 0.2994873523712158, + -0.21487002074718475, + -0.16125759482383728, + 0.35567227005958557, + 0.09099612385034561, + -1.3889282941818237, + 1.9466298818588257, + -1.2556309700012207, + 0.4389301836490631, + -0.010665428824722767, + 0.4707520306110382, + -1.4310415983200073, + 2.0986156463623047, + -1.5515614748001099, + 0.3905705511569977, + 0.01881679706275463, + 0.057307951152324677, + -0.29734691977500916, + 0.369127094745636, + -0.05115725100040436, + -0.44008156657218933, + 0.48642784357070923, + -0.13904061913490295, + -0.004375698510557413, + -0.06351548433303833, + 0.256020188331604, + -0.34121274948120117, + 0.22490821778774261, + 0.004067304544150829, + -0.059063635766506195, + -0.010710661299526691, + 0.03514768183231354, + -0.08577805012464523, + 0.05103181675076485, + 0.04276616871356964, + -0.10832246392965317, + 0.03325289487838745, + 0.06318283081054688, + -0.11063538491725922, + -0.062119144946336746, + 0.40978243947029114, + -0.5597845315933228, + 0.34106317162513733, + -0.030269838869571686, + 0.057014383375644684, + -0.44329890608787537, + 1.0965592861175537, + -1.0767146348953247, + 0.13287265598773956, + 0.517289400100708, + -0.310720294713974, + -0.15501761436462402, + 0.5854693055152893, + -0.12469431757926941, + -1.7694847583770752, + 2.6433238983154297, + -1.596714735031128, + 0.3888415992259979, + -0.02415616251528263, + 0.42178481817245483, + -1.8008503913879395, + 2.8845136165618896, + -1.7628657817840576, + 0.1951047033071518, + 0.11415407806634903, + 0.07305648922920227, + -0.34212157130241394, + 0.46562451124191284, + 0.03175807744264603, + -0.7942091226577759, + 0.6133171319961548, + -0.14596694707870483, + 0.010496735572814941, + -0.03459644690155983, + 0.2948842942714691, + -0.47654271125793457, + 0.2612597346305847, + 0.016025209799408913, + -0.05287598818540573, + -0.01606004498898983, + 0.022197037935256958, + 0.028397703543305397, + -0.0390767939388752, + 0.0037972000427544117, + -0.07010228931903839, + 0.10934390872716904, + 0.017220165580511093, + 0.02215729095041752, + -0.14772991836071014, + 0.2353552132844925, + -0.3846408724784851, + 0.23990634083747864, + -0.02300707995891571, + 0.12085225433111191, + -0.3576957881450653, + 0.6410096883773804, + -0.532350480556488, + -0.002389132045209408, + 0.41821879148483276, + -0.24739143252372742, + -0.10216745734214783, + 0.16793736815452576, + 0.16367803514003754, + -1.1304419040679932, + 1.676539421081543, + -1.064436435699463, + 0.26995453238487244, + -0.07634275406599045, + 0.3324422240257263, + -1.11312997341156, + 1.8095507621765137, + -1.2477567195892334, + 0.3605581820011139, + -0.06627745926380157, + 0.008511146530508995, + -0.19528241455554962, + 0.4320055842399597, + -0.22881783545017242, + -0.18463851511478424, + 0.3064245581626892, + -0.14437103271484375, + 0.02049900032579899, + 0.018321938812732697, + 0.14011529088020325, + -0.26683253049850464, + 0.2172057181596756, + -0.12119362503290176, + 0.025965997949242592, + -0.03424325957894325, + 0.0433838777244091, + 0.1072857677936554, + 0.1997794657945633, + 0.0648089200258255, + -0.06444115936756134, + -0.13146057724952698, + 0.02106364443898201, + -0.22582228481769562, + -0.007233713287860155, + 0.18876874446868896, + -0.5612399578094482, + 0.2632557451725006, + 0.44088244438171387, + 0.11389002948999405, + -0.2791701555252075, + -0.18004432320594788, + 0.8571203947067261, + -1.9517340660095215, + -1.4906251430511475, + 0.3436146676540375, + 0.31222787499427795, + -0.20083315670490265, + -0.217665895819664, + 3.801243782043457, + 1.2014728784561157, + -0.9149202704429626, + 0.6968244910240173, + 0.12756747007369995, + -0.06783506274223328, + -2.086660385131836, + 0.5455523133277893, + 0.49095916748046875, + -0.5991013050079346, + 0.7938552498817444, + -0.1335069239139557, + 0.4730406701564789, + -1.00951087474823, + -0.537578821182251, + -0.49764835834503174, + -1.2683815956115723, + -0.045739322900772095, + -0.16049732267856598, + 0.30239275097846985, + 0.035600025206804276, + 0.6344828605651855, + 0.8256548643112183, + -0.12940075993537903, + 0.09257010370492935, + -0.11000311374664307, + 0.003206665627658367, + -0.008585316129028797, + -0.14573170244693756, + 0.172541081905365, + 0.2107972949743271, + -0.05270108953118324, + -0.08480435609817505, + 0.1914149820804596, + 0.21630872786045074, + -0.23309426009655, + -0.29484814405441284, + -0.1899339109659195, + 0.02601807750761509, + -0.05416746065020561, + 0.20924429595470428, + 0.15566189587116241, + -0.1556546688079834, + -0.23387494683265686, + -0.5112816691398621, + 0.24130745232105255, + -0.049835484474897385, + -0.2685615122318268, + -0.024764614179730415, + 0.5458847880363464, + 0.9501044750213623, + 0.1328524947166443, + 0.21218529343605042, + 0.2524968683719635, + -0.5205130577087402, + -0.3361912667751312, + 1.1678112745285034, + -0.004513490945100784, + -0.9149109125137329, + 0.2125048041343689, + 0.22423015534877777, + -0.08384363353252411, + -0.2866036593914032, + -0.20210212469100952, + -1.2377471923828125, + -0.7704879641532898, + 0.365038126707077, + -0.08308980613946915, + -0.08326874673366547, + 0.456358402967453, + 0.35142943263053894, + 0.19268833100795746, + 0.3706081509590149, + -0.04951317980885506, + 0.10151109844446182, + 0.005193099845200777, + -0.1124582439661026, + -0.08353164792060852, + -0.18709596991539001, + -0.18975794315338135, + 0.17628741264343262, + 0.05536900460720062, + 0.008301885798573494, + -0.1890449970960617, + 0.056875281035900116, + 0.7981322407722473, + -0.05872391164302826, + -0.4860122501850128, + -0.08073797076940536, + 0.13145819306373596, + -0.03608228266239166, + -0.6600452661514282, + 2.243560314178467, + 1.9288626909255981, + -0.5698518753051758, + -0.2486664056777954, + 0.42693793773651123, + 0.2667267322540283, + -4.395429611206055, + -2.15342378616333, + 0.819127082824707, + -0.9362612962722778, + -0.3760467767715454, + 0.5671858787536621, + 2.468177080154419, + -1.6694080829620361, + -0.49952322244644165, + 1.502772569656372, + -1.0188850164413452, + -0.10419629514217377, + -0.36795151233673096, + 1.2645196914672852, + 0.7223924994468689, + 1.751431941986084, + 2.018704891204834, + -0.3197852671146393, + 0.22054125368595123, + -0.19326329231262207, + -0.5307535529136658, + -0.9362435936927795, + -1.0772119760513306, + -0.19870880246162415, + -0.0650869607925415, + -0.0796947032213211, + 0.15733301639556885, + 0.08798394352197647, + 0.0010860684560611844, + 0.05327683687210083, + 0.1107875183224678, + 0.13224183022975922, + 0.08979664742946625, + 0.004348093178123236, + -0.07060158997774124, + -0.19925491511821747, + -0.15811985731124878, + -0.08220887929201126, + -0.022623460739850998, + 0.08509720861911774, + 0.00792989507317543, + -0.14345014095306396, + -0.2720486521720886, + -0.18885627388954163, + -0.11063539236783981, + -0.0355350486934185, + 0.048891279846429825, + -0.12828074395656586, + -0.2712610363960266, + -0.20134924352169037, + -0.1863398402929306, + -0.19976121187210083, + -0.09535074234008789, + 0.009852319024503231, + -0.2776590585708618, + -0.3087778687477112, + -0.21431012451648712, + -0.19772370159626007, + -0.23412325978279114, + -0.11640459299087524, + 0.09514907747507095, + -0.17561811208724976, + -0.29451555013656616, + -0.2381855845451355, + -0.18296842277050018, + -0.18682444095611572, + -0.023345205932855606, + 0.1438502073287964, + 0.02504260651767254, + -0.1554802507162094, + -0.1477985382080078, + -0.07874225080013275, + -0.002977968193590641, + 0.1048416793346405, + -0.1779504120349884, + 0.13204343616962433, + 0.14215172827243805, + 0.049610622227191925, + 0.0888131782412529, + 0.07250366359949112, + 0.0696505531668663, + 0.009899160824716091, + 0.032067786902189255, + 0.08401404321193695, + -0.03567894548177719, + -0.004740188363939524, + -0.0021664693485945463, + -0.011156522668898106, + 0.0821070745587349, + 0.10295391082763672, + -0.0017653254326432943, + -0.16915833950042725, + -0.062223054468631744, + 0.004783258773386478, + 0.038355808705091476, + 0.10124270617961884, + -0.003437258303165436, + -0.18881437182426453, + -0.15905225276947021, + -0.12576808035373688, + -0.11059725284576416, + 0.021587060764431953, + 0.07237453758716583, + -0.1706620156764984, + -0.27434206008911133, + -0.23003827035427094, + -0.20530915260314941, + -0.20856624841690063, + -0.021966496482491493, + 0.13395215570926666, + -0.03810539469122887, + -0.2409798800945282, + -0.2515420913696289, + -0.1872486174106598, + -0.15951117873191833, + 0.04223426431417465, + 0.09909931570291519, + 0.12328703701496124, + -0.057749148458242416, + -0.1300545036792755, + -0.046062104403972626, + 0.019744107499718666, + 0.09484386444091797, + -0.2709728479385376, + 0.03540695831179619, + 0.1206774190068245, + 0.057636432349681854, + 0.10385740548372269, + 0.032486993819475174, + -0.020434774458408356, + -0.10122086852788925, + -0.0023329253308475018, + 0.16941140592098236, + 0.098082534968853, + 0.1250472217798233, + 0.06134447827935219, + -0.025240115821361542, + 0.004181401338428259, + 0.14425808191299438, + 0.17515034973621368, + 0.04739757999777794, + 0.1618604063987732, + 0.1751406490802765, + 0.09162088483572006, + 0.09512057155370712, + 0.13736343383789062, + 0.028775952756404877, + 0.042535409331321716, + 0.08839954435825348, + 0.09229374676942825, + 0.1658262014389038, + 0.09852072596549988, + 0.002680110279470682, + -0.05479496717453003, + -0.03634755313396454, + -0.002902726177126169, + -0.023990361019968987, + 0.1277875006198883, + 0.12727677822113037, + 0.1002269834280014, + -0.040967896580696106, + -0.07101184874773026, + -0.007902896963059902, + 0.019561029970645905, + 0.145268052816391, + 0.017638152465224266, + 0.19240263104438782, + 0.12857146561145782, + 0.05043037235736847, + 0.11596394330263138, + 0.12513381242752075, + 0.12088746577501297, + 0.04333524778485298, + 0.05500142276287079, + 0.05169082432985306, + -0.09941842406988144, + -0.005959822330623865, + -0.032586321234703064, + -0.03065132349729538, + -0.04826900362968445, + 0.14192889630794525, + 0.2543988823890686, + 0.09563885629177094, + -0.28965362906455994, + -0.1341734230518341, + 0.033991701900959015, + -0.22402706742286682, + -0.3190857768058777, + 0.011840387247502804, + 0.9620282053947449, + 1.0609054565429688, + -0.13429726660251617, + -0.20191268622875214, + 0.05324135720729828, + -0.16234318912029266, + -0.9101927280426025, + -1.7916113138198853, + 0.3981992304325104, + 1.3173034191131592, + 0.53525310754776, + 0.18472574651241302, + 0.3719426691532135, + 0.7792536020278931, + -0.027768991887569427, + -2.245561122894287, + -1.2211185693740845, + 0.22817185521125793, + -0.0023349972907453775, + -0.12598364055156708, + 0.06836964190006256, + 0.9917387366294861, + 1.1885775327682495, + -0.2851368486881256, + -0.7428704500198364, + -0.04798422381281853, + -0.00811613816767931, + -0.19619861245155334, + -0.28184008598327637, + 0.0828644260764122, + 0.44643187522888184, + 0.1461745798587799, + -0.005575121380388737, + -0.06604957580566406, + 0.011459077708423138, + 0.03927984461188316, + 0.0634538009762764, + -0.005732079967856407, + -0.01014732290059328, + 0.07607843726873398, + 0.06948187947273254, + -0.010600326582789421, + -0.056259915232658386, + -0.24602480232715607, + -0.01649448834359646, + 0.11143466085195541, + -0.0027401424013078213, + -0.012853104621171951, + 0.08452893793582916, + 0.639316201210022, + 0.5167437195777893, + -0.2775256335735321, + -0.22241903841495514, + -0.07067711651325226, + -0.06368192285299301, + -0.4687917232513428, + -1.1776493787765503, + 0.36015447974205017, + 0.9171182513237, + 0.1905054748058319, + -0.010661551728844643, + 0.10800722986459732, + 0.5352235436439514, + 0.18558207154273987, + -1.5184046030044556, + -0.8130561709403992, + 0.15417319536209106, + 0.0713079422712326, + -0.07369451224803925, + -0.09037846326828003, + 0.6168488264083862, + 0.9663773775100708, + -0.007113471627235413, + -0.33585548400878906, + -0.02738586813211441, + 0.061310965567827225, + -0.0955657884478569, + -0.23896107077598572, + -0.1107473075389862, + 0.1830059289932251, + 0.10748914629220963, + -0.040772341191768646, + -0.05803938955068588, + -0.0004895658930763602, + 0.07664632797241211, + 0.039049405604600906, + -0.002806248841807246, + -0.02642429992556572, + 0.05169009417295456, + -0.036710865795612335, + -0.1002974808216095, + -0.12001149356365204, + -0.08043934404850006, + 0.11466419696807861, + 0.12322796136140823, + 0.07564827799797058, + 0.10148002207279205, + 0.04720174893736839, + 0.14046646654605865, + -0.0819464847445488, + -0.30803975462913513, + -0.0838734582066536, + -0.0801682323217392, + 0.05861072987318039, + 0.04970559477806091, + -0.20592759549617767, + 0.2673366665840149, + 0.2431953400373459, + -0.10027645528316498, + -0.07884806394577026, + -0.09939537942409515, + 0.1181628480553627, + 0.25269386172294617, + -0.3439132571220398, + -0.11160463094711304, + 0.08640077710151672, + 0.07200870662927628, + -0.03449570760130882, + -0.17610406875610352, + -0.021308166906237602, + 0.30556705594062805, + 0.05186203494668007, + -0.004691269714385271, + -0.005278654862195253, + 0.06289899349212646, + 0.052224051207304, + -0.05927770212292671, + -0.1586783081293106, + -0.022610770538449287, + 0.03463536128401756, + 0.004338411148637533, + 0.01452699676156044, + -0.008622901514172554, + 0.010536444373428822, + -0.038111478090286255, + 0.013373414985835552, + 0.007125865668058395, + -0.003420598339289427, + 0.03533756732940674, + 0.0320388600230217, + 0.045789655297994614, + -0.08139114826917648, + -0.03447948023676872, + -0.01453007198870182, + -0.004573625046759844, + 0.10279268026351929, + 0.10881853848695755, + 0.07537791877985, + -0.10887791216373444, + -0.0980544164776802, + -0.06889445334672928, + 0.006558350287377834, + 0.197514146566391, + 0.17890937626361847, + 0.07630149275064468, + -0.16081148386001587, + -0.16685302555561066, + -0.11421715468168259, + -0.013679573312401772, + 0.22477784752845764, + 0.20761631429195404, + 0.07321957498788834, + -0.17697854340076447, + -0.17810045182704926, + -0.1579347848892212, + -0.02679254300892353, + 0.1408146619796753, + 0.15144851803779602, + 0.08801613748073578, + -0.13237154483795166, + -0.13181765377521515, + -0.1279487907886505, + -0.01779216341674328, + 0.08145096898078918, + 0.05625852569937706, + 0.07724357396364212, + -0.04653938114643097, + -0.07479449361562729, + -0.06189379468560219, + -0.04310920089483261, + 0.02028634026646614, + -0.006228619255125523, + 0.03549303859472275, + -0.043929651379585266, + 0.007818001322448254, + 0.00874761026352644, + -0.017027731984853745, + 0.11014463752508163, + 0.0841977447271347, + 0.05960552394390106, + -0.12814101576805115, + -0.0544624924659729, + -0.045333195477724075, + 0.02336869016289711, + 0.22365787625312805, + 0.18523427844047546, + 0.09366372227668762, + -0.20144090056419373, + -0.16367222368717194, + -0.13003699481487274, + 0.0590205080807209, + 0.3301562964916229, + 0.26524844765663147, + 0.09425198286771774, + -0.26156124472618103, + -0.28513699769973755, + -0.21749621629714966, + 0.04356053099036217, + 0.35879984498023987, + 0.29898661375045776, + 0.0977487862110138, + -0.28175386786460876, + -0.2964495122432709, + -0.249031201004982, + 0.028877725824713707, + 0.26395633816719055, + 0.23059280216693878, + 0.09593978524208069, + -0.22489066421985626, + -0.2248908430337906, + -0.19214706122875214, + 0.007535146549344063, + 0.15299226343631744, + 0.09148521721363068, + 0.06946425884962082, + -0.1445557326078415, + -0.11587042361497879, + -0.0978587418794632, + -0.00984917301684618, + -0.012626220472157001, + -0.02837960794568062, + 0.02399199828505516, + -0.005340439733117819, + 0.023224178701639175, + 0.011642432771623135, + 0.003958537708967924, + 0.042965203523635864, + 0.01099414099007845, + 0.024063799530267715, + -0.0702008455991745, + 0.007805663626641035, + 0.0050195748917758465, + 0.017281856387853622, + 0.10123670846223831, + 0.06401767581701279, + 0.02626805007457733, + -0.1073761060833931, + -0.03802435100078583, + -0.014407800510525703, + -0.0006281707319431007, + 0.15516239404678345, + 0.12629136443138123, + 0.033691491931676865, + -0.17609107494354248, + -0.15251316130161285, + -0.07914211601018906, + -0.015578335151076317, + 0.18422608077526093, + 0.1740245372056961, + 0.06139932945370674, + -0.17213505506515503, + -0.1602732092142105, + -0.08922445774078369, + -0.012822975404560566, + 0.13543544709682465, + 0.12543149292469025, + 0.07651004195213318, + -0.13805902004241943, + -0.09661149233579636, + -0.052669934928417206, + -0.03268992528319359, + 0.0391642227768898, + 0.01116940937936306, + 0.04585625231266022, + -0.06474924832582474, + -0.023607701063156128, + -0.007017284631729126, + -0.026150476187467575, + 0.05729387328028679, + -0.10095079243183136, + 0.16617903113365173, + -0.13664309680461884, + 0.026482274755835533, + 0.008411461487412453, + -0.03410203382372856, + 0.022963764145970345, + 0.008903563022613525, + 0.11244194954633713, + -0.20863348245620728, + 0.11064451932907104, + -0.024916114285588264, + 0.009591493755578995, + -0.26092270016670227, + 0.5717483758926392, + -0.38539814949035645, + 0.035056713968515396, + 0.08623965084552765, + -0.016184961423277855, + 0.11129201203584671, + -0.6138678789138794, + 1.3646206855773926, + -1.4969615936279297, + 0.8465064764022827, + -0.2794847786426544, + 0.05826558917760849, + 0.07709132134914398, + -0.5444677472114563, + 1.3013663291931152, + -1.5686073303222656, + 0.9930508732795715, + -0.39188963174819946, + 0.08085884898900986, + -0.05875617265701294, + 0.03498996049165726, + 0.23967482149600983, + -0.3468690514564514, + 0.19146253168582916, + 0.019604403525590897, + -0.027150027453899384, + -0.024670494720339775, + 0.09944183379411697, + -0.11718503385782242, + 0.09772855788469315, + -0.11857263743877411, + 0.09660946577787399, + -0.03638811036944389, + -0.0295167975127697, + 0.1032838523387909, + -0.12557579576969147, + 0.11812210828065872, + -0.08446288853883743, + 0.027706580236554146, + 0.010997293516993523, + -0.06348618865013123, + 0.09578556567430496, + -0.0165568757802248, + -0.014778072014451027, + -0.07772849500179291, + 0.11245536059141159, + -0.043248821049928665, + 0.013345679268240929, + -0.22149333357810974, + 0.6456363797187805, + -0.7280437350273132, + 0.3046833574771881, + 0.06304280459880829, + -0.07310052216053009, + 0.08824795484542847, + -0.65179842710495, + 1.6453673839569092, + -2.046448230743408, + 1.3267604112625122, + -0.42399832606315613, + 0.0010522910160943866, + 0.07953720539808273, + -0.5960973501205444, + 1.5601089000701904, + -2.084894895553589, + 1.4612183570861816, + -0.5491638779640198, + 0.13709494471549988, + -0.09170618653297424, + 0.07287970930337906, + 0.24422486126422882, + -0.4581631124019623, + 0.29479551315307617, + -0.07515113800764084, + -0.012292998842895031, + -0.04451148584485054, + 0.14961428940296173, + -0.15577177703380585, + 0.06323063373565674, + -0.07806269824504852, + 0.07061618566513062, + -0.026793144643306732, + -0.051938362419605255, + 0.13946141302585602, + -0.14129231870174408, + 0.11092118173837662, + -0.08889970183372498, + 0.034787945449352264, + -0.008983314968645573, + -0.04930088296532631, + 0.09856640547513962, + -0.09350966662168503, + 0.07015673816204071, + -0.06468848884105682, + 0.08028972148895264, + -0.02378295361995697, + 0.004251216538250446, + -0.11239825189113617, + 0.2660067081451416, + -0.367576539516449, + 0.2212517410516739, + -0.035011082887649536, + -0.037866897881031036, + 0.11835235357284546, + -0.4868132174015045, + 0.9402765035629272, + -1.0933791399002075, + 0.9518744349479675, + -0.5096855759620667, + 0.12277142703533173, + 0.12916085124015808, + -0.4648635983467102, + 0.8895858526229858, + -1.0776352882385254, + 1.023865818977356, + -0.5914785861968994, + 0.1682877242565155, + -0.05646277964115143, + 0.04132156819105148, + -0.01790236309170723, + -0.059831030666828156, + 0.10092897713184357, + -0.1268356889486313, + 0.013669619336724281, + -0.02746082842350006, + 0.11544085294008255, + -0.2124193012714386, + 0.2733248472213745, + -0.1360178142786026, + 0.025302443653345108, + 0.01249375008046627, + -0.015119954012334347, + 0.017966970801353455, + 0.00269943755120039, + 0.014392177574336529, + 0.007648292928934097, + 0.011665135622024536, + -0.006192799191921949, + 0.004215092398226261, + 0.017718149349093437, + 0.046436555683612823, + 0.044417623430490494, + 0.01518242433667183, + -0.0020157198887318373, + -0.01828707568347454, + -0.029163505882024765, + -0.03131464868783951, + -0.004393945913761854, + 0.048599082976579666, + 0.015757638961076736, + -0.015650734305381775, + -0.002684049541130662, + -0.0697445422410965, + -0.25050923228263855, + -0.4758685231208801, + -0.5382962822914124, + -0.38907238841056824, + -0.12599025666713715, + -0.00266047241166234, + 0.0758173018693924, + 0.26593172550201416, + 0.4203726053237915, + 0.4958920478820801, + 0.3697706162929535, + 0.12434400618076324, + 0.026325728744268417, + 0.022295912727713585, + 0.08135133236646652, + 0.2627769708633423, + 0.26325660943984985, + 0.12326934933662415, + 0.058665141463279724, + 0.04346219077706337, + -0.0013142779935151339, + -0.10037153959274292, + -0.27075886726379395, + -0.28071707487106323, + -0.17300420999526978, + -0.06914675980806351, + 0.004067219793796539, + -0.020674005150794983, + 0.02103183977305889, + 0.0033879741095006466, + 0.013523808680474758, + -0.007318845018744469, + -0.009975744411349297, + -0.02981705591082573, + 0.023193644359707832, + 0.09624253213405609, + 0.1077117845416069, + 0.11186518520116806, + 0.07592211663722992, + 0.04614634811878204, + 0.015908582136034966, + -0.05212458223104477, + -0.1262977123260498, + -0.10974782705307007, + -0.07645918428897858, + -0.06987964361906052, + -0.08783216774463654, + -0.046172842383384705, + -0.22593465447425842, + -0.5281140804290771, + -0.8424770832061768, + -0.9608982801437378, + -0.7363743185997009, + -0.3312055170536041, + -0.10426472127437592, + 0.24067367613315582, + 0.5504152178764343, + 0.81276935338974, + 0.9592635035514832, + 0.7479950785636902, + 0.32608768343925476, + 0.14525265991687775, + 0.15008939802646637, + 0.32246851921081543, + 0.5287250876426697, + 0.5817036032676697, + 0.37340155243873596, + 0.20366452634334564, + 0.1546182781457901, + -0.11224830150604248, + -0.29856279492378235, + -0.5281672477722168, + -0.5890122056007385, + -0.4024880528450012, + -0.23706914484500885, + -0.0641399398446083, + -0.0025121152866631746, + 0.0051757702603936195, + -0.014290476217865944, + 0.0043721878901124, + -0.004783981014043093, + 0.021787043660879135, + -0.004969750996679068, + -0.022116241976618767, + 0.05208030343055725, + 0.07022145390510559, + 0.03730607405304909, + 0.03242917358875275, + 0.04344351217150688, + -0.01189794484525919, + -0.0418211966753006, + -0.059125497937202454, + -0.014576594345271587, + 0.01294493954628706, + -0.011262460611760616, + -0.059920165687799454, + -0.04733816161751747, + -0.12665517628192902, + -0.29677024483680725, + -0.5247481465339661, + -0.6474934816360474, + -0.4751538038253784, + -0.1937171369791031, + -0.05117221921682358, + 0.14646948873996735, + 0.32891425490379333, + 0.5415402054786682, + 0.6071264147758484, + 0.4653589427471161, + 0.18045872449874878, + 0.09937354922294617, + 0.1264665126800537, + 0.18507222831249237, + 0.31783968210220337, + 0.3545042872428894, + 0.22468777000904083, + 0.09973976761102676, + 0.1227618008852005, + -0.07824759930372238, + -0.20465101301670074, + -0.36476215720176697, + -0.38243186473846436, + -0.2540777623653412, + -0.13525226712226868, + -0.03621843457221985, + -0.012233156710863113, + -0.01481863297522068, + -0.04313792288303375, + 0.002874002791941166, + -0.028444716706871986, + -0.04687628522515297, + -0.026806645095348358, + -0.0228339321911335, + -0.015892738476395607, + -0.015550780110061169, + 0.07011140882968903, + 0.0017389585264027119, + -0.05721491947770119, + -0.017484690994024277, + -0.03954736143350601, + -0.006339249666780233, + 0.08166316151618958, + 0.37439921498298645, + 0.2830294966697693, + 0.00668215099722147, + -0.038873329758644104, + -0.012295035645365715, + 0.04932165890932083, + 0.31826695799827576, + 0.8449289202690125, + 0.7123299241065979, + 0.2574000954627991, + 0.04747961834073067, + -0.04416817054152489, + -0.005029442720115185, + 0.2027042657136917, + 0.6639980673789978, + 0.6243636012077332, + 0.21359916031360626, + 0.027929672971367836, + -0.05395142361521721, + -0.04981911554932594, + -0.006375179626047611, + 0.23660773038864136, + 0.2155737280845642, + 0.020577391609549522, + -0.032118700444698334, + -0.02332071214914322, + -0.009217707440257072, + -0.038096409291028976, + 0.05811609327793121, + 0.03776064142584801, + -0.03570764884352684, + -0.042420413345098495, + 0.017812976613640785, + 0.019242385402321815, + 0.030057156458497047, + 0.003040613606572151, + 0.02378096617758274, + 0.04043402150273323, + 0.0243258997797966, + 0.014026327058672905, + 0.005650558043271303, + -0.002831381279975176, + -0.0645776093006134, + -0.03761167451739311, + 0.043774381279945374, + 0.010685136541724205, + 0.031011218205094337, + -0.0025828774087131023, + -0.11959855258464813, + -0.3524792194366455, + -0.30037227272987366, + -0.053334690630435944, + 0.009859252721071243, + 0.0010005333460867405, + -0.04819931834936142, + -0.3154168128967285, + -0.7240553498268127, + -0.6380828022956848, + -0.25695785880088806, + -0.06639125943183899, + 0.03295261785387993, + -0.012727363035082817, + -0.24232468008995056, + -0.6055921912193298, + -0.5679556727409363, + -0.20067356526851654, + -0.03628019988536835, + 0.04774145409464836, + 0.029560575261712074, + -0.038632482290267944, + -0.24032950401306152, + -0.2095729559659958, + -0.006905315909534693, + 0.02563827484846115, + 0.03053808957338333, + 0.0012747920118272305, + 0.004095789045095444, + -0.07932732999324799, + -0.046672020107507706, + 0.02153847925364971, + 0.019504766911268234, + -0.006118285935372114, + 0.0026654782705008984, + 0.013819373212754726, + -0.01078135147690773, + 0.0070082321763038635, + 0.00906399916857481, + 0.010149766691029072, + 0.000516490894369781, + 0.00034157291520386934, + 0.02412085421383381, + 0.006926041562110186, + 0.023299943655729294, + 0.01129852794110775, + -0.0018704778049141169, + 0.016042279079556465, + 0.023886069655418396, + 0.04207555204629898, + -0.0021778997033834457, + 0.041684601455926895, + 0.05059140920639038, + 0.03518521040678024, + -0.0032736151479184628, + -0.0007146652205847204, + 0.015503454953432083, + -0.11896659433841705, + -0.07006713002920151, + 0.007565992418676615, + 0.012584990821778774, + 0.00843358226120472, + 0.017024952918291092, + 0.0359124094247818, + -0.05997823178768158, + -0.04116949439048767, + -0.016472430899739265, + 0.002696823561564088, + 0.00829327292740345, + 0.016238784417510033, + 0.0455794483423233, + 0.0019872160628437996, + -0.005927432328462601, + -0.003552153240889311, + 0.020063765347003937, + 0.00010026743984781206, + 0.01045019831508398, + 0.034689340740442276, + 0.014206668362021446, + 0.015128945000469685, + 0.00972809735685587, + 0.019944868981838226, + 0.020581791177392006, + 0.02938947267830372, + 0.03923909366130829, + 0.03601628914475441, + 0.030168617144227028, + 0.05403255671262741, + 0.03985666483640671, + 0.020015308633446693, + 0.0285494402050972, + 0.013555807992815971, + -0.04409409686923027, + -0.07503483444452286, + 0.01716756261885166, + 0.02053452841937542, + 0.057520389556884766, + 0.02973104454576969, + -0.04563397541642189, + -0.2676408588886261, + -0.30933722853660583, + -0.11671236902475357, + 0.0020135289523750544, + 0.022801443934440613, + -0.03161352127790451, + -0.2704106271266937, + -0.5803710222244263, + -0.5762420296669006, + -0.30449461936950684, + -0.0780220776796341, + 0.017343536019325256, + -0.05319945886731148, + -0.2906038463115692, + -0.598426342010498, + -0.5925986766815186, + -0.31852787733078003, + -0.09950074553489685, + 0.05888299271464348, + 0.01939479075372219, + -0.1060815081000328, + -0.3505017161369324, + -0.3200446665287018, + -0.10609738528728485, + 0.03659524768590927, + 0.056114207953214645, + 0.03447861596941948, + 0.014380007050931454, + -0.09436371922492981, + -0.07562272250652313, + 0.04223132133483887, + 0.06327345967292786, + -0.03735652193427086, + -0.052881840616464615, + -0.058017320930957794, + -0.02474917098879814, + -0.02431381866335869, + -0.0629878118634224, + -0.05212349444627762, + -0.03820814937353134, + -0.0034579068887978792, + -0.004930540919303894, + 0.07968354970216751, + 0.07278168946504593, + 0.015167324803769588, + -0.013638288713991642, + -0.05875609815120697, + -0.008851750753819942, + 0.10708516091108322, + 0.33075177669525146, + 0.3502756953239441, + 0.14791442453861237, + 0.03131852671504021, + -0.028764141723513603, + 0.07454497367143631, + 0.3000347316265106, + 0.6147283315658569, + 0.6289594173431396, + 0.3398674726486206, + 0.13494613766670227, + -0.03705109655857086, + 0.0633230209350586, + 0.3147434592247009, + 0.595033586025238, + 0.594217836856842, + 0.33864542841911316, + 0.11264053732156754, + -0.059276629239320755, + 0.005206871312111616, + 0.14524762332439423, + 0.37473905086517334, + 0.34477534890174866, + 0.12632343173027039, + 0.011062734760344028, + -0.06149457022547722, + -0.028670497238636017, + 0.011082210578024387, + 0.13112866878509521, + 0.1106843650341034, + -0.0025933771394193172, + -0.03781202808022499, + 0.030325254425406456, + 0.017758814617991447, + 0.01635698974132538, + -0.008786264806985855, + -0.0005018062074668705, + 0.005934061016887426, + 0.020206287503242493, + 0.019497420638799667, + -0.01290479488670826, + -0.010817185044288635, + -0.032760608941316605, + -0.026973316445946693, + -0.0021766452118754387, + -0.012848617509007454, + -0.0002560729335527867, + -0.02383977733552456, + -0.05322824791073799, + -0.05382781848311424, + -0.04459262639284134, + -0.04581240937113762, + -0.03465775027871132, + 0.0026904877740889788, + -0.026097090914845467, + -0.05170493200421333, + -0.04981262609362602, + -0.05221042037010193, + -0.05268307775259018, + -0.04735802114009857, + 0.019142162054777145, + -0.019374292343854904, + -0.03312355652451515, + -0.04133244976401329, + -0.033129844814538956, + -0.01844680868089199, + -0.024726904928684235, + 0.0012146441731601954, + -0.025521529838442802, + -0.03120318427681923, + -0.04863203689455986, + -0.021450525149703026, + -0.04190714284777641, + -0.02833862416446209, + 0.017827404662966728, + -0.010181388817727566, + -0.020994380116462708, + -0.04290826618671417, + -0.031555648893117905, + -0.030525390058755875, + -0.024981478229165077, + -0.017512500286102295, + 0.019927235320210457, + 0.00433371402323246, + -0.009276121854782104, + -0.03990143537521362, + -0.021251117810606956, + 0.017825132235884666, + -0.02313065528869629, + 0.012881814502179623, + 0.0009175563463941216, + -0.0656605213880539, + -0.007037178613245487, + 0.023603176698088646, + 0.04873553663492203, + 0.013912673108279705, + 9.78652315097861e-05, + -0.03166677802801132, + -0.11772678792476654, + -0.034320034086704254, + 0.04952533170580864, + 0.10113520920276642, + 0.030472615733742714, + -0.05131377652287483, + -0.1371452510356903, + -0.2326214611530304, + -0.0629519522190094, + 0.12444627285003662, + 0.15845368802547455, + 0.014535457827150822, + -0.06888624280691147, + -0.18798232078552246, + -0.24720685184001923, + -0.04858007654547691, + 0.26889580488204956, + 0.2433905005455017, + -0.01772989332675934, + -0.06027546152472496, + -0.12164203822612762, + -0.20018024742603302, + 0.0035393801517784595, + 0.27190765738487244, + 0.1929154396057129, + -0.012923460453748703, + -0.013931642286479473, + -0.043986693024635315, + -0.0655391663312912, + 0.04751605913043022, + 0.13482201099395752, + 0.06690078228712082, + -0.01862635649740696, + 0.02938506379723549, + 0.01789080537855625, + -0.006509440019726753, + -0.029202938079833984, + -0.023693149909377098, + 0.01042762491852045, + -0.0035929735749959946, + 0.024952176958322525, + -0.013459124602377415, + -0.10798560827970505, + -0.020217353478074074, + 0.017876077443361282, + 0.07628928124904633, + 0.04444783553481102, + 0.012667268514633179, + -0.09012818336486816, + -0.22452381253242493, + -0.07556752860546112, + 0.07942477613687515, + 0.17035256326198578, + 0.0396822914481163, + -0.08236342668533325, + -0.23916372656822205, + -0.3645225763320923, + -0.10748416185379028, + 0.1996970921754837, + 0.3076043725013733, + -0.0033923503942787647, + -0.13259321451187134, + -0.28894615173339844, + -0.3605952262878418, + -0.07969008386135101, + 0.3583948314189911, + 0.4267900586128235, + -0.02228585258126259, + -0.11386624723672867, + -0.21445821225643158, + -0.26956692337989807, + 0.026791207492351532, + 0.37918713688850403, + 0.37130093574523926, + -0.05172214284539223, + -0.05132569745182991, + -0.07469630241394043, + -0.11400169134140015, + 0.07863093167543411, + 0.24061299860477448, + 0.19393151998519897, + -0.03217098489403725, + 0.013085477985441685, + 0.032348379492759705, + 0.03207695484161377, + 0.010604938492178917, + -0.026534704491496086, + -0.018284842371940613, + -0.01768680103123188, + -0.001516501884907484, + 0.013829287141561508, + -0.034318119287490845, + 0.015753330662846565, + -0.0018936718115583062, + 0.014737343415617943, + 0.03306088596582413, + 0.020835628733038902, + -0.03396771103143692, + -0.10758449137210846, + -0.03052518330514431, + 0.020080547779798508, + 0.06180800125002861, + 0.03735671192407608, + -0.037925880402326584, + -0.09720461815595627, + -0.21495617926120758, + -0.06842153519392014, + 0.08532039076089859, + 0.13350333273410797, + 0.03649023920297623, + -0.03904158994555473, + -0.1483580619096756, + -0.2068314403295517, + -0.05687328055500984, + 0.21108660101890564, + 0.21018920838832855, + 0.009318819269537926, + -0.037683792412281036, + -0.09845960140228271, + -0.1535443514585495, + 0.004504916723817587, + 0.20256847143173218, + 0.1799001693725586, + -0.03175490349531174, + -0.020391397178173065, + -0.007309200707823038, + -0.06765769422054291, + 0.013149870559573174, + 0.08469820767641068, + 0.04147877171635628, + -0.0027241194620728493, + 0.008016721345484257, + 0.001382349175401032, + 0.0001219741752720438, + -0.059255484491586685, + -0.03761141747236252, + 0.0381690077483654, + -0.01603613793849945, + 0.0017731477273628116, + -0.016544193029403687, + 0.09518970549106598, + 0.1735895872116089, + 0.005558829288929701, + -0.13464735448360443, + -0.0703420490026474, + 0.001990854274481535, + -0.03426021337509155, + -0.4390500485897064, + -0.11292288452386856, + 0.20430812239646912, + 0.14832687377929688, + 0.06074441969394684, + -0.03749264031648636, + 0.408058226108551, + 0.43119552731513977, + -0.3804298937320709, + -0.3694773018360138, + -0.03696960583329201, + 0.04022200033068657, + -0.0812998041510582, + -0.4322642385959625, + 0.19638888537883759, + 0.7809834480285645, + 0.11584538966417313, + -0.04975399747490883, + -0.015579828992486, + 0.1362757831811905, + 0.027220597490668297, + -0.4703449606895447, + -0.3726261258125305, + 0.11754196882247925, + -0.01204066164791584, + -0.00118898821529001, + -0.05152498185634613, + 0.08767394721508026, + 0.14183296263217926, + 0.01692730002105236, + -0.04587334021925926, + 0.011115594767034054, + 0.021572716534137726, + -0.021584773436188698, + -0.012763801962137222, + 0.05708793178200722, + 0.021982798352837563, + -0.02731800265610218, + 0.03000856563448906, + 0.006653181277215481, + -0.02485630102455616, + -0.20296195149421692, + -0.10483214259147644, + 0.20483383536338806, + 0.1350196748971939, + -0.08543248474597931, + 0.02644401416182518, + 0.26855263113975525, + 0.1071053072810173, + -0.8168368935585022, + -0.6617473363876343, + 0.02877889946103096, + 0.21807144582271576, + -0.02164696715772152, + -0.03712613880634308, + 0.9743875861167908, + 1.1631361246109009, + -0.45643851161003113, + -0.8180081844329834, + -0.28109386563301086, + -0.09115415811538696, + -0.4352502226829529, + -0.7433719038963318, + 0.5383746027946472, + 1.7271664142608643, + 0.509749174118042, + -0.0689467042684555, + 0.010011479258537292, + 0.11752951890230179, + -0.28825971484184265, + -1.113126277923584, + -0.6029489636421204, + 0.357056587934494, + 0.19766344130039215, + 0.023361098021268845, + 0.04305602237582207, + 0.24867205321788788, + 0.16359609365463257, + -0.2485191822052002, + -0.2251967489719391, + 0.030422789976000786, + 0.0049157580360770226, + -0.05497031658887863, + -0.030760835856199265, + 0.034536562860012054, + 0.019565051421523094, + -0.00933124776929617, + 0.01611645519733429, + 0.07988770306110382, + -0.021982649341225624, + -0.21876110136508942, + -0.10555483400821686, + 0.1893070936203003, + 0.14684906601905823, + -0.031080693006515503, + 0.09768003225326538, + 0.3261844515800476, + 0.1466774046421051, + -0.6738073825836182, + -0.5424039363861084, + 0.04689512774348259, + 0.22039148211479187, + -0.07084018737077713, + -0.07436021417379379, + 0.8260523080825806, + 1.0253428220748901, + -0.38162854313850403, + -0.727206289768219, + -0.2605172097682953, + -0.0996573269367218, + -0.3653049170970917, + -0.6791687607765198, + 0.43514078855514526, + 1.4186147451400757, + 0.38797008991241455, + -0.12675431370735168, + 0.02766786515712738, + 0.14237603545188904, + -0.2306709885597229, + -0.9204807877540588, + -0.5071616172790527, + 0.32662850618362427, + 0.20703284442424774, + -0.020968681201338768, + 0.014105334877967834, + 0.24642448127269745, + 0.20103473961353302, + -0.15519124269485474, + -0.22072142362594604, + 0.049920063465833664, + -0.05465548485517502, + 0.018651481717824936, + 0.030082669109106064, + 0.05234164372086525, + 0.10243640840053558, + 0.03569166734814644, + 0.038984544575214386, + 0.05248976871371269, + 0.24501988291740417, + 0.4674161374568939, + 0.7142530083656311, + 0.7423628568649292, + 0.6262048482894897, + 0.4019012451171875, + -0.010997634381055832, + 0.17266513407230377, + 0.4467124342918396, + 0.7795005440711975, + 0.8282667994499207, + 0.6824804544448853, + 0.3955397605895996, + 0.009771074168384075, + 0.10707246512174606, + 0.23039454221725464, + 0.33151063323020935, + 0.36120596528053284, + 0.3240644633769989, + 0.17939962446689606, + -0.01115038525313139, + -0.11081521213054657, + -0.2146066278219223, + -0.3572347164154053, + -0.44021451473236084, + -0.38320258259773254, + -0.24643990397453308, + 0.031578775495290756, + -0.21325217187404633, + -0.4312629997730255, + -0.7276368141174316, + -0.8273008465766907, + -0.718246340751648, + -0.4161607027053833, + -0.06636986136436462, + -0.28078269958496094, + -0.476252943277359, + -0.734549880027771, + -0.7796792984008789, + -0.6637035608291626, + -0.41896238923072815, + 0.021693198010325432, + 0.006199972704052925, + -0.016619624570012093, + -0.010678192600607872, + 0.012267512269318104, + 0.004102918319404125, + -0.004080160986632109, + -0.0029241242446005344, + -0.027252744883298874, + -0.0772257149219513, + -0.09107967466115952, + -0.11302012205123901, + -0.08569496124982834, + -0.07242150604724884, + -0.016465697437524796, + -0.04874062165617943, + -0.09103028476238251, + -0.09025602042675018, + -0.07523388415575027, + -0.06320428103208542, + -0.048220545053482056, + -0.028701437637209892, + -0.008647853508591652, + -0.022354092448949814, + -0.06076030433177948, + -0.030872423201799393, + -0.045786645263433456, + -0.04190178960561752, + 0.03718986362218857, + 0.021405767649412155, + 0.007675759959965944, + 0.02794131636619568, + 0.030316906049847603, + 0.007403802592307329, + 0.04861852154135704, + 0.023217258974909782, + 0.04545973241329193, + 0.07504793256521225, + 0.06824314594268799, + 0.07417462021112442, + 0.0769289955496788, + 0.0766506940126419, + -0.0028638055082410574, + 0.05911175534129143, + 0.055706772953271866, + 0.10735032707452774, + 0.10494870692491531, + 0.11092723160982132, + 0.09338293969631195, + 0.04235343262553215, + -0.022347571328282356, + -0.026347652077674866, + -0.06954608112573624, + -0.06944439560174942, + -0.05570404976606369, + -0.042987462133169174, + -0.056951191276311874, + -0.2151203453540802, + -0.3603246510028839, + -0.5899456143379211, + -0.6453464031219482, + -0.5338351726531982, + -0.31790611147880554, + 0.049492284655570984, + -0.12898015975952148, + -0.40155911445617676, + -0.6737278699874878, + -0.7170611619949341, + -0.5817899703979492, + -0.32979026436805725, + -0.005899591837078333, + -0.07673019915819168, + -0.190496027469635, + -0.34019437432289124, + -0.3314637243747711, + -0.2796767055988312, + -0.1381818801164627, + -0.008025999180972576, + 0.08429048955440521, + 0.2105528861284256, + 0.3415210545063019, + 0.4151126444339752, + 0.34003961086273193, + 0.21059827506542206, + -0.03514896333217621, + 0.1792585551738739, + 0.3903186321258545, + 0.6413942575454712, + 0.7557680010795593, + 0.6069726943969727, + 0.3415443003177643, + 0.03447553142905235, + 0.21517080068588257, + 0.4215562045574188, + 0.6151171922683716, + 0.6550290584564209, + 0.5680058002471924, + 0.33561068773269653, + -0.12205997854471207, + -0.0038300298620015383, + 0.3281119763851166, + -0.2328944057226181, + -0.03834507241845131, + 0.05432930961251259, + -0.014430212788283825, + 0.006271198857575655, + 0.32864242792129517, + 0.47277259826660156, + -0.5593215227127075, + -0.14971251785755157, + 0.13066314160823822, + -0.09738356620073318, + 0.2966129779815674, + 0.5606555342674255, + -0.3184640407562256, + -2.022890090942383, + -0.361995667219162, + 0.5496177673339844, + 0.02796279452741146, + -0.21818380057811737, + -0.5373459458351135, + -1.9538941383361816, + -1.9984712600708008, + 1.6747761964797974, + 1.5063239336013794, + -0.24534250795841217, + -0.040306344628334045, + -0.16963164508342743, + -0.40690454840660095, + 1.3548375368118286, + 3.922116279602051, + 0.8723023533821106, + -0.8986141681671143, + 0.06912416964769363, + 0.2192920595407486, + 0.352949321269989, + 1.2243634462356567, + 1.1395865678787231, + -1.5146961212158203, + -1.1557590961456299, + -0.05440744385123253, + -0.04629289731383324, + -0.002693743444979191, + -0.21906790137290955, + -0.5464610457420349, + -1.1933224201202393, + 0.01913866586983204, + 0.09363497048616409, + -0.06080613285303116, + -0.049100056290626526, + 0.04482033848762512, + -0.04087500274181366, + -0.009318803437054157, + 0.009458474814891815, + -0.09565524011850357, + -0.2264278084039688, + -0.0698866918683052, + 0.13825084269046783, + 0.014815542846918106, + -0.05801662430167198, + 0.012776852585375309, + -0.0753035843372345, + -0.07555855065584183, + 0.484436959028244, + 0.6397283673286438, + 0.12687323987483978, + -0.01779526099562645, + 0.05689511448144913, + 0.06747376173734665, + 0.26353734731674194, + 0.5908273458480835, + 0.4315526783466339, + -0.5426794290542603, + -0.44501280784606934, + -0.019558124244213104, + -0.03320806100964546, + -0.025809556245803833, + 0.17376014590263367, + -0.5201969742774963, + -1.2842578887939453, + -0.3674038052558899, + 0.0882175862789154, + -0.030023137107491493, + -0.1173325777053833, + 0.02555503323674202, + -0.39882710576057434, + -0.37364596128463745, + 0.3550366163253784, + 0.3903135359287262, + 0.04022252932190895, + 0.016731394454836845, + 0.11207644641399384, + -0.020967213436961174, + -0.028497911989688873, + 0.37590932846069336, + 0.14920172095298767, + 0.029958104714751244, + 0.039632707834243774, + -0.24969367682933807, + 0.16809938848018646, + 0.07703239470720291, + -0.03522319719195366, + -0.007072617299854755, + 0.07751759141683578, + -0.06782346963882446, + -0.4010501801967621, + 0.41269779205322266, + 0.1311105638742447, + -0.07331988960504532, + 0.08240311592817307, + -0.20034979283809662, + -0.4718745946884155, + -0.178948312997818, + 1.3285318613052368, + 0.20384186506271362, + -0.48546233773231506, + -0.09941625595092773, + 0.13249020278453827, + 0.29977336525917053, + 1.2681238651275635, + 1.5725642442703247, + -1.0834472179412842, + -1.0335719585418701, + 0.25975045561790466, + 0.06584863364696503, + 0.1609305590391159, + 0.25940945744514465, + -0.8426372408866882, + -2.590407609939575, + -0.4723183214664459, + 0.7581043243408203, + -0.03634117543697357, + -0.10199672728776932, + -0.3744191527366638, + -0.7823801636695862, + -0.7062401175498962, + 1.116550087928772, + 0.7735803127288818, + 0.012776976451277733, + 0.034575968980789185, + -0.10188565403223038, + 0.2212170958518982, + 0.5182898044586182, + 0.8056022524833679, + -0.1897655427455902, + -0.005556725896894932, + -0.003909373190253973, + -0.02175678312778473, + -0.04085654392838478, + -0.03573022410273552, + -0.0038509985897690058, + 0.02454996667802334, + 0.039437733590602875, + 0.02077251859009266, + 0.02166259102523327, + 0.17245841026306152, + 0.09513862431049347, + -0.10491111874580383, + -0.08084940910339355, + -0.026179829612374306, + 0.0215831957757473, + -0.16602416336536407, + -0.2803819179534912, + 0.23894084990024567, + 0.3269801735877991, + 0.04504352807998657, + 0.0009768904419615865, + 0.01959501951932907, + 0.24426960945129395, + -0.1451571136713028, + -0.5944203734397888, + -0.17875447869300842, + 0.028336334973573685, + 0.004323791246861219, + -0.045389141887426376, + 0.0343034490942955, + 0.46665430068969727, + 0.3707427978515625, + -0.114569291472435, + 0.04335101321339607, + -0.018011711537837982, + -0.021181274205446243, + -0.19074901938438416, + -0.20113815367221832, + 0.048786211758852005, + 0.08533122390508652, + -0.06084573268890381, + 0.01217757910490036, + 0.030666939914226532, + 0.05272842198610306, + 0.010849648155272007, + -0.05913804844021797, + -0.04202868044376373, + -0.0015147016383707523, + -0.03421122953295708, + 0.015080726705491543, + 0.12191007286310196, + 0.10450142621994019, + -0.04972418025135994, + -0.07557133585214615, + -0.02221665158867836, + -0.0861242413520813, + -0.14919178187847137, + -0.04388582333922386, + 0.4605262875556946, + 0.5697804093360901, + 0.1583399623632431, + -0.045628566294908524, + -0.05220475420355797, + -0.13630147278308868, + -0.7103163599967957, + -1.0178179740905762, + 0.1927143931388855, + 0.7479860186576843, + 0.47013771533966064, + 0.16943301260471344, + 0.2398149073123932, + 0.4710526168346405, + -0.5974176526069641, + -1.8564051389694214, + -0.7726883292198181, + 0.05584309995174408, + 0.08902852982282639, + 0.0931839719414711, + 0.46213099360466003, + 1.2080260515213013, + 0.6001025438308716, + -0.590207576751709, + -0.4145379662513733, + -0.04529324173927307, + -0.08303339034318924, + -0.2470429688692093, + -0.03481363505125046, + 0.4808541238307953, + 0.4001348614692688, + -0.1292688548564911, + -0.03635162487626076, + -0.006270444020628929, + -0.0314505510032177, + -0.13043232262134552, + -0.10837803781032562, + 0.10718243569135666, + 0.07523836195468903, + -0.00597786670550704, + 0.06580565124750137, + 0.11166563630104065, + 0.021869506686925888, + -0.10510984063148499, + -0.07651247084140778, + 0.01229890063405037, + -0.08976037800312042, + -0.14929910004138947, + -0.018859578296542168, + 0.4408939778804779, + 0.4029107689857483, + -0.05015433207154274, + -0.13887189328670502, + -0.04514491930603981, + -0.07346425950527191, + -0.5277182459831238, + -0.7335640788078308, + 0.24182197451591492, + 0.626846432685852, + 0.23399080336093903, + 0.09675730019807816, + 0.15529058873653412, + 0.42680656909942627, + -0.4012089967727661, + -1.3605350255966187, + -0.4793834686279297, + 0.10987094044685364, + 0.07592830061912537, + 0.003319029463455081, + 0.24004696309566498, + 0.9590277671813965, + 0.4946591258049011, + -0.4889579117298126, + -0.34744441509246826, + -0.020535729825496674, + -0.026767954230308533, + -0.2090117186307907, + -0.11841326951980591, + 0.37452432513237, + 0.39960840344429016, + -0.07025045901536942, + -0.022984744980931282, + 0.022319970652461052, + -0.0027356306090950966, + -0.13681942224502563, + -0.09797768294811249, + 0.09914079308509827, + 0.10856777429580688, + ] + value = numpy.array(list_value, dtype=numpy.float32).reshape((64, 3, 7, 7)) + tensor = numpy_helper.from_array(value, name="onnx::Conv_501") + + initializers.append(tensor) + + list_value = [ + 3.085598945617676, + 2.2436060905456543, + 4.244357585906982, + 1.4069645404815674, + -4.00622034072876, + 2.595770835876465, + 2.7202603816986084, + 2.4405417442321777, + 1.1759933233261108, + 2.021026372909546, + 2.6628992557525635, + 6.445226192474365, + -7.029932498931885, + 1.1305793523788452, + 2.537140369415283, + 5.456772327423096, + 4.780154705047607, + 10.039976119995117, + 2.912492275238037, + 15.781542778015137, + 2.5154318809509277, + 2.628824472427368, + 2.2992050647735596, + 2.0950584411621094, + -7.93365478515625, + 2.067786931991577, + 4.094852447509766, + 1.673399806022644, + 3.1814424991607666, + 22.49496078491211, + 2.232640027999878, + 2.6427979469299316, + -9.418174743652344, + 1.790976643562317, + 2.3774726390838623, + 2.5836219787597656, + 2.5608203411102295, + 2.287343978881836, + 2.6439085006713867, + 16.859027862548828, + 1.8699607849121094, + -3.6987526416778564, + 2.6861538887023926, + 2.8997464179992676, + 2.689293384552002, + 2.6654043197631836, + 2.3799915313720703, + 2.5603086948394775, + 3.146122694015503, + 2.715951681137085, + 2.889486789703369, + 2.966134548187256, + -4.960191249847412, + 2.6123547554016113, + 1.3074164390563965, + 2.2033026218414307, + 2.2114620208740234, + 4.132844924926758, + 4.893764495849609, + 2.6469600200653076, + 2.654136896133423, + 1.9311997890472412, + 2.881012439727783, + 2.6991193294525146, + ] + value = numpy.array(list_value, dtype=numpy.float32) + tensor = numpy_helper.from_array(value, name="onnx::Conv_502") + + initializers.append(tensor) + + list_value = [ + 0.057212892919778824, + 0.06299274414777756, + -0.018499961122870445, + -0.06501776725053787, + -0.015820641070604324, + 0.024293724447488785, + 0.05624663084745407, + -0.025112055242061615, + 0.043546054512262344, + 0.08439744263887405, + 0.005678815301507711, + 0.0034800865687429905, + 0.030301403254270554, + -0.011669250205159187, + -0.005434689112007618, + -0.1591511219739914, + 0.02324092946946621, + -0.018942436203360558, + 0.025366367772221565, + -0.07414374500513077, + 0.03468436002731323, + -0.003742520697414875, + -0.06651683896780014, + 0.005561002530157566, + 0.04527103528380394, + -0.13710148632526398, + 0.0025444801431149244, + 0.03583350405097008, + 0.015219246037304401, + -0.053635064512491226, + 0.004856681916862726, + -0.07223699986934662, + 0.016770021989941597, + 0.0012010147329419851, + 0.014582094736397266, + -0.005172556731849909, + 0.02009868621826172, + -0.0064261858351528645, + -0.029086023569107056, + 0.001915874076075852, + 0.0008194410474970937, + 0.01620865799486637, + 0.03067426010966301, + -0.0018463254673406482, + 0.05358384922146797, + -0.003966080490499735, + -0.05991416424512863, + -0.06455761194229126, + 0.01634763367474079, + -0.013959774747490883, + 0.03615918383002281, + 0.004434086848050356, + 0.02086004987359047, + -0.004025993403047323, + -0.8869641423225403, + 0.05558132007718086, + 0.024729542434215546, + -0.005809253081679344, + -0.025079259648919106, + 0.04757235199213028, + 0.0023902510292828083, + 0.01522061601281166, + 0.011692625470459461, + 0.023033330217003822, + -0.012664714828133583, + -0.29325294494628906, + -0.006855700630694628, + -0.243958979845047, + 0.0024398649111390114, + -0.060877203941345215, + -0.21996521949768066, + -0.008708474226295948, + -0.06639625877141953, + -0.03170674294233322, + -0.09708897024393082, + 0.013403226621448994, + 0.024766888469457626, + 0.2594103217124939, + -0.02221749909222126, + 0.0662861093878746, + -0.15123076736927032, + -0.010314224287867546, + -0.0029192541260272264, + 0.05985910817980766, + 0.021665453910827637, + 0.003247617743909359, + -0.006802591495215893, + 0.00772367138415575, + 0.0399332195520401, + 0.005198766943067312, + 0.006013805978000164, + -0.04212838411331177, + -0.03166411817073822, + 0.13363900780677795, + 0.006383878644555807, + -0.05536859482526779, + 0.02053261175751686, + 0.015062958002090454, + 0.03352641686797142, + -0.2944328486919403, + 0.019855381920933723, + -0.15567174553871155, + -0.06759943068027496, + 0.07467031478881836, + 0.01674237661063671, + 0.004549413453787565, + -0.0032498433720320463, + -0.1837870180606842, + -0.04725493863224983, + -0.111307792365551, + 0.022237055003643036, + 0.004200428258627653, + 0.00970534235239029, + -0.045657914131879807, + -0.024577995762228966, + 0.0035376595333218575, + 0.008936531841754913, + -0.03904002904891968, + 0.05013228952884674, + -0.011168933473527431, + -0.008444730192422867, + 0.0035155978985130787, + -0.023502476513385773, + 0.005275514908134937, + -0.09448224306106567, + -0.009177467785775661, + -0.010720008052885532, + 0.004110944457352161, + -0.0060218218713998795, + 0.058124978095293045, + -0.0016586220590397716, + 0.15812785923480988, + -0.049118027091026306, + -0.007983109913766384, + -0.04265601187944412, + -0.01627231575548649, + 0.33705562353134155, + 0.01555223111063242, + 0.035853929817676544, + 0.0005046340520493686, + 0.054810188710689545, + -0.08808254450559616, + -0.0013819067971780896, + -0.14938786625862122, + -0.019771935418248177, + 0.004152575507760048, + 0.021979758515954018, + 0.1985529363155365, + -0.07694264501333237, + 0.013187955133616924, + -0.016572976484894753, + -0.03094586730003357, + -0.03673199936747551, + -0.03916170820593834, + -0.003836784977465868, + -0.012262578122317791, + 0.005559554789215326, + 0.1488093137741089, + -0.01842501200735569, + -0.004847189411520958, + -0.02391587756574154, + 0.015824301168322563, + 0.012022596783936024, + 0.06724318116903305, + -0.032682593911886215, + 0.00450896704569459, + -0.0024625889491289854, + 0.00933725107461214, + -0.04473242908716202, + 0.06270455569028854, + -0.02062271721661091, + -0.01071448065340519, + -0.017757099121809006, + 0.01575278490781784, + -0.06489317119121552, + -0.01519051194190979, + 0.0028058059979230165, + 0.00917835533618927, + -0.01291860081255436, + -0.009537308476865292, + 0.041757628321647644, + 0.03203853219747543, + -0.10918509215116501, + -0.007152496371418238, + -0.06777876615524292, + 0.03223242610692978, + 0.01780836284160614, + -0.09791012853384018, + -0.009385241195559502, + 0.013184775598347187, + 0.0031673219054937363, + -0.010640445165336132, + 0.024713385850191116, + -0.026738369837403297, + -0.004191657993942499, + -0.13764967024326324, + -0.003720735665410757, + 0.01737186871469021, + 0.015459887683391571, + 0.033229030668735504, + 0.008042111992835999, + -0.007184108253568411, + 0.008226306177675724, + 0.0031303109135478735, + 0.0406314842402935, + -0.8669105768203735, + 0.02079751342535019, + -0.17030003666877747, + -0.03849703446030617, + 0.034153200685977936, + -0.007219486869871616, + 0.11227627843618393, + -0.2681085467338562, + 0.015872526913881302, + 0.10855260491371155, + -0.008631505072116852, + 0.02556358277797699, + 0.06043418496847153, + -0.012900532223284245, + -0.08834894001483917, + 0.028099440038204193, + -0.05156330019235611, + 0.032628703862428665, + 0.044928934425115585, + 0.006176372990012169, + 0.007333829998970032, + -0.037409231066703796, + -0.046724822372198105, + -0.011172871105372906, + 0.04603327810764313, + 0.03288746625185013, + -0.20848578214645386, + 0.0028185085393488407, + -0.032673876732587814, + 0.061944279819726944, + 0.016787173226475716, + 0.02703898213803768, + -0.0060023171827197075, + 0.06870592385530472, + 0.03154531493782997, + 0.02784041129052639, + 0.007780189625918865, + 0.02033168077468872, + 0.0019289497286081314, + 0.02545374445617199, + 0.04262726008892059, + 0.01301807351410389, + -0.023882156237959862, + 0.027872221544384956, + -0.013518108054995537, + -0.0031075032893568277, + 0.03753834590315819, + 0.0369209349155426, + -0.014378191903233528, + 0.004397932440042496, + -0.030286893248558044, + -0.007679021451622248, + -0.045032769441604614, + 0.032050322741270065, + -0.03373495861887932, + -0.04363032802939415, + 0.034301597625017166, + -0.07021668553352356, + 0.03942524269223213, + -0.11061309278011322, + 0.049139462411403656, + 0.04161922261118889, + -0.01507576834410429, + -0.012748259119689465, + 0.06599434465169907, + 0.007602245546877384, + -0.03973209857940674, + -0.06923151016235352, + 0.026153067126870155, + -0.04221056029200554, + -0.4828230142593384, + 0.03360651433467865, + 0.01847662217915058, + -0.08594681322574615, + 0.04071836546063423, + -0.0035729086957871914, + 0.0049045816995203495, + -0.036198534071445465, + 0.03046257793903351, + 0.013275806792080402, + 0.09266786277294159, + -0.03625647351145744, + -0.059672992676496506, + 0.050213005393743515, + -0.018153885379433632, + -0.0858495831489563, + 0.01621098257601261, + -0.03029749169945717, + 0.02193332649767399, + 0.0422661192715168, + 0.6109512448310852, + -0.01068826112896204, + -0.02184930257499218, + -0.03213764354586601, + -0.03148162364959717, + -0.055331334471702576, + 0.006972005590796471, + -0.00815682765096426, + 0.014874683693051338, + -0.012943249195814133, + -0.03318992629647255, + -0.0010484680533409119, + 0.005414161365479231, + -0.013610370457172394, + 0.008836873807013035, + -0.05890084058046341, + -0.022663919255137444, + -0.018899116665124893, + -0.01037894282490015, + 0.005064660683274269, + 0.08522599190473557, + 0.0075323861092329025, + 0.013720778748393059, + 0.032096460461616516, + -0.008450351655483246, + 0.020377663895487785, + 0.04537765309214592, + 0.014030816033482552, + 0.024340089410543442, + 0.0231801588088274, + -0.10347768664360046, + 0.041163086891174316, + -0.060614243149757385, + -0.09241361171007156, + 0.05831432715058327, + -0.16008608043193817, + -0.04505622759461403, + 0.04866329953074455, + -0.0656094029545784, + 0.09627313911914825, + 0.1153625100851059, + 0.008151216432452202, + 0.03813345730304718, + 0.05990723893046379, + 0.24788673222064972, + 0.06294118613004684, + 0.11761849373579025, + -0.0722033903002739, + -0.013892017304897308, + -0.016778236255049706, + 0.038522012531757355, + -0.015539593063294888, + 0.01263216882944107, + 0.0003969807003159076, + -0.0224238783121109, + -0.005919966846704483, + 0.031987495720386505, + -0.014712700620293617, + 0.03508169203996658, + 0.07568854838609695, + -0.011961974203586578, + 0.027983952313661575, + -0.03512958809733391, + -0.010324078612029552, + -0.2895449995994568, + 0.007338976487517357, + -0.042290836572647095, + -0.1640917807817459, + -0.034807007759809494, + -0.1268443465232849, + 0.18418198823928833, + -0.3867812156677246, + -0.14214494824409485, + 0.001021744217723608, + 0.11288078874349594, + 0.006741920951753855, + -0.006421610247343779, + 0.021150892600417137, + 0.02486848644912243, + 0.002660338068380952, + 0.03732302784919739, + 0.10844919830560684, + -0.032568808645009995, + 0.009477612562477589, + 0.053578171879053116, + -0.07421902567148209, + 0.05660263076424599, + 0.03038308583199978, + 0.049440011382102966, + 0.0395139642059803, + 0.0217339675873518, + 0.028231965377926826, + 0.1661153882741928, + -0.02168717049062252, + 0.055143170058727264, + -0.14159196615219116, + 0.05894732475280762, + 0.006888065952807665, + -0.06988262385129929, + 0.017527412623167038, + -0.007171930745244026, + -0.00448343763127923, + 0.02932717651128769, + -0.00652179354801774, + -0.002897858154028654, + 0.020487705245614052, + -0.027063967660069466, + -0.02539752423763275, + -0.1066114604473114, + -0.10011029988527298, + -0.03331710025668144, + -0.003807300003245473, + -0.010441976599395275, + -0.005605363752692938, + 0.09679440408945084, + 0.020033519715070724, + -0.010188378393650055, + -0.030630890280008316, + -0.00955540407449007, + 0.02825581096112728, + -0.4307324290275574, + 0.012557203881442547, + 0.043258048593997955, + 0.09386534243822098, + -0.009555542841553688, + 0.05304868891835213, + 0.014706632122397423, + -0.012911850586533546, + 0.0981304720044136, + -0.010722141712903976, + -0.027317194268107414, + 0.0893903523683548, + -0.19983792304992676, + -0.15778200328350067, + -0.1012115329504013, + -0.3758164644241333, + -0.05782865360379219, + -0.01230492815375328, + -0.37126046419143677, + -0.01596723683178425, + 0.0020407456904649734, + -0.017498979344964027, + 0.005369496997445822, + -0.023121315985918045, + 0.022279681637883186, + -0.006232256535440683, + 0.05115891620516777, + 0.006679570768028498, + 0.0026316209696233273, + 0.04291496425867081, + 0.04381528124213219, + -0.05994122102856636, + 0.007081915624439716, + -0.04571640491485596, + 0.07592425495386124, + -0.00836833007633686, + 0.008123279549181461, + -0.008003163151443005, + -0.003938044421374798, + 0.005643180105835199, + 0.016194086521863937, + -0.004063089843839407, + 0.012334472499787807, + 0.017072021961212158, + 0.005761854816228151, + 0.004702428821474314, + 0.005736868362873793, + 0.0017962371930480003, + 0.059996701776981354, + 0.19533602893352509, + 0.02649352326989174, + -0.06493135541677475, + -0.05955052375793457, + 0.015692468732595444, + -0.10623155534267426, + 0.07290898263454437, + 0.036108434200286865, + -0.01248949021100998, + 0.16444285213947296, + -0.005899128969758749, + 0.07875277101993561, + 0.0014204353792592883, + 0.03381470963358879, + -0.09680792689323425, + 0.002102318685501814, + 0.026962973177433014, + 0.031665392220020294, + -0.18168538808822632, + 0.11163855344057083, + -0.5409999489784241, + 0.07833191007375717, + -0.005324948113411665, + 0.0267564058303833, + 0.02250477857887745, + 0.03249068558216095, + -0.18441715836524963, + -0.006447427906095982, + 0.037927329540252686, + 0.0005173985846340656, + -0.02617005631327629, + 0.05929232016205788, + -0.028510913252830505, + 0.05447050556540489, + 0.012390155345201492, + 0.00046797769027762115, + -0.008598590269684792, + -0.17247197031974792, + -0.02855759859085083, + 0.033968932926654816, + -0.09011702984571457, + 0.05276056379079819, + 0.03299655020236969, + -0.005699596833437681, + -0.1954648792743683, + 0.011109501123428345, + -0.0013570536393672228, + -0.6543989181518555, + 0.009102803654968739, + 0.0407538004219532, + 0.04312055557966232, + 0.027609223499894142, + -0.035538043826818466, + 0.027167823165655136, + -0.024043193086981773, + 0.0047575319185853004, + -0.006788836792111397, + 0.025714389979839325, + 0.007848678156733513, + -0.07680192589759827, + 0.009700766764581203, + -0.0097329281270504, + 0.00586724653840065, + 0.022815868258476257, + -0.023448282852768898, + -0.05608998239040375, + 0.10786863416433334, + -0.02803603559732437, + 0.012898198328912258, + -0.009270391426980495, + -0.021972229704260826, + 0.26533082127571106, + -0.01021308358758688, + -0.01972626894712448, + 0.062940314412117, + 0.022569671273231506, + 0.027042347937822342, + -0.05669092759490013, + -0.01200617104768753, + -0.006279367487877607, + -0.009608528576791286, + -0.013600943610072136, + -0.02187415212392807, + 0.0351138636469841, + 0.006282923277467489, + -0.011123511008918285, + -0.009205769747495651, + 0.001010146806947887, + -0.4796978235244751, + -0.0030205894727259874, + -0.011987377889454365, + -0.027548225596547127, + 0.009372347965836525, + -0.005388603545725346, + -0.006444129627197981, + -0.02501147985458374, + 0.027465635910630226, + 0.027784524485468864, + 0.006878893356770277, + -0.027763860300183296, + -0.0047700353898108006, + -0.018965192139148712, + 0.027898501604795456, + 0.022454144433140755, + 0.02973407506942749, + 0.03505602851510048, + 0.04003170132637024, + -0.004336829297244549, + -0.01998550072312355, + -0.06097743660211563, + -0.07844759523868561, + 0.0013787010684609413, + 0.0066132270731031895, + -0.03124997951090336, + 0.0313432514667511, + 0.047656893730163574, + 0.06175797060132027, + -0.02077358029782772, + -0.004535601008683443, + -0.10219905525445938, + -0.07125344127416611, + -0.06927482783794403, + -0.04813461750745773, + -0.02618095651268959, + -0.01255929097533226, + -0.009180150926113129, + -0.005838831886649132, + 0.09108023345470428, + -0.032710760831832886, + 0.03091445378959179, + -0.01955563761293888, + 0.0959300771355629, + -0.09353741258382797, + -0.0761636272072792, + -0.023445438593626022, + -0.012328366748988628, + 0.05850536748766899, + -0.052494827657938004, + 0.0025638933293521404, + -0.017152179032564163, + -0.004435579292476177, + 0.12312240898609161, + -0.007241012528538704, + 0.09605048596858978, + 0.03355967625975609, + -0.015987426042556763, + -0.03470349311828613, + -0.02499505691230297, + -0.015004142187535763, + -0.018609771504998207, + -0.06654462963342667, + 0.013861652463674545, + -0.005973289255052805, + -0.04734775796532631, + 0.08755116909742355, + 0.03012942522764206, + 0.07887610793113708, + -0.01827712170779705, + 0.10793066769838333, + 0.10793614387512207, + -0.01075535174459219, + 0.03439560532569885, + 0.011567444540560246, + 0.0016386889619752765, + -0.031207261607050896, + -0.01707504875957966, + 0.20471863448619843, + 0.0025428179651498795, + 0.004082779865711927, + -0.012389302253723145, + 0.0400562584400177, + -0.21075034141540527, + 0.012872264720499516, + -0.01639414019882679, + 0.016652485355734825, + 0.0016037120949476957, + -0.006540367379784584, + -0.0068405005149543285, + -0.2484254390001297, + 0.0008089764742180705, + -0.022340824827551842, + -0.005441636312752962, + 0.002882100408896804, + 0.008654038421809673, + 0.07159754633903503, + -0.02537086047232151, + 0.011997461318969727, + -0.49913132190704346, + -0.02300887741148472, + 0.044442202895879745, + 0.001787978457286954, + 0.010291379876434803, + 0.009601960889995098, + -0.5312613248825073, + -0.014247804880142212, + 0.06685849279165268, + 0.035772595554590225, + 0.03432310372591019, + 0.03151272237300873, + -0.10318460315465927, + -0.030476456508040428, + -0.004469831008464098, + -0.16645164787769318, + -0.021104637533426285, + 0.013934006914496422, + -0.011767406016588211, + 0.008054615929722786, + 0.06089277192950249, + 0.0003409573109820485, + -0.0053401123732328415, + 0.05970478057861328, + -0.004363172687590122, + 0.014423285610973835, + -0.002795026171952486, + -0.019875092431902885, + -0.07540513575077057, + -0.09043378382921219, + 0.00750827556475997, + -0.045314721763134, + -0.00724808732047677, + 0.005193864461034536, + -0.020468784496188164, + -0.01098695583641529, + -0.0003122477210126817, + -0.007263806648552418, + -0.03325646370649338, + 0.021689830347895622, + -0.13272541761398315, + 0.02332465350627899, + -0.019292252138257027, + 0.05533658340573311, + -0.018616480752825737, + -0.015228793025016785, + -0.28432801365852356, + -0.29721561074256897, + 0.04648810625076294, + -0.014750649221241474, + -0.15370936691761017, + -0.1497083604335785, + 0.013243601657450199, + 0.042343802750110626, + -0.017519792541861534, + -0.0161418616771698, + 0.00807454064488411, + -0.023562468588352203, + -0.0315413773059845, + 0.03386805206537247, + 0.2854529917240143, + 0.0191020630300045, + -0.49126777052879333, + 0.052687134593725204, + -0.023298051208257675, + -0.009119837544858456, + 0.05149759724736214, + -0.8527837991714478, + 0.08062390983104706, + 0.057379938662052155, + -0.020724931731820107, + -0.006624895613640547, + 0.05322050303220749, + 0.017887847498059273, + 0.04229281470179558, + 0.04171830415725708, + 0.029683062806725502, + -0.00028416322311386466, + 0.1112222746014595, + -0.0448714978992939, + -0.005255761090666056, + 0.017773712053894997, + -0.0016064767260104418, + -0.013840594328939915, + -0.00398495327681303, + -4.32919041486457e-05, + 0.040796443819999695, + 0.018185198307037354, + -0.018671950325369835, + 0.0028256692457944155, + -0.020582057535648346, + 0.05567716807126999, + -0.056062404066324234, + 0.01614757999777794, + -0.0029299987945705652, + 0.048686008900403976, + 0.04299888014793396, + 0.12249592691659927, + 0.01469603180885315, + -0.1254546344280243, + -0.18532024323940277, + -0.003263876074925065, + 0.014804725535213947, + 0.004450956825166941, + -0.013681051321327686, + -0.0030781759414821863, + -0.03433656692504883, + -0.0035507124848663807, + 0.1600082814693451, + -0.028547707945108414, + -0.00989136379212141, + -0.012126478366553783, + -0.12963305413722992, + 0.008547360077500343, + 0.017959514632821083, + -0.012571084313094616, + 0.0008666724897921085, + -0.010519342496991158, + -0.009684977121651173, + -0.04285729303956032, + 0.015031769871711731, + -0.030043724924325943, + 0.018907636404037476, + 0.08019450306892395, + -0.04836742579936981, + 0.01025464478880167, + -0.004908542148768902, + -0.10327022522687912, + -0.10163667798042297, + -0.03403499722480774, + -0.019678063690662384, + -0.043049123138189316, + 0.0384567566215992, + -0.05596519634127617, + -0.09381429851055145, + -0.18688108026981354, + -0.09762943536043167, + -0.03164997324347496, + -0.006416287273168564, + 0.07003920525312424, + -0.016646990552544594, + -0.025972194969654083, + -0.028768088668584824, + -0.06332779675722122, + 0.045144014060497284, + -0.03735211119055748, + -0.010442189872264862, + 0.10948455333709717, + 0.14629514515399933, + -0.023416690528392792, + -0.01347778458148241, + 0.020830679684877396, + 0.0003131759003736079, + 0.007049075793474913, + 0.06547018885612488, + 0.03152740001678467, + 0.08380027115345001, + 0.03185325488448143, + -0.015359007753431797, + 0.08864206075668335, + 0.032676901668310165, + -0.002908645663410425, + 0.053111132234334946, + 0.0026159954722970724, + -0.05177146941423416, + -0.033048152923583984, + -0.0020293137058615685, + -0.07363513857126236, + -0.17662747204303741, + 0.004798125941306353, + 0.07139395922422409, + 0.019802849739789963, + 0.009199771098792553, + -0.009043877013027668, + -0.07681646943092346, + -0.06748555600643158, + 0.05094710737466812, + 0.0014789587585255504, + -0.0166088305413723, + -0.27988284826278687, + 0.03634800389409065, + 0.05322619527578354, + -0.15566207468509674, + -0.019964642822742462, + -0.010204506106674671, + -0.011832086369395256, + -0.0680927112698555, + -0.05793820694088936, + 0.0020100779365748167, + -0.24647225439548492, + 0.04904041066765785, + -0.05589786171913147, + -0.030167482793331146, + 0.023974033072590828, + -0.22719347476959229, + 0.019620347768068314, + -0.18078163266181946, + -0.11321499198675156, + -0.023790234699845314, + -0.1266157031059265, + 0.01117659267038107, + 0.13824795186519623, + -0.024211348965764046, + -0.0548308864235878, + 0.04849318787455559, + -0.0016174454940482974, + -0.01826266385614872, + 0.006709347013384104, + -0.350631982088089, + 0.03139018639922142, + 0.021502504125237465, + -0.12596893310546875, + 0.04311670735478401, + -0.005905786994844675, + -0.0807335153222084, + -0.07214773446321487, + -0.2054852843284607, + -0.04526854678988457, + -0.09145382046699524, + 0.002603817731142044, + -0.01951524056494236, + -0.0028278473764657974, + -0.03270411863923073, + -0.0003385065938346088, + -0.019816655665636063, + -0.003430107608437538, + 0.010664679110050201, + 0.030127109959721565, + 0.02611778862774372, + 0.030213139951229095, + 0.04682943969964981, + 0.010338326916098595, + -0.02618880569934845, + 0.014982170425355434, + -0.06979402899742126, + 0.06403722614049911, + 0.025545112788677216, + -0.11981001496315002, + 0.004320457112044096, + 0.008849565871059895, + 0.07450827211141586, + -0.04322020336985588, + -0.07648278027772903, + 0.009221173822879791, + -0.12771189212799072, + 0.027474528178572655, + -0.1637975573539734, + -0.022587651386857033, + 0.0713210329413414, + -0.09652210026979446, + -0.04942077025771141, + -0.08977267891168594, + -0.004629603121429682, + -0.09891843795776367, + 0.0004028059483971447, + 0.12999524176120758, + 0.009417874738574028, + -0.012465995736420155, + 0.09959464520215988, + 0.012048770673573017, + 0.00529639283195138, + -0.1231047734618187, + -0.010156300850212574, + -0.0067022680304944515, + 0.09231371432542801, + 0.1372271031141281, + 0.01140755694359541, + -0.014376018196344376, + 0.009014246053993702, + -0.0558021254837513, + 0.009297777898609638, + -0.023461824283003807, + 0.12312523275613785, + 0.0013492326252162457, + -0.10130659490823746, + 0.07867099344730377, + -0.04363301396369934, + -0.05203291028738022, + 0.010715829208493233, + 0.2679101228713989, + 0.047242000699043274, + 0.009700302965939045, + -0.004188477993011475, + 0.04595324397087097, + -0.10256988555192947, + 0.013266253285109997, + 0.13415516912937164, + -0.06461263447999954, + -0.04262775555253029, + 0.014638054184615612, + -0.020396970212459564, + 0.016008291393518448, + 0.012964261695742607, + 0.030219901353120804, + -0.03906702250242233, + -0.009459082037210464, + -0.006880247965455055, + 0.009383107535541058, + 0.0591101311147213, + -0.049882922321558, + -0.014105924405157566, + -0.04896679148077965, + 0.021726086735725403, + -0.013863577507436275, + -0.05801064148545265, + -0.031143831089138985, + 0.0010298469569534063, + -0.03104572743177414, + 0.1193046048283577, + 0.00880056619644165, + -0.01678626798093319, + 0.0014990485506132245, + -0.001967367948964238, + -0.0053575835190713406, + -0.006879259832203388, + -0.008937212638556957, + 0.014141763560473919, + 0.00687083275988698, + -0.0012949275551363826, + 0.017160816118121147, + -0.035110652446746826, + -0.00976842176169157, + 0.026605995371937752, + 0.004003277514129877, + 0.010927689261734486, + 0.002173327375203371, + -0.05133439600467682, + -0.04658171907067299, + 0.03023359179496765, + -0.015038624405860901, + 0.016580749303102493, + 0.02393144741654396, + 0.004817661829292774, + -0.008468102663755417, + 0.017239807173609734, + 0.019924553111195564, + 0.02557404898107052, + 0.01985766738653183, + -0.01881517469882965, + -0.14637643098831177, + -0.005403783638030291, + -0.013156545348465443, + -0.3882855176925659, + 0.01537711638957262, + 0.005061861593276262, + 0.018044542521238327, + 0.00010373388067819178, + -0.01769324019551277, + -0.020439250394701958, + 0.01761222817003727, + 0.017716309055685997, + -0.01828574948012829, + 0.0059916484169662, + 0.006117791403084993, + -0.0025541253853589296, + 0.01598154753446579, + 0.0015296537894755602, + 0.006711189169436693, + -0.005831963382661343, + 0.024547481909394264, + 0.011665170080959797, + 0.013990279287099838, + -0.009193074889481068, + -0.0014407691778615117, + 0.0025373499374836683, + -0.001535113900899887, + 0.022016262635588646, + 0.002165747107937932, + -0.00010288839985150844, + -0.01185672264546156, + 0.3959958255290985, + -0.06701132655143738, + 0.024550342932343483, + -0.007259713020175695, + 0.00011224728223169222, + 0.08959072828292847, + 0.006745494436472654, + -0.007461291737854481, + -0.0010788652580231428, + -0.003997487016022205, + 0.0023250498343259096, + 0.005845727398991585, + 0.002441686810925603, + 0.0010628585005179048, + 0.004687050357460976, + 0.03825820982456207, + 0.0027951127849519253, + 0.004356732591986656, + 0.0036379920784384012, + -0.00048690394032746553, + -0.31681910157203674, + 0.01621195860207081, + 0.009373913519084454, + -0.005099120549857616, + 0.004866141825914383, + 0.008112045004963875, + -0.009933174587786198, + -0.006929770577698946, + 0.005561198107898235, + -0.2225065976381302, + -0.00019208311277907342, + -0.003284667618572712, + 0.010527989827096462, + -0.010160842910408974, + -0.008410060778260231, + 0.004605174530297518, + 0.01542133092880249, + 0.013958578929305077, + 0.0021779180970042944, + 0.002810562262311578, + 0.001369283301755786, + -0.0003347232413943857, + 0.013902815990149975, + -0.0022218015510588884, + 0.00024955783737823367, + -0.0019350153161212802, + 0.0025213193148374557, + -0.0054915109649300575, + -0.00011564489977899939, + -0.0037644850090146065, + -0.002863431815057993, + -0.0025196163915097713, + 0.02352992817759514, + 0.00354134407825768, + -0.010700036771595478, + -0.03428381308913231, + 0.008170859888195992, + 0.005420713219791651, + -0.0013479178305715322, + 0.0015741022070869803, + -0.18286381661891937, + 0.03189067915081978, + 0.0014371845172718167, + -4.885893940809183e-05, + -0.004666821099817753, + -0.026595929637551308, + -0.0064376350492239, + 0.01583540253341198, + -0.085715651512146, + -0.00916224904358387, + -0.3605174124240875, + 0.019973354414105415, + 0.05533794313669205, + 0.053907446563243866, + 0.030877795070409775, + -0.919844925403595, + 8.968543988885358e-05, + -0.02068270742893219, + 0.012602192349731922, + 0.03245612978935242, + 0.06622699648141861, + 0.00882122665643692, + -0.03616628423333168, + -0.02428283728659153, + 0.003318701172247529, + -0.0007259293342940509, + -0.026197656989097595, + -0.059503961354494095, + 0.029495801776647568, + -0.006955073680728674, + -0.01926456019282341, + 0.009927013888955116, + 0.059641581028699875, + 0.0016886347439140081, + -0.029346982017159462, + 0.01948450319468975, + -0.04397860914468765, + 0.025248751044273376, + 0.04597266763448715, + 0.009454794228076935, + -0.018872544169425964, + -0.039650529623031616, + 0.026324709877371788, + -0.01808176562190056, + 0.028935831040143967, + 0.009501701220870018, + -0.05183069407939911, + -0.005787428934127092, + -0.021436212584376335, + 0.029735956341028214, + 0.0350160151720047, + 0.033825185149908066, + 0.03185566887259483, + 0.018431033939123154, + 0.02450188808143139, + 0.03271135315299034, + -0.0027792940381914377, + -0.0004625302099157125, + 0.01268392987549305, + 0.045023106038570404, + 0.05562014505267143, + 0.029052015393972397, + -0.002513203304260969, + -0.08349838852882385, + 7.017837560852058e-06, + -0.0014392733573913574, + 0.016982918605208397, + 0.016358936205506325, + -0.024013325572013855, + -0.004375616554170847, + -0.03734249249100685, + 0.04336351156234741, + 0.07323610782623291, + -0.0243068914860487, + 0.009403819218277931, + 0.02663031965494156, + 0.01930687017738819, + 0.02175578847527504, + 0.01639295555651188, + 0.024892140179872513, + 0.031219134107232094, + 0.02986173704266548, + -0.002100786194205284, + 0.05054357647895813, + 0.04015854373574257, + 0.0048207067884504795, + -0.03244275599718094, + 0.027246609330177307, + 0.00409608893096447, + -0.0054193479008972645, + 0.07014931738376617, + 0.009954879060387611, + 0.022472694516181946, + -0.47738370299339294, + -0.019097158685326576, + 0.028984038159251213, + -0.042564358562231064, + -0.006040808744728565, + 0.04094231128692627, + -0.007740774191915989, + -0.07854597270488739, + 0.003920051269233227, + -0.050799619406461716, + 0.023691626265645027, + 0.019952887669205666, + 0.00716764759272337, + -0.0046928380616009235, + 0.00041822553612291813, + 0.006359069608151913, + 0.017860781401395798, + -0.22999149560928345, + -0.02180831879377365, + -0.024055887013673782, + -0.0226126741617918, + -0.01795077696442604, + 0.015591473318636417, + -0.004053472075611353, + 0.016760380938649178, + 0.03378744795918465, + -0.0027090508956462145, + 0.00999806821346283, + 0.019252799451351166, + 0.0027550198137760162, + 0.03454355522990227, + -0.0295003242790699, + -0.007663591764867306, + 0.061172280460596085, + 0.049142658710479736, + -0.00858291145414114, + -0.0035321018658578396, + -0.7689260244369507, + 0.0004916944890283048, + 0.02915046364068985, + 0.017000442370772362, + -0.003298018593341112, + -0.0405484102666378, + 0.021160880103707314, + 0.0013289587805047631, + -0.07510386407375336, + 0.03890690207481384, + 0.03729970380663872, + -0.04906352981925011, + -0.10020274668931961, + 0.01506283599883318, + -0.053726132959127426, + 0.016631007194519043, + 0.03425036743283272, + 0.03358260169625282, + -0.023937245830893517, + -0.13656578958034515, + -0.13947314023971558, + 0.012915699742734432, + 0.02431132085621357, + -0.03089652583003044, + 0.1382707953453064, + 0.056695129722356796, + -0.09263960272073746, + 0.10406216233968735, + 0.02619105577468872, + -0.01678614132106304, + -0.16045455634593964, + 8.974489173851907e-05, + -0.03521093726158142, + -0.028908027336001396, + 0.21234789490699768, + -0.02046572044491768, + -0.09703273326158524, + 0.05248226970434189, + 0.011973158456385136, + 0.004557646345347166, + -0.018632734194397926, + -0.1649131029844284, + -0.00682018743827939, + -0.12712189555168152, + 0.10513507574796677, + 0.020745709538459778, + 0.02996259182691574, + -0.15409024059772491, + -0.08719073981046677, + -0.14634187519550323, + -0.16255779564380646, + -0.15963757038116455, + -0.1324772834777832, + -0.022830091416835785, + -0.06426219642162323, + -0.025459224358201027, + 0.00281702633947134, + 0.03255268186330795, + -0.05778049677610397, + -0.30381152033805847, + -0.06582051515579224, + -0.033722274005413055, + 0.014956191182136536, + 0.004153797868639231, + 0.2391217201948166, + -0.0311420951038599, + 0.001518488978035748, + 0.019769812002778053, + -0.056324463337659836, + -0.006009253207594156, + -0.21367721259593964, + -0.0481688529253006, + 0.22422266006469727, + 0.0402204655110836, + 0.1432792693376541, + 0.14159953594207764, + -0.0025862890761345625, + -0.028965365141630173, + 0.011978867463767529, + 0.161293163895607, + 0.028642605990171432, + -0.008417634293437004, + -0.10145614296197891, + 0.08381767570972443, + 0.05199432373046875, + 0.18680602312088013, + -0.023287687450647354, + 0.03601476550102234, + 0.03738229721784592, + 0.19291405379772186, + 0.03553088754415512, + 0.05483124405145645, + 0.09577616304159164, + -0.004635817836970091, + 0.052481625229120255, + -0.042084019631147385, + -0.2629147469997406, + -0.006157668773084879, + -0.0401761569082737, + 0.02154349908232689, + -0.056558139622211456, + -0.003753019031137228, + 0.01922912523150444, + 0.1291409730911255, + -0.21358416974544525, + 0.004696246236562729, + 0.13787509500980377, + -0.07022479176521301, + -0.06828727573156357, + 0.09193858504295349, + -0.06863763928413391, + -0.05677935853600502, + -0.030970478430390358, + -0.10181070864200592, + -0.1247706487774849, + 0.014181962236762047, + -0.09259836375713348, + -0.03174220770597458, + -0.014812505804002285, + -0.024658311158418655, + -0.04815720021724701, + -0.01683010160923004, + 0.015726473182439804, + 0.002938281511887908, + -0.1586887538433075, + -0.29276973009109497, + -0.029981529340147972, + -0.046828676015138626, + -0.04909103736281395, + 0.06043976545333862, + 0.03698069602251053, + -0.04807118698954582, + 0.0943484902381897, + 0.01930702105164528, + 0.06498143821954727, + 0.0381690077483654, + -0.19611406326293945, + 0.006944946013391018, + 0.06454038619995117, + -0.19779883325099945, + 0.04966692253947258, + 0.046355295926332474, + 0.0590626522898674, + -0.24392037093639374, + -0.0018132536206394434, + 0.010944955050945282, + -0.014556891284883022, + 0.051466893404722214, + -0.0059846509248018265, + -0.06719732284545898, + 0.030604040250182152, + 0.051190104335546494, + -0.053196243941783905, + -0.06912374496459961, + -0.06263922154903412, + 0.05626852437853813, + 0.013047950342297554, + -0.005828890949487686, + 0.056055404245853424, + 0.007044378202408552, + 0.030499491840600967, + -0.035373322665691376, + 0.030934391543269157, + 0.04358363524079323, + 0.001537138712592423, + 0.005963161122053862, + -0.005889860913157463, + 0.053225863724946976, + 0.052091702818870544, + -0.02871675044298172, + 0.05662619322538376, + -0.4585985839366913, + 0.06490323692560196, + 0.02542230300605297, + 0.017592567950487137, + 0.05066920816898346, + -0.20954127609729767, + -0.06689731031656265, + -0.3632309138774872, + -0.03407476842403412, + 0.04976007342338562, + 0.03856723755598068, + 0.009329214692115784, + -0.10107281804084778, + 0.007077769376337528, + -0.005482642911374569, + 0.04388934373855591, + 0.03984231874346733, + 0.005358297843486071, + 0.05032944679260254, + 0.007170544005930424, + 0.017318176105618477, + -0.03577208146452904, + -0.02195456624031067, + 0.014414021745324135, + -0.008203372359275818, + 0.04585091397166252, + -0.012298643589019775, + 0.03959968313574791, + -0.06015963852405548, + -0.1360240876674652, + -0.07704123109579086, + -0.0842466950416565, + -0.11261942237615585, + 0.0433686338365078, + -0.1059969812631607, + 0.014813154004514217, + 0.04216694459319115, + 0.10441470146179199, + 0.04579426348209381, + 0.026033954694867134, + 0.08725529909133911, + -0.14662955701351166, + -0.0726592168211937, + 0.1293957382440567, + 0.013497715815901756, + -0.01318936888128519, + -0.05188713222742081, + 0.08793413639068604, + 0.1094818189740181, + 0.07991892844438553, + 0.03549068048596382, + -0.04469897970557213, + -0.10442564636468887, + 0.13456915318965912, + 0.01154977548867464, + -0.05959299951791763, + 0.01768219843506813, + 0.0179652888327837, + -0.010112428106367588, + 0.020603090524673462, + -0.7144030928611755, + 0.20126283168792725, + 0.058172807097435, + -0.10543914139270782, + 0.07461538910865784, + -0.1744592934846878, + 0.055722273886203766, + -0.046595826745033264, + 0.06237049773335457, + 0.05800141766667366, + 0.04118870943784714, + 0.002582935383543372, + 0.010623090900480747, + -0.0439014658331871, + 0.044685740023851395, + -0.017063472419977188, + -0.0173367727547884, + -0.04761765897274017, + 0.06136244907975197, + 0.08495236933231354, + 0.24923592805862427, + -0.061080869287252426, + 0.15922360122203827, + -0.09322690963745117, + -0.09617402404546738, + 0.0029533954802900553, + 0.12630371749401093, + 0.0011397749185562134, + 0.0005059551913291216, + -0.060922350734472275, + -0.16446451842784882, + 0.057099178433418274, + 0.03073902614414692, + -0.031064951792359352, + 0.012277435511350632, + 0.020447896793484688, + 0.06010727211833, + 0.07065457105636597, + 0.026963504031300545, + 0.010798406787216663, + -0.02631279267370701, + 0.02046871930360794, + -0.004800989292562008, + -0.03282550349831581, + 0.053904879838228226, + -0.03294985368847847, + -0.4204113185405731, + 0.028552187606692314, + 0.023685462772846222, + 0.0017703581834211946, + 0.02868991158902645, + -0.3585520088672638, + -0.011516556143760681, + -0.00248165475204587, + 0.011379038915038109, + 0.0459531806409359, + 0.015357235446572304, + 0.05573337897658348, + 0.06516549736261368, + 0.02981666848063469, + 0.05498211458325386, + 0.028714550659060478, + -0.005899528972804546, + 0.008476868271827698, + 0.11328839510679245, + 0.020578190684318542, + -0.15382742881774902, + 0.015724696218967438, + -0.08402770012617111, + 0.060314107686281204, + 0.032343748956918716, + 0.014438764192163944, + -0.13614842295646667, + -0.0017508765449747443, + 0.09998518973588943, + -0.06364594399929047, + 0.049632295966148376, + -0.11922458559274673, + -0.08834195137023926, + 0.019541991874575615, + 0.06320779770612717, + 0.017419861629605293, + -0.0028468866366893053, + -0.14753428101539612, + 0.02623703144490719, + -0.011462770402431488, + 0.06676206737756729, + -0.014891563914716244, + -0.002118025440722704, + 0.02519390918314457, + -0.29581141471862793, + 0.0264339130371809, + 0.04027356952428818, + 0.00412194337695837, + 0.03778498247265816, + -0.012331741861999035, + 0.15336745977401733, + -0.034510836005210876, + 0.0319819413125515, + 0.01916184462606907, + 0.04952343553304672, + -0.026733938604593277, + -0.014996573328971863, + 0.0010714810341596603, + 0.01959756202995777, + -0.0392388179898262, + -0.0052064210176467896, + -0.05015777423977852, + -0.0002977418771479279, + -0.04029487073421478, + -0.012846150435507298, + -0.09198840707540512, + 0.0118671590462327, + -0.06176264211535454, + 0.006427878048270941, + 0.04043034091591835, + -0.017270859330892563, + -0.012422707863152027, + 0.01713552325963974, + -0.026697810739278793, + 0.2446632832288742, + -0.020500628277659416, + -0.0012782106641680002, + -0.13429665565490723, + 0.07528743892908096, + -0.002225265372544527, + 0.06695574522018433, + 0.0017388156848028302, + -0.0629071593284607, + -0.05081196129322052, + 0.042025983333587646, + 0.029097404330968857, + 0.07048555463552475, + -0.11881273239850998, + 0.012633765116333961, + -0.06181430071592331, + 0.038810230791568756, + 0.05186169967055321, + 0.03248963877558708, + 0.07868267595767975, + 0.024977494031190872, + 0.023991582915186882, + 0.0023529180325567722, + 0.07197123020887375, + 0.02653665468096733, + 0.058702051639556885, + 0.015001803636550903, + 0.043739400804042816, + -0.07251746207475662, + 0.045659150928258896, + -0.02111324854195118, + 0.26666632294654846, + 0.1975221484899521, + -0.031074335798621178, + 0.029075143858790398, + 0.013020229525864124, + 0.015244663693010807, + 0.01387549377977848, + -0.025354426354169846, + 0.06151636317372322, + -0.034430794417858124, + 0.00752665288746357, + 0.1678706705570221, + -0.016560610383749008, + 0.0421285480260849, + -0.02527586743235588, + -0.02166694961488247, + -0.034658536314964294, + 0.036866605281829834, + -0.036233626306056976, + 0.02042747661471367, + 0.028099242597818375, + 0.020503878593444824, + 0.022789381444454193, + 0.08666791766881943, + -0.06426636874675751, + -0.043599683791399, + 0.1136128157377243, + 0.020200412720441818, + -0.003839759388938546, + -0.06010120362043381, + -0.02218424715101719, + 0.09008956700563431, + 0.008711264468729496, + -0.04874516651034355, + -0.011533043347299099, + -0.036206502467393875, + -0.006006627343595028, + -0.0350450798869133, + 0.005623341538012028, + 0.09562186151742935, + -0.03952183946967125, + -0.013931595720350742, + -0.020029470324516296, + 0.0022144403774291277, + -0.020198611542582512, + 0.012238736264407635, + 0.054415784776210785, + -0.024457741528749466, + -0.01174110360443592, + 0.031656913459300995, + 0.060322560369968414, + 0.01573050767183304, + 0.03361794352531433, + 0.022875478491187096, + 0.036340806633234024, + -0.02932620421051979, + 0.0224352665245533, + -0.013475337065756321, + -0.030774995684623718, + 0.013921404257416725, + -0.01229875348508358, + -0.07986237108707428, + -0.007543445099145174, + 0.05208213999867439, + -0.04440496116876602, + -0.029659371823072433, + -0.029070377349853516, + 0.07376870512962341, + -0.07208643853664398, + -0.05429431423544884, + -0.007887271232903004, + 0.011400371789932251, + 0.014227204024791718, + 0.01763899251818657, + -0.0426466204226017, + 0.0024213625583797693, + 0.02564665488898754, + 0.0020850151777267456, + 0.027386819943785667, + 0.12722602486610413, + -0.060991525650024414, + -0.009061425924301147, + 0.014208497479557991, + -0.006956137716770172, + 0.09096626192331314, + 0.0037735258229076862, + -0.8347064852714539, + -0.2857951521873474, + 0.0011818337952718139, + 0.0341162234544754, + -0.04230167716741562, + 0.05230262130498886, + 0.08486262708902359, + -0.34235459566116333, + -0.02393503487110138, + 0.02718495950102806, + 0.050966840237379074, + 0.024611525237560272, + -0.004936584271490574, + -0.036420952528715134, + -0.009803534485399723, + 0.05421328917145729, + 0.008357672952115536, + 0.020987343043088913, + -0.007292840629816055, + 0.018060531467199326, + 0.06739793717861176, + 0.06161382421851158, + 0.000842935056425631, + -0.007857701741158962, + 0.023870037868618965, + -0.009690430946648121, + -0.04231289029121399, + -0.22531479597091675, + 0.034284885972738266, + 0.07360551506280899, + 0.0421777106821537, + 0.000788167177233845, + -0.3953339457511902, + -0.042627450078725815, + -0.02774403616786003, + 0.02647743932902813, + -0.01561375055462122, + 0.04745408892631531, + 0.021774733439087868, + 0.006606150884181261, + 0.03879173845052719, + 0.06500626355409622, + 0.044954728335142136, + 0.01523532159626484, + 0.04741065576672554, + -0.13645507395267487, + 0.0038059696089476347, + -0.012993253767490387, + -0.004529603291302919, + 0.03268986567854881, + -0.025349941104650497, + -0.02268051542341709, + -0.0001516443444415927, + -0.010289257392287254, + -0.0010476588504388928, + -0.0690254345536232, + 0.04298266023397446, + -0.05470968782901764, + 0.04369102790951729, + -0.007372597698122263, + 0.027607066556811333, + 0.0009343988494947553, + -0.09573916345834732, + 0.04389296472072601, + -0.01522558368742466, + -0.03138086944818497, + 0.04511113464832306, + -0.0342172235250473, + -0.00033129166695289314, + -0.037289440631866455, + 0.055575959384441376, + 0.01849759928882122, + 0.03041103295981884, + -0.01965116336941719, + 0.07604960352182388, + -0.0399625338613987, + -0.008190250024199486, + -0.015386211685836315, + -0.04315667226910591, + 0.0023679479490965605, + 0.018971435725688934, + -0.005599244497716427, + -0.029607947915792465, + 0.07574024051427841, + -0.013816094025969505, + 0.04464992880821228, + 0.00032806122908368707, + 0.06071484833955765, + 0.04261377081274986, + 0.012208743952214718, + 0.0801805928349495, + 0.02875029854476452, + -0.0662921741604805, + 0.015754999592900276, + 0.05831082537770271, + 0.03810921683907509, + 0.05483977496623993, + -0.019509335979819298, + 0.0032034649048000574, + 0.011807492934167385, + -0.01916244812309742, + 0.022101666778326035, + -0.0366031751036644, + 0.10915965586900711, + 0.030322788283228874, + -0.028386037796735764, + -0.05443429946899414, + -0.02489445172250271, + 0.0892239362001419, + -0.05427740886807442, + -0.034238025546073914, + -0.04136161506175995, + -0.041148390620946884, + 0.06879492849111557, + -0.37424594163894653, + 0.028803903609514236, + 0.05349116027355194, + 0.0359492301940918, + -0.3629145622253418, + -0.17875684797763824, + -0.012246759608387947, + 0.2744927704334259, + -0.010421697050333023, + -0.19415415823459625, + 0.005668101832270622, + 0.018326066434383392, + 0.28319111466407776, + -0.008164885453879833, + -0.07401272654533386, + -0.04154321923851967, + 0.030028337612748146, + -0.008959534578025341, + -0.03160349279642105, + -0.0191870778799057, + 0.044875819236040115, + 0.052173007279634476, + 0.012135458178818226, + 0.008775291964411736, + 0.005302258301526308, + 0.009224606677889824, + -0.07574712485074997, + 0.06096252053976059, + 0.02645082212984562, + 0.05135556682944298, + 0.021985528990626335, + 0.0076704383827745914, + 0.02961125783622265, + -0.07608609646558762, + -0.17564956843852997, + 0.03679918497800827, + -0.2696506083011627, + 0.0627906322479248, + 0.031165480613708496, + 0.01799822598695755, + 0.02351829782128334, + 0.015595306642353535, + -0.25137314200401306, + -0.011266927234828472, + 0.04895596578717232, + 0.01718883402645588, + 0.0009224268142133951, + 0.021923478692770004, + 0.044791676104068756, + 0.079147569835186, + 0.02014082670211792, + -0.0003547854721546173, + -0.02535748854279518, + -0.029639363288879395, + -0.01965961419045925, + -0.37630724906921387, + 0.01674639992415905, + 0.01316642016172409, + -0.025120021775364876, + -0.12474260479211807, + 0.059980470687150955, + 0.036066047847270966, + -0.15973420441150665, + -0.010871605016291142, + 0.014708316884934902, + -0.2174367904663086, + 0.012985467910766602, + -0.03782057762145996, + -0.003427069401368499, + -0.011010636575520039, + 0.02433733455836773, + 0.08641276508569717, + -0.004630533047020435, + 0.019430357962846756, + -0.02088969387114048, + -0.06182911619544029, + 0.02577812969684601, + 0.015741532668471336, + 0.04723552614450455, + -0.003783567575737834, + 0.11646346747875214, + 0.01827184483408928, + -0.0999741181731224, + -0.0031216999050229788, + -0.002268272452056408, + -0.019456079229712486, + -0.003156653605401516, + 0.0067732855677604675, + 0.027299508452415466, + 0.06979037076234818, + 0.013329057022929192, + -0.016705401241779327, + 0.33774301409721375, + 0.007617524825036526, + 0.044453222304582596, + 0.0016282782889902592, + 0.0010982973035424948, + 0.04183036834001541, + 0.016857653856277466, + 0.006673034280538559, + -0.0187662523239851, + 0.0037163379602134228, + -0.04568779841065407, + -0.007807960733771324, + 0.016653010621666908, + 0.0033014933578670025, + 0.015063234604895115, + 0.012843966484069824, + -0.012042546644806862, + 0.016909126192331314, + 0.022089935839176178, + -0.002550398698076606, + 0.04166745766997337, + -0.0014742743223905563, + -0.010846617631614208, + -0.12333541363477707, + 0.0018612967105582356, + 0.04913188889622688, + -0.029431112110614777, + 0.01824735291302204, + 0.10425490140914917, + -0.08880072832107544, + 0.03029320202767849, + 0.018876856192946434, + 0.016104502603411674, + 0.00882721971720457, + 0.0029782119672745466, + 0.007922517135739326, + -0.02030068263411522, + -0.029835309833288193, + 0.006661414168775082, + -0.04313879832625389, + -0.001850730157457292, + -0.0035070034209638834, + -0.0070700813084840775, + 0.009637435898184776, + -0.016844747588038445, + -0.026075454428792, + 0.0030682040378451347, + 0.004208600614219904, + -0.005515689495950937, + -0.018976539373397827, + -0.019196776673197746, + -0.008948019705712795, + 0.016215825453400612, + 0.00296461652033031, + 0.14222395420074463, + -0.029066482558846474, + -0.011013337410986423, + -0.01267730537801981, + -0.004976287949830294, + -0.016607511788606644, + -0.0005681798211298883, + -0.012520174495875835, + -0.0015903630992397666, + -0.0013642794219776988, + -0.21956196427345276, + -0.0011431180173531175, + -0.0008808697457425296, + -0.022889399901032448, + 0.024718068540096283, + -0.054929111152887344, + -0.015585094690322876, + -0.018188318237662315, + -0.0008287815726362169, + -0.01957552134990692, + 0.10818513482809067, + -0.0034382494632154703, + -0.02667389065027237, + -0.01304248720407486, + -0.0034645304549485445, + -0.008519704453647137, + -0.015123830176889896, + -0.008219013921916485, + -0.009952309541404247, + -2.3375787350232713e-05, + -0.012512428686022758, + -0.001955948770046234, + -0.0029842876829206944, + -0.004291659686714411, + 0.006655955221503973, + 0.007771315053105354, + 0.014132227748632431, + -0.007390063256025314, + -0.024650415405631065, + -0.022503213956952095, + 0.0032607221510261297, + -0.008497492410242558, + 0.00860870536416769, + 0.002819088753312826, + -0.01841069757938385, + -0.010009711608290672, + -0.2912862300872803, + 0.017160022631287575, + 0.11349690705537796, + -0.027656083926558495, + -0.04482223838567734, + -0.019336597993969917, + 0.07413014769554138, + 0.014554106630384922, + 0.020965611562132835, + -0.028231356292963028, + -0.0582813061773777, + 0.05617539584636688, + -0.05042734369635582, + 0.025630727410316467, + -0.0956532284617424, + -0.14554104208946228, + -0.020851148292422295, + 0.006990485824644566, + 0.08457829803228378, + -0.11314752697944641, + 0.004020951222628355, + -0.03477870300412178, + 0.005594289395958185, + 0.011181964538991451, + 0.010988114401698112, + 0.019416088238358498, + 0.026451971381902695, + -0.00452260859310627, + 0.0004952011513523757, + 0.012377702631056309, + -0.0063480171374976635, + 0.0256175734102726, + -0.020753338932991028, + 0.03223377838730812, + -0.1147943064570427, + -0.009170151315629482, + 0.015267477370798588, + -0.0009072314132936299, + -0.1621374636888504, + 0.022807778790593147, + 0.007394107989966869, + 0.01378557924181223, + -0.10719677805900574, + -0.000919080339372158, + -0.006567052565515041, + -0.007409179583191872, + -0.007469762582331896, + -0.004784661345183849, + -0.03967805579304695, + 0.015857066959142685, + -0.02015744335949421, + 0.056037548929452896, + 0.03962035849690437, + 0.08429893851280212, + 0.022117067128419876, + -0.2675061821937561, + 0.016738418489694595, + 0.0037785861641168594, + 0.004771686624735594, + -0.134505033493042, + -0.010618447326123714, + -0.004784524440765381, + 0.014044507406651974, + -0.03105556219816208, + 0.05049083009362221, + 0.012162688188254833, + 0.005920265335589647, + 0.008554516360163689, + 0.0025892227422446012, + 0.023483717814087868, + -0.20711173117160797, + 0.03360452130436897, + -0.24758699536323547, + -0.05136318504810333, + -0.015016172081232071, + 0.06466241925954819, + 0.023470288142561913, + 0.023495715111494064, + 0.004300899337977171, + 0.02461574412882328, + 0.025745516642928123, + -0.026187308132648468, + 0.08441776037216187, + -0.06955462694168091, + -0.11116205900907516, + -0.2169608771800995, + -0.004244703333824873, + -0.024184226989746094, + -0.10068271309137344, + -0.021129190921783447, + -0.021129680797457695, + -0.0054467362351715565, + 0.17416934669017792, + 0.015367642976343632, + -0.01237915363162756, + 0.024573752656579018, + 0.004588739015161991, + 0.05616860091686249, + -0.0018992060795426369, + -0.12394066900014877, + -0.03691404312849045, + -0.15878455340862274, + 0.10572423785924911, + 0.014409378170967102, + -0.008566108532249928, + -0.20319701731204987, + -0.018277373164892197, + -0.21615462005138397, + -0.11269525438547134, + -0.2767113745212555, + -0.25617966055870056, + -0.0036413148045539856, + -0.008058675564825535, + -0.051732294261455536, + -0.013052727095782757, + 0.05229722708463669, + -0.03535814583301544, + 0.3111231327056885, + -0.044130608439445496, + -0.02232682704925537, + -0.0040402463637292385, + 0.013798556290566921, + -0.07689940929412842, + -0.028940049931406975, + -0.00565366679802537, + -0.028972560539841652, + -0.007728889584541321, + 0.013665011152625084, + -0.014678380452096462, + -0.06747694313526154, + -0.06480871140956879, + -0.00028885426581837237, + -0.01525174267590046, + 0.027096102014183998, + -0.05200905352830887, + 0.0066903820261359215, + 0.0023834225721657276, + -0.002379713812842965, + -0.0208051148802042, + 0.335977703332901, + 0.03895771875977516, + -0.04814215749502182, + -0.037339694797992706, + -0.004409746266901493, + 0.07042848318815231, + -0.08318590372800827, + -0.04138712212443352, + 0.06309781968593597, + 0.007484383415430784, + 0.09696535021066666, + 0.024134323000907898, + -0.009859816171228886, + -0.06243982911109924, + 0.04630015045404434, + -0.06593744456768036, + 0.009306293912231922, + 0.5033899545669556, + 0.007804783061146736, + 0.024170484393835068, + -0.036085959523916245, + 0.016438491642475128, + 0.01678072288632393, + -0.006299734115600586, + -0.027441656216979027, + -0.014344800263643265, + 0.022293711081147194, + 0.011197407729923725, + -0.0026971842162311077, + 0.2685070335865021, + 0.01403988990932703, + -0.005100077483803034, + -0.026031343266367912, + -0.005419034510850906, + -0.014735087752342224, + -0.0283498577773571, + 0.002656748052686453, + -0.07137783616781235, + 0.02235356532037258, + -0.02970476634800434, + 0.20672672986984253, + 0.017398398369550705, + 0.02438206970691681, + 0.025746773928403854, + -0.03279582038521767, + 0.043908532708883286, + -0.003417646512389183, + 0.020200302824378014, + 0.007243862375617027, + -0.004560714587569237, + -0.01142876222729683, + -0.028091270476579666, + -0.2949703335762024, + 0.0729827880859375, + 0.004566277377307415, + 0.16689160466194153, + 0.034872010350227356, + -0.09590360522270203, + -0.13309867680072784, + 0.06429398059844971, + 0.04174232855439186, + -0.022723963484168053, + -0.04695400968194008, + 0.013115685433149338, + 0.013574879616498947, + 0.04794493317604065, + -0.015077140182256699, + 0.09493618458509445, + 0.008845972828567028, + 0.020302923396229744, + 0.02037016488611698, + 0.009083293378353119, + 0.0747746080160141, + -0.008078188635408878, + 0.024796344339847565, + -0.015212535858154297, + -0.005867444910109043, + 0.08309170603752136, + 0.03676094114780426, + 0.07232356816530228, + -0.3577176630496979, + 0.0013658110983669758, + -0.0009247250854969025, + 0.02284996211528778, + 0.012630275450646877, + 0.013745593838393688, + 0.003447894938290119, + 0.03563565015792847, + -0.031025355681777, + -0.07258180528879166, + -0.13482442498207092, + -0.029425248503684998, + -0.014927731826901436, + 0.045984312891960144, + -0.0176406130194664, + -0.22678181529045105, + -0.025248311460018158, + -0.11617762595415115, + -0.056157518178224564, + 0.009453062899410725, + -0.34616726636886597, + 0.05691010504961014, + -0.32302799820899963, + -0.026544231921434402, + -0.007374088745564222, + -0.07682909071445465, + -0.021214107051491737, + -0.07102422416210175, + 0.02693488635122776, + 0.014817211776971817, + 0.015572831965982914, + 0.04313618317246437, + -0.1277216374874115, + 0.02174532599747181, + -0.0226149819791317, + -0.00010956164624076337, + 0.023728065192699432, + 0.008212783373892307, + 0.010561724193394184, + -0.011036543175578117, + -0.022485855966806412, + 0.008243439719080925, + -0.03383245691657066, + -0.5630682110786438, + 0.0015974265988916159, + -0.28416821360588074, + 0.04123701527714729, + -0.0042976438999176025, + 0.03786511346697807, + 0.01862393692135811, + -0.04082413762807846, + -0.05792848393321037, + 0.0068894242867827415, + 0.0024085959885269403, + 0.001471342402510345, + 0.030681759119033813, + -0.026314062997698784, + 0.0555737242102623, + 0.03169534355401993, + 0.0031395808327943087, + 0.018701769411563873, + -0.5604594945907593, + 0.01526441890746355, + -0.00621993700042367, + 0.0009401043644174933, + 0.01587403193116188, + 0.030135583132505417, + -0.007350685074925423, + 0.006527469493448734, + 0.016000108793377876, + -0.042957425117492676, + 0.018247080966830254, + 0.0025622656103223562, + -0.03169511258602142, + 0.09235119074583054, + -0.013365034945309162, + 0.01607452519237995, + 0.017734844237565994, + 0.05609896034002304, + 0.04819876700639725, + -0.0871855691075325, + 0.05157865956425667, + 0.009171447716653347, + 0.022200705483555794, + -0.005507844965904951, + -0.024452703073620796, + 0.010224574245512486, + -0.006914906669408083, + 0.004650818649679422, + 0.02167516015470028, + 0.10456826537847519, + -0.07652094960212708, + -6.050072988728061e-05, + 0.012855490669608116, + 0.022669879719614983, + 0.022655120119452477, + 0.033012885600328445, + 0.025709744542837143, + 0.00481270719319582, + 0.005920717027038336, + -0.08545156568288803, + -0.004363589454442263, + -0.01531639602035284, + 0.030760569497942924, + 0.02796284481883049, + -0.03690989315509796, + 0.044959694147109985, + -0.14276015758514404, + -0.0002254673163406551, + -0.15694372355937958, + 0.012381293810904026, + -0.021977441385388374, + 0.005496624857187271, + -0.035593707114458084, + -0.0950438603758812, + 0.03825876861810684, + 0.05915532633662224, + -0.023323312401771545, + 0.017213119193911552, + -0.03807183355093002, + 0.02619507722556591, + 0.02741156332194805, + 0.005847832188010216, + 0.0020307491067796946, + 0.025714349001646042, + -0.04780200496315956, + 0.010206928476691246, + -0.01345440000295639, + 0.029133174568414688, + -0.0014764482621103525, + 0.004046705551445484, + -0.007725241594016552, + 0.013041527941823006, + 0.0018969239899888635, + 0.002417983952909708, + -0.010975837707519531, + 0.0015862436266615987, + 0.00597577728331089, + 0.002882696921005845, + 0.02855525352060795, + -0.005954153370112181, + 0.04090835899114609, + -0.39500924944877625, + 0.03586621209979057, + -0.5250031352043152, + -0.05697731301188469, + -0.09568691998720169, + -0.07179264724254608, + 0.04683076590299606, + 0.009320023469626904, + -0.11629963666200638, + -0.0016945215174928308, + 0.01624997705221176, + -0.0063682254403829575, + 0.15033549070358276, + -0.5171176791191101, + -0.01525783073157072, + 0.016417231410741806, + -0.00303818890824914, + 0.2500321865081787, + 0.022074062377214432, + 0.01191191840916872, + 0.012274803593754768, + 0.016534989699721336, + -0.028437916189432144, + 0.04241323843598366, + -0.01824999786913395, + -0.34815871715545654, + 0.04734490439295769, + -0.06419701874256134, + -0.022288290783762932, + -0.0004865761147812009, + 0.05369419604539871, + -0.058212973177433014, + -0.2196469008922577, + 0.010950890369713306, + 0.029042819514870644, + -0.07349151372909546, + -0.0422789566218853, + 0.062069639563560486, + 0.05589267984032631, + 0.014877256006002426, + 0.04236084595322609, + 0.03975239768624306, + 0.16930873692035675, + 0.03981085494160652, + 0.11499395221471786, + 0.0271450225263834, + 0.013969083316624165, + -0.0002660648606251925, + 0.010936664417386055, + -0.18389767408370972, + -0.10237602889537811, + 0.03041323646903038, + -0.013864071108400822, + -0.015729930251836777, + 0.037400804460048676, + -0.009598327800631523, + -0.09533312171697617, + -0.014712700620293617, + 0.08537333458662033, + -0.007200485561043024, + -0.31139102578163147, + -0.06366845220327377, + 0.02039063163101673, + -0.023356139659881592, + -0.0029549277387559414, + -0.12494662404060364, + 0.011755092069506645, + -0.26468148827552795, + -0.11541861295700073, + 0.010529865510761738, + -0.05965733155608177, + -0.05945499241352081, + -0.08796169608831406, + -0.014683439396321774, + 0.008732054382562637, + 0.010073489509522915, + 0.09553763270378113, + 0.034884922206401825, + 0.018675342202186584, + -0.009549405425786972, + -0.0007051719003356993, + -0.16936513781547546, + -0.0030460187699645758, + -0.022060535848140717, + -0.06689190864562988, + 0.013926704414188862, + 0.012043816037476063, + -0.0587068572640419, + -0.03814113140106201, + 0.06235629320144653, + 0.013228330761194229, + 0.04154474660754204, + -0.08039120584726334, + 0.028436705470085144, + -0.042226389050483704, + -0.019135186448693275, + 0.03747033327817917, + -0.14261123538017273, + 0.02827540971338749, + 0.0455685593187809, + -0.031124960631132126, + -0.007588588632643223, + 0.0034326373133808374, + -0.07682976871728897, + 0.24654042720794678, + -0.014518304727971554, + -0.07052458822727203, + -0.08241941034793854, + -0.04116151109337807, + -0.048463717103004456, + -0.038745298981666565, + 0.036902472376823425, + 0.0442035011947155, + 0.05572585016489029, + -0.014312628656625748, + 0.010794793255627155, + -0.3440641760826111, + -0.5161325335502625, + 0.0005156552069820464, + -0.010257269255816936, + -0.02412656880915165, + -0.023385023698210716, + 0.05533458665013313, + -0.012186119332909584, + -0.029286568984389305, + 0.04116401448845863, + -0.044610101729631424, + -0.019175484776496887, + 0.06835268437862396, + 0.06366674602031708, + 0.0373748242855072, + 0.03804386034607887, + 0.05369521677494049, + -0.04451881721615791, + 0.0018838117830455303, + 0.34775662422180176, + 0.010958605445921421, + -0.047990139573812485, + 0.04386777803301811, + -0.10427688807249069, + 0.04417382925748825, + 4.402965714689344e-05, + 0.01935163326561451, + -0.06753949075937271, + 0.02735923044383526, + 0.01465953141450882, + 0.06198301538825035, + -0.015980403870344162, + -0.2108263075351715, + 0.008177559822797775, + 0.006046924740076065, + 0.002665479900315404, + 0.20868580043315887, + -0.013740362599492073, + 0.008203004486858845, + -0.005066391546279192, + 0.026405498385429382, + 0.01383009273558855, + 0.012581533752381802, + 0.009014940820634365, + 0.022820021957159042, + -0.008534795604646206, + 0.2603924572467804, + 0.02297227643430233, + -0.000749691273085773, + 0.044753506779670715, + 0.018596511334180832, + 0.006852792575955391, + -0.008686172775924206, + -0.10452616959810257, + 0.017021872103214264, + 0.003722329391166568, + -0.025453045964241028, + -0.011473417282104492, + -0.017907623201608658, + 0.01400628499686718, + -0.1670989990234375, + 0.004298652987927198, + -0.0022204748820513487, + 0.16521315276622772, + -0.008831127546727657, + 0.026490870863199234, + 0.006190746556967497, + -0.0177209060639143, + 0.08967147767543793, + 0.0033069502096623182, + -0.005021366756409407, + 0.0004906906979158521, + 0.0169216375797987, + -0.06124846637248993, + -0.005200678016990423, + 0.08404737710952759, + -0.010559299029409885, + -0.006309974938631058, + 0.023113396018743515, + -0.010227260179817677, + 0.001256447983905673, + 0.019783375784754753, + -0.006308461539447308, + -0.04529590904712677, + -0.00908862054347992, + -0.043217338621616364, + -0.32200074195861816, + 0.02592635713517666, + 0.030795685946941376, + -0.001814531977288425, + 0.0092842485755682, + 0.07088880985975266, + -0.0867588147521019, + 0.024099843576550484, + -0.0034031609538942575, + 0.007234686985611916, + -0.02505563199520111, + 0.0030480287969112396, + -0.019158190116286278, + 0.26473408937454224, + -0.011918547563254833, + -0.023240016773343086, + -0.06084466353058815, + -0.021916134282946587, + -0.010251260362565517, + -0.0009625791572034359, + 0.082605741918087, + -0.013018425554037094, + 0.007627277635037899, + -0.0010813736589625478, + 0.007952406071126461, + 0.06551267951726913, + -0.026020025834441185, + 0.050048135221004486, + -0.010610008612275124, + -0.02429312653839588, + -0.025263017043471336, + -0.04611891135573387, + 0.04451768472790718, + -0.08045025914907455, + -0.048037610948085785, + 0.008019295521080494, + 0.0160224549472332, + 0.002078550634905696, + -0.0202508345246315, + -0.5446130633354187, + 0.012585492804646492, + -0.0331973135471344, + 0.08371605724096298, + -0.00590998912230134, + -0.013058983720839024, + 0.027742384001612663, + 0.1042199358344078, + -0.3072803318500519, + 0.06284149736166, + -0.28551968932151794, + 0.026768438518047333, + 0.022245990112423897, + 0.018242113292217255, + -0.035077981650829315, + 0.03546127676963806, + 0.10165776312351227, + -0.025475669652223587, + -0.014933750964701176, + 0.040547240525484085, + -0.033055808395147324, + 0.011755919083952904, + -0.014459444209933281, + -0.03455093130469322, + 0.020743343979120255, + 0.02720930427312851, + -0.287664532661438, + 0.008260028436779976, + -0.009877690114080906, + 0.16657423973083496, + -0.010943812318146229, + -0.012381386943161488, + 0.030678801238536835, + 0.1559792459011078, + 0.038967035710811615, + -0.023399239405989647, + 0.015019542537629604, + -0.014201333746314049, + -0.014202176593244076, + -0.006699408870190382, + -0.13175444304943085, + 0.004643211141228676, + 0.012747463770210743, + -0.04086190089583397, + 0.06581410765647888, + -0.12192045897245407, + -0.03126347437500954, + 0.011175516061484814, + -0.00914736744016409, + -0.02883930690586567, + -0.11305265873670578, + -0.04405384883284569, + -0.009120048955082893, + -0.008926079608500004, + -0.03169447183609009, + 0.05464877560734749, + 0.25674498081207275, + 0.08497058600187302, + -0.023222925141453743, + 0.35592252016067505, + -0.006929511670023203, + 0.025255810469388962, + -0.05150032415986061, + 0.039239466190338135, + -0.07082924991846085, + -0.017321549355983734, + 0.17293211817741394, + -0.02155853807926178, + -0.014333213679492474, + 0.0031305316369980574, + -0.013490653596818447, + -0.1376512199640274, + -0.021713266149163246, + -0.029826253652572632, + -0.0011473714839667082, + -0.012434332631528378, + -0.04860873892903328, + 0.013857590034604073, + 0.0703854188323021, + 0.034528713673353195, + -0.014423011802136898, + 0.0882454589009285, + -0.091700978577137, + 0.038885727524757385, + 0.012043441645801067, + -0.03183690831065178, + -0.014495689421892166, + -0.019726552069187164, + -0.010094117373228073, + -0.004218627233058214, + -0.04413086175918579, + -0.1344134360551834, + -0.0004976870259270072, + -0.0008357573533430696, + 0.04518067091703415, + 0.046797975897789, + 0.24766182899475098, + 0.01065139751881361, + -0.0034267394803464413, + -0.016103556379675865, + -0.05139121413230896, + 0.012563390657305717, + -0.03310413286089897, + -0.030157553032040596, + 0.046670909970998764, + 0.012565785087645054, + -0.040275491774082184, + 0.023816417902708054, + -0.38536572456359863, + 0.04508889466524124, + 0.13637560606002808, + -0.010654824785888195, + 0.0459851399064064, + -0.0046302699483931065, + -0.020852191373705864, + 0.10662271827459335, + 0.06486576050519943, + 0.05727925896644592, + 0.09816201776266098, + 0.04878557100892067, + -0.16256237030029297, + 0.014547038823366165, + 0.018567964434623718, + -0.07284612208604813, + 0.017150163650512695, + 0.0246741883456707, + -0.38470372557640076, + -0.07465949654579163, + 0.03010236658155918, + -0.004397575277835131, + -0.06618984788656235, + -0.02908281609416008, + 0.060166433453559875, + -0.0020949048921465874, + 0.007689109072089195, + -0.0047390698455274105, + -0.014199030585587025, + -0.01794746331870556, + -0.02528063952922821, + 0.002218312583863735, + 0.10169881582260132, + 0.010602130554616451, + -0.06605861335992813, + -0.0008762837387621403, + -0.035027723759412766, + -0.011684391647577286, + 0.02247578091919422, + 0.17245104908943176, + 0.22525252401828766, + -0.010771296918392181, + 0.05595310404896736, + 0.06338834017515182, + -0.0038216698449105024, + -0.0032836494501680136, + 0.005779017228633165, + -0.18020786345005035, + -0.05066698044538498, + -0.0035458216443657875, + -0.10578767210245132, + -0.041712939739227295, + 0.2104150652885437, + -0.03753345459699631, + 0.013989892788231373, + 0.01988149993121624, + 0.05108603090047836, + 0.04496738687157631, + -0.3034508526325226, + 0.0226743221282959, + -0.0431472510099411, + -0.025635428726673126, + -0.18961989879608154, + -0.17218825221061707, + 0.03576141223311424, + 0.060613714158535004, + -0.011970550753176212, + -0.21435107290744781, + 0.01422552578151226, + 0.02974064089357853, + -0.061079952865839005, + 0.031064646318554878, + 0.009629320353269577, + -0.13762925565242767, + 0.01928475871682167, + 0.007310172542929649, + 0.06103459745645523, + -0.16216528415679932, + 0.03330384939908981, + 0.09578404575586319, + -0.0037327276077121496, + 0.029233848676085472, + -0.0015759399393573403, + 0.005511409603059292, + -0.4195749759674072, + 0.024169376119971275, + 0.13220365345478058, + 0.007961929775774479, + 0.008045470342040062, + 0.01919495314359665, + -0.023188553750514984, + 0.07084394991397858, + -0.24922333657741547, + 0.02011212892830372, + -0.18514998257160187, + 0.03114209696650505, + 0.09826567023992538, + 0.00592303741723299, + -0.010020115412771702, + 0.027117054909467697, + -0.214133620262146, + -0.01214816514402628, + 0.06564164906740189, + 0.02513044886291027, + 0.02132420241832733, + -0.02127540111541748, + -0.041606876999139786, + 0.04196378216147423, + -0.02060609683394432, + 0.01730814389884472, + -0.17418994009494781, + 0.03462710976600647, + -0.017470642924308777, + -0.3992193639278412, + 0.02652592957019806, + 0.025042008608579636, + 0.026447610929608345, + -0.19199316203594208, + 3.27593952533789e-05, + 0.002988220192492008, + -0.21171888709068298, + 0.03300239518284798, + 0.015727035701274872, + -0.008947308175265789, + 0.03924538940191269, + -0.08990193158388138, + 0.023726975545287132, + 0.03463870286941528, + -0.05018220469355583, + 0.13170146942138672, + 0.054000236093997955, + 0.01158218178898096, + 0.062349993735551834, + -0.014724616892635822, + 0.039657603949308395, + 0.04436490684747696, + 0.014076294377446175, + 0.07666806876659393, + 0.09630247205495834, + -0.04152659326791763, + -0.1860806941986084, + -0.07671733945608139, + 0.031573690474033356, + -0.44617798924446106, + -0.004897239152342081, + -0.03991628438234329, + 0.01880800537765026, + -0.04769768565893173, + 0.02198435738682747, + 0.01341161783784628, + -0.12239313870668411, + 0.019765935838222504, + 0.005221452098339796, + -0.025201082229614258, + 0.005132562946528196, + 0.08668412268161774, + 0.0035341952461749315, + 0.008583099581301212, + 0.032979920506477356, + 0.03324040770530701, + 0.04411708936095238, + -0.008390798233449459, + 0.040486790239810944, + -0.059673551470041275, + 0.02003314346075058, + -0.0990666076540947, + 0.03971675783395767, + 0.012021057307720184, + 0.0017271327087655663, + 0.01818535290658474, + 0.0025106174871325493, + 0.043714240193367004, + 0.019146842882037163, + -0.0041794623248279095, + 0.033447377383708954, + 0.06863203644752502, + -0.004350902978330851, + 0.0113364327698946, + -0.05825724080204964, + -0.04649435728788376, + -0.10618306696414948, + 0.02653644233942032, + 0.012514552101492882, + 0.019399365410208702, + -0.0022177041973918676, + 0.017741208896040916, + 0.04115311801433563, + 0.05122101679444313, + 0.055051617324352264, + 0.01687677949666977, + -0.03698579967021942, + 0.10053858160972595, + -0.007528421934694052, + 0.003968802746385336, + 0.02458524890244007, + -0.02144794538617134, + 0.026791265234351158, + -0.016701897606253624, + 0.014119372703135014, + -0.03460531681776047, + -0.02320348098874092, + 0.056146953254938126, + 0.028700685128569603, + -0.14820916950702667, + -0.016996873542666435, + 0.025667931884527206, + 0.08408629894256592, + 0.00034475952270440757, + 0.007573155220597982, + 0.06784884631633759, + 0.025982951745390892, + -0.08363039791584015, + -0.015748541802167892, + -0.0029514851048588753, + -0.01523523684591055, + 0.10500328987836838, + 0.3070858418941498, + -0.024624783545732498, + 0.0058471946977078915, + -0.039751242846250534, + 0.0012745993444696069, + -0.0796508714556694, + 0.024727927520871162, + 0.056764136999845505, + -0.013338261283934116, + -0.04794292524456978, + -0.02609768509864807, + -0.010784422047436237, + -0.048712026327848434, + 0.020345501601696014, + 0.0021618579048663378, + -0.0021724768448621035, + 0.03056410700082779, + -0.01633712649345398, + -0.47168225049972534, + -0.014639903791248798, + -0.012550815008580685, + 0.03358187526464462, + 0.07889427989721298, + -0.03615899011492729, + -0.002809660043567419, + -0.006953644100576639, + 0.02024337276816368, + -0.0738825723528862, + -0.006984011270105839, + -0.04472561925649643, + -0.027498915791511536, + 0.07207506150007248, + -0.09166522324085236, + -0.008861960843205452, + 0.05264359340071678, + 0.01889069564640522, + -0.1380404680967331, + -0.010141258127987385, + 0.015403619967401028, + -0.16416165232658386, + -0.03529815003275871, + 0.042106859385967255, + 0.11173021793365479, + -0.3143587112426758, + 0.011045016348361969, + 0.0012351945042610168, + 0.03840603306889534, + 0.0685538575053215, + -0.000746160454582423, + -0.028142500668764114, + 0.027154160663485527, + 0.005731801502406597, + 0.04433267563581467, + -0.8158469796180725, + 0.02226361259818077, + -0.07650655508041382, + 0.026958195492625237, + -0.005810025613754988, + -0.020102059468626976, + -0.0019310436910018325, + 0.07697021961212158, + -0.057701658457517624, + 0.05954534560441971, + 0.0027106746565550566, + -0.06311310827732086, + 0.011713752523064613, + -0.0034454476553946733, + -0.0006881420267745852, + 0.08937360346317291, + -0.0008253820124082267, + -0.031066063791513443, + -0.14708301424980164, + -0.04438449814915657, + 0.004772413522005081, + 0.05992274731397629, + 0.07473544776439667, + -0.1784757375717163, + -0.19057415425777435, + -0.014637955464422703, + -0.24898527562618256, + 0.13606221973896027, + -0.018039124086499214, + -0.047193415462970734, + -0.06526428461074829, + 0.04075757786631584, + 0.049901530146598816, + -0.008585861884057522, + 0.01616351678967476, + -3.091737016802654e-05, + 0.024283329024910927, + 0.008861682377755642, + -0.0005823548417538404, + 0.0997646301984787, + 0.051001910120248795, + 0.009473294951021671, + -0.0032046104315668344, + 0.018362928181886673, + 0.008627718314528465, + -0.4148157835006714, + -0.016077928245067596, + 0.0745391696691513, + 0.00724065862596035, + 0.08948155492544174, + 0.11626332253217697, + -0.052439428865909576, + 0.005599102005362511, + 0.002622961765155196, + 0.07586965709924698, + 0.03274847939610481, + -0.02099076844751835, + -0.04666733741760254, + -0.0013019372709095478, + 0.04945925995707512, + 0.11393380910158157, + 0.006346395239233971, + 0.04721064493060112, + 0.010331138968467712, + 0.08918803185224533, + 0.04288423806428909, + -0.09234773367643356, + 0.020141584798693657, + -3.256054696976207e-05, + -0.02799108810722828, + 0.018966441974043846, + -0.4136410355567932, + -0.07217283546924591, + 0.01840362884104252, + -0.055327851325273514, + 0.003275467548519373, + -0.017174070701003075, + -0.032178670167922974, + 0.09021560847759247, + -0.524413526058197, + 0.01994725503027439, + 0.10380692034959793, + -0.01043684035539627, + -0.00011200909648323432, + 0.01331041194498539, + 0.020127851516008377, + -0.025159789249300957, + 0.05252581834793091, + 0.04759140685200691, + 0.0032084162812680006, + -0.03579062595963478, + 0.054719552397727966, + -0.04674411937594414, + 0.028389262035489082, + 0.001127603929489851, + -0.0006243048119358718, + -0.00550495833158493, + -0.022523507475852966, + -0.024282312020659447, + 0.009519628249108791, + -0.39908328652381897, + -0.009265545755624771, + -0.00037090369733050466, + 0.06425131112337112, + -0.05998316407203674, + -0.015221518464386463, + -0.004825026262551546, + 0.11847284436225891, + -0.011302731931209564, + -0.006884834263473749, + -0.04678218811750412, + -0.012078279629349709, + 0.021638741716742516, + -0.016819776967167854, + -0.009127719327807426, + -0.002491263672709465, + 0.0016752213705331087, + -0.016600262373685837, + 0.011772023513913155, + -0.013447183184325695, + -0.020662957802414894, + -0.011593316681683064, + 0.008270744234323502, + -0.0026990456972271204, + -0.004406482446938753, + -0.023110052570700645, + -0.00208942755125463, + -0.1711198389530182, + 0.012432538904249668, + -0.0045453268103301525, + 0.024807902052998543, + -0.0035043740645051003, + -0.004001997876912355, + -0.013488625176250935, + -0.02020987868309021, + -0.01216109935194254, + -0.004432092886418104, + 0.09323672950267792, + -0.015641510486602783, + -0.019307948648929596, + 0.01117538008838892, + -0.01422040443867445, + 0.01705607771873474, + -0.0029596879612654448, + -0.0021530911326408386, + -0.006551788654178381, + 0.00429268553853035, + -0.1620807945728302, + -0.014128226786851883, + -0.005428737495094538, + -0.006771362852305174, + 0.005730633158236742, + 0.0007243106956593692, + 0.0024031582288444042, + -0.00199915561825037, + 0.006133859045803547, + -0.013380909338593483, + 0.00733462069183588, + -0.001863821060396731, + -0.0020169683266431093, + -0.014070986770093441, + -0.006501683499664068, + -0.029421553015708923, + 0.0009377509704791009, + -0.01718256250023842, + -0.05819401144981384, + -0.018859732896089554, + 0.0010356366401538253, + 0.006394123658537865, + -0.021985618397593498, + -0.01204769592732191, + -0.002014884725213051, + -0.019398409873247147, + -0.013122898526489735, + -0.017277296632528305, + -0.002270353492349386, + -0.05294327810406685, + -0.020317314192652702, + -0.018196573480963707, + -0.010375416837632656, + -0.019704729318618774, + -0.016109557822346687, + -0.0167380403727293, + -0.0285252146422863, + -0.02665277197957039, + -0.03554505482316017, + -0.00741522666066885, + -0.013580105267465115, + -0.026335405185818672, + -0.011694515123963356, + -0.004639182705432177, + -0.03996071219444275, + -0.022463932633399963, + -0.007204636000096798, + -0.021065134555101395, + -0.014410646632313728, + 0.0035447971895337105, + -0.0013098351191729307, + -0.024171002209186554, + 0.00047751085367053747, + -0.01870289072394371, + -0.06016797944903374, + -0.025703946128487587, + -0.009730588644742966, + -0.021792838349938393, + -0.024519823491573334, + -0.01843440905213356, + -0.0016325484029948711, + -0.008116388693451881, + -0.017774557694792747, + -0.04375867918133736, + -0.03893980756402016, + -0.018188582733273506, + -0.007122726645320654, + -0.028115490451455116, + -0.01821342669427395, + -0.01011319737881422, + -0.02616124413907528, + -0.013797983527183533, + -0.03202736750245094, + -0.030110370367765427, + -0.01883666031062603, + -0.01185502391308546, + -0.006012012716382742, + -0.017311619594693184, + -0.022577986121177673, + -0.02101938985288143, + 0.0025952248834073544, + -0.005058783106505871, + -0.004162575118243694, + -0.01559755764901638, + -0.017923563718795776, + -0.04231095686554909, + -0.017630560323596, + -0.011938830837607384, + -0.01587115228176117, + 0.004972478374838829, + -0.016601158306002617, + 0.15419845283031464, + 0.0009241115767508745, + 0.051028184592723846, + 0.008128340356051922, + -0.019917558878660202, + -0.0010339801665395498, + 0.022349294275045395, + -0.0072520882822573185, + 0.0017750378465279937, + -0.10526080429553986, + 0.03420695662498474, + 0.019183926284313202, + -0.0006544998032040894, + -0.0032203509472310543, + -0.01216941885650158, + -0.03561796247959137, + 0.024905826896429062, + -0.026948239654302597, + -0.01913355104625225, + -0.014459407888352871, + 0.006972283590584993, + -0.033184293657541275, + 0.04884861409664154, + -0.002296984428539872, + -0.19194477796554565, + 0.00392142403870821, + 0.009490449912846088, + -0.02687196619808674, + -0.06327224522829056, + -0.03684951737523079, + -0.0002613202668726444, + -0.012086644768714905, + 0.03630973398685455, + 0.007296048104763031, + 0.011186012998223305, + 0.0074085514061152935, + -0.020394617691636086, + -0.010585476644337177, + -0.030289918184280396, + 0.0773506686091423, + 0.008841303177177906, + 0.019423579797148705, + 0.001184571417979896, + 0.005553434602916241, + 0.015373414382338524, + -0.0027953842654824257, + 0.013204757124185562, + 0.029097743332386017, + 0.012627501040697098, + 0.02102004364132881, + -0.09469914436340332, + -0.023324014618992805, + 0.029243655502796173, + 0.002979277865961194, + -0.004492263309657574, + 0.20549021661281586, + -0.3244459927082062, + 0.025892559438943863, + 0.009620796889066696, + -0.05520407855510712, + -0.02271144650876522, + 0.008378816768527031, + -0.0671214610338211, + -0.016056722030043602, + -0.02355658821761608, + 0.0005429868469946086, + -0.007960098795592785, + 0.02513299137353897, + -0.13005328178405762, + -0.0025323680602014065, + -0.02197088487446308, + -0.02404806576669216, + 0.08261960744857788, + 0.17078880965709686, + 0.02880753017961979, + -0.03642067685723305, + 0.021994341164827347, + -0.012368184514343739, + -0.10681373625993729, + 0.16371481120586395, + 0.17881983518600464, + -0.10202010720968246, + -0.08641688525676727, + -0.1259487271308899, + 0.06907707452774048, + 0.023792706429958344, + -0.02534419298171997, + 0.016984017565846443, + -0.06743635982275009, + 0.08445960283279419, + -0.08037827908992767, + -0.11935994029045105, + -0.31716489791870117, + -0.01860150322318077, + 0.060669515281915665, + -0.06137414649128914, + 0.09878886491060257, + 0.01794014871120453, + 0.12382296472787857, + -0.016424886882305145, + 0.09045679122209549, + -0.02998783066868782, + -0.00972777884453535, + -0.024124544113874435, + 0.09879253059625626, + 0.05500243604183197, + -0.06635259836912155, + 0.11268552392721176, + 0.011751363053917885, + -0.04690232127904892, + -0.025168607011437416, + 0.088335320353508, + -0.1140628531575203, + 0.04129032790660858, + -0.04258979484438896, + -0.0903872698545456, + 0.008473021909594536, + -0.026690304279327393, + -0.051559556275606155, + -0.05481572076678276, + -0.05251916125416756, + -0.0018165932269766927, + 0.09836867451667786, + 0.0054859439842402935, + 0.06432581692934036, + 0.10621821135282516, + -0.019325286149978638, + -0.028727786615490913, + 0.014013150706887245, + -0.008022608235478401, + -0.006281842477619648, + -0.0297000203281641, + 0.01525485422462225, + -0.4346403479576111, + 0.07787995040416718, + -0.25380268692970276, + 0.05261845141649246, + 0.010875157080590725, + 0.0014149334747344255, + 0.05021188035607338, + -0.24382442235946655, + 0.0807114690542221, + 0.022907381877303123, + 0.006440790370106697, + -0.017028095200657845, + 0.001552293193526566, + 0.05961666256189346, + -0.14113056659698486, + 0.03398876264691353, + -0.005411976482719183, + -0.014025667682290077, + -0.5433799624443054, + 0.019015472382307053, + 0.04091138765215874, + 0.05059061944484711, + 0.0274446289986372, + -0.010288042947649956, + -0.001335533568635583, + -0.013533512130379677, + 0.018798377364873886, + -0.04099345579743385, + 0.0031264263670891523, + -0.21071769297122955, + -0.014384736306965351, + -0.1045387014746666, + -0.014340974390506744, + 0.001986369490623474, + -0.04118456318974495, + -0.10952988266944885, + 0.049147430807352066, + -0.08382093161344528, + -0.1741400957107544, + -0.0885215476155281, + -0.10934099555015564, + 0.05553343519568443, + 0.02434251271188259, + 0.006634524557739496, + -0.0017163373995572329, + 0.0185443926602602, + 0.06250902265310287, + -0.17145656049251556, + -0.07543934881687164, + 0.026583310216665268, + 0.01634727604687214, + 0.003603539662435651, + -0.2817271649837494, + 0.03882112354040146, + 0.011341865174472332, + 0.00826666783541441, + 0.050427842885255814, + -0.22358834743499756, + 0.06419781595468521, + 0.03245265409350395, + -0.04503164440393448, + -0.023194484412670135, + -0.027968740090727806, + 0.08563586324453354, + 0.07954753190279007, + -0.08513130992650986, + 0.02850884199142456, + 0.008976672776043415, + 0.07886530458927155, + 0.0022273347713053226, + -0.09540755301713943, + 0.032016951590776443, + -0.05196075513958931, + 0.10555616766214371, + 0.07629868388175964, + 0.039732079952955246, + -0.0029798501636832952, + 0.014692343771457672, + 0.09200941026210785, + -0.04299614951014519, + -0.023488566279411316, + -0.01851060427725315, + 0.09257487207651138, + 0.055612049996852875, + 0.06423109769821167, + -0.28587806224823, + -0.09950444847345352, + 0.10397437959909439, + 0.025166453793644905, + -0.03235514089465141, + -0.033381711691617966, + 0.1513858139514923, + 0.06468874961137772, + 0.01928441785275936, + 0.0032701045274734497, + -0.0579083226621151, + -0.022929169237613678, + 0.012971373274922371, + -0.018524186685681343, + -0.06484643369913101, + 0.012233717367053032, + 0.06590451300144196, + -0.04558677598834038, + 0.05253027006983757, + 0.048656731843948364, + -0.2288871705532074, + 0.037114787846803665, + -0.20519588887691498, + 0.0058607361279428005, + -0.002009372925385833, + -0.006671734619885683, + -0.07107856124639511, + -0.07407436519861221, + 0.03941629081964493, + 0.0447598397731781, + 0.03509354963898659, + -0.061107732355594635, + -0.09305761009454727, + -0.012180411256849766, + 0.04902016744017601, + 0.07974442094564438, + -0.016854410991072655, + 0.005089411046355963, + -0.08127597719430923, + 0.03258403390645981, + 0.039813362061977386, + -0.01668727956712246, + 0.027226485311985016, + -0.029213925823569298, + -0.008598217740654945, + 0.00931101106107235, + 0.026936721056699753, + -0.03083401545882225, + -0.05799110606312752, + -0.008277476765215397, + -0.014854338951408863, + -0.20012643933296204, + 0.012290815822780132, + 0.007194168865680695, + 0.06858328729867935, + -0.3296163082122803, + -0.11424986273050308, + 0.009912200272083282, + -0.06211454048752785, + 0.0007546336855739355, + 0.03507614880800247, + 0.10649498552083969, + -0.03036407195031643, + 0.0646015852689743, + -0.01595110446214676, + -0.16919563710689545, + 0.0013865949586033821, + -0.08339446783065796, + 0.06962471455335617, + 0.016058098524808884, + -0.04729780554771423, + 0.010602935217320919, + 0.01470863912254572, + 0.06903938204050064, + 0.014901719056069851, + -0.15120048820972443, + 0.016727851703763008, + 0.05003673583269119, + 0.04370126873254776, + 0.029703885316848755, + 0.021875420585274696, + 0.026293285191059113, + -0.01048936415463686, + 0.00040810942300595343, + -0.015616541728377342, + -0.062451593577861786, + 0.010016348212957382, + -0.06790193170309067, + -0.02077331207692623, + 0.007985175587236881, + -0.04435744881629944, + 0.06920231133699417, + 0.018344474956393242, + 0.028591370210051537, + 0.021957838907837868, + 0.0017570338677614927, + 0.036665257066488266, + 0.015438515692949295, + -0.0006347382441163063, + 0.04621066153049469, + -0.001942177303135395, + 0.010664877481758595, + -0.016754357144236565, + 0.006541184149682522, + -0.027716301381587982, + -0.0058586387895047665, + -0.005346015095710754, + 0.020482052117586136, + 0.06882552057504654, + 0.0026622572913765907, + 0.016321638599038124, + 0.017728103324770927, + -0.13356441259384155, + 0.030281176790595055, + 1.0354949154134374e-05, + 0.050639618188142776, + 0.0013030078262090683, + -0.11136802285909653, + -0.006832807790488005, + -0.09628921747207642, + 0.046699415892362595, + 0.002175685251131654, + 0.008100612089037895, + 0.012449901551008224, + -0.01713990420103073, + -0.000769267207942903, + 0.022544430568814278, + -0.0018787183798849583, + -0.014189678244292736, + 0.37042510509490967, + -0.030317893251776695, + 0.012663356028497219, + -0.04071582853794098, + 0.01653047651052475, + 0.06578584760427475, + 0.005606585182249546, + 0.0029362838249653578, + -0.02035594917833805, + 0.016131827607750893, + -0.06512665003538132, + 0.020292088389396667, + 0.12818951904773712, + -0.00017647731874603778, + 0.0004811069811694324, + 0.013025660999119282, + -0.006004344671964645, + 0.011330580338835716, + 0.0021733916364610195, + -0.0026290342211723328, + 0.008579215034842491, + -0.017107143998146057, + 0.0032798980828374624, + 0.21415431797504425, + -0.011049880646169186, + 0.04915957152843475, + -0.01152863260358572, + 0.01988764852285385, + -0.30189022421836853, + 0.1491061896085739, + 0.022540517151355743, + 0.02323656715452671, + -0.0028044115751981735, + -0.02501249685883522, + 0.0016759912250563502, + 0.023405946791172028, + 0.0865691602230072, + 0.0056661744602024555, + 0.2334042638540268, + -0.05771901085972786, + 0.03428330272436142, + -0.05191519856452942, + 0.025708407163619995, + -0.11474912613630295, + 0.05345827341079712, + 0.050046734511852264, + -0.03785427287220955, + 0.02726786397397518, + 0.008640051819384098, + -0.05810163915157318, + 0.19147679209709167, + 0.12065602838993073, + -0.08667072653770447, + -0.12831886112689972, + 0.027053257450461388, + -0.1771622896194458, + -0.2615586817264557, + 0.112942636013031, + 0.002398239215835929, + 0.00907410029321909, + 0.059947770088911057, + 0.040937639772892, + 0.003431124845519662, + 0.012721046805381775, + -0.10228776186704636, + 0.04169567674398422, + -0.04826785624027252, + -0.021415220573544502, + 0.027615519240498543, + 0.16087181866168976, + 0.03552674129605293, + -0.36409878730773926, + 0.0015418739058077335, + 0.03940089792013168, + -0.12929502129554749, + 0.017082052305340767, + -0.07193783670663834, + 0.10395099222660065, + -0.2240910828113556, + -0.003303584409877658, + -0.0074868109077215195, + -0.13708709180355072, + 0.2098008245229721, + 0.013808795250952244, + -0.03606148064136505, + 0.001965852687135339, + 0.04186573252081871, + 0.02105732634663582, + -0.11873909085988998, + -0.08529136329889297, + 0.0060731275007128716, + 0.04803553968667984, + 0.07665349543094635, + 0.026997262611985207, + 0.05191565304994583, + 0.09013131260871887, + 0.013081093318760395, + 0.04667182266712189, + -0.19899451732635498, + 0.004642056301236153, + 0.0025570227298885584, + -0.2640555500984192, + 0.008254006505012512, + 0.05971720814704895, + -0.002980671590194106, + 0.0011313167633488774, + -0.004445134196430445, + 0.01951296627521515, + -0.006634386721998453, + -0.008033698424696922, + 0.012400158680975437, + -0.15906694531440735, + 0.007047838997095823, + 0.0003521084145177156, + -0.00517050176858902, + -0.0003226286207791418, + -0.01226231548935175, + -0.06750697642564774, + -0.03061128593981266, + -0.0027100055012851954, + 0.004726986400783062, + 0.010185977444052696, + 0.021205933764576912, + -0.05105980113148689, + -0.006725164130330086, + 0.26042309403419495, + 0.003935054875910282, + 0.009450466372072697, + -0.009512278251349926, + 0.036205559968948364, + 0.0066987741738557816, + 0.05687355250120163, + -0.0070350514724850655, + 0.021287698298692703, + 0.004246287513524294, + -0.004053668584674597, + 0.0030501342844218016, + -0.003596516093239188, + 0.00571554945781827, + 0.039099883288145065, + 0.06648323684930801, + 0.011140268296003342, + 0.002779693342745304, + 0.0004113377653993666, + 0.0019621821120381355, + 0.002047213725745678, + -9.034215327119455e-05, + 0.006674906238913536, + -0.024464793503284454, + 4.372629337012768e-05, + 0.04560312256217003, + 0.029951298609375954, + 0.0053787860088050365, + 0.010052027180790901, + 0.0018156497972086072, + 0.001613074098713696, + -0.3710610568523407, + 0.18385423719882965, + 0.0197732076048851, + -2.409513217571657e-05, + 0.043657880276441574, + 0.029824273660779, + -0.0015272254822775722, + -0.0009817760437726974, + 0.030571524053812027, + 0.05133187025785446, + 0.021092001348733902, + -0.022430723533034325, + -0.011050102300941944, + -0.01653454266488552, + 0.00856624636799097, + 0.007617316208779812, + 0.023697074502706528, + -0.00541776092723012, + -0.06940567493438721, + -0.024501511827111244, + 0.0029131292831152678, + 0.005110545549541712, + 0.02394089475274086, + 0.009317552670836449, + -0.05198051407933235, + -0.14872707426548004, + -0.03553030639886856, + 0.05354774370789528, + 0.053996339440345764, + 0.016679847612977028, + -0.4505158066749573, + 0.006403166800737381, + -0.014287465251982212, + 0.010499212890863419, + 0.00510875741019845, + 0.0230255089700222, + -0.04791099205613136, + -0.08405473828315735, + -0.00807158276438713, + -0.016310568898916245, + -0.018034789711236954, + -0.03381670266389847, + 0.038599055260419846, + 0.01189411524683237, + 0.0038598189130425453, + 0.0077203805558383465, + -0.0006835742969997227, + 0.3038807809352875, + 0.00930703990161419, + -0.017654214054346085, + -0.029550395905971527, + 0.0014829621650278568, + -0.010562432929873466, + -0.011867706663906574, + -0.008104459382593632, + 0.008003979921340942, + -0.028282882645726204, + 0.00898829661309719, + -0.04963170364499092, + 0.014971665106713772, + 0.028662119060754776, + 0.055792808532714844, + 0.018142173066735268, + 0.029526766389608383, + 0.04726170003414154, + 0.020290115848183632, + -0.01347910612821579, + -0.027794860303401947, + -0.033374592661857605, + 0.05699307844042778, + -0.005888971965759993, + 0.009723466821014881, + 0.011825029738247395, + 0.0005665962235070765, + -0.22433574497699738, + 0.04777664318680763, + 0.054696254432201385, + 0.06447272002696991, + 0.006656138692051172, + -0.2656468152999878, + -0.006602808367460966, + -0.04309352487325668, + 0.024392882362008095, + -0.046948980540037155, + 0.17317010462284088, + -0.014694501645863056, + 0.09150391072034836, + 0.05414793640375137, + -0.0034523033536970615, + -0.029682809486985207, + -0.11646991223096848, + 0.036394182592630386, + -0.008510537445545197, + -0.09555189311504364, + 0.012331446632742882, + 0.022554755210876465, + 0.037040166556835175, + 0.011939534917473793, + -0.035405583679676056, + -0.008284371346235275, + 0.008629710413515568, + -0.0017152110813185573, + -0.01656493730843067, + 0.02205522358417511, + -0.008015291765332222, + -0.02198217809200287, + -0.08165504783391953, + 0.018647879362106323, + 0.010489191859960556, + 0.0009643095545470715, + 0.08301698416471481, + 0.00795030314475298, + -0.08973152935504913, + 0.05324552580714226, + 0.0187348835170269, + 0.00770497927442193, + 0.016434336081147194, + 0.0031714467331767082, + 0.031489044427871704, + -0.01682765781879425, + -0.0006042059976607561, + 0.006229344755411148, + 0.0031935630831867456, + -0.03694210946559906, + -0.027148112654685974, + 0.03319454565644264, + 0.013541879132390022, + 0.04362545907497406, + 0.010766182094812393, + 0.01287879142910242, + 0.02723391354084015, + 0.01831277459859848, + -0.0028144901152700186, + 0.0317537821829319, + -0.05053209140896797, + 0.03341667726635933, + 0.009338690899312496, + 0.030376508831977844, + 0.028512636199593544, + 0.002190604107454419, + 0.031132254749536514, + 0.04174429178237915, + 0.025147251784801483, + 0.02602408640086651, + 0.022863827645778656, + 0.024160150438547134, + 0.04043813422322273, + 0.011693909764289856, + 0.008020071312785149, + 0.010814648121595383, + 0.014862221665680408, + 0.043966785073280334, + 0.04133215174078941, + 0.03920775279402733, + 0.02128027193248272, + -0.0024078795686364174, + 0.03185494989156723, + 0.030951442196965218, + 0.008766901679337025, + -0.0013500713976100087, + 0.012680909596383572, + 0.01911563239991665, + 0.02226334996521473, + 0.03873631730675697, + 0.005242412444204092, + 0.02335301972925663, + 0.00577192846685648, + 0.0019918885082006454, + 0.019501060247421265, + 0.048295676708221436, + 0.027288099750876427, + 0.03500128164887428, + 0.032504353672266006, + 0.03619033470749855, + 0.022762063890695572, + 0.014124974608421326, + 0.04055529460310936, + 0.040181197226047516, + 0.04843837395310402, + 0.019578352570533752, + 0.04370861127972603, + 0.024640914052724838, + 0.027013463899493217, + 0.04700532928109169, + 0.018523193895816803, + 0.03569294884800911, + 0.031140455976128578, + 0.010298499837517738, + 0.03979840502142906, + 0.015059049241244793, + 0.020604899153113365, + 0.010335667058825493, + 0.02557498589158058, + 0.015946611762046814, + 0.018900645896792412, + 0.05494159087538719, + 0.015756357461214066, + 0.0452926866710186, + 0.04820817708969116, + -0.0183499027043581, + 0.04002442955970764, + -0.08226092159748077, + -0.034417178481817245, + 0.059122342616319656, + 0.028960591182112694, + -0.020427608862519264, + -0.043222296983003616, + 0.023134637624025345, + -0.014232538640499115, + -0.06970997899770737, + -0.0035826240200549364, + -0.015384080819785595, + -0.0695020854473114, + 0.03645527362823486, + 0.013986784033477306, + -0.027729706838726997, + -0.05711805075407028, + -0.0763891413807869, + -0.16338491439819336, + -0.02358265034854412, + -0.004730133805423975, + 0.022057903930544853, + -0.011578230187296867, + 0.040772147476673126, + -0.059327173978090286, + -0.03819728270173073, + -0.050089117139577866, + -0.005152902565896511, + -0.3071111738681793, + -0.010683669708669186, + 0.030922774225473404, + 0.08924981951713562, + 0.005679265595972538, + 0.06334424018859863, + 0.016136568039655685, + -0.02575727365911007, + -0.012562219053506851, + 0.007206748705357313, + -0.1373208612203598, + -0.010450832545757294, + -0.05991309881210327, + -0.006700845435261726, + -0.006468744482845068, + -0.02040017955005169, + -0.010068708099424839, + 0.008442427963018417, + 0.012259873561561108, + -0.002103718463331461, + -0.019605906680226326, + -0.010690353810787201, + 0.0005222380859777331, + -0.015031278133392334, + -0.012983204796910286, + -0.03552224859595299, + -0.007792052812874317, + -0.035602111369371414, + -0.03479204699397087, + -0.02480080910027027, + -0.05733964219689369, + 4.38804054283537e-05, + -0.021825626492500305, + -0.03287259489297867, + -0.05437042564153671, + -0.007981077767908573, + 0.023045696318149567, + 0.05785335600376129, + 0.03685669228434563, + 0.04314129799604416, + -0.005843586288392544, + -0.024806369096040726, + -0.02562016434967518, + 0.0015172295970842242, + -0.01568800024688244, + -0.005925294477492571, + 0.010173594579100609, + 0.06834683567285538, + 0.024159085005521774, + -0.009547322988510132, + 0.014080812223255634, + 0.013578452169895172, + 0.035671167075634, + 0.01240566186606884, + -0.021352441981434822, + 0.05245270952582359, + -0.008943279273808002, + -0.010131126269698143, + 0.02976749651134014, + 0.0600045844912529, + 0.0014893191400915384, + 0.03796907886862755, + 0.01258794590830803, + -0.025344882160425186, + 0.14140591025352478, + 0.028354406356811523, + 0.0035325682256370783, + 0.05017172172665596, + 0.01994139887392521, + 0.03679897263646126, + -0.009579945355653763, + -0.012607194483280182, + -0.00034231581958010793, + 0.00046832446241751313, + 0.057916246354579926, + 0.02351403795182705, + 0.06157909706234932, + 0.00789523497223854, + -0.018361341208219528, + 0.0018971840618178248, + -0.007180131506174803, + -0.0010631990153342485, + -0.03140748664736748, + -0.028505641967058182, + 0.010669395327568054, + -0.036474280059337616, + 0.01703447848558426, + 0.04667484760284424, + -0.007303370162844658, + 0.01768752932548523, + 0.012412219308316708, + 0.013702306896448135, + 0.07651616632938385, + 0.05469715967774391, + 0.013292597606778145, + -0.006288900971412659, + 0.0215559434145689, + 0.010094149969518185, + -0.024216346442699432, + -0.15225785970687866, + 0.05467289313673973, + 0.019871067255735397, + 0.04662928730249405, + 0.05072600021958351, + -0.011824453249573708, + -0.028083933517336845, + 0.013322187587618828, + -0.044827401638031006, + 0.05955006927251816, + -0.006152187939733267, + 0.013426700606942177, + -0.014220507815480232, + 0.022510837763547897, + 0.019426455721259117, + -0.05546477064490318, + -0.49202534556388855, + 0.026985207572579384, + -0.08852843940258026, + 0.07166163623332977, + 0.05509938299655914, + -0.42284780740737915, + -0.05131356418132782, + 0.0196990966796875, + -0.008681846782565117, + 0.02739996463060379, + 0.0010900507913902402, + 0.04289104416966438, + -0.06694932281970978, + 0.05930810049176216, + -0.02174360118806362, + 0.03464379161596298, + 0.018284866586327553, + 0.018807150423526764, + 0.019874336197972298, + -0.03665176033973694, + -0.2980017066001892, + 0.050937239080667496, + -0.013874954544007778, + -0.0229057464748621, + 0.016420641914010048, + 0.024160616099834442, + -0.10750921070575714, + -0.010134756565093994, + 0.026874780654907227, + 0.007151094265282154, + 0.06304068863391876, + -0.11811652034521103, + -0.12590888142585754, + 0.031846947968006134, + -0.06898463517427444, + 0.03395693376660347, + -0.00010166154243052006, + -0.19019480049610138, + 0.06616076827049255, + -0.035927142947912216, + 0.08526375889778137, + 0.0015017242403700948, + -0.009137739427387714, + 0.04529058188199997, + -0.23621641099452972, + 0.02148340456187725, + -0.02741178683936596, + -0.20779411494731903, + ] + value = numpy.array(list_value, dtype=numpy.float32).reshape((64, 64, 1, 1)) + tensor = numpy_helper.from_array(value, name="onnx::Conv_504") + + initializers.append(tensor) + + list_value = [ + 5.195802688598633, + 0.940099835395813, + -7.016428470611572, + 5.185446739196777, + -4.134859085083008, + 2.0121846199035645, + 5.215719223022461, + 3.371406078338623, + 3.7616095542907715, + -3.6593239307403564, + 15.99945068359375, + 3.306276321411133, + 5.790191173553467, + 6.33050537109375, + 3.4512906074523926, + 2.5531861782073975, + 4.278702259063721, + 4.350361347198486, + 8.025779724121094, + -2.8830037117004395, + 2.915111541748047, + 3.592482805252075, + 5.810481071472168, + 3.4743332862854004, + 3.5245680809020996, + 1.8243598937988281, + 8.069726943969727, + 1.401036024093628, + 5.110081672668457, + -12.873579978942871, + 10.977816581726074, + 5.909627437591553, + -0.4007779359817505, + -20.147268295288086, + 6.649413585662842, + 3.325921058654785, + 5.84471321105957, + 4.47447395324707, + 3.754193067550659, + -5.167671203613281, + 3.2778055667877197, + -9.067073822021484, + 2.6243438720703125, + 1.7002031803131104, + 5.476454734802246, + 2.510835647583008, + 3.856968402862549, + 2.3172807693481445, + 12.462139129638672, + 7.355924129486084, + 4.140628814697266, + 4.807559967041016, + 5.7524309158325195, + 4.128836154937744, + 11.4532470703125, + -12.482564926147461, + 5.590144157409668, + 0.9172697067260742, + 4.356811046600342, + 0.9934853315353394, + -4.3548994064331055, + 15.853201866149902, + -5.241130828857422, + 5.9644365310668945, + ] + value = numpy.array(list_value, dtype=numpy.float32) + tensor = numpy_helper.from_array(value, name="onnx::Conv_505") + + initializers.append(tensor) + + # inputs + + inputs.append(make_tensor_value_info("input", 1, ["batch_size", 3, 32, 32])) + + # outputs + + outputs.append(make_tensor_value_info("/layer1/layer1.0/relu/Relu_output_0", 1, ["batch_size", 64, 8, 8])) + + # nodes + + node = make_node( + "Conv", + ["input", "onnx::Conv_501", "onnx::Conv_502"], + ["/conv1/Conv_output_0"], + name="/conv1/Conv", + dilations=[1, 1], + group=1, + kernel_shape=[7, 7], + pads=[3, 3, 3, 3], + strides=[2, 2], + domain="", + ) + nodes.append(node) + + node = make_node("Relu", ["/conv1/Conv_output_0"], ["/relu/Relu_output_0"], name="/relu/Relu", domain="") + nodes.append(node) + + node = make_node( + "MaxPool", + ["/relu/Relu_output_0"], + ["/maxpool/MaxPool_output_0"], + name="/maxpool/MaxPool", + ceil_mode=0, + kernel_shape=[3, 3], + pads=[1, 1, 1, 1], + strides=[2, 2], + domain="", + ) + nodes.append(node) + + node = make_node( + "Conv", + ["/maxpool/MaxPool_output_0", "onnx::Conv_504", "onnx::Conv_505"], + ["/layer1/layer1.0/conv1/Conv_output_0"], + name="/layer1/layer1.0/conv1/Conv", + dilations=[1, 1], + group=1, + kernel_shape=[1, 1], + pads=[0, 0, 0, 0], + strides=[1, 1], + domain="", + ) + nodes.append(node) + + node = make_node( + "Relu", + ["/layer1/layer1.0/conv1/Conv_output_0"], + ["/layer1/layer1.0/relu/Relu_output_0"], + name="/layer1/layer1.0/relu/Relu", + domain="", + ) + nodes.append(node) + + # opsets + opset_imports = [make_opsetid(domain, 1 if version is None else version) for domain, version in opsets.items()] + + # graph + graph = make_graph(nodes, "torch_jit", inputs, outputs, initializers) + # '7' + + onnx_model = make_model(graph, opset_imports=opset_imports, functions=functions) + onnx_model.ir_version = 7 + onnx_model.producer_name = "pytorch" + onnx_model.producer_version = "" + onnx_model.domain = "" + onnx_model.model_version = 0 + onnx_model.doc_string = "" + set_model_props(onnx_model, {}) + + return onnx_model diff --git a/onnxruntime/test/python/quantization/test_conv_dynamic.py b/onnxruntime/test/python/quantization/test_conv_dynamic.py index 9578c9fe708aa..18467bcbc1083 100644 --- a/onnxruntime/test/python/quantization/test_conv_dynamic.py +++ b/onnxruntime/test/python/quantization/test_conv_dynamic.py @@ -95,7 +95,7 @@ def test_quant_conv(self): for use_quant_config in [True, False]: self.dynamic_quant_conv_test(QuantType.QUInt8, extra_options={}, use_quant_config=use_quant_config) - # TODO: uncomment following after ConvInteger s8 supportted + # TODO: uncomment following after ConvInteger s8 supported # def test_quant_conv_s8s8(self): # self.dynamic_quant_conv_test(QuantType.QInt8, extra_options={'ActivationSymmetric': True}) diff --git a/onnxruntime/test/python/quantization/test_quantize_static_resnet.py b/onnxruntime/test/python/quantization/test_quantize_static_resnet.py new file mode 100644 index 0000000000000..1efa283af6881 --- /dev/null +++ b/onnxruntime/test/python/quantization/test_quantize_static_resnet.py @@ -0,0 +1,138 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- +import os +import random +import tempfile +import unittest + +import numpy as np +import onnx +from numpy.testing import assert_allclose +from onnx.numpy_helper import to_array +from resnet_code import create_model + +from onnxruntime import InferenceSession +from onnxruntime import __version__ as ort_version +from onnxruntime.quantization import QuantFormat, QuantType, quantize_static +from onnxruntime.quantization.calibrate import CalibrationDataReader, CalibrationMethod + + +class FakeResnetCalibrationDataReader(CalibrationDataReader): + def __init__(self, batch_size: int = 16): + super().__init__() + self.dataset = [ + (np.random.rand(1, 3, 32, 32).astype(np.float32), random.randint(0, 9)) for _ in range(batch_size) + ] + self.iterator = iter(self.dataset) + + def get_next(self) -> dict: + try: + return {"input": next(self.iterator)[0]} + except Exception: + return None + + +class TestStaticQuantizationResNet(unittest.TestCase): + def test_quantize_static_resnet(self): + kwargs = { + "activation_type": QuantType.QUInt8, + "weight_type": QuantType.QInt8, + "calibrate_method": CalibrationMethod.Percentile, + "extra_options": { + "ActivationSymmetric": False, + "EnableSubgraph": False, + "ForceQuantizeNoInputCheck": False, + "MatMulConstBOnly": False, + "WeightSymmetric": True, + "extra.Sigmoid.nnapi": False, + }, + "nodes_to_exclude": None, + "nodes_to_quantize": None, + "op_types_to_quantize": None, + "per_channel": True, + "quant_format": QuantFormat.QDQ, + "reduce_range": False, + } + + proto = create_model() + + with tempfile.TemporaryDirectory() as temp: + model = os.path.join(temp, "resnet_first_nodes.onnx") + with open(model, "wb") as f: + f.write(proto.SerializeToString()) + + for per_channel in [True, False]: + kwargs["per_channel"] = per_channel + dataloader = FakeResnetCalibrationDataReader(16) + with self.subTest(per_channel=per_channel): + qdq_file = os.path.join( + temp, f"preprocessed-small-qdq-{1 if per_channel else 0}-ort-{ort_version}.onnx" + ) + quantize_static( + model_input=model, + model_output=qdq_file, + calibration_data_reader=dataloader, + use_external_data_format=False, + **kwargs, + ) + + # With onnxruntime==1.15.1, the initializer 'onnx::Conv_504_zero_point' is: + # * uint8(128) if per_channel is False + # * int8([0, 0, ....]) if per_channel is True + # With onnxruntime>1.16.0 + # * uint8(128) if per_channel is False + # * uint8([128, 128, ..., 127, ...]) if per_channel is True + # QLinearConv : zero point of per-channel filter must be same. + # That's why the quantization forces a symmetric quantization into INT8. + # zero_point is guaranted to be zero whatever the channel is. + + with open(qdq_file, "rb") as f: + onx = onnx.load(f) + for init in onx.graph.initializer: + arr = to_array(init) + if ( + arr.dtype == np.int8 + and "zero_point" not in init.name + and not init.name.endswith("quantized") + ): + raise AssertionError( + f"Initializer {init.name!r} has type {arr.dtype} and " + f"shape {arr.shape} but should be {np.uint8}." + ) + + sess = InferenceSession(qdq_file, providers=["CPUExecutionProvider"]) + shape = (1, 3, 32, 32) + size = np.prod(shape) + dummy = (np.arange(size) / float(size)).astype(np.float32).reshape(shape) + got = sess.run(None, {"input": dummy}) + self.assertEqual(got[0].shape, (1, 64, 8, 8)) + self.assertEqual(got[0].dtype, np.float32) + if per_channel: + expected = np.array( + [ + [[1.0862497091293335, 0.9609132409095764], [1.0862497091293335, 0.9191343784332275]], + [[0.7520190477371216, 1.0026921033859253], [1.0444709062576294, 1.0862497091293335]], + [[0.0, 0.0], [0.0, 0.0]], + [[0.0, 0.0], [0.9609132409095764, 0.7937979102134705]], + ], + dtype=np.float32, + ) + assert_allclose(expected, got[0][0, :4, :2, :2], atol=0.2) + else: + expected = np.array( + [ + [[1.428238868713379, 1.2602107524871826], [1.3442248106002808, 1.2182037830352783]], + [[0.8821475505828857, 1.0921826362609863], [1.1341897249221802, 1.1761966943740845]], + [[0.0, 0.0], [0.0, 0.0]], + [[0.0, 0.0], [1.2182037830352783, 1.050175666809082]], + ], + dtype=np.float32, + ) + assert_allclose(expected, got[0][0, :4, :2, :2], atol=0.2) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 8357ce22fb710..b605673d6606f 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -184,6 +184,8 @@ static constexpr PATH_TYPE SEQUENCE_MODEL_URI_2 = TSTR("testdata/optional_sequen #endif static constexpr PATH_TYPE CUSTOM_OP_MODEL_URI = TSTR("testdata/foo_1.onnx"); static constexpr PATH_TYPE CUSTOM_OP_LIBRARY_TEST_MODEL_URI = TSTR("testdata/custom_op_library/custom_op_test.onnx"); +static constexpr PATH_TYPE CUSTOM_OP_LIBRARY_COPY_TENSOR_ARRAY_2 = TSTR("testdata/custom_op_library/copy_2_inputs_2_outputs.onnx"); +static constexpr PATH_TYPE CUSTOM_OP_LIBRARY_COPY_TENSOR_ARRAY_3 = TSTR("testdata/custom_op_library/copy_3_inputs_3_outputs.onnx"); #if !defined(DISABLE_FLOAT8_TYPES) static constexpr PATH_TYPE CUSTOM_OP_LIBRARY_TEST_MODEL_FLOAT8_URI = TSTR("testdata/custom_op_library/custom_op_test_float8.onnx"); #endif @@ -1406,6 +1408,59 @@ TEST(CApiTest, test_custom_op_library) { #endif } +// It has memory leak. The OrtCustomOpDomain created in custom_op_library.cc:RegisterCustomOps function was not freed +#if defined(__ANDROID__) +TEST(CApiTest, test_custom_op_library_copy_variadic) { +// To accomodate a reduced op build pipeline +#elif defined(REDUCED_OPS_BUILD) && defined(USE_CUDA) +TEST(CApiTest, test_custom_op_library_copy_variadic) { +#else +TEST(CApiTest, test_custom_op_library_copy_variadic) { +#endif + std::cout << "Running inference using custom op shared library" << std::endl; + + std::vector inputs(2); + inputs[0].name = "input_0"; + inputs[0].dims = {15}; + inputs[0].values = {1.1f, 2.2f, 3.3f, 4.4f, 5.5f, + 6.6f, 7.7f, 8.8f, 9.9f, 10.0f, + 11.1f, 12.2f, 13.3f, 14.4f, 15.5f}; + inputs[1].name = "input_1"; + inputs[1].dims = {15}; + inputs[1].values = {15.5f, 14.4f, 13.3f, 12.2f, 11.1f, + 10.0f, 9.9f, 8.8f, 7.7f, 6.6f, + 5.5f, 4.4f, 3.3f, 2.2f, 1.1f}; + + // prepare expected inputs and outputs + std::vector expected_dims_y = {15}; + std::vector expected_values_y = inputs[1].values; + + onnxruntime::PathString lib_name; +#if defined(_WIN32) + lib_name = ORT_TSTR("custom_op_library.dll"); +#elif defined(__APPLE__) + lib_name = ORT_TSTR("libcustom_op_library.dylib"); +#else + lib_name = ORT_TSTR("./libcustom_op_library.so"); +#endif + + TestInference(*ort_env, CUSTOM_OP_LIBRARY_COPY_TENSOR_ARRAY_2, + inputs, "output_1", expected_dims_y, + expected_values_y, 0, nullptr, lib_name.c_str()); + + inputs.push_back({}); + inputs[2].name = "input_2"; + inputs[2].dims = {15}; + inputs[2].values = {6.6f, 7.7f, 8.8f, 9.9f, 10.0f, + 1.1f, 2.2f, 3.3f, 4.4f, 5.5f, + 11.1f, 12.2f, 13.3f, 14.4f, 15.5f}; + + expected_values_y = inputs[2].values; + TestInference(*ort_env, CUSTOM_OP_LIBRARY_COPY_TENSOR_ARRAY_3, + inputs, "output_2", expected_dims_y, + expected_values_y, 0, nullptr, lib_name.c_str()); +} + #if !defined(DISABLE_FLOAT8_TYPES) struct InputF8 { @@ -3126,9 +3181,12 @@ struct Merge { ORT_ENFORCE(ort_api->KernelInfoGetAttribute_int64(info, "reverse", &reverse) == nullptr); reverse_ = reverse != 0; } - void Compute(const Ort::Custom::Tensor& strings_in, - std::string_view string_in, - Ort::Custom::Tensor* strings_out) { + Ort::Status Compute(const Ort::Custom::Tensor& strings_in, + std::string_view string_in, + Ort::Custom::Tensor* strings_out) { + if (strings_in.NumberOfElement() == 0) { + return Ort::Status("the 1st input must have more than one string!", OrtErrorCode::ORT_INVALID_ARGUMENT); + } std::vector string_pool; for (const auto& s : strings_in.Data()) { string_pool.emplace_back(s.data(), s.size()); @@ -3141,6 +3199,7 @@ struct Merge { std::reverse(string_pool.begin(), string_pool.end()); } strings_out->SetStringOutput(string_pool, {static_cast(string_pool.size())}); + return Ort::Status(nullptr); } bool reverse_ = false; }; diff --git a/onnxruntime/test/testdata/custom_op_library/copy_2_inputs_2_outputs.onnx b/onnxruntime/test/testdata/custom_op_library/copy_2_inputs_2_outputs.onnx new file mode 100644 index 0000000000000..1756aae6e8d5a --- /dev/null +++ b/onnxruntime/test/testdata/custom_op_library/copy_2_inputs_2_outputs.onnx @@ -0,0 +1,16 @@ + :ï +d +input_0 +input_1output_0output_1copy_tensor_array"CopyTensorArrayAllVariadic: test.customopgraphZ +input_0 + + ÿÿÿÿÿÿÿÿÿZ +input_1 + + ÿÿÿÿÿÿÿÿÿb +output_0 + + ÿÿÿÿÿÿÿÿÿb +output_1 + + ÿÿÿÿÿÿÿÿÿB \ No newline at end of file diff --git a/onnxruntime/test/testdata/custom_op_library/copy_3_inputs_3_outputs.onnx b/onnxruntime/test/testdata/custom_op_library/copy_3_inputs_3_outputs.onnx new file mode 100644 index 0000000000000..86c9ec1c1fc37 --- /dev/null +++ b/onnxruntime/test/testdata/custom_op_library/copy_3_inputs_3_outputs.onnx @@ -0,0 +1,23 @@ + :À +t +input_0 +input_1 +input_2output_0output_1output_2copy_tensor_array"CopyTensorArrayCombined: test.customopgraphZ +input_0 + + ÿÿÿÿÿÿÿÿÿZ +input_1 + + ÿÿÿÿÿÿÿÿÿZ +input_2 + + ÿÿÿÿÿÿÿÿÿb +output_0 + + ÿÿÿÿÿÿÿÿÿb +output_1 + + ÿÿÿÿÿÿÿÿÿb +output_2 + + ÿÿÿÿÿÿÿÿÿB \ No newline at end of file diff --git a/onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc b/onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc index f9e537fb61047..f7a623901f263 100644 --- a/onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc +++ b/onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc @@ -11,16 +11,20 @@ using namespace Ort::Custom; namespace Cpu { -void KernelOne(const Ort::Custom::Tensor& X, - const Ort::Custom::Tensor& Y, - Ort::Custom::Tensor& Z) { - auto input_shape = X.Shape(); +Ort::Status KernelOne(const Ort::Custom::Tensor& X, + const Ort::Custom::Tensor& Y, + Ort::Custom::Tensor& Z) { + if (X.NumberOfElement() != Y.NumberOfElement()) { + return Ort::Status("x and y has different number of elements", OrtErrorCode::ORT_INVALID_ARGUMENT); + } + auto x_shape = X.Shape(); auto x_raw = X.Data(); auto y_raw = Y.Data(); - auto z_raw = Z.Allocate(input_shape); + auto z_raw = Z.Allocate(x_shape); for (int64_t i = 0; i < Z.NumberOfElement(); ++i) { z_raw[i] = x_raw[i] + y_raw[i]; } + return Ort::Status{nullptr}; } // lite custom op as a function @@ -162,6 +166,46 @@ void FilterFloat8(const Ort::Custom::Tensor& floats_in, } #endif +// a sample custom op accepting variadic inputs, and generate variadic outputs by simply 1:1 copying. +template +Ort::Status CopyTensorArrayAllVariadic(const Ort::Custom::TensorArray& inputs, Ort::Custom::TensorArray& outputs) { + for (size_t ith_input = 0; ith_input < inputs.Size(); ++ith_input) { + const auto& input = inputs[ith_input]; + const auto& input_shape = input->Shape(); + const T* raw_input = reinterpret_cast(input->DataRaw()); + auto num_elements = input->NumberOfElement(); + T* raw_output = outputs.AllocateOutput(ith_input, input_shape); + if (!raw_output) { + return Ort::Status("Failed to allocate output!", OrtErrorCode::ORT_FAIL); + } + for (int64_t jth_elem = 0; jth_elem < num_elements; ++jth_elem) { + raw_output[jth_elem] = raw_input[jth_elem]; + } + } + return Ort::Status{nullptr}; +} + +template +Ort::Status CopyTensorArrayCombined(const Ort::Custom::Tensor& first_input, + const Ort::Custom::TensorArray& other_inputs, + Ort::Custom::Tensor& first_output, + Ort::Custom::TensorArray& other_outputs) { + const auto first_input_shape = first_input.Shape(); + const T* raw_input = reinterpret_cast(first_input.DataRaw()); + + T* raw_output = first_output.Allocate(first_input_shape); + if (!raw_output) { + return Ort::Status("Failed to allocate output!", OrtErrorCode::ORT_FAIL); + } + + auto num_elements = first_input.NumberOfElement(); + for (int64_t ith_elem = 0; ith_elem < num_elements; ++ith_elem) { + raw_output[ith_elem] = raw_input[ith_elem]; + } + + return CopyTensorArrayAllVariadic(other_inputs, other_outputs); +} + void RegisterOps(Ort::CustomOpDomain& domain) { static const std::unique_ptr c_CustomOpOne{Ort::Custom::CreateLiteCustomOp("CustomOpOne", "CPUExecutionProvider", KernelOne)}; static const std::unique_ptr c_CustomOpTwo{Ort::Custom::CreateLiteCustomOp("CustomOpTwo", "CPUExecutionProvider", KernelTwo)}; @@ -171,6 +215,8 @@ void RegisterOps(Ort::CustomOpDomain& domain) { static const std::unique_ptr c_Select{Ort::Custom::CreateLiteCustomOp("Select", "CPUExecutionProvider", Select)}; static const std::unique_ptr c_Fill{Ort::Custom::CreateLiteCustomOp("Filter", "CPUExecutionProvider", Filter)}; static const std::unique_ptr c_Box{Ort::Custom::CreateLiteCustomOp("Box", "CPUExecutionProvider", Box)}; + static const std::unique_ptr c_CopyTensorArrayAllVariadic{Ort::Custom::CreateLiteCustomOp("CopyTensorArrayAllVariadic", "CPUExecutionProvider", CopyTensorArrayAllVariadic)}; + static const std::unique_ptr c_CopyTensorArrayCombined{Ort::Custom::CreateLiteCustomOp("CopyTensorArrayCombined", "CPUExecutionProvider", CopyTensorArrayCombined)}; #if !defined(DISABLE_FLOAT8_TYPES) static const CustomOpOneFloat8 c_CustomOpOneFloat8; @@ -185,6 +231,9 @@ void RegisterOps(Ort::CustomOpDomain& domain) { domain.Add(c_Select.get()); domain.Add(c_Fill.get()); domain.Add(c_Box.get()); + domain.Add(c_CopyTensorArrayAllVariadic.get()); + domain.Add(c_CopyTensorArrayCombined.get()); + #if !defined(DISABLE_FLOAT8_TYPES) domain.Add(&c_CustomOpOneFloat8); domain.Add(c_FilterFloat8.get()); diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index 71a10f646a7c6..c142106ed506c 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -233,7 +233,64 @@ "^test_resize_upsample_sizes_nearest_cuda", "^test_resize_upsample_sizes_nearest_floor_align_corners_cuda", "^test_resize_upsample_sizes_nearest_not_larger_cuda", - "^test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric_cuda" + "^test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric_cuda", + // onnx 1.15 (opset 20) new and updated op tests + "^test_ai_onnx_ml_label_encoder_string_int", + "^test_ai_onnx_ml_label_encoder_string_int_no_default", + "^test_ai_onnx_ml_label_encoder_tensor_mapping", + "^test_ai_onnx_ml_label_encoder_tensor_value_only_mapping", + "^test_gridsample_aligncorners_true", + "^test_gridsample_bicubic_align_corners_0_additional_1", + "^test_gridsample_bicubic_align_corners_1_additional_1", + "^test_gridsample_bicubic", + "^test_gridsample_bilinear_align_corners_0_additional_1", + "^test_gridsample_bilinear_align_corners_1_additional_1", + "^test_gridsample_bilinear", + "^test_gridsample_border_padding", + "^test_gridsample", + "^test_gridsample_nearest_align_corners_0_additional_1", + "^test_gridsample_nearest_align_corners_1_additional_1", + "^test_gridsample_nearest", + "^test_gridsample_reflection_padding", + "^test_gridsample_volumetric_bilinear_align_corners_0", + "^test_gridsample_volumetric_bilinear_align_corners_1", + "^test_gridsample_volumetric_nearest_align_corners_0", + "^test_gridsample_volumetric_nearest_align_corners_1", + "^test_gridsample_zeros_padding", + "^test_image_decoder_decode_bmp_rgb", + "^test_image_decoder_decode_jpeg2k_rgb", + "^test_image_decoder_decode_jpeg_bgr", + "^test_image_decoder_decode_jpeg_grayscale", + "^test_image_decoder_decode_jpeg_rgb", + "^test_image_decoder_decode_png_rgb", + "^test_image_decoder_decode_pnm_rgb", + "^test_image_decoder_decode_tiff_rgb", + "^test_image_decoder_decode_webp_rgb", + "^test_regex_full_match_basic", + "^test_regex_full_match_email_domain", + "^test_regex_full_match_empty", + "^test_string_concat_broadcasting", + "^test_string_concat", + "^test_string_concat_empty_string", + "^test_string_concat_utf8", + "^test_string_concat_zero_dimensional", + "^test_string_split_basic", + "^test_string_split_consecutive_delimiters", + "^test_string_split_empty_string_delimiter", + "^test_string_split_empty_tensor", + "^test_string_split_maxsplit", + "^test_string_split_no_delimiter", + "^test_dft_axis", + "^test_dft", + "^test_dft_inverse", + "^test_isinf", + "^test_isinf_float16", + "^test_isinf_negative", + "^test_isinf_positive", + "^test_isnan", + "^test_isnan_float16", + "^test_reduce_max_bool_inputs", + "^test_reduce_min_bool_inputs" ], "current_failing_tests_x86": [ "^test_vgg19", @@ -316,7 +373,24 @@ "^test_layer_normalization_4d_axis_negative_1_expanded_ver18_cpu", "^test_layer_normalization_4d_axis_negative_2_expanded_ver18_cpu", "^test_layer_normalization_4d_axis_negative_3_expanded_ver18_cpu", - "^test_layer_normalization_default_axis_expanded_ver18_cpu" + "^test_layer_normalization_default_axis_expanded_ver18_cpu", + // onnx 1.15 (opset 20) new and updated op tests (test_affine_grid_???_expanded utilizes ConstantOfShape so it needs to be skipped as well) + // https://dev.azure.com/onnxruntime/onnxruntime/_build/results?buildId=1139541&view=logs&j=249e9d58-0012-5814-27cf-6a201adbd9cf&t=bb33e81f-0527-50e0-0fd2-e94f509f0a82 + // only supported with cpu provider + "^test_affine_grid_2d", + "^test_affine_grid_2d_align_corners", + "^test_affine_grid_2d_align_corners_expanded", + "^test_affine_grid_2d_expanded", + "^test_affine_grid_3d", + "^test_affine_grid_3d_align_corners", + "^test_affine_grid_3d_align_corners_expanded", + "^test_affine_grid_3d_expanded", + "^test_constantofshape_float_ones", + "^test_constantofshape_int_shape_zero", + "^test_constantofshape_int_zeros", + // https://dev.azure.com/onnxruntime/onnxruntime/_build/results?buildId=1141563&view=logs&j=a018b46d-e41a-509d-6581-c95fdaa42fcd&t=d61c1d37-f101-5d28-982f-e5931b720302 + "^test_gelu_tanh_2_cpu", + "^test_gelu_tanh_2_expanded_cpu" ], "current_failing_tests_NNAPI": [ "^test_maxpool_2d_uint8", @@ -390,6 +464,14 @@ "^test_squeeze_negative_axes" ], "current_failing_tests_OPENVINO_CPU_FP32": [ + "^test_affine_grid_2d_align_corners", + "^test_affine_grid_2d_align_corners_expanded", + "^test_affine_grid_2d", + "^test_affine_grid_2d_expanded", + "^test_affine_grid_3d_align_corners", + "^test_affine_grid_3d_align_corners_expanded", + "^test_affine_grid_3d", + "^test_affine_grid_3d_expanded", "^test_operator_permute2", "^test_operator_repeat", "^test_operator_repeat_dim_overflow", @@ -569,7 +651,22 @@ "^test_sequence_map_identity_1_sequence_cpu", "^test_sequence_map_identity_1_sequence_expanded_cpu", "^test_sequence_map_identity_2_sequences_cpu", - "^test_sequence_map_identity_2_sequences_expanded_cpu" + "^test_sequence_map_identity_2_sequences_expanded_cpu", + // onnx 1.15 (opset 20) new and updated op tests (test_affine_grid_???_expanded utilizes ConstantOfShape so it needs to be skipped as well) + // https://dev.azure.com/onnxruntime/onnxruntime/_build/results?buildId=1139542&view=logs&j=3032dfba-5baf-5872-0871-2e69cb7f4b6a&t=f0d05deb-fc26-5aaf-e43e-7db2764c07da + // only supported with cpu provider + "^test_affine_grid_2d", + "^test_affine_grid_2d_align_corners", + "^test_affine_grid_2d_align_corners_expanded", + "^test_affine_grid_2d_expanded", + "^test_affine_grid_3d", + "^test_affine_grid_3d_align_corners", + "^test_affine_grid_3d_align_corners_expanded", + "^test_affine_grid_3d_expanded", + "^test_constantofshape_float_ones", + "^test_constantofshape_int_shape_zero", + "^test_constantofshape_int_zeros" + ], // ORT first supported opset 7, so models with nodes that require versions prior to opset 7 are not supported "tests_with_pre_opset7_dependencies": [ diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_overrides.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_overrides.jsonc index caeea0a758ad9..07385ac9ade05 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_overrides.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_overrides.jsonc @@ -7,6 +7,9 @@ "test_dft": 1e-3, "test_dft_axis": 1e-3, "test_dft_inverse": 1e-3, + "test_dft_opset19": 1e-3, + "test_dft_axis_opset19": 1e-3, + "test_dft_inverse_opset19": 1e-3, "test_stft": 1e-4, "test_stft_with_window": 1e-4 }, diff --git a/onnxruntime/test/testdata/transform/fusion/layer_norm_fusion_scale_bias.onnx b/onnxruntime/test/testdata/transform/fusion/layer_norm_fusion_scale_bias.onnx new file mode 100644 index 0000000000000..ec0f9a97815b8 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/layer_norm_fusion_scale_bias.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/layer_norm_fusion_scale_bias.py b/onnxruntime/test/testdata/transform/fusion/layer_norm_fusion_scale_bias.py new file mode 100644 index 0000000000000..a59e263763a3d --- /dev/null +++ b/onnxruntime/test/testdata/transform/fusion/layer_norm_fusion_scale_bias.py @@ -0,0 +1,81 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +import onnx +from onnx import OperatorSetIdProto, TensorProto, helper + + +def GenerateModel(model_name, has_casts=False, has_identity=False): # noqa: N802 + nodes = [ # LayerNorm subgraph + helper.make_node("ReduceMean", ["A"], ["rd_out"], "reduce1", axes=[-1], keepdims=1), + helper.make_node("Sub", ["A", "rd_out"], ["sub_out"], "sub"), + helper.make_node("Pow", ["cast_sub_out" if has_casts else "sub_out", "pow_in_2"], ["pow_out"], "pow"), + helper.make_node("ReduceMean", ["pow_out"], ["rd2_out"], "reduce2", axes=[-1], keepdims=1), + helper.make_node("Add", ["rd2_out", "const_e12_f32"], ["add1_out"], "add1"), + helper.make_node("Sqrt", ["add1_out"], ["sqrt_out"], "sqrt"), + helper.make_node("Div", ["cast_sub_out" if has_casts else "sub_out", "sqrt_out"], ["div_out"], "div"), + helper.make_node( + "Mul", + ["gamma_id_out" if has_identity else "gamma", "cast_div_out" if has_casts else "div_out"], + ["mul_out"], + "mul", + ), + helper.make_node("Add", ["mul_out", "const_e6_f16_out" if has_identity else "const_e6_f16"], ["C"], "add2"), + ] + + if has_casts: + nodes.extend( + [ + helper.make_node("Cast", ["sub_out"], ["cast_sub_out"], "cast_sub", to=1), + helper.make_node("Cast", ["div_out"], ["cast_div_out"], "cast_2", to=10), + ] + ) + + if has_identity: + nodes.extend( + [ + helper.make_node("Identity", ["gamma"], ["gamma_id_out"], "gamma_identity"), + helper.make_node("Identity", ["const_e6_f16"], ["const_e6_f16_out"], "const_e6_f16_identity"), + ] + ) + + initializers = [ # initializers + helper.make_tensor("pow_in_2", TensorProto.FLOAT, [], [2]), + helper.make_tensor("const_e12_f32", TensorProto.FLOAT, [], [1e-12]), + helper.make_tensor("const_e6_f16", TensorProto.FLOAT16, [4], [1e-6, 1e-6, 1e-6, 1e-6]), + helper.make_tensor( + "gamma", + TensorProto.FLOAT16 if has_casts else TensorProto.FLOAT, + [4], + [1, 2, 3, 4], + ), + ] + + input_type = TensorProto.FLOAT16 if has_casts else TensorProto.FLOAT + output_type = TensorProto.FLOAT16 if has_casts else TensorProto.FLOAT + + graph = helper.make_graph( + nodes, + "LayerNorm", # name + [ # inputs + helper.make_tensor_value_info("A", input_type, [16, 32, 4]), + ], + [ # outputs + helper.make_tensor_value_info("C", output_type, [16, 32, 4]), + ], + initializers, + ) + + onnxdomain = OperatorSetIdProto() + onnxdomain.version = 12 + # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. + onnxdomain.domain = "" + msdomain = OperatorSetIdProto() + msdomain.version = 1 + msdomain.domain = "com.microsoft" + opsets = [onnxdomain, msdomain] + + model = helper.make_model(graph, opset_imports=opsets) + onnx.save(model, model_name) + + +GenerateModel("layer_norm_fusion_scale_bias.onnx", True, True) diff --git a/onnxruntime/test/testdata/transform/gh_issue_17392.onnx b/onnxruntime/test/testdata/transform/gh_issue_17392.onnx new file mode 100644 index 0000000000000..ca9b78a4179bd --- /dev/null +++ b/onnxruntime/test/testdata/transform/gh_issue_17392.onnx @@ -0,0 +1,9 @@ + :Z ++C"Constant* + value_stringsJabcJdef  + +CY"IdentityConstantb +Y + + +B \ No newline at end of file diff --git a/onnxruntime/test/testdata/transpose_optimizer_shared_initializers.onnx b/onnxruntime/test/testdata/transpose_optimizer_shared_initializers.onnx new file mode 100644 index 0000000000000..9d82d68a40098 Binary files /dev/null and b/onnxruntime/test/testdata/transpose_optimizer_shared_initializers.onnx differ diff --git a/onnxruntime/test/testdata/transpose_optimizer_shared_initializers.py b/onnxruntime/test/testdata/transpose_optimizer_shared_initializers.py new file mode 100644 index 0000000000000..d4e3f7e8cbab6 --- /dev/null +++ b/onnxruntime/test/testdata/transpose_optimizer_shared_initializers.py @@ -0,0 +1,60 @@ +import numpy as np +import onnx +from onnx import TensorProto, helper + + +# Create a model with shared initializers that can be updated in-place by the transpose optimizer, +# including ones behind a DQ node. The transpose optimizer updates the first usage and inserts +# Transpose/Unsqueeze ops on the others (see UnsqueezeInput and TransposeInput). +# When we push the Transpose past other usages we should be able to cancel out those Transpose/Unsqueeze ops. +# We need 3 DQ nodes to ensure the Transpose or Unsqueeze added by the transpose optimizer is not +# removed prematurely. +def create_model(broadcast_weights: bool): + if broadcast_weights: + bias_shape = [2, 2] + bias_values = np.random.randn(2, 2) + else: + bias_shape = [1, 3, 2, 2] + bias_values = np.random.randn(1, 3, 2, 2) + + graph = helper.make_graph( + name="graph", + inputs=[ + helper.make_tensor_value_info("input0", TensorProto.FLOAT, [1, 2, 2, 3]), + ], + initializer=[ + helper.make_tensor("bias_quant", TensorProto.UINT8, bias_shape, bias_values.astype(np.uint8)), + helper.make_tensor("bias_fp32", TensorProto.FLOAT, bias_shape, bias_values.astype(np.float32)), + helper.make_tensor("dq_scale0", TensorProto.FLOAT, [], [1.5]), + helper.make_tensor("dq_zp0", TensorProto.UINT8, [], [5]), + helper.make_tensor("dq_scale1", TensorProto.FLOAT, [], [0.5]), + ], + nodes=[ + # Transpose input from channels last to channels first + helper.make_node("Transpose", ["input0"], ["input_T"], perm=[0, 3, 1, 2]), + helper.make_node("DequantizeLinear", ["bias_quant", "dq_scale0", "dq_zp0"], ["DQ0"], "DQ0"), + helper.make_node("Add", ["input_T", "DQ0"], ["A0"], "A0"), + helper.make_node("DequantizeLinear", ["bias_quant", "dq_scale1"], ["DQ1"], "DQ1"), + helper.make_node("Add", ["A0", "DQ1"], ["A1"], "A1"), + helper.make_node("DequantizeLinear", ["bias_quant", "dq_scale0"], ["DQ2"], "DQ2"), + helper.make_node("Add", ["A1", "DQ2"], ["A2"], "A2"), + helper.make_node("Add", ["A2", "bias_fp32"], ["A3"], "A3"), + helper.make_node("Add", ["A3", "bias_fp32"], ["A4"], "A4"), + # NCHW to NHWC + helper.make_node("Transpose", ["A4"], ["output0"], perm=[0, 2, 3, 1]), + ], + outputs=[ + helper.make_tensor_value_info("output0", TensorProto.FLOAT, [1, 2, 2, 3]), + ], + ) + + model = helper.make_model(graph) + onnx.checker.check_model(model, full_check=True) + return model + + +if __name__ == "__main__": + model = create_model(broadcast_weights=False) + onnx.save(model, "transpose_optimizer_shared_initializers.onnx") + model = create_model(broadcast_weights=True) + onnx.save(model, "transpose_optimizer_shared_initializers_broadcast.onnx") diff --git a/onnxruntime/test/testdata/transpose_optimizer_shared_initializers_broadcast.onnx b/onnxruntime/test/testdata/transpose_optimizer_shared_initializers_broadcast.onnx new file mode 100644 index 0000000000000..8bb2c6fd4a8b5 Binary files /dev/null and b/onnxruntime/test/testdata/transpose_optimizer_shared_initializers_broadcast.onnx differ diff --git a/onnxruntime/test/xctest/xcgtest.mm b/onnxruntime/test/xctest/xcgtest.mm index 5367f3e89c07c..c02f18d906cbe 100644 --- a/onnxruntime/test/xctest/xcgtest.mm +++ b/onnxruntime/test/xctest/xcgtest.mm @@ -201,7 +201,7 @@ + (void)registerTestClasses { delete listeners.Release(listeners.default_result_printer()); free(argv); - BOOL runDisabledTests = testing::GTEST_FLAG(also_run_disabled_tests); + BOOL runDisabledTests = GTEST_FLAG_GET(also_run_disabled_tests); NSMutableDictionary* testFilterMap = [NSMutableDictionary dictionary]; NSCharacterSet* decimalDigitCharacterSet = [NSCharacterSet decimalDigitCharacterSet]; diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index 174edabbc91fe..968eece361724 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -9,6 +9,7 @@ #include "api.h" #include +#include #include namespace { @@ -17,6 +18,14 @@ OrtErrorCode g_last_error_code; std::string g_last_error_message; } // namespace +enum DataLocation { + DATA_LOCATION_NONE = 0, + DATA_LOCATION_CPU = 1, + DATA_LOCATION_CPU_PINNED = 2, + DATA_LOCATION_TEXTURE = 3, + DATA_LOCATION_GPU_BUFFER = 4 +}; + static_assert(sizeof(const char*) == sizeof(size_t), "size of a pointer and a size_t value should be the same."); static_assert(sizeof(size_t) == 4, "size of size_t should be 4 in this build (wasm32)."); @@ -223,13 +232,23 @@ void OrtFree(void* ptr) { } } -OrtValue* OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* dims, size_t dims_length) { +OrtValue* OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* dims, size_t dims_length, int data_location) { + if (data_location != DATA_LOCATION_CPU && + data_location != DATA_LOCATION_CPU_PINNED && + data_location != DATA_LOCATION_GPU_BUFFER) { + std::ostringstream ostr; + ostr << "Invalid data location: " << data_location; + CheckStatus(Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, ostr.str().c_str())); + return nullptr; + } + std::vector shapes(dims_length); for (size_t i = 0; i < dims_length; i++) { shapes[i] = dims[i]; } if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) { + // data_location is ignored for string tensor. It is always CPU. OrtAllocator* allocator = nullptr; RETURN_NULLPTR_IF_ERROR(GetAllocatorWithDefaultOptions, &allocator); @@ -244,12 +263,16 @@ OrtValue* OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* return UNREGISTER_AUTO_RELEASE(value); } else { - OrtMemoryInfo* memoryInfo = nullptr; - RETURN_NULLPTR_IF_ERROR(CreateCpuMemoryInfo, OrtDeviceAllocator, OrtMemTypeDefault, &memoryInfo); - REGISTER_AUTO_RELEASE_HANDLE(MemoryInfo, memoryInfo); + OrtMemoryInfo* memory_info = nullptr; + if (data_location != DATA_LOCATION_GPU_BUFFER) { + RETURN_NULLPTR_IF_ERROR(CreateCpuMemoryInfo, OrtDeviceAllocator, OrtMemTypeDefault, &memory_info); + } else { + RETURN_NULLPTR_IF_ERROR(CreateMemoryInfo, "WebGPU_Buffer", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info); + } + REGISTER_AUTO_RELEASE_HANDLE(MemoryInfo, memory_info); OrtValue* value = nullptr; - int error_code = CHECK_STATUS(CreateTensorWithDataAsOrtValue, memoryInfo, data, data_length, + int error_code = CHECK_STATUS(CreateTensorWithDataAsOrtValue, memory_info, data, data_length, dims_length > 0 ? shapes.data() : nullptr, dims_length, static_cast(data_type), &value); @@ -373,15 +396,85 @@ void OrtReleaseRunOptions(OrtRunOptions* run_options) { Ort::GetApi().ReleaseRunOptions(run_options); } +OrtIoBinding* OrtCreateBinding(OrtSession* session) { + OrtIoBinding* binding = nullptr; + int error_code = CHECK_STATUS(CreateIoBinding, session, &binding); + return (error_code == ORT_OK) ? binding : nullptr; +} + +int EMSCRIPTEN_KEEPALIVE OrtBindInput(OrtIoBinding* io_binding, + const char* name, + OrtValue* input) { + return CHECK_STATUS(BindInput, io_binding, name, input); +} + +int EMSCRIPTEN_KEEPALIVE OrtBindOutput(OrtIoBinding* io_binding, + const char* name, + OrtValue* output, + int output_location) { + if (output) { + return CHECK_STATUS(BindOutput, io_binding, name, output); + } else { + if (output_location != DATA_LOCATION_NONE && + output_location != DATA_LOCATION_CPU && + output_location != DATA_LOCATION_CPU_PINNED && + output_location != DATA_LOCATION_GPU_BUFFER) { + std::ostringstream ostr; + ostr << "Invalid data location (" << output_location << ") for output: \"" << name << "\"."; + return CheckStatus(Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, ostr.str().c_str())); + } + + OrtMemoryInfo* memory_info = nullptr; + if (output_location != DATA_LOCATION_GPU_BUFFER) { + RETURN_ERROR_CODE_IF_ERROR(CreateCpuMemoryInfo, OrtDeviceAllocator, OrtMemTypeDefault, &memory_info); + } else { + RETURN_ERROR_CODE_IF_ERROR(CreateMemoryInfo, "WebGPU_Buffer", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info); + } + REGISTER_AUTO_RELEASE_HANDLE(MemoryInfo, memory_info); + return CHECK_STATUS(BindOutputToDevice, io_binding, name, memory_info); + } +} + +void OrtClearBoundOutputs(OrtIoBinding* io_binding) { + Ort::GetApi().ClearBoundOutputs(io_binding); +} + +void OrtReleaseBinding(OrtIoBinding* io_binding) { + Ort::GetApi().ReleaseIoBinding(io_binding); +} + +int OrtRunWithBinding(OrtSession* session, + OrtIoBinding* io_binding, + size_t output_count, + OrtValue** outputs, + OrtRunOptions* run_options) { + RETURN_ERROR_CODE_IF_ERROR(RunWithBinding, session, run_options, io_binding); + + OrtAllocator* allocator = nullptr; + RETURN_ERROR_CODE_IF_ERROR(GetAllocatorWithDefaultOptions, &allocator); + + size_t binding_output_count = 0; + OrtValue** binding_outputs = nullptr; + RETURN_ERROR_CODE_IF_ERROR(GetBoundOutputValues, io_binding, allocator, &binding_outputs, &binding_output_count); + REGISTER_AUTO_RELEASE_BUFFER(OrtValue*, binding_outputs, allocator); + + if (binding_output_count != output_count) { + return CheckStatus( + Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, "Output count is inconsistent with IO Binding output data.")); + } + + for (size_t i = 0; i < output_count; i++) { + outputs[i] = binding_outputs[i]; + } + + return ORT_OK; +} + int OrtRun(OrtSession* session, const char** input_names, const ort_tensor_handle_t* inputs, size_t input_count, const char** output_names, size_t output_count, ort_tensor_handle_t* outputs, OrtRunOptions* run_options) { - auto status_code = CHECK_STATUS(Run, session, run_options, input_names, inputs, input_count, output_names, output_count, outputs); -#if defined(USE_JSEP) - EM_ASM({ Module.jsepRunPromiseResolve ?.($0); }, status_code); -#endif - return status_code; + return CHECK_STATUS(Run, session, run_options, input_names, inputs, input_count, output_names, output_count, outputs); } char* OrtEndProfiling(ort_session_handle_t session) { diff --git a/onnxruntime/wasm/api.h b/onnxruntime/wasm/api.h index 398c901e0e5ed..9a0664697f0ff 100644 --- a/onnxruntime/wasm/api.h +++ b/onnxruntime/wasm/api.h @@ -15,6 +15,9 @@ struct OrtSession; using ort_session_handle_t = OrtSession*; +struct OrtIoBinding; +using ort_io_binding_handle_t = OrtIoBinding*; + struct OrtSessionOptions; using ort_session_options_handle_t = OrtSessionOptions*; @@ -164,9 +167,10 @@ void EMSCRIPTEN_KEEPALIVE OrtFree(void* ptr); * @param data_length size of the buffer 'data' in bytes. * @param dims a pointer to an array of dims. the array should contain (dims_length) element(s). * @param dims_length the length of the tensor's dimension + * @param data_location specify the memory location of the tensor data. 0 for CPU, 1 for GPU buffer. * @returns a tensor handle. Caller must release it after use by calling OrtReleaseTensor(). */ -ort_tensor_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* dims, size_t dims_length); +ort_tensor_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* dims, size_t dims_length, int data_location); /** * get type, shape info and data of the specified tensor. @@ -216,6 +220,58 @@ int EMSCRIPTEN_KEEPALIVE OrtAddRunConfigEntry(ort_run_options_handle_t run_optio */ void EMSCRIPTEN_KEEPALIVE OrtReleaseRunOptions(ort_run_options_handle_t run_options); +/** + * create an instance of ORT IO binding. + */ +ort_io_binding_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateBinding(ort_session_handle_t session); + +/** + * bind an input tensor to the IO binding instance. A cross device copy will be performed if necessary. + * @param io_binding handle of the IO binding + * @param name name of the input + * @param input handle of the input tensor + * @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message. + */ +int EMSCRIPTEN_KEEPALIVE OrtBindInput(ort_io_binding_handle_t io_binding, + const char* name, + ort_tensor_handle_t input); + +/** + * bind an output tensor or location to the IO binding instance. + * @param io_binding handle of the IO binding + * @param name name of the output + * @param output handle of the output tensor. nullptr for output location binding. + * @param output_location specify the memory location of the output tensor data. + * @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message. + */ +int EMSCRIPTEN_KEEPALIVE OrtBindOutput(ort_io_binding_handle_t io_binding, + const char* name, + ort_tensor_handle_t output, + int output_location); + +/** + * clear all bound outputs. + */ +void EMSCRIPTEN_KEEPALIVE OrtClearBoundOutputs(ort_io_binding_handle_t io_binding); + +/** + * release the specified ORT IO binding. + */ +void EMSCRIPTEN_KEEPALIVE OrtReleaseBinding(ort_io_binding_handle_t io_binding); + +/** + * inference the model. + * @param session handle of the specified session + * @param io_binding handle of the IO binding + * @param run_options handle of the run options + * @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message. + */ +int EMSCRIPTEN_KEEPALIVE OrtRunWithBinding(ort_session_handle_t session, + ort_io_binding_handle_t io_binding, + size_t output_count, + ort_tensor_handle_t* outputs, + ort_run_options_handle_t run_options); + /** * inference the model. * @param session handle of the specified session diff --git a/onnxruntime/wasm/js_internal_api.js b/onnxruntime/wasm/js_internal_api.js index 15d393f4ce62d..427ad6f6d14f3 100644 --- a/onnxruntime/wasm/js_internal_api.js +++ b/onnxruntime/wasm/js_internal_api.js @@ -14,40 +14,156 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea Module.jsepReleaseKernel = releaseKernel; Module.jsepRunKernel = runKernel; - Module['jsepOnRunStart'] = sessionId => { - Module['jsepRunPromise'] = new Promise(r => { - Module.jsepRunPromiseResolve = r; - }); - - if (Module.jsepSessionState) { - throw new Error('Session already started'); - } - - Module.jsepSessionState = { - sessionId, - errors: [] + // This is a simplified version of cwrap() with options.async === true (-sASYNCIFY=1) + // It removes some overhead in cwarp() and ccall() that we don't need. + // + // Currently in JSEP build, we only use this for the following functions: + // - OrtRun() + // - OrtRunWithBinding() + // - OrtBindInput() + // + // Note: about parameters "getFunc" and "setFunc": + // - Emscripten has different behaviors for Debug and Release builds for generating exported function wrapper. + // + // - In Debug build, it will generate a wrapper function for each exported function. For example, it generates a + // wrapper for OrtRun() like this (minified): + // ``` + // var _OrtRun = Module["_OrtRun"] = createExportWrapper("OrtRun"); + // ``` + // + // - In Release build, it will generate a lazy loading wrapper for each exported function. For example, it generates + // a wrapper for OrtRun() like this (minified): + // ``` + // d._OrtRun = (a, b, c, e, f, h, l, q) => (d._OrtRun = J.ka)(a, b, c, e, f, h, l, q); + // ``` + // + // The behavior of these two wrappers are different. The debug build will assign `Module["_OrtRun"]` only once + // because `createExportWrapper()` does not reset `Module["_OrtRun"]` inside. The release build, however, will + // reset d._OrtRun to J.ka when the first time it is called. + // + // The difference is important because we need to design the async wrapper in a way that it can handle both cases. + // + // Now, let's look at how the async wrapper is designed to work for both cases: + // + // - Debug build: + // 1. When Web assembly is being loaded, `Module["_OrtRun"]` is assigned to `createExportWrapper("OrtRun")`. + // 2. When the first time `Module["jsepInit"]` is called, `Module["_OrtRun"]` is re-assigned to a new async + // wrapper function. + // Value of `Module["_OrtRun"]` will not be changed again. + // + // - Release build: + // 1. When Web assembly is being loaded, `Module["_OrtRun"]` is assigned to a lazy loading wrapper function. + // 2. When the first time `Module["jsepInit"]` is called, `Module["_OrtRun"]` is re-assigned to a new async + // wrapper function. + // 3. When the first time `Module["_OrtRun"]` is called, the async wrapper will be called. It will call into this + // function: + // ``` + // (a, b, c, e, f, h, l, q) => (d._OrtRun = J.ka)(a, b, c, e, f, h, l, q); + // ``` + // This function will assign d._OrtRun (ie. the minimized `Module["_OrtRun"]`) to the real function (J.ka). + // 4. Since d._OrtRun is re-assigned, we need to update the async wrapper to re-assign its stored + // function to the updated value (J.ka), and re-assign the value of `d._OrtRun` back to the async wrapper. + // Value of `Module["_OrtRun"]` will not be changed again. + // + // The value of `Module["_OrtRun"]` will need to be assigned for 2 times for debug build and 4 times for release + // build. + // + // This is why we need this `getFunc` and `setFunc` parameters. They are used to get the current value of an + // exported function and set the new value of an exported function. + // + const jsepWrapAsync = (func, getFunc, setFunc) => { + return (...args) => { + // cache the async data before calling the function. + const previousAsync = Asyncify.currData; + + const previousFunc = getFunc?.(); + const ret = func(...args); + const newFunc = getFunc?.(); + if (previousFunc !== newFunc) { + // The exported function has been updated. + // Set the sync function reference to the new function. + func = newFunc; + // Set the exported function back to the async wrapper. + setFunc(previousFunc); + // Remove getFunc and setFunc. They are no longer needed. + setFunc = null; + getFunc = null; + } + + // If the async data has been changed, it means that the function started an async operation. + if (Asyncify.currData != previousAsync) { + // returns the promise + return Asyncify.whenDone(); + } + // the function is synchronous. returns the result. + return ret; }; }; - Module['jsepOnRunEnd'] = sessionId => { - if (Module.jsepSessionState.sessionId !== sessionId) { - throw new Error('Session ID mismatch'); - } - - const errorPromises = Module.jsepSessionState.errors; - Module.jsepSessionState = null; - - return errorPromises.length === 0 ? Promise.resolve() : new Promise((resolve, reject) => { - Promise.all(errorPromises).then(errors => { - errors = errors.filter(e => e); - if (errors.length > 0) { - reject(new Error(errors.join('\n'))); - } else { - resolve(); + // This is a wrapper for OrtRun() and OrtRunWithBinding() to ensure that Promises are handled correctly. + const runAsync = (runAsyncFunc) => { + return async (...args) => { + try { + // Module.jsepSessionState should be null, unless we are in the middle of a session. + // If it is not null, it means that the previous session has not finished yet. + if (Module.jsepSessionState) { + throw new Error('Session already started'); + } + const state = Module.jsepSessionState = {sessionHandle: args[0], errors: []}; + + // Run the acyncified function: OrtRun() or OrtRunWithBinding() + const ret = await runAsyncFunc(...args); + + // Check if the session is still valid. this object should be the same as the one we set above. + if (Module.jsepSessionState !== state) { + throw new Error('Session mismatch'); + } + + // Flush the backend. This will submit all pending commands to the GPU. + backend['flush'](); + + // Await all pending promises. This includes GPU validation promises for diagnostic purposes. + const errorPromises = state.errors; + if (errorPromises.length > 0) { + let errors = await Promise.all(errorPromises); + errors = errors.filter(e => e); + if (errors.length > 0) { + throw new Error(errors.join('\n')); + } } - }, reason => { - reject(reason); - }); - }); + + return ret; + } finally { + Module.jsepSessionState = null; + } + }; + }; + + // replace the original functions with asyncified versions + Module['_OrtRun'] = runAsync(jsepWrapAsync( + Module['_OrtRun'], + () => Module['_OrtRun'], + v => Module['_OrtRun'] = v)); + Module['_OrtRunWithBinding'] = runAsync(jsepWrapAsync( + Module['_OrtRunWithBinding'], + () => Module['_OrtRunWithBinding'], + v => Module['_OrtRunWithBinding'] = v)); + Module['_OrtBindInput'] = jsepWrapAsync( + Module['_OrtBindInput'], + () => Module['_OrtBindInput'], + v => Module['_OrtBindInput'] = v); + + // expose webgpu backend functions + Module['jsepRegisterBuffer'] = (sessionId, index, buffer, size) => { + return backend['registerBuffer'](sessionId, index, buffer, size); + }; + Module['jsepUnregisterBuffers'] = sessionId => { + backend['unregisterBuffers'](sessionId); + }; + Module['jsepGetBuffer'] = (dataId) => { + return backend['getBuffer'](dataId); + }; + Module['jsepCreateDownloader'] = (gpuBuffer, size, type) => { + return backend['createDownloader'](gpuBuffer, size, type); }; }; diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index 91b1df7b7cf2d..86d3cdee9ba98 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -4171,7 +4171,7 @@ Return true if all elements are true and false otherwise. "T", OpSchema::Variadic, /*is_homogeneous*/ false, /*min_arity*/ 1) - .TypeConstraint("T", OpSchema::all_tensor_types_with_bfloat(), + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Allow inputs and outputs to be any kind of tensor."); #endif // ENABLE_TRITON diff --git a/orttraining/orttraining/models/runner/training_runner.cc b/orttraining/orttraining/models/runner/training_runner.cc index d35366e556a42..9ac9f3ee090bb 100644 --- a/orttraining/orttraining/models/runner/training_runner.cc +++ b/orttraining/orttraining/models/runner/training_runner.cc @@ -1252,7 +1252,7 @@ Status WithOrtValuesFromTensorProtos( OrtValue ort_value; - ORT_RETURN_IF_ERROR(utils::TensorProtoToMLValue( + ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue( Env::Default(), model_location.c_str(), tensor_proto, mem_buffer, ort_value)); diff --git a/orttraining/orttraining/models/runner/training_util.cc b/orttraining/orttraining/models/runner/training_util.cc index fa9c3f24b2ee5..7764508d9a091 100644 --- a/orttraining/orttraining/models/runner/training_util.cc +++ b/orttraining/orttraining/models/runner/training_util.cc @@ -52,7 +52,7 @@ common::Status DataSet::AddData(const vector& featu OrtValue ort_value; OrtMemoryInfo info("Cpu", OrtDeviceAllocator, OrtDevice{}, 0, OrtMemTypeDefault); std::unique_ptr buffer = std::make_unique(cpu_tensor_length); - ORT_RETURN_IF_ERROR(utils::TensorProtoToMLValue( + ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue( Env::Default(), nullptr, tensor_proto, MemBuffer(buffer.get(), cpu_tensor_length, info), ort_value)); sample->push_back(ort_value); diff --git a/orttraining/orttraining/python/orttraining_pybind_state.cc b/orttraining/orttraining/python/orttraining_pybind_state.cc index 3f3aa396e6ca0..35d9755ba0ba7 100644 --- a/orttraining/orttraining/python/orttraining_pybind_state.cc +++ b/orttraining/orttraining/python/orttraining_pybind_state.cc @@ -1065,17 +1065,60 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn checkpoint_state(m, "CheckpointState", R"pbdoc(CheckpointState.)pbdoc"); checkpoint_state .def(py::init()) - .def("add_property", [](onnxruntime::training::api::CheckpointState* state, - const std::string& property_name, - const std::variant& property_value) { - state->property_bag.AddProperty(property_name, property_value); - }) - .def("get_property", [](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) { - return state->property_bag.GetProperty(property_name); - }) - .def("has_property", [](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) { - return state->property_bag.HasProperty(property_name); - }); + .def("add_property", + [](onnxruntime::training::api::CheckpointState* state, + const std::string& property_name, + const std::variant& property_value) { + state->property_bag.AddProperty(property_name, property_value); + }) + .def("get_property", + [](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) { + return state->property_bag.GetProperty(property_name); + }) + .def("has_property", + [](onnxruntime::training::api::CheckpointState* state, const std::string& property_name) { + return state->property_bag.HasProperty(property_name); + }) + .def("copy_parameter_from", + [](onnxruntime::training::api::CheckpointState* state, + const std::string& parameter_name, OrtValue& value) -> void { + auto it = state->module_checkpoint_state.named_parameters.find(parameter_name); + if (it == state->module_checkpoint_state.named_parameters.end()) { + ORT_THROW("Parameter with name ", parameter_name, " does not exist."); + } + ORT_THROW_IF_ERROR(it->second->CopyFrom( + state->module_checkpoint_state.train_session_data_transfer_mgr, value)); + }) + .def("get_parameter", + [](onnxruntime::training::api::CheckpointState* state, const std::string& parameter_name) { + auto it = state->module_checkpoint_state.named_parameters.find(parameter_name); + if (it == state->module_checkpoint_state.named_parameters.end()) { + ORT_THROW("Parameter with name ", parameter_name, " does not exist."); + } + return it->second; + }) + .def("has_parameter", + [](onnxruntime::training::api::CheckpointState* state, const std::string& parameter_name) { + return state->module_checkpoint_state.named_parameters.count(parameter_name); + }) + .def("parameter_names", + [](onnxruntime::training::api::CheckpointState* state) { + std::vector names; + for ([[maybe_unused]] auto& [name, value] : state->module_checkpoint_state.named_parameters) { + names.push_back(name); + } + std::sort(names.begin(), names.end()); + return names; + }) + .def("property_names", + [](onnxruntime::training::api::CheckpointState* state) { + std::vector names; + for ([[maybe_unused]] auto& [name, value] : state->property_bag) { + names.push_back(name); + } + std::sort(names.begin(), names.end()); + return names; + }); py::class_ training_optimizer(m, "Optimizer", R"pbdoc(Training Optimizer.)pbdoc"); @@ -1111,6 +1154,21 @@ void addObjectMethodsForTraining(py::module& m, ExecutionProviderRegistrationFn ORT_THROW_IF_ERROR(scheduler->Step()); }); + py::class_> + parameter(m, "Parameter"); + parameter + .def_property_readonly("name", &onnxruntime::training::api::Parameter::Name) + .def_property_readonly("data", &onnxruntime::training::api::Parameter::Data) + .def_property_readonly("grad", &onnxruntime::training::api::Parameter::Gradient) + .def_property_readonly("requires_grad", &onnxruntime::training::api::Parameter::RequiresGrad) + .def("copy_from", + [](onnxruntime::training::api::Parameter* parameter, + onnxruntime::training::api::CheckpointState* state, + OrtValue& value) -> void { + ORT_THROW_IF_ERROR(parameter->CopyFrom(state->module_checkpoint_state.train_session_data_transfer_mgr, value)); + }); + m.def( "save_checkpoint", [](const std::vector& trainable_tensor_protos_pybytes, diff --git a/orttraining/orttraining/python/orttraining_python_module.cc b/orttraining/orttraining/python/orttraining_python_module.cc index 7024244629c3e..88ef90a7feaa8 100644 --- a/orttraining/orttraining/python/orttraining_python_module.cc +++ b/orttraining/orttraining/python/orttraining_python_module.cc @@ -15,6 +15,12 @@ namespace onnxruntime { namespace python { namespace py = pybind11; +#if defined(USE_MPI) && defined(ORT_USE_NCCL) +static constexpr bool HAS_COLLECTIVE_OPS = true; +#else +static constexpr bool HAS_COLLECTIVE_OPS = false; +#endif + using namespace onnxruntime::logging; std::unique_ptr CreateExecutionProviderInstance( @@ -361,6 +367,8 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) { }, "Clean the execution provider instances used in ort training module."); + m.def("has_collective_ops", []() -> bool { return HAS_COLLECTIVE_OPS; }); + // See documentation for class TrainingEnvInitialzer earlier in this module // for an explanation as to why this is needed. auto atexit = py::module_::import("atexit"); diff --git a/orttraining/orttraining/python/training/api/checkpoint_state.py b/orttraining/orttraining/python/training/api/checkpoint_state.py index 285264bbed744..ba95cd04fce7e 100644 --- a/orttraining/orttraining/python/training/api/checkpoint_state.py +++ b/orttraining/orttraining/python/training/api/checkpoint_state.py @@ -5,70 +5,171 @@ import os +import numpy as np + from onnxruntime.capi import _pybind_state as C +from onnxruntime.capi.onnxruntime_inference_collection import OrtValue -class CheckpointState: - """Class that holds the state of the training session +class Parameter: + """Class that represents a model parameter - This class holds all the state information of the training session such as the model parameters, - its gradients, the optimizer state and user defined properties. + This class represents a model parameter and provides access to its data, + gradient and other properties. This class is not expected to be instantiated directly. + Instead, it is returned by the `CheckpointState` object. + + Args: + parameter: The C.Parameter object that holds the underlying parameter data. + state: The C.CheckpointState object that holds the underlying session state. + """ + + def __init__(self, parameter: C.Parameter, state: C.CheckpointState): + self._parameter = parameter + self._state = state - User defined properties can be indexed by name from the `CheckpointState` object. + @property + def name(self) -> str: + """The name of the parameter""" + return self._parameter.name - To create the `CheckpointState`, use the `CheckpointState.load_checkpoint` method. + @property + def data(self) -> np.ndarray: + """The data of the parameter""" + return self._parameter.data.numpy() + + @data.setter + def data(self, value: np.ndarray) -> None: + """Sets the data of the parameter""" + self._parameter.copy_from(self._state, OrtValue.ortvalue_from_numpy(value)._ortvalue) + + @property + def grad(self) -> np.ndarray: + """The gradient of the parameter""" + return self._parameter.grad.numpy() if self._parameter.grad.has_value() else None + + @property + def requires_grad(self) -> bool: + """Whether or not the parameter requires its gradient to be computed""" + return self._parameter.requires_grad + + def __repr__(self) -> str: + """Returns a string representation of the parameter""" + return f"Parameter(name={self.name}, requires_grad={self.requires_grad})" + + +class Parameters: + """Class that holds all the model parameters + + This class holds all the model parameters and provides access to them. + This class is not expected to be instantiated directly. Instead, it is returned by the + `CheckpointState`'s parameters attribute. + This class behaves like a dictionary and provides access to the parameters by name. Args: - state: The C.Checkpoint state object that holds the underlying session state. + state: The C.CheckpointState object that holds the underlying session state. """ def __init__(self, state: C.CheckpointState): - if not isinstance(state, C.CheckpointState): - raise TypeError(f"Invalid argument for CheckpointState received {type(state)}") self._state = state - @classmethod - def load_checkpoint(cls, checkpoint_uri: str | os.PathLike) -> CheckpointState: - """Loads the checkpoint state from the checkpoint file + def __getitem__(self, name: str) -> Parameter: + """Gets the parameter associated with the given name + + Searches for the name in the parameters of the checkpoint state. Args: - checkpoint_uri: The path to the checkpoint file. + name: The name of the parameter Returns: - CheckpointState: The checkpoint state object. + The value of the parameter + + Raises: + KeyError: If the parameter is not found """ - return cls(C.load_checkpoint(os.fspath(checkpoint_uri))) - @classmethod - def save_checkpoint( - cls, state: CheckpointState, checkpoint_uri: str | os.PathLike, include_optimizer_state: bool = False - ) -> None: - """Saves the checkpoint state to the checkpoint file + if name not in self: + raise KeyError(f"Parameter {name} not found.") + + return Parameter(self._state.get_parameter(name), self._state) + + def __setitem__(self, name: str, value: np.ndarray) -> None: + """Sets the parameter value for the given name + + Searches for the name in the parameters of the checkpoint state. + If the name is found in parameters, the value is updated. Args: - state: The checkpoint state object. - checkpoint_uri: The path to the checkpoint file. - include_optimizer_state: If True, the optimizer state is also saved to the checkpoint file. + name: The name of the parameter + value: The value of the parameter as a numpy array + + Raises: + KeyError: If the parameter is not found """ - C.save_checkpoint(state._state, os.fspath(checkpoint_uri), include_optimizer_state) + if name not in self: + raise KeyError(f"Parameter {name} not found.") + + self._state.copy_parameter_from(name, OrtValue.ortvalue_from_numpy(value)._ortvalue) + + def __contains__(self, name: str) -> bool: + """Checks if the parameter exists in the state + + Args: + name: The name of the parameter + + Returns: + True if the name is a parameter False otherwise + """ + + return self._state.has_parameter(name) + + def __iter__(self): + """Returns an iterator over the properties""" + for parameter_name in self._state.parameter_names(): + yield parameter_name, Parameter(self._state.get_parameter(parameter_name), self._state) + + def __repr__(self) -> str: + """Returns a string representation of the parameters""" + return self._state.parameter_names() + + def __len__(self) -> int: + """Returns the number of parameters""" + return len(self._state.parameter_names()) + + +class Properties: + def __init__(self, state: C.CheckpointState): + self._state = state def __getitem__(self, name: str) -> int | float | str: """Gets the property associated with the given name + Searches for the name in the properties of the checkpoint state. + Args: name: The name of the property Returns: The value of the property + + Raises: + KeyError: If the property is not found """ + + if name not in self: + raise KeyError(f"Property {name} not found.") + return self._state.get_property(name) def __setitem__(self, name: str, value: int | float | str) -> None: """Sets the property value for the given name + Searches for the name in the properties of the checkpoint state. + The value is added or updated in the properties. + Args: name: The name of the property value: The value of the property + Properties only support int, float and str values. """ self._state.add_property(name, value) @@ -79,6 +180,75 @@ def __contains__(self, name: str) -> bool: name: The name of the property Returns: - True if the property exists, False otherwise + True if the name is a property, False otherwise """ + return self._state.has_property(name) + + def __iter__(self): + """Returns an iterator over the properties""" + for property_name in self._state.property_names(): + yield property_name, self._state.get_property(property_name) + + def __repr__(self) -> str: + """Returns a string representation of the properties""" + return self._state.property_names() + + def __len__(self) -> int: + """Returns the number of properties""" + return len(self._state.property_names()) + + +class CheckpointState: + """Class that holds the state of the training session + + This class holds all the state information of the training session such as the model parameters, + its gradients, the optimizer state and user defined properties. + + To create the `CheckpointState`, use the `CheckpointState.load_checkpoint` method. + + Args: + state: The C.Checkpoint state object that holds the underlying session state. + """ + + def __init__(self, state: C.CheckpointState): + if not isinstance(state, C.CheckpointState): + raise TypeError(f"Invalid argument for CheckpointState received {type(state)}") + self._state = state + self._parameters = Parameters(self._state) + self._properties = Properties(self._state) + + @classmethod + def load_checkpoint(cls, checkpoint_uri: str | os.PathLike) -> CheckpointState: + """Loads the checkpoint state from the checkpoint file + + Args: + checkpoint_uri: The path to the checkpoint file. + + Returns: + CheckpointState: The checkpoint state object. + """ + return cls(C.load_checkpoint(os.fspath(checkpoint_uri))) + + @classmethod + def save_checkpoint( + cls, state: CheckpointState, checkpoint_uri: str | os.PathLike, include_optimizer_state: bool = False + ) -> None: + """Saves the checkpoint state to the checkpoint file + + Args: + state: The checkpoint state object. + checkpoint_uri: The path to the checkpoint file. + include_optimizer_state: If True, the optimizer state is also saved to the checkpoint file. + """ + C.save_checkpoint(state._state, os.fspath(checkpoint_uri), include_optimizer_state) + + @property + def parameters(self) -> Parameters: + """Returns the model parameters from the checkpoint state""" + return self._parameters + + @property + def properties(self) -> Properties: + """Returns the properties from the checkpoint state""" + return self._properties diff --git a/orttraining/orttraining/python/training/ort_triton/_codegen.py b/orttraining/orttraining/python/training/ort_triton/_codegen.py index c071f01f87ea5..8e21013da2353 100644 --- a/orttraining/orttraining/python/training/ort_triton/_codegen.py +++ b/orttraining/orttraining/python/training/ort_triton/_codegen.py @@ -263,10 +263,20 @@ def ReduceKernelNode( # noqa: N802 "Rsqrt": "{indent}{o0} = 1.0 / tl.sqrt({i0})\n", "Cast": "{indent}{o0} = {i0}.to(tl.{dtype})\n", "CastBool": "{indent}{o0} = {i0} != 0\n", - "Erf": "{indent}{o0} = tl.libdevice.erf({i0})\n", - "Gelu": "{indent}{o0} = (tl.libdevice.erf({i0} / 1.41421356237) + 1.0) * 0.5\n", + "Erf": "{indent}{o0} = tl.erf({i0})\n", + "Gelu": "{indent}{o0} = {i0} * 0.5 * (tl.math.erf({i0} * 0.70710678118654752440) + 1.0)\n", + "QuickGelu": "{indent}{o0} = {i0} * tl.sigmoid({i0} * {alpha})\n", + "GeluGrad": ( + "{indent}{o0} = {i0} * (0.5 * (1.0 + tl.math.erf(0.70710678118654752440 * {i1})) + " + "{i1} * 1.12837916709551257390 * 0.70710678118654752440 * 0.5 * tl.exp(-0.5 * {i1} * {i1}))\n" + ), + "QuickGeluGrad": ( + "{indent}tmp_v = {i1} * {alpha}\n" + "{indent}tmp_sigmoid = tl.sigmoid(tmp_v)\n" + "{indent}{o0} = {i0} * tmp_sigmoid * (1.0 + tmp_v * (1.0 - tmp_sigmoid))\n" + ), "Exp": "{indent}{o0} = tl.exp({i0})\n", - "Tanh": "{indent}{o0} = tl.libdevice.tanh({i0})\n", + "Tanh": "{indent}{o0} = tl.math.tanh({i0})\n", "Where": "{indent}{o0} = tl.where({i0}, {i1}, {i2})\n", "Sigmoid": "{indent}{o0} = tl.sigmoid({i0})\n", "Log": "{indent}{o0} = tl.log({i0})\n", @@ -303,6 +313,9 @@ def ComputeNode( # noqa: N802 else: kwargs["dtype"] = to_dtype.__name__ + if op_type == "QuickGelu" or op_type == "QuickGeluGrad": + kwargs["alpha"] = str(node.attributes.get("alpha", 1.702)) + if op_type == "Sum": output_var = kwargs["o0"] formula = " + ".join([kwargs[f"i{idx}"] for idx in range(len(node.inputs))]) diff --git a/orttraining/orttraining/python/training/ort_triton/_common.py b/orttraining/orttraining/python/training/ort_triton/_common.py index 65540202420b5..82ac82cfa2919 100644 --- a/orttraining/orttraining/python/training/ort_triton/_common.py +++ b/orttraining/orttraining/python/training/ort_triton/_common.py @@ -131,6 +131,10 @@ class TypeAndShapeInfer: "ReduceMax": _infer_reduction, "ReduceMin": _infer_reduction, "Sum": _infer_elementwise, + "Gelu": _infer_unary, + "QuickGelu": _infer_unary, + "GeluGrad": _infer_elementwise, + "QuickGeluGrad": _infer_elementwise, } @classmethod diff --git a/orttraining/orttraining/python/training/ort_triton/_ir.py b/orttraining/orttraining/python/training/ort_triton/_ir.py index 8aa5c1b13159b..f7d3b31eac5b6 100644 --- a/orttraining/orttraining/python/training/ort_triton/_ir.py +++ b/orttraining/orttraining/python/training/ort_triton/_ir.py @@ -5,7 +5,7 @@ from abc import abstractmethod from collections import defaultdict -from typing import Dict, List, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple import numpy as np import sympy @@ -184,14 +184,25 @@ class ComputeNode(IRNode): Each operator is represented as a ComputeNode. """ - def __init__(self, op_type: str, inputs: List[TensorArg], outputs: List[TensorArg]): + def __init__( + self, + op_type: str, + inputs: List[TensorArg], + outputs: List[TensorArg], + attributes: Dict[str, Any] = {}, # noqa: B006 + ): super().__init__(inputs, outputs) self._op_type: str = op_type + self._attributes: Dict[str, Any] = attributes @property def op_type(self): return self._op_type + @property + def attributes(self): + return self._attributes + class ReduceNode(ComputeNode): def __init__(self, op_type: str, inputs: List[TensorArg], outputs: List[TensorArg], offset_calc: OffsetCalculator): diff --git a/orttraining/orttraining/python/training/ort_triton/_lowering.py b/orttraining/orttraining/python/training/ort_triton/_lowering.py index 16db9ab000834..5de60e69437a0 100644 --- a/orttraining/orttraining/python/training/ort_triton/_lowering.py +++ b/orttraining/orttraining/python/training/ort_triton/_lowering.py @@ -9,7 +9,7 @@ from typing import Any, Dict, List, Set, Tuple import sympy -from onnx import NodeProto +from onnx import NodeProto, helper from ._common import AutotuneConfigs, TensorInfo from ._ir import ( @@ -378,7 +378,10 @@ def _to_compute_node(self, node: NodeProto, offset_calc: OffsetCalculator): return DropoutNode(inputs, outputs, offset_calc) if is_reduction_node(node): return ReduceNode(op_type, inputs, outputs, offset_calc) - return ComputeNode(op_type, inputs, outputs) + attributes = {} + for attr in node.attribute: + attributes[attr.name] = helper.get_attribute_value(attr) + return ComputeNode(op_type, inputs, outputs, attributes) def _analyze_kernel_io_list(self): cross_kernel_inputs = set() diff --git a/orttraining/orttraining/python/training/ort_triton/_op_config.py b/orttraining/orttraining/python/training/ort_triton/_op_config.py index f58d0e1847207..7d9af00933a75 100644 --- a/orttraining/orttraining/python/training/ort_triton/_op_config.py +++ b/orttraining/orttraining/python/training/ort_triton/_op_config.py @@ -36,6 +36,10 @@ "DropoutGrad": {"domain": "com.microsoft", "versions": [1]}, "Identity": {"versions": [13], "is_no_op": True}, "Sum": {"versions": [13]}, + "Gelu": {"domain": "com.microsoft", "versions": [1]}, + "QuickGelu": {"domain": "com.microsoft", "versions": [1]}, + "GeluGrad": {"domain": "com.microsoft", "versions": [1]}, + "QuickGeluGrad": {"domain": "com.microsoft", "versions": [1]}, } _REDUCTION_OPS = { diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py index 4c72b6d98a088..f75d553a5f460 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py @@ -28,7 +28,8 @@ class PythonOpShapeInferStore: @classmethod def register(cls, kclass: torch.autograd.Function) -> None: - """Register a shape inference function for a torch.autograd.Function if there is staticmethod "infer_shape" defined. + """Register a shape inference function for a torch.autograd.Function if there is staticmethod + "infer_shape" defined. The signature of the shape inference function should be: @staticmethod @@ -51,6 +52,11 @@ def infer_shape( if hasattr(kclass, "infer_shape") and kclass_name not in cls._CLASS_MAP: cls._CLASS_MAP[kclass_name] = kclass.infer_shape + @classmethod + def register_func(cls, name: str, func: Callable) -> None: + """Register a shape inference function for a torch.autograd.Function by name.""" + cls._CLASS_MAP[name] = func + @classmethod def get_shape_infer(cls, name: str) -> Optional[Callable]: return cls._CLASS_MAP.get(name, None) @@ -228,9 +234,9 @@ def _export_pt_1_10(g, n, *args, **kwargs): input_float_tuples.extend(list(arg)) continue - is_inspect_activation = ( - func_full_qual_name == "onnxruntime.training.utils.hooks._subscriber_manager._InspectActivation" - ) + from onnxruntime.training.utils.hooks._statistics_subscriber import _InspectActivation + + is_inspect_activation = func_full_qual_name == get_fully_qualified_class_name(_InspectActivation) if is_inspect_activation and isinstance(arg, str): # _InspectActivation is a special case where the first argument is a string # that is used to determine the activation name to be inspected. @@ -307,14 +313,7 @@ def _export_pt_1_10(g, n, *args, **kwargs): _export = wrap_custom_export_function(_export_pt_1_10) -def _post_process_after_export(exported_model: ModelProto, enable_custom_autograd_function: bool) -> ModelProto: - """Post process the exported model.""" - if enable_custom_autograd_function: - exported_model = _post_process_enabling_autograd_function(exported_model) - return exported_model - - -def _post_process_enabling_autograd_function(exported_model: ModelProto) -> ModelProto: +def post_process_enabling_autograd_function(exported_model: ModelProto) -> ModelProto: # Loop all PythonOp, append "_ctx" as the first output. index = 0 for node in exported_model.graph.node: @@ -330,8 +329,7 @@ def _post_process_enabling_autograd_function(exported_model: ModelProto) -> Mode op_name_prefix = kclass_name break - if not node.name: - node.name = f"{op_name_prefix}_id_{index}" - index += 1 + node.name = f"{op_name_prefix}_id_{index}" + index += 1 return exported_model diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py index 845c7d83c2e7b..a5b96c4e37140 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py @@ -376,6 +376,16 @@ def wrap_all_outputs(result): result = backward_function(*wrapped_args) # Extract results as DLPack tensor list. + if isinstance(result, torch.Tensor): + result = [result] + elif isinstance(result, (tuple, list)): + result = list(result) + else: + raise wrap_exception( + ORTModuleIOError, + TypeError(f"ORTModule does not support the following model output type {type(result)}."), + ) + wrapped_returned_args = wrap_all_outputs(result) torch_interop_utils.unregister_grad_fn(id(ctx)) diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 2227b630aee23..dfaac5f0fa836 100755 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -19,11 +19,10 @@ import onnxruntime from onnxruntime.capi import _pybind_state as C from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference -from onnxruntime.training.utils import ORTModelInputOutputSchemaType +from onnxruntime.training.utils import ORTModelInputOutputSchemaType, onnx_dtype_to_pytorch from onnxruntime.training.utils.hooks import configure_ort_compatible_zero_stage3 from . import _are_deterministic_algorithms_enabled, _io, _logger, _onnx_models, _utils -from ._custom_autograd_function_exporter import _post_process_after_export from ._fallback import ( ORTModuleDeviceException, ORTModuleONNXModelException, @@ -141,9 +140,14 @@ def __init__( register_triton_op_executor() + self._zero_stage3_param_map = {} if self._runtime_options.enable_zero_stage3_support: # Cannot toggle feature enabling/disabling after the first time enabled. - configure_ort_compatible_zero_stage3() + from onnxruntime.training.utils.hooks._zero_offload_subscriber import _get_all_zero_stage3_params + + self._zero_stage3_param_map = _get_all_zero_stage3_params(self._flattened_module) + + configure_ort_compatible_zero_stage3(debug=False, stats_output_dir="ort_output", stats_overwrite=True) def _get_torch_gpu_allocator_function_addresses(self): if self._runtime_options.use_external_gpu_allocator and torch.cuda.is_available(): @@ -345,7 +349,8 @@ def _get_exported_model(self, input_schema: ORTModelInputOutputSchemaType, *inpu ) if os.path.exists(cache_dir) and os.path.isfile(filename): self._logger.info( - f"Cached model detected! Cached model will be used to save export and initialization time. If you want the model to be re-exported then DELETE {filename}." + f"Cached model detected! Cached model will be used to save export and initialization time." + f"If you want the model to be re-exported then DELETE {filename}." ) exported_model = onnx.load(filename) return exported_model @@ -409,9 +414,24 @@ def _get_exported_model(self, input_schema: ORTModelInputOutputSchemaType, *inpu ) exported_model = onnx.load_model_from_string(f.getvalue()) - exported_model = _post_process_after_export( - exported_model, self._runtime_options.enable_custom_autograd_function - ) + if self._runtime_options.enable_custom_autograd_function: + from ._custom_autograd_function_exporter import post_process_enabling_autograd_function + + exported_model = post_process_enabling_autograd_function(exported_model) + + if self._runtime_options.enable_zero_stage3_support: + from ._zero_stage3_compatibility import post_processing_enable_zero_stage3_compat + + exported_model = post_processing_enable_zero_stage3_compat( + exported_model, + self._zero_stage3_param_map, + [name for name, _ in self._flattened_module.named_parameters()], + ) + + # Cannot append pull weight trigger name to input names as following, otherwise, the later check ( + # https://github.com/microsoft/onnxruntime/blob/068300d97eb25e5b52324e7af54a45ed1fa6a4c3/orttraining/orttraining/python/training/ortmodule/_training_manager.py#L466C18-L466C18) + # find input info mismatch, will re-initialize the graph builder. + # self._input_info.require_grad_names.append(STAGE3_PULL_WEIGHT_TRIGGER_NAME) # Cache model for future runs if cache_dir: @@ -477,7 +497,14 @@ def _initialize_graph_builder(self): grad_builder_config = C.OrtModuleGraphBuilderConfiguration() grad_builder_config.initializer_names = initializer_names grad_builder_config.initializer_names_to_train = initializer_names_to_train - grad_builder_config.input_names_require_grad = self._input_info.require_grad_names + + input_names_require_grad = self._input_info.require_grad_names + if self._runtime_options.enable_zero_stage3_support: + from ._zero_stage3_compatibility import STAGE3_PULL_WEIGHT_TRIGGER_NAME + + # Add stage3 pull weight trigger name to require_grad_names, so that it will be included in the gradient graph. + input_names_require_grad.append(STAGE3_PULL_WEIGHT_TRIGGER_NAME) + grad_builder_config.input_names_require_grad = input_names_require_grad grad_builder_config.build_gradient_graph = self._export_mode == torch.onnx.TrainingMode.TRAINING grad_builder_config.enable_caching = self._runtime_options.enable_grad_acc_optimization grad_builder_config.loglevel = _logger.ortmodule_loglevel_to_onnxruntime_c_loglevel( @@ -553,6 +580,9 @@ def _enable_conditional_optimizations( inputs, kwargs ) + if self._runtime_options.enable_zero_stage3_support: + self._append_pull_weight_trigger_as_input(kwargs, detected_device) + _, embed_sparsity_results, label_sparsity_results = _io._combine_input_buffers_initializers( self._graph_initializers, self._graph_builder.get_graph_info().user_input_names, @@ -562,6 +592,7 @@ def _enable_conditional_optimizations( kwargs, detected_device, self._runtime_inspector, + self._zero_stage3_param_map, ) # Enable sparsity-based optimization when applicable. @@ -587,6 +618,21 @@ def _enable_conditional_optimizations( if self._runtime_options.print_memory_stat: self._runtime_inspector.enable_memory_inspector(self._original_module) + def _append_pull_weight_trigger_as_input(self, kwargs: Dict, device: torch.device): + from ._zero_stage3_compatibility import ( + STAGE3_PULL_WEIGHT_TRIGGER_NAME, + STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE, + STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE, + ) + + kwargs[STAGE3_PULL_WEIGHT_TRIGGER_NAME] = torch.zeros( + STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE, + dtype=onnx_dtype_to_pytorch(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE), + device=device, + ).requires_grad_() + + return kwargs + def _log_feature_stats(self): if get_rank() != 0: return diff --git a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py index b7c01a1f5baf9..8d8be81c549d1 100644 --- a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py @@ -159,6 +159,9 @@ def forward(self, *inputs, **kwargs): # Assert that the input and model device match _utils._check_same_device(self._device, "Input argument to forward", *inputs) + if self._runtime_options.enable_zero_stage3_support: + self._append_pull_weight_trigger_as_input(kwargs, self._device) + prepared_input_list, _, _ = _io._combine_input_buffers_initializers( self._graph_initializers, self._graph_info.user_input_names, @@ -168,6 +171,7 @@ def forward(self, *inputs, **kwargs): kwargs, self._device, self._runtime_inspector, + self._zero_stage3_param_map, ) user_outputs, _ = InferenceManager.execution_session_run_forward( diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py index 18b965c549645..e7c1b30daae0d 100644 --- a/orttraining/orttraining/python/training/ortmodule/_io.py +++ b/orttraining/orttraining/python/training/ortmodule/_io.py @@ -168,6 +168,7 @@ def _combine_input_buffers_initializers( kwargs: Mapping[str, ORTModelInputOutputType], device: torch.device, rt_inspector: RuntimeInspector, + zero_stage3_offload_param_map: Optional[Dict[str, torch.nn.parameter.Parameter]], ): """Creates forward `*inputs` list from user input and PyTorch initializers @@ -254,7 +255,12 @@ def _expand_inputs(current_input, non_none_inputs, name=""): ) # params is a list of all initializers known to the onnx graph - result.extend(params) + if zero_stage3_offload_param_map: + for p in params: + if p not in zero_stage3_offload_param_map.values(): + result.append(p) + else: + result.extend(params) return result, embed_sparsity_results, label_sparsity_results diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index 3be4c05797978..19effe2086e0a 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -311,6 +311,9 @@ def forward(self, *inputs, **kwargs): self._gradient_accumulation_manager.maybe_update_cache_before_run() + if self._runtime_options.enable_zero_stage3_support: + self._append_pull_weight_trigger_as_input(kwargs, self._device) + prepared_input_list, _, _ = _io._combine_input_buffers_initializers( self._graph_initializers, self._graph_info.user_input_names, @@ -320,6 +323,7 @@ def forward(self, *inputs, **kwargs): kwargs, self._device, self._runtime_inspector, + self._zero_stage3_param_map, ) outputs = unflatten_user_output( diff --git a/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py new file mode 100644 index 0000000000000..17756600d601e --- /dev/null +++ b/orttraining/orttraining/python/training/ortmodule/_zero_stage3_compatibility.py @@ -0,0 +1,312 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +from typing import Dict, List, Optional, Tuple, Union + +import torch +from onnx import ModelProto, NodeProto, TensorProto, ValueInfoProto, helper + +from onnxruntime.capi._pybind_state import register_torch_autograd_function +from onnxruntime.training.utils import pytorch_dtype_to_onnx + +from ._custom_autograd_function_exporter import PythonOpShapeInferStore +from ._utils import get_fully_qualified_class_name + +STAGE3_PULL_WEIGHT_TRIGGER_NAME = "pull_weight_trigger" +STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE = TensorProto.FLOAT +STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE = [1] + + +def post_processing_enable_zero_stage3_compat( + exported_model: ModelProto, + zero_stage3_named_params: Dict[str, torch.nn.parameter.Parameter], + all_param_names: List[str], +) -> ModelProto: + """This function is used to enable zero stage3 compatibility. + + Args: + exported_model (ModelProto): The exported model. + zero_stage3_named_params (Optional[Dict[str, torch.nn.parameter.Parameter]]): The offload named parameters. + all_param_names (List[str]): All parameter names. + """ + + # Register symbolic shape inference functions for PythonOp used in DeepSpeed ZeRO stage3. + _register_symbolic_shape_infer_functions() + + # Create weight retrieving function using zero_stage3_named_params. + func_full_qual_name = _create_weight_retrieval_function(zero_stage3_named_params) + + consumer_map = {} + for node in exported_model.graph.node: + for inp in node.input: + if inp not in consumer_map: + consumer_map[inp] = [] + + if node not in consumer_map[inp]: + consumer_map[inp].append(node) + + def _get_param_pull_trigger_name(param_name: str) -> str: + return f"pull_{param_name}" + + def _get_func_name(node: NodeProto) -> Optional[str]: + for attr in node.attribute: + if attr.name == "func_name": + return attr.s.decode("utf-8") if isinstance(attr.s, bytes) else attr.s + return None + + # Create weight retrieving PythonOp. + new_input, weight_pull_node = _create_weight_retrieval_pythonop( + zero_stage3_named_params, + func_full_qual_name, + STAGE3_PULL_WEIGHT_TRIGGER_NAME, + [_get_param_pull_trigger_name(pname) for pname in zero_stage3_named_params], + STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE, + STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE, + ) + + from onnxruntime.training.utils.hooks._zero_offload_subscriber import ORTZeROOffloadPreForwardFunction + + prefowrad_function_name = get_fully_qualified_class_name(ORTZeROOffloadPreForwardFunction) + + # Connect weight consumers to use the full-sized parameter output of ORTZeROOffloadPreForwardFunction. + for graph_input in exported_model.graph.input: + if graph_input.name not in zero_stage3_named_params: + continue + + if graph_input.name not in consumer_map: + continue + + consumers = consumer_map[graph_input.name] + pre_forward_pythonop_node = None + + for c in consumers: + if c.op_type != "PythonOp": + continue + + func_name = _get_func_name(c) + if func_name == prefowrad_function_name: + assert ( + pre_forward_pythonop_node is None + ), "Multiple ORTZeROOffloadPreForwardFunction nodes found, it should not happen" + pre_forward_pythonop_node = c + + if pre_forward_pythonop_node is None: + raise RuntimeError( + "Fail to find ORTZeROOffloadPreForwardFunction for partitioned param: " + graph_input.name + ) + + index_offset_on_python_op_input = [] + for i, input_name in enumerate(pre_forward_pythonop_node.input): + if input_name == graph_input.name: + index_offset_on_python_op_input.append(i) + + assert ( + len(index_offset_on_python_op_input) == 1 + ), f"index_offset_on_python_op_input length is not 1: {index_offset_on_python_op_input}" + + reverse_index_among_inputs = index_offset_on_python_op_input[0] - len(pre_forward_pythonop_node.input) + new_input_name = _get_param_pull_trigger_name(graph_input.name) + pre_forward_pythonop_node.input[index_offset_on_python_op_input[0]] = new_input_name + + _update_python_op_input_related_attributes( + pre_forward_pythonop_node, + new_input_name, + len(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE), # new rank + STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE, # new data type + ) + + output_index = reverse_index_among_inputs + len(pre_forward_pythonop_node.output) + pre_forward_pythonop_node.output[output_index] = graph_input.name + + # If the consumer of original `graph_input.name` is PythonOp, we need also update its attributes because now + # `graph_input.name` as output of pre_forward_pythonop_node, is full-sized parameter, the rank might differ + # from the original one. + for c in consumers: + if c == pre_forward_pythonop_node or c.op_type != "PythonOp": + continue + _update_python_op_input_related_attributes( + c, + graph_input.name, + len(zero_stage3_named_params[graph_input.name].ds_shape), # new rank + pytorch_dtype_to_onnx(zero_stage3_named_params[graph_input.name].dtype), # new data type + ) + + # Delete exported_model.graph.input + graph_inputs_to_remove = [ + graph_input for graph_input in exported_model.graph.input if graph_input.name in zero_stage3_named_params + ] + for input_to_remove in graph_inputs_to_remove: + exported_model.graph.input.remove(input_to_remove) + + # Re-order graph input to make sure the weight pull trigger is before all parameter inputs. + offset = 0 + for graph_input in exported_model.graph.input: + if graph_input.name in all_param_names: + break + offset += 1 + + exported_model.graph.input.insert(offset, new_input) + exported_model.graph.node.insert(0, weight_pull_node) + + return exported_model + + +def _create_weight_retrieval_function( + zero_stage3_named_params: Optional[Dict[str, torch.nn.parameter.Parameter]] +) -> str: + """This function is used to create a weight retrieving function using zero_stage3_named_params.""" + + class WeightRetrievalFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, weight_in_trigger): + params = list(zero_stage3_named_params.values()) + ctx.params = params + ctx.dtype = weight_in_trigger.dtype + ctx.device = weight_in_trigger.device + ctx.shape = weight_in_trigger.shape + return (torch.zeros(ctx.shape, device=ctx.device, dtype=ctx.dtype),) * len(params) + + @staticmethod + def backward(ctx, *grad_outputs): + return torch.zeros(ctx.shape, device=ctx.device, dtype=ctx.dtype) + + @staticmethod + def infer_shape( + node: NodeProto, + tensor_input_shapes: List[Optional[List[Union[int, str]]]], + tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], + ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: + param_count = len(zero_stage3_named_params.values()) + tensor_output_shapes = [ + tensor_input_shapes[0], + ] * param_count + tensor_output_dtypes = [ + tensor_input_dtypes[0], + ] * param_count + return tensor_output_shapes, tensor_output_dtypes + + func_full_qual_name = get_fully_qualified_class_name(WeightRetrievalFunction) + register_torch_autograd_function(func_full_qual_name, WeightRetrievalFunction) + PythonOpShapeInferStore.register(WeightRetrievalFunction) + + return func_full_qual_name + + +def _register_symbolic_shape_infer_functions(): + """This function is used to register symbolic shape inference functions for PythonOp used in + DeepSpeed ZeRO stage3.""" + + def _simple_pass_through_infer_shape( + node: NodeProto, + tensor_input_shapes: List[Optional[List[Union[int, str]]]], + tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], + ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: + return tensor_input_shapes, tensor_input_dtypes + + PythonOpShapeInferStore.register_func( + "deepspeed.runtime.zero.parameter_offload.PreBackwardFunction", _simple_pass_through_infer_shape + ) + PythonOpShapeInferStore.register_func( + "deepspeed.runtime.zero.parameter_offload.PostBackwardFunction", _simple_pass_through_infer_shape + ) + + def _linear_infer_shape( + node: NodeProto, + tensor_input_shapes: List[Optional[List[Union[int, str]]]], + tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], + ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: + # output = input.matmul(weight.t()) + tensor_input_shapes[0] # input + shape2 = tensor_input_shapes[1] # weight + output_shape = tensor_input_shapes[0] + output_shape[-1] = shape2[-2] + return [output_shape], [tensor_input_dtypes[0]] + + PythonOpShapeInferStore.register_func( + "deepspeed.runtime.zero.linear.LinearFunctionForZeroStage3", _linear_infer_shape + ) + + +def _create_weight_retrieval_pythonop( + zero_stage3_named_params: Optional[Dict[str, torch.nn.parameter.Parameter]], + func_full_qual_name: str, + input_name: str, + output_names: List[str], + STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE, + STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE: List[int], +) -> Tuple[ValueInfoProto, NodeProto]: + """This function is used to create a weight retrieving PythonOp.""" + offload_param_count = 0 if zero_stage3_named_params is None else len(zero_stage3_named_params) + new_input = helper.make_tensor_value_info( + input_name, STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE, STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE + ) + output_rank_for_pull_weight_trigger = len(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE) + output_dtype_for_pull_weight_trigger = STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE + output_tensor_ranks = [ + output_rank_for_pull_weight_trigger, + ] * offload_param_count + output_tensor_types = [ + output_dtype_for_pull_weight_trigger, + ] * offload_param_count + + node_attributes = { + "comment": "", + "inplace": 0, + "input_convention": "d", + "input_tensor_ranks": [len(STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_SHAPE)], + "input_tensor_types": [STAGE3_PULL_WEIGHT_TRIGGER_OUTPUT_DTYPE], + "output_tensor_ranks": output_tensor_ranks, + "output_tensor_types": output_tensor_types, + "training_mode": 1, + "func_name": func_full_qual_name, + } + + weight_pull_node = helper.make_node( + "PythonOp", + [input_name], + ["pull_weight_trigger_ctx", *output_names], + "pull_weight_trigger", # node name + "PythonOp for weight retrieving.", + "com.microsoft", + **node_attributes, + ) + + return new_input, weight_pull_node + + +def _update_python_op_input_related_attributes(node: NodeProto, input_name: str, new_rank: int, new_dtype: int): + """This function is used to update PythonOp's input related attributes, e.g. + input_tensor_ranks and input_tensor_types. + + Args: + node (NodeProto): The PythonOp node. + input_name (str): The input name to be updated. + new_rank (int): The new rank of the input, to be used in input_tensor_ranks. + new_dtype (int): The new data type of the input, to be used in input_tensor_types. + """ + input_tensor_ranks = None + input_tensor_dtypes = None + rank_attr = None + dtype_attr = None + for attr in node.attribute: + if attr.name == "input_tensor_ranks": + input_tensor_ranks = attr.ints + rank_attr = attr + if attr.name == "input_tensor_types": + input_tensor_dtypes = attr.ints + dtype_attr = attr + + assert input_tensor_ranks is not None, "input_tensor_ranks is None" + assert input_tensor_dtypes is not None, "input_tensor_dtypes is None" + + for index, node_input_name in enumerate(node.input): + if node_input_name == input_name: + input_tensor_ranks[index] = new_rank + input_tensor_dtypes[index] = new_dtype + + node.attribute.remove(rank_attr) + node.attribute.remove(dtype_attr) + node.attribute.append(helper.make_attribute("input_tensor_ranks", input_tensor_ranks)) + node.attribute.append(helper.make_attribute("input_tensor_types", input_tensor_dtypes)) diff --git a/orttraining/orttraining/python/training/utils/__init__.py b/orttraining/orttraining/python/training/utils/__init__.py index acf2698d55eaf..fa7c9f2750cdd 100644 --- a/orttraining/orttraining/python/training/utils/__init__.py +++ b/orttraining/orttraining/python/training/utils/__init__.py @@ -9,7 +9,7 @@ extract_data_and_schema, unflatten_data_using_schema, ) -from onnxruntime.training.utils.torch_type_map import pytorch_dtype_to_onnx +from onnxruntime.training.utils.torch_type_map import onnx_dtype_to_pytorch, pytorch_dtype_to_onnx __all__ = [ "PrimitiveType", @@ -18,4 +18,5 @@ "extract_data_and_schema", "unflatten_data_using_schema", "pytorch_dtype_to_onnx", + "onnx_dtype_to_pytorch", ] diff --git a/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py index 6c8027b2fefaa..db1c69cf95ba4 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py +++ b/orttraining/orttraining/python/training/utils/hooks/_statistics_subscriber.py @@ -6,6 +6,7 @@ import os import shutil import warnings +from io import TextIOWrapper from pathlib import Path from typing import List, Optional, Tuple, Union @@ -178,87 +179,97 @@ def _summarize_activations(self, tensor: torch.Tensor, depth: int, name: str, st order_file_path = step_path / "order.txt" tensor_file_path = step_path / output_file_name - # This is to try the best effort to align the count of numbers per line for easier comparison in diff views, - # though it does not always guarantee to do this way. - torch.set_printoptions(precision=6, linewidth=128) - - tensor_shape = tensor.shape - tensor_dtype = tensor.dtype - flatten_array = tensor.flatten().view(-1) - - if self._run_on_cpu: - flatten_array = flatten_array.to("cpu") - - if self._run_on_cpu: - num_nan = torch.isnan(flatten_array).sum() - num_inf = torch.isinf(flatten_array).sum() - num_neg = (flatten_array < 0).sum() - num_pos = (flatten_array > 0).sum() - num_zero = (flatten_array == 0).sum() - min_value = flatten_array.min() - max_value = flatten_array.max() - mean_value = flatten_array.mean() - std_value = flatten_array.std() - else: - # Split the calculation for each bucket, then do another round of calculation on the bucket results. - # This can at the best effort reduce the peak memory impact. - bucket_size = self._bucket_size - element_count = flatten_array.numel() - ceil_bucket_count = (element_count + bucket_size - 1) // (bucket_size) - nan_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) - inf_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) - neg_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) - pos_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) - zero_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) - min_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device) - max_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device) - mean_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device) - std_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device) - - # Summary for each bucket - element_count_per_bucket = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) - for i in range(ceil_bucket_count): - end = min((i + 1) * bucket_size, element_count) - bucket = flatten_array[i * bucket_size : end] - element_count_per_bucket[i] = bucket.numel() - - nan_buckets[i] = torch.isnan(bucket).sum() - inf_buckets[i] = torch.isinf(bucket).sum() - neg_buckets[i] = (bucket < 0).sum() - pos_buckets[i] = (bucket > 0).sum() - zero_buckets[i] = (bucket == 0).sum() - min_buckets[i] = bucket.min() - max_buckets[i] = bucket.max() - mean_buckets[i] = bucket.sum() - std_buckets[i] = bucket.std() - - # Reduction across all buckets - num_nan = nan_buckets.sum() - num_inf = inf_buckets.sum() - num_neg = neg_buckets.sum() - num_pos = pos_buckets.sum() - num_zero = zero_buckets.sum() - min_value = min_buckets.min() - max_value = max_buckets.max() - mean_value = float(mean_buckets.sum()) / float(element_count) - # Here we refer to - # https://math.stackexchange.com/questions/2971315/how-do-i-combine-standard-deviations-of-two-groups - # to calculate the combined standard deviation of all buckets. - s = (element_count_per_bucket - 1) * (std_buckets**2) + element_count_per_bucket * ( - (mean_buckets - mean_value) ** 2 - ) - std_value = torch.sqrt(s.sum() / (element_count - 1)) - with order_file_path.open(mode="a", encoding="utf-8") as f: f.write(f"{output_file_name}\n") with tensor_file_path.open(mode="w", encoding="utf-8") as f: - f.write( - f"{'>'*max(0, depth) + display_name} shape: {tensor_shape} dtype: {tensor_dtype} size: {flatten_array.size()} \n" - f"min: {min_value} max: {max_value}, mean: {mean_value}, " - f"std: {std_value} \n" - f"nan: {num_nan}, inf: {num_inf}\n" - ) - f.write(f"samples(top 128): {flatten_array[:128]}\n") - f.write(f"neg: {num_neg}, pos: {num_pos}, zero: {num_zero},\n") - f.write(f"{'='*16}\n") + _summarize_tensor(display_name, tensor, f, depth, self._run_on_cpu, self._bucket_size) + + +def _summarize_tensor( + display_name: str, + tensor: torch.Tensor, + f: TextIOWrapper, + depth: int = 0, + run_on_cpu: bool = False, + bucket_size: int = 1024 * 1024 * 1024 // 2, +): + # This is to try the best effort to align the count of numbers per line for easier comparison in diff views, + # though it does not always guarantee to do this way. + torch.set_printoptions(precision=6, linewidth=128) + + tensor_shape = tensor.shape + tensor_dtype = tensor.dtype + flatten_array = tensor.flatten().view(-1) + + if run_on_cpu: + flatten_array = flatten_array.to("cpu") + + if run_on_cpu: + num_nan = torch.isnan(flatten_array).sum() + num_inf = torch.isinf(flatten_array).sum() + num_neg = (flatten_array < 0).sum() + num_pos = (flatten_array > 0).sum() + num_zero = (flatten_array == 0).sum() + min_value = flatten_array.min() + max_value = flatten_array.max() + mean_value = flatten_array.mean() + std_value = flatten_array.std() + else: + # Split the calculation for each bucket, then do another round of calculation on the bucket results. + # This can at the best effort reduce the peak memory impact. + element_count = flatten_array.numel() + ceil_bucket_count = (element_count + bucket_size - 1) // (bucket_size) + nan_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) + inf_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) + neg_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) + pos_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) + zero_buckets = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) + min_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device) + max_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device) + mean_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device) + std_buckets = torch.zeros(ceil_bucket_count, dtype=flatten_array.dtype, device=flatten_array.device) + + # Summary for each bucket + element_count_per_bucket = torch.zeros(ceil_bucket_count, dtype=torch.int64, device=flatten_array.device) + for i in range(ceil_bucket_count): + end = min((i + 1) * bucket_size, element_count) + bucket = flatten_array[i * bucket_size : end] + element_count_per_bucket[i] = bucket.numel() + + nan_buckets[i] = torch.isnan(bucket).sum() + inf_buckets[i] = torch.isinf(bucket).sum() + neg_buckets[i] = (bucket < 0).sum() + pos_buckets[i] = (bucket > 0).sum() + zero_buckets[i] = (bucket == 0).sum() + min_buckets[i] = bucket.min() + max_buckets[i] = bucket.max() + mean_buckets[i] = bucket.sum() + std_buckets[i] = bucket.std() + + # Reduction across all buckets + num_nan = nan_buckets.sum() + num_inf = inf_buckets.sum() + num_neg = neg_buckets.sum() + num_pos = pos_buckets.sum() + num_zero = zero_buckets.sum() + min_value = min_buckets.min() + max_value = max_buckets.max() + mean_value = float(mean_buckets.sum()) / float(element_count) + # Here we refer to + # https://math.stackexchange.com/questions/2971315/how-do-i-combine-standard-deviations-of-two-groups + # to calculate the combined standard deviation of all buckets. + s = (element_count_per_bucket - 1) * (std_buckets**2) + element_count_per_bucket * ( + (mean_buckets - mean_value) ** 2 + ) + std_value = torch.sqrt(s.sum() / (element_count - 1)) + + f.write( + f"{'>'*max(0, depth) + display_name} shape: {tensor_shape} dtype: {tensor_dtype} size: {flatten_array.size()} \n" + f"min: {min_value} max: {max_value}, mean: {mean_value}, " + f"std: {std_value} \n" + f"nan: {num_nan}, inf: {num_inf}\n" + ) + f.write(f"samples(top 128): {flatten_array[:128]}\n") + f.write(f"neg: {num_neg}, pos: {num_pos}, zero: {num_zero},\n") + f.write(f"{'='*16}\n") diff --git a/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py b/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py index db38f58d8f324..b2bc64be42fc1 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py +++ b/orttraining/orttraining/python/training/utils/hooks/_subscriber_manager.py @@ -29,14 +29,6 @@ def no_increase_global_step(): finally: ORT_NO_INCREASE_GLOBAL_STEP[0] = False - @staticmethod - def infer_shape( - node: onnx.NodeProto, - tensor_input_shapes: List[Optional[List[Union[int, str]]]], - tensor_input_dtypes: List[torch.onnx.TensorProtoDataType], - ) -> Tuple[List[Optional[List[Union[int, str]]]], List[torch.onnx.TensorProtoDataType]]: - return tensor_input_shapes, tensor_input_dtypes - class _IncrementStep(torch.autograd.Function): """This class is used to manage the global execution step, e.g. @@ -55,8 +47,9 @@ def forward(ctx, run_ctx: RuntimeStates, *input_tensor_list: Tuple[torch.Tensor, ctx.current_step = run_ctx.global_states.execution_step ctx.run_ctx = run_ctx - if ctx.current_step >= 0: - print(f"{'='*6} Completed forward pass for STEP {ctx.current_step} {'='*6}") + # Uncomment the following line for debugging purposes. + # if ctx.current_step >= 0: + # print(f"{'='*6} Completed forward pass for STEP {ctx.current_step} {'='*6}") if ORT_NO_INCREASE_GLOBAL_STEP[0] is False: ctx.run_ctx.global_states.execution_step += 1 @@ -191,7 +184,7 @@ def _reset_recursively(module: torch.nn.Module, depth: int, next_module_index: L next_module_index: list of int, carrying a global unique module index that can be used next. """ module_index = next_module_index[0] - module.id = module_index # STAGE3WARN: needed by DeepSpeed + module.id = module_index # STAGE3WARN#1: needed by DeepSpeed self._run_ctx.global_states.module_index_to_depth[module_index] = depth self._run_ctx.global_states.module_to_module_index[module] = module_index @@ -217,7 +210,7 @@ def _register_hooks_recursively(self, module: torch.nn.Module, depth: int, next_ next_module_index: list of int, carrying a global unique module index that can be used next. """ module_index = next_module_index[0] - module.id = module_index # STAGE3WARN: needed by DeepSpeed + module.id = module_index # STAGE3WARN#2: needed by DeepSpeed self._run_ctx.global_states.module_index_to_depth[module_index] = depth self._run_ctx.global_states.module_to_module_index[module] = module_index diff --git a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py index 3d42e172eea82..ad1297962db71 100644 --- a/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py +++ b/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py @@ -23,25 +23,37 @@ from ._subscriber_base import RuntimeStates, SubscriberBase -# Used to monkey patch the original function -# Adapted from https://github.com/microsoft/DeepSpeed/blob/e8318634b4313eaad89842cf4322e1762d34ced3/deepspeed/runtime/zero/parameter_offload.py#L333 -def _setup_zero_stage3_ort_compatible_hooks(self): - self.hierarchy = 0 +def _get_ort_compatible_zero_stage3_hook_function(debug, stats_output_dir, stats_overwrite): + """Create ort compatible hook function for DeepSpeed ZeRO stage3. - from onnxruntime.training.utils.hooks import SubscriberManager, ZeROOffloadSubscriber - from onnxruntime.training.utils.hooks._zero_offload_subscriber import _zero_offload_one_time_initializer + Args: + debug: whether to enable convergence debugging. + stats_output_dir: the directory to store convergence stats. + stats_overwrite: whether to overwrite the stats file if it already exists. + """ + + # Used to monkey patch the original function + # Adapted from https://github.com/microsoft/DeepSpeed/blob/e8318634b4313eaad89842cf4322e1762d34ced3/deepspeed/runtime/zero/parameter_offload.py#L333 + def _setup_zero_stage3_ort_compatible_hooks(self): + self.hierarchy = 0 + + from onnxruntime.training.utils.hooks import StatisticsSubscriber, SubscriberManager, ZeROOffloadSubscriber + from onnxruntime.training.utils.hooks._zero_offload_subscriber import _zero_offload_one_time_initializer - # Each DeepSpeed engine has a separate subscriber manager. - self._offload_subscriber_manager = SubscriberManager() - self._offload_subscriber_manager.subscribe( - self.module, [ZeROOffloadSubscriber(self, _zero_offload_one_time_initializer)] - ) - self.forward_hooks.extend(self._offload_subscriber_manager._pre_forward_hooks) - self.forward_hooks.extend(self._offload_subscriber_manager._post_forward_hooks) + subscribers = [ZeROOffloadSubscriber(self, _zero_offload_one_time_initializer)] + if debug is True: + subscribers.append(StatisticsSubscriber(output_dir=stats_output_dir, override_output_dir=stats_overwrite)) + # Each DeepSpeed engine has a separate subscriber manager. + self._offload_subscriber_manager = SubscriberManager() + self._offload_subscriber_manager.subscribe(self.module, subscribers) + self.forward_hooks.extend(self._offload_subscriber_manager._pre_forward_hooks) + self.forward_hooks.extend(self._offload_subscriber_manager._post_forward_hooks) - # Add top module to stack trace - global FWD_MODULE_STACK # noqa: PLW0602 - FWD_MODULE_STACK.append(self.module) + # Add top module to stack trace + global FWD_MODULE_STACK # noqa: PLW0602 + FWD_MODULE_STACK.append(self.module) + + return _setup_zero_stage3_ort_compatible_hooks # Adapted from https://github.com/microsoft/DeepSpeed/blob/e8318634b4313eaad89842cf4322e1762d34ced3/deepspeed/runtime/zero/linear.py#L104 @@ -86,14 +98,16 @@ def collect_code(self, function: Callable): _zero_offload_one_time_initializer.collect_code(DeepSpeedZeRoOffload.setup_zero_stage3_hooks) # This is the function to enable ORT ZeRO offload. - def configure_ort_compatible_zero_stage3(): + def configure_ort_compatible_zero_stage3(debug=False, stats_output_dir="./", stats_overwrite=False): """Configure ZeRO stage3 to be ORT compatible. This function will overwrite the original DeepSpeed ZeRO stage3 hooks to make it ORT compatible. """ # Only done once no matter how many times this function is called for different modules. - DeepSpeedZeRoOffload.setup_zero_stage3_hooks = _setup_zero_stage3_ort_compatible_hooks + DeepSpeedZeRoOffload.setup_zero_stage3_hooks = _get_ort_compatible_zero_stage3_hook_function( + debug, stats_output_dir, stats_overwrite + ) from deepspeed.runtime.zero.linear import zero3_linear_wrap @@ -103,7 +117,7 @@ def configure_ort_compatible_zero_stage3(): except ImportError as e: warnings.warn(f"DeepSpeed import error {e}") - def configure_ort_compatible_zero_stage3(): + def configure_ort_compatible_zero_stage3(debug=False, stats_output_dir=None, stats_overwrite=False): raise RuntimeError("DeepSpeed is not installed, cannot configure ORT compatible ZeRO stage3.") @@ -115,13 +129,13 @@ def _get_params_for_current_module(module: torch.nn.Module) -> List[torch.nn.par """ from deepspeed.runtime.zero.partitioned_param_coordinator import iter_params - # Retrive the parameters that are not available for this module. + # Retrieve all parameters for this module. partitioned_params = [param for param in iter_params(module)] return partitioned_params -def _get_all_offloaded_params(module: torch.nn.Module) -> Dict[str, torch.nn.parameter.Parameter]: +def _get_all_zero_stage3_params(module: torch.nn.Module) -> Dict[str, torch.nn.parameter.Parameter]: """Retrieve all the parameters that are offloaded.""" from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus @@ -134,16 +148,13 @@ def _get_all_offloaded_params(module: torch.nn.Module) -> Dict[str, torch.nn.par class ORTZeROOffloadPreForwardFunction(torch.autograd.Function): - """This function is a common bridge to call original PyTorch's - pre_forward_function and post_backward_function. - """ + """This function is a common bridge to call original PyTorch's pre_forward_function""" @staticmethod def forward( ctx, module, pre_forward_with_kwargs_function, - post_backward_function, args_schema, kwargs_schema, args_tensor_count, @@ -155,7 +166,6 @@ def forward( ctx: context object module: the module to be called pre_forward_with_kwargs_function: the function to be called before forward (PyTorch's pre_forward_function) - post_backward_function: the function to be called after backward (PyTorch's post_backward_function) args_schema: the schema of the args, used to reconstruct the args in original form in PyTorch's pre_forward_function's inputs. kwargs_schema: the schema of the kwargs, used to reconstruct the kwargs in original form in @@ -168,6 +178,17 @@ def forward( args_tensors = tensor_list[:args_tensor_count] kwargs_tensors = tensor_list[args_tensor_count : args_tensor_count + kwargs_tensor_count] + # For PyTorch runs, the sizes are all 0, it does not need a gradient because + # param._detach().requires_grad_(False) is called. + # But for ORT runs, the sizes are all [1], as output of weight retrieval function. + # So we keep track of the shapes and dtypes of the passed-in tensors, then generate the grads in backward. + # While for both PyTorch and ORT runs, the grad is not important because they are not param grads + # anymore, they are only used for completing the full backward propagation. + passed_in_param_tensors = tensor_list[args_tensor_count + kwargs_tensor_count :] + ctx.shapes = [p.shape for p in passed_in_param_tensors] + ctx.dtypes = [p.dtype for p in passed_in_param_tensors] + ctx.devices = [p.device for p in passed_in_param_tensors] + args = unflatten_data_using_schema(args_tensors, args_schema) kwargs = unflatten_data_using_schema(kwargs_tensors, kwargs_schema) @@ -179,6 +200,8 @@ def forward( partitioned_params = _get_params_for_current_module(module) ctx.partitioned_params = partitioned_params + assert len(partitioned_params) == len(passed_in_param_tensors) + f_ret = pre_forward_with_kwargs_function(module, args, kwargs) if f_ret is None: @@ -188,7 +211,6 @@ def forward( updated_args, updated_kwargs = f_ret ctx.module = module - ctx.post_backward_function = post_backward_function updated_args_tensors, _ = extract_data_and_schema(updated_args) updated_kwargs_tensors, _ = extract_data_and_schema(updated_kwargs) @@ -203,17 +225,32 @@ def forward( @staticmethod def backward(ctx, *grads): updated_grads = grads - if ctx.post_backward_function is not None: - ret = ctx.post_backward_function(ctx.module, grads) - if ret is not None: - updated_grads = ret - # TODO(pengwa) Update grad for partitioned parameters. input_count = len(updated_grads) - len(ctx.partitioned_params) - zeros = [torch.zeros(0, dtype=p.dtype, device=p.device) for p in ctx.partitioned_params] - zero_grads = updated_grads[:input_count] + tuple(zeros) - - return (None, None, None, None, None, None, None, *zero_grads) + param_start_offset = input_count + + # Only need to accumulate grad explicitly for ORT run (e.g. ctx.shapes[0] == (1,)); + # In the PyTorch run, the accumulation happens automatically. + need_manual_grad_acc = len(ctx.shapes) > 0 and ctx.shapes[0] == (1,) + if need_manual_grad_acc: + for param_index, p in enumerate(ctx.partitioned_params): + g = updated_grads[param_index + param_start_offset] + if g is None: + raise RuntimeError(f"param {p} has no grad, this should not happen.") + # Param gradient accumulation is triggered here, along with the attached hooks, done by PyTorch. + assert p.shape == g.shape, f"param_index: {param_index} - param shape {p.shape} != grad shape {g.shape}" + p.backward(g) + + # At this point, the **real** param grads are already updated, the following grads are only used for + # completing the full backward propagation, will not affect parameter updates. + passed_in_param_grad = [ + torch.zeros(shape, dtype=dtype, device=device) + for shape, dtype, device in zip(ctx.shapes, ctx.dtypes, ctx.devices) + ] + + zero_grads = updated_grads[:input_count] + tuple(passed_in_param_grad) + + return (None, None, None, None, None, None, *zero_grads) @staticmethod def infer_shape( @@ -258,14 +295,14 @@ def forward( module: the module to be called post_forward_function: the function to be called after forward (PyTorch's post_forward_function) pre_backward_function: the function to be called before backward (PyTorch's pre_backward_function) - output_schema: the schema of the output, used to reconstruct the output in original form in + output_schema: the schema of the output, used to reconstruct the output in its original form in PyTorch's post_forward_function's inputs. output_tensors: the list of tensors. """ outputs = unflatten_data_using_schema(output_tensors, output_schema) - # STAGE3WARN: _post_forward_module_hook's second argument `input is not used, so we just pass a None here. + # STAGE3WARN#3: _post_forward_module_hook's second argument `input is not used, so we just pass a None here. updated_outputs = post_forward_function(module, None, outputs) if updated_outputs is None: @@ -341,11 +378,19 @@ def pre_forward_module_apply_impl( input and output for torch.autograd.Function, so we do flatten and unflatten here. """ + ## Handle `_post_backward_module_hook` - args_tensors, args_schema = extract_data_and_schema(args) - kwargs_tensors, kwargs_schema = extract_data_and_schema(kwargs) + # Put `_post_backward_module_hook` first because in backward, it is responsible for unloading parameters, + # we want ORTZeROOffloadPreForwardFunction's backward still be able to access the full sized parameters. + _post_backward_module_hook = self._functions.get("_post_backward_module_hook") + # STAGE3WARN#4: most logic in _post_backward_module_hook can be traced correctly so we don't need to + # wrap with PythonOp. For those cannot be traced, we handle them in STAGE3WARN#5. + updated_args = _post_backward_module_hook(module, args) - partitioned_params = _get_params_for_current_module(module) + ## Handle `_pre_forward_module_hook` + + args_tensors, args_schema = extract_data_and_schema(updated_args) + kwargs_tensors, kwargs_schema = extract_data_and_schema(kwargs) _pre_forward_module_hook = self._functions.get("_pre_forward_module_hook") @@ -358,18 +403,29 @@ def _wrap_pre_forward_module_hook(module, args, kwargs): if rets is not None: updated_args = rets - # STAGE3WARN: Moved from _post_backward_module_hook to make sure ORT run will trigger every iteration. + # STAGE3WARN#5: Moved from _post_backward_module_hook to make sure ORT run will trigger every iteration. module.ds_grads_remaining = 0 + return updated_args, updated_kwargs - all_tensors = args_tensors + kwargs_tensors + partitioned_params + # Need to pass the parameters as input to let the exporter trace the related weights for + # current ORTZeROOffloadPreForwardFunction + partitioned_params = _get_params_for_current_module(module) + # Don't require grad for passed-in parameter, otherwise it will be treated as a leaf node, in backward + # returned 0-sized grad did not match the param's gradient accumulator function's input shape metadata, + # PyTorch run will fail during backward. + # This will not harm parameter gradient build either in ORT or PyTorch, imagine the weights are used by + # computation anyway, so the gradient will be built. This hook only references the parameter, but won't + # generate a gradient path for it. + detached_partitioned_params = [p.detach().requires_grad_(False) for p in partitioned_params] + + all_tensors = args_tensors + kwargs_tensors + detached_partitioned_params self._check_all_tensor(all_tensors, module, "pre_forward_module_apply_impl input check") rets = ORTZeROOffloadPreForwardFunction.apply( module, _wrap_pre_forward_module_hook, - None, args_schema, kwargs_schema, args_tensor_count, @@ -385,11 +441,6 @@ def _wrap_pre_forward_module_hook(module, args, kwargs): updated_args = unflatten_data_using_schema(updated_args_tensors, args_schema) updated_kwargs = unflatten_data_using_schema(updated_kwargs_tensors, kwargs_schema) - _post_backward_module_hook = self._functions.get("_post_backward_module_hook") - # STAGE3WARN: Other part of _post_backward_module_hook can be traced correctly so we don't need to - # wrap with PythonOp. - updated_args = _post_backward_module_hook(module, updated_args) - return updated_args, updated_kwargs def post_forward_module_apply_impl( @@ -411,7 +462,7 @@ def post_forward_module_apply_impl( _post_forward_module_hook = self._functions.get("_post_forward_module_hook") def _wrap_post_forward_module_hook(module, input, outputs): - # STAGE3WARN: _post_forward_module_hook applied this for each tensor output, so we do a simple wrap here. + # STAGE3WARN#6: _post_forward_module_hook applied this for each tensor output, so we do a simple wrap here. from deepspeed.runtime.zero.partition_parameters import is_zero_param updated_outputs = _post_forward_module_hook(module, input, outputs) @@ -438,8 +489,8 @@ def _wrap_post_forward_module_hook(module, input, outputs): updated_outputs = unflatten_data_using_schema(updated_outputs_tensors, outputs_schema) _pre_backward_module_hook = self._functions.get("_pre_backward_module_hook") - # STAGE3WARN: _pre_backward_module_hook's second argument `input is not used, so we just pass a None here. - # STAGE3WARN: part of the original _pre_backward_module_hook can be traced correctly so we moved them into + # STAGE3WARN#7: _pre_backward_module_hook's second argument `input is not used, so we just pass a None here. + # STAGE3WARN#8: part of the original _pre_backward_module_hook can be traced correctly so we moved them into # _wrap_post_forward_module_hook above. updated_outputs = _pre_backward_module_hook(module, None, updated_outputs) diff --git a/orttraining/orttraining/python/training/utils/torch_type_map.py b/orttraining/orttraining/python/training/utils/torch_type_map.py index 699747723f457..bdacab8ad04fe 100644 --- a/orttraining/orttraining/python/training/utils/torch_type_map.py +++ b/orttraining/orttraining/python/training/utils/torch_type_map.py @@ -33,6 +33,8 @@ _DTYPE_TO_ONNX = {torch_dtype: onnx_dtype for k, (onnx_dtype, torch_dtype) in _CAST_PYTORCH_TO_ONNX.items()} +_ONNX_TO_DTYPE = {onnx_dtype: torch_dtype for torch_dtype, onnx_dtype in _DTYPE_TO_ONNX.items()} + def pytorch_dtype_to_onnx(dtype_or_scalar_type: Union[torch.dtype, str]) -> torch.onnx.TensorProtoDataType: """Converts a pytorch dtype or scalar type string to an onnx dtype.""" @@ -45,3 +47,10 @@ def pytorch_dtype_to_onnx(dtype_or_scalar_type: Union[torch.dtype, str]) -> torc if dtype not in _DTYPE_TO_ONNX: raise RuntimeError(f"Unsupported dtype {dtype}") return _DTYPE_TO_ONNX[dtype] + + +def onnx_dtype_to_pytorch(dtype: torch.onnx.TensorProtoDataType) -> torch.dtype: + """Converts an onnx dtype to a pytorch dtype.""" + if dtype not in _ONNX_TO_DTYPE: + raise RuntimeError(f"Unsupported dtype {dtype}") + return _ONNX_TO_DTYPE[dtype] diff --git a/orttraining/orttraining/test/framework/checkpointing_test.cc b/orttraining/orttraining/test/framework/checkpointing_test.cc index b91cc2f1d5f5f..a7ee776b9bc39 100644 --- a/orttraining/orttraining/test/framework/checkpointing_test.cc +++ b/orttraining/orttraining/test/framework/checkpointing_test.cc @@ -52,21 +52,18 @@ void CompareOrtValuesToTensorProtoValues( ASSERT_EQ(name_to_ort_value.size(), name_to_tensor_proto.size()); NameMLValMap name_to_ort_value_from_tensor_proto{}; - std::vector> tensor_buffers{}; + AllocatorPtr tmp_allocator = std::make_shared(); for (const auto& name_and_tensor_proto : name_to_tensor_proto) { const auto& name = name_and_tensor_proto.first; const auto& tensor_proto = name_and_tensor_proto.second; TensorShape shape{tensor_proto.dims().data(), static_cast(tensor_proto.dims().size())}; ASSERT_EQ(tensor_proto.data_type(), ONNX_NAMESPACE::TensorProto_DataType_FLOAT); - std::vector tensor_buffer(shape.Size() * sizeof(float)); - MemBuffer m(tensor_buffer.data(), tensor_buffer.size(), cpu_alloc_info); OrtValue ort_value; - ASSERT_STATUS_OK(utils::TensorProtoToMLValue( - Env::Default(), model_path.c_str(), tensor_proto, m, ort_value)); + ASSERT_STATUS_OK(utils::TensorProtoToOrtValue(Env::Default(), model_path.c_str(), tensor_proto, + tmp_allocator, ort_value)); name_to_ort_value_from_tensor_proto.emplace(name, ort_value); - tensor_buffers.emplace_back(std::move(tensor_buffer)); } for (const auto& name_and_ort_value : name_to_ort_value) { diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py index 318de843efb8f..d205e8f2377dd 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py @@ -135,8 +135,31 @@ def _torch_layer_norm(input, weight, bias, **kwargs): return torch.nn.functional.layer_norm(input, normalized_shape, weight, bias) +def _torch_gelu(input): + return torch.nn.functional.gelu(input) + + +def _torch_quick_gelu(input, **kwargs): + alpha = kwargs.get("alpha", 1.702) + return input * torch.sigmoid(input * alpha) + + +def _torch_gelu_grad(dy, x): + alpha = 0.70710678118654752440 + beta = 1.12837916709551257390 * 0.70710678118654752440 * 0.5 + cdf = 0.5 * (1 + torch.erf(x * alpha)) + pdf = beta * torch.exp(x * x * -0.5) + return dy * (cdf + x * pdf) + + +def _torch_quick_gelu_grad(dy, x, **kwargs): + alpha = kwargs.get("alpha", 1.702) + sigmoid = torch.sigmoid(x * alpha) + return dy * sigmoid * (1.0 + x * alpha * (1.0 - sigmoid)) + + class TorchFuncExecutor: - _INFER_FUNC_MAP = { # noqa: RUF012 + _TORCH_FUNC_MAP = { # noqa: RUF012 "Add": _torch_add, "Sub": _torch_sub, "Mul": _torch_mul, @@ -154,13 +177,17 @@ class TorchFuncExecutor: "ReduceMin": _torch_reduce_min, "Softmax": _torch_softmax, "LayerNormalization": _torch_layer_norm, + "Gelu": _torch_gelu, + "QuickGelu": _torch_quick_gelu, + "GeluGrad": _torch_gelu_grad, + "QuickGeluGrad": _torch_quick_gelu_grad, } @classmethod def run(cls, op_type, *torch_tensors, **kwargs): - if op_type not in cls._INFER_FUNC_MAP: + if op_type not in cls._TORCH_FUNC_MAP: raise NotImplementedError(f"Unsupported op type: {op_type}") - return cls._INFER_FUNC_MAP[op_type](*torch_tensors, **kwargs) + return cls._TORCH_FUNC_MAP[op_type](*torch_tensors, **kwargs) def _run_op_test(op_type, onnx_dtype, create_model_func, gen_inputs_func, **kwargs): @@ -169,6 +196,8 @@ def _run_op_test(op_type, onnx_dtype, create_model_func, gen_inputs_func, **kwar pt_inputs = gen_inputs_func(_onnx_dtype_to_torch_dtype(onnx_dtype)) ort_inputs = copy.deepcopy(pt_inputs) ort_inputs = [tensor.to(torch.uint8) if tensor.dtype == torch.bool else tensor for tensor in ort_inputs] + if "::" in op_type: + _, op_type = op_type.split("::") pt_outputs = TorchFuncExecutor.run(op_type, *pt_inputs, **kwargs) model_str = create_model_func(op_type, onnx_dtype, **kwargs).SerializeToString() ort_outputs = call_triton_by_onnx(hash(model_str), model_str, *[to_dlpack(tensor) for tensor in ort_inputs]) @@ -260,13 +289,27 @@ def _run_tunable_op_test(module_cls, dtype, gen_inputs_func, tunable_op, impl_co del os.environ["ORTMODULE_TUNING_RESULTS_PATH"] -@pytest.mark.parametrize("op_type", ["Add", "Sub", "Mul", "Div"]) +@pytest.mark.parametrize( + "op", + [ + ("Add", {}), + ("Sub", {}), + ("Mul", {}), + ("Div", {}), + ("com.microsoft::GeluGrad", {}), + ("com.microsoft::QuickGeluGrad", {}), + ("com.microsoft::QuickGeluGrad", {"alpha": 1.0}), + ], +) @pytest.mark.parametrize("onnx_dtype", [TensorProto.FLOAT, TensorProto.FLOAT16]) @pytest.mark.parametrize("input_shapes", [([1024, 2], [1024, 2]), ([2, 3, 3, 3], [3, 1, 3]), ([2049], [1])]) -def test_binary_elementwise_op(op_type, onnx_dtype, input_shapes): - def _create_model(op_type, onnx_dtype): +def test_binary_elementwise_op(op, onnx_dtype, input_shapes): + def _create_model(op_type, onnx_dtype, **kwargs): + domain = "" + if "::" in op_type: + domain, op_type = op_type.split("::") graph = helper.make_graph( - [helper.make_node(op_type, ["X", "Y"], ["Z"], name="test")], + [helper.make_node(op_type, ["X", "Y"], ["Z"], name="test", domain=domain, **kwargs)], "test", [ helper.make_tensor_value_info("X", onnx_dtype, None), @@ -282,7 +325,7 @@ def _gen_inputs(dtype): torch.randn(*input_shapes[1], dtype=dtype, device=DEVICE), ] - _run_op_test(op_type, onnx_dtype, _create_model, _gen_inputs) + _run_op_test(op[0], onnx_dtype, _create_model, _gen_inputs, **op[1]) @pytest.mark.parametrize("onnx_dtype", [TensorProto.FLOAT, TensorProto.FLOAT16]) @@ -303,13 +346,25 @@ def _gen_inputs(dtype): _run_op_test("Sum", onnx_dtype, _create_model, _gen_inputs) -@pytest.mark.parametrize("op_type", ["Sqrt", "Exp"]) +@pytest.mark.parametrize( + "op", + [ + ("Sqrt", {}), + ("Exp", {}), + ("com.microsoft::Gelu", {}), + ("com.microsoft::QuickGelu", {}), + ("com.microsoft::QuickGelu", {"alpha": 1.0}), + ], +) @pytest.mark.parametrize("onnx_dtype", [TensorProto.FLOAT, TensorProto.FLOAT16]) @pytest.mark.parametrize("input_shape", [[1024, 4], [2, 3, 3, 3], [2049, 1]]) -def test_unary_elementwise_op(op_type, onnx_dtype, input_shape): - def _create_model(op_type, onnx_dtype): +def test_unary_elementwise_op(op, onnx_dtype, input_shape): + def _create_model(op_type, onnx_dtype, **kwargs): + domain = "" + if "::" in op_type: + domain, op_type = op_type.split("::") graph = helper.make_graph( - [helper.make_node(op_type, ["X"], ["Y"], name="test")], + [helper.make_node(op_type, ["X"], ["Y"], name="test", domain=domain, **kwargs)], "test", [helper.make_tensor_value_info("X", onnx_dtype, None)], [helper.make_tensor_value_info("Y", onnx_dtype, None)], @@ -319,7 +374,7 @@ def _create_model(op_type, onnx_dtype): def _gen_inputs(dtype): return [torch.rand(*input_shape, dtype=dtype, device=DEVICE)] - _run_op_test(op_type, onnx_dtype, _create_model, _gen_inputs) + _run_op_test(op[0], onnx_dtype, _create_model, _gen_inputs, **op[1]) @pytest.mark.parametrize("onnx_dtype", [TensorProto.FLOAT, TensorProto.FLOAT16]) diff --git a/orttraining/orttraining/test/python/orttraining_test_python_bindings.py b/orttraining/orttraining/test/python/orttraining_test_python_bindings.py index 56338ddbaffef..d5c37b3e36ee7 100644 --- a/orttraining/orttraining/test/python/orttraining_test_python_bindings.py +++ b/orttraining/orttraining/test/python/orttraining_test_python_bindings.py @@ -360,14 +360,18 @@ def test_add_get_property(property_value): if isinstance(property_value, float): property_value = float(np.float32(property_value)) - state["property"] = property_value - assert "property" in state - assert state["property"] == property_value + assert len(state.properties) == 0 + + state.properties["property"] = property_value + assert "property" in state.properties + assert state.properties["property"] == property_value + assert len(state.properties) == 1 CheckpointState.save_checkpoint(state, checkpoint_file_path) new_state = CheckpointState.load_checkpoint(checkpoint_file_path) - assert "property" in new_state - assert new_state["property"] == property_value + assert "property" in new_state.properties + assert new_state.properties["property"] == property_value + assert len(new_state.properties) == 1 def test_get_input_output_names(): @@ -563,3 +567,60 @@ def test_eval_step_with_ort_values(): fetches = model(inputs, labels) assert isinstance(fetches, OrtValue) assert fetches + + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +def test_get_and_set_parameter_values(device): + with tempfile.TemporaryDirectory() as temp_dir: + ( + checkpoint_file_path, + training_model_file_path, + eval_model_file_path, + _, + pt_model, + ) = _create_training_artifacts( + temp_dir, requires_grad=["fc2.weight", "fc2.bias"], frozen_params=["fc1.weight", "fc1.bias"] + ) + + state = CheckpointState.load_checkpoint(checkpoint_file_path) + + model = Module(training_model_file_path, state, eval_model_file_path, device=device) + + state_dict = pt_model.state_dict() + assert len(state_dict) == len(state.parameters) + for parameter_name, _ in state.parameters: + assert parameter_name in state_dict + + for name, pt_param in pt_model.named_parameters(): + ort_param = state.parameters[name] + assert ort_param.name == name + assert np.allclose(pt_param.detach().cpu().numpy(), ort_param.data) + if name in ["fc1.weight", "fc1.bias"]: + assert ort_param.requires_grad is False + assert ort_param.grad is None + else: + assert ort_param.requires_grad is True + assert np.allclose(ort_param.grad, np.zeros_like(ort_param.data, dtype=np.float32)) + + original_param = state.parameters["fc1.weight"].data + state.parameters["fc1.weight"].data = np.ones_like(state.parameters["fc1.weight"].data, dtype=np.float32) + updated_param = state.parameters["fc1.weight"].data + assert np.allclose(updated_param, np.ones_like(updated_param, dtype=np.float32)) + + model.train() + inputs = torch.randn(64, 784).numpy() + labels = torch.randint(high=10, size=(64,), dtype=torch.int64).numpy() + loss = model(inputs, labels) + assert loss is not None + for name, _ in pt_model.named_parameters(): + ort_param = state.parameters[name] + assert ort_param.name == name + if name in ["fc1.weight", "fc1.bias"]: + assert ort_param.requires_grad is False + assert ort_param.grad is None + else: + assert ort_param.requires_grad is True + assert ort_param.grad.any() + + state.parameters["fc1.weight"] = original_param + assert np.allclose(state.parameters["fc1.weight"].data, original_param) diff --git a/orttraining/orttraining/test/training_api/core/training_capi_tests.cc b/orttraining/orttraining/test/training_api/core/training_capi_tests.cc index d734be8e3474b..e46952d87c2bf 100644 --- a/orttraining/orttraining/test/training_api/core/training_capi_tests.cc +++ b/orttraining/orttraining/test/training_api/core/training_capi_tests.cc @@ -318,4 +318,106 @@ TEST(TrainingCApiTest, LoadModelsFromBufferThrows) { testing::HasSubstr("Training Session Creation failed. Train model data cannot be NULL.")); } } + +TEST(TrainingCApiTest, GetParameter) { + auto model_uri = MODEL_FOLDER "training_model.onnx"; + + Ort::Env env; + Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt"); + Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(), checkpoint_state, model_uri); + + Ort::Value parameter = checkpoint_state.GetParameter("fc1.weight"); + auto tensor_info = parameter.GetTensorTypeAndShapeInfo(); + auto shape = tensor_info.GetShape(); + ASSERT_EQ(shape.size(), 2U); + ASSERT_EQ(shape.front(), static_cast(500)); + ASSERT_EQ(shape.back(), static_cast(784)); +} + +TEST(TrainingCApiTest, UpdateParameter) { + auto model_uri = MODEL_FOLDER "training_model.onnx"; + + Ort::Env env; + Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt"); + Ort::TrainingSession training_session = Ort::TrainingSession(env, Ort::SessionOptions(), checkpoint_state, model_uri); + + Ort::Value parameter = checkpoint_state.GetParameter("fc1.weight"); + auto tensor_info = parameter.GetTensorTypeAndShapeInfo(); + auto shape = tensor_info.GetShape(); + ASSERT_EQ(shape.size(), 2U); + ASSERT_EQ(shape.front(), static_cast(500)); + ASSERT_EQ(shape.back(), static_cast(784)); + + OrtValue* updated_param_value = std::make_unique().release(); + GenerateRandomInput(std::array{500, 784}, *updated_param_value); + Ort::Value updated_parameter{updated_param_value}; + checkpoint_state.UpdateParameter("fc1.weight", updated_parameter); + + Ort::Value current_parameter = checkpoint_state.GetParameter("fc1.weight"); + gsl::span actual = gsl::span(current_parameter.GetTensorMutableData(), + current_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + gsl::span expected = gsl::span(updated_parameter.GetTensorMutableData(), + updated_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + gsl::span not_expected = gsl::span(parameter.GetTensorMutableData(), + parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + ASSERT_EQ(actual, expected); + ASSERT_NE(actual, not_expected); + + checkpoint_state.UpdateParameter("fc1.weight", parameter); + current_parameter = checkpoint_state.GetParameter("fc1.weight"); + actual = gsl::span(current_parameter.GetTensorMutableData(), + current_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + expected = gsl::span(parameter.GetTensorMutableData(), + parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + not_expected = gsl::span(updated_parameter.GetTensorMutableData(), + updated_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + ASSERT_EQ(actual, expected); + ASSERT_NE(actual, not_expected); +} + +#ifdef USE_CUDA +TEST(TrainingCApiTest, UpdateParameterDifferentDevices) { + auto model_uri = MODEL_FOLDER "training_model.onnx"; + + Ort::Env env; + Ort::SessionOptions session_options; + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CUDA(session_options, 0)); + Ort::CheckpointState checkpoint_state = Ort::CheckpointState::LoadCheckpoint(MODEL_FOLDER "checkpoint.ckpt"); + Ort::TrainingSession training_session = Ort::TrainingSession(env, session_options, checkpoint_state, model_uri); + + Ort::Value parameter = checkpoint_state.GetParameter("fc1.weight"); + auto tensor_info = parameter.GetTensorTypeAndShapeInfo(); + auto shape = tensor_info.GetShape(); + ASSERT_EQ(shape.size(), 2U); + ASSERT_EQ(shape.front(), static_cast(500)); + ASSERT_EQ(shape.back(), static_cast(784)); + + OrtValue* updated_param_value = std::make_unique().release(); + GenerateRandomInput(std::array{500, 784}, *updated_param_value); + Ort::Value updated_parameter{updated_param_value}; + checkpoint_state.UpdateParameter("fc1.weight", updated_parameter); + + Ort::Value current_parameter = checkpoint_state.GetParameter("fc1.weight"); + gsl::span actual = gsl::span(current_parameter.GetTensorMutableData(), + current_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + gsl::span expected = gsl::span(updated_parameter.GetTensorMutableData(), + updated_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + gsl::span not_expected = gsl::span(parameter.GetTensorMutableData(), + parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + ASSERT_EQ(actual, expected); + ASSERT_NE(actual, not_expected); + + checkpoint_state.UpdateParameter("fc1.weight", parameter); + current_parameter = checkpoint_state.GetParameter("fc1.weight"); + actual = gsl::span(current_parameter.GetTensorMutableData(), + current_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + expected = gsl::span(parameter.GetTensorMutableData(), + parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + not_expected = gsl::span(updated_parameter.GetTensorMutableData(), + updated_parameter.GetTensorTypeAndShapeInfo().GetElementCount()); + ASSERT_EQ(actual, expected); + ASSERT_NE(actual, not_expected); +} +#endif + } // namespace onnxruntime::training::test diff --git a/orttraining/orttraining/training_api/checkpoint_property.h b/orttraining/orttraining/training_api/checkpoint_property.h index d7b1e295df53e..3c38c99b3152f 100644 --- a/orttraining/orttraining/training_api/checkpoint_property.h +++ b/orttraining/orttraining/training_api/checkpoint_property.h @@ -22,10 +22,12 @@ struct PropertyBag { PropertyBag() = default; void AddProperty(const std::string& name, const PropertyDataType& val) { - ORT_ENFORCE(named_properties_.find(name) == named_properties_.end(), - "Duplicated property named ", name); - - named_properties_.insert({name, val}); + auto it = named_properties_.find(name); + if (it == named_properties_.end()) { + named_properties_.insert({name, val}); + } else { + it->second = val; + } } template diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h index 0af737074964d..0e8544a7639ba 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h @@ -608,14 +608,14 @@ struct OrtTrainingApi { /// \name Accessing The Training Session State /// @{ - /** \brief Adds the given property to the checkpoint state. + /** \brief Adds or updates the given property to/in the checkpoint state. * * Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint - * state by the user if they desire by calling this function with the appropriate property name and - * value. The given property name must be unique to be able to successfully add the property. + * state by the user by calling this function with the corresponding property name and value. + * The given property name must be unique to be able to successfully add the property. * * \param[in] checkpoint_state The checkpoint state which should hold the property. - * \param[in] property_name Unique name of the property being added. + * \param[in] property_name Name of the property being added or updated. * \param[in] property_type Type of the property associated with the given name. * \param[in] property_value Property value associated with the given name. * @@ -632,7 +632,7 @@ struct OrtTrainingApi { * exist in the checkpoint state to be able to retrieve it successfully. * * \param[in] checkpoint_state The checkpoint state that is currently holding the property. - * \param[in] property_name Unique name of the property being retrieved. + * \param[in] property_name Name of the property being retrieved. * \param[in] allocator Allocator used to allocate the memory for the property_value. * \param[out] property_type Type of the property associated with the given name. * \param[out] property_value Property value associated with the given name. @@ -669,6 +669,57 @@ struct OrtTrainingApi { ORT_API2_STATUS(LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer, _In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state); + /** \brief Retrieves the type and shape information of the parameter associated with the given parameter name. + * + * This function retrieves the type and shape of the parameter associated with the given parameter name. + * The parameter must exist in the checkpoint state to be able to retrieve its type and shape information successfully. + * + * \param[in] checkpoint_state The checkpoint state. + * \param[in] parameter_name Name of the parameter being retrieved. + * \param[out] parameter_type_and_shape The type and shape of the parameter being retrieved. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape); + + /** \brief Updates the data associated with the model parameter in the checkpoint state for the given parameter name. + * + * This function updates a model parameter in the checkpoint state with the given parameter data. + * The training session must be already created with the checkpoint state that contains the parameter + * being updated. The given parameter is copied over to the registered device for the training session. + * The parameter must exist in the checkpoint state to be able to update it successfully. + * + * \param[in] checkpoint_state The checkpoint state. + * \param[in] parameter_name Name of the parameter being updated. + * \param[in] parameter The parameter data that should replace the existing parameter data. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _In_ OrtValue* parameter); + + /** \brief Gets the data associated with the model parameter from the checkpoint state for the given parameter name. + * + * This function retrieves the model parameter data from the checkpoint state for the given parameter name. + * The parameter is copied over and returned as an OrtValue. The training session must be already created + * with the checkpoint state that contains the parameter being retrieved. + * The parameter must exist in the checkpoint state to be able to retrieve it successfully. + * + * \param[in] checkpoint_state The checkpoint state. + * \param[in] parameter_name Name of the parameter being retrieved. + * \param[in] allocator Allocator used to allocate the memory for the parameter. + * \param[out] parameter The parameter data that is retrieved from the checkpoint state. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + */ + ORT_API2_STATUS(GetParameter, _In_ const OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _Inout_ OrtAllocator* allocator, + _Outptr_ OrtValue** parameter); + /// @} }; diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h index 0edef20ba6da8..218bef524200c 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h @@ -112,13 +112,13 @@ class CheckpointState : public detail::Base { const std::basic_string& path_to_checkpoint, const bool include_optimizer_state = false); - /** \brief Adds the given property to the checkpoint state. + /** \brief Adds or updates the given property to/in the checkpoint state. * * Runtime properties such as epoch, training step, best score, and others can be added to the checkpoint - * state by the user if they desire by calling this function with the appropriate property name and - * value. The given property name must be unique to be able to successfully add the property. + * state by the user by calling this function with the corresponding property name and value. + * The given property name must be unique to be able to successfully add the property. * - * \param[in] property_name Unique name of the property being added. + * \param[in] property_name Name of the property being added or updated. * \param[in] property_value Property value associated with the given name. * */ @@ -129,12 +129,38 @@ class CheckpointState : public detail::Base { * Gets the property value from an existing entry in the checkpoint state. The property must * exist in the checkpoint state to be able to retrieve it successfully. * - * \param[in] property_name Unique name of the property being retrieved. + * \param[in] property_name Name of the property being retrieved. * \return Property value associated with the given property name. * */ Property GetProperty(const std::string& property_name); + /** \brief Updates the data associated with the model parameter in the checkpoint state for the given parameter name. + * + * This function updates a model parameter in the checkpoint state with the given parameter data. + * The training session must be already created with the checkpoint state that contains the parameter + * being updated. The given parameter is copied over to the registered device for the training session. + * The parameter must exist in the checkpoint state to be able to update it successfully. + * + * \param[in] parameter_name Name of the parameter being updated. + * \param[in] parameter The parameter data that should replace the existing parameter data. + * + */ + void UpdateParameter(const std::string& parameter_name, const Value& parameter); + + /** \brief Gets the data associated with the model parameter from the checkpoint state for the given parameter name. + * + * This function retrieves the model parameter data from the checkpoint state for the given parameter name. + * The parameter is copied over to the provided OrtValue. The training session must be already created + * with the checkpoint state that contains the parameter being retrieved. + * The parameter must exist in the checkpoint state to be able to retrieve it successfully. + * + * \param[in] parameter_name Name of the parameter being retrieved. + * \return The parameter data that is retrieved from the checkpoint state. + * + */ + Value GetParameter(const std::string& parameter_name); + /// @} }; diff --git a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h index c0048458ddf4d..7d1326a10f8f8 100644 --- a/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h +++ b/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h @@ -279,4 +279,16 @@ inline Property CheckpointState::GetProperty(const std::string& property_name) { return property; } +inline void CheckpointState::UpdateParameter(const std::string& parameter_name, const Value& parameter) { + ThrowOnError(GetTrainingApi().UpdateParameter(p_, parameter_name.c_str(), parameter)); +} + +inline Value CheckpointState::GetParameter(const std::string& parameter_name) { + AllocatorWithDefaultOptions allocator; + OrtValue* parameter; + ThrowOnError(GetTrainingApi().GetParameter(p_, parameter_name.c_str(), allocator, ¶meter)); + + return Value{parameter}; +} + } // namespace Ort diff --git a/orttraining/orttraining/training_api/module.cc b/orttraining/orttraining/training_api/module.cc index d1775e358163c..cf49a01517d6b 100644 --- a/orttraining/orttraining/training_api/module.cc +++ b/orttraining/orttraining/training_api/module.cc @@ -119,6 +119,61 @@ Status TransformModelInputsForInference(Graph& inference_graph, #endif } // namespace +Status Parameter::CopyTo(const DataTransferManager* data_transfer_manager, OrtValue& data) const { + ORT_ENFORCE(data.IsAllocated(), "Given parameter data is not allocated. Cannot copy the checkpoint parameter to it."); + ORT_ENFORCE(data.IsTensor(), "Parameter data should be of tensor type."); + ORT_ENFORCE(data.Get().Shape() == data_.Get().Shape(), + "Parameter data shape mismatch. Expected: ", data_.Get().Shape().ToString(), + ", Got: ", data.Get().Shape().ToString()); +#ifdef ENABLE_STRIDED_TENSORS + auto data_strides = data.Get().Strides(); + auto param_strides = data_.Get().Strides(); + ORT_ENFORCE(data_strides.size() == param_strides.size(), + "Parameter data stride mismatch. Expected strides of size: ", param_strides.size(), + ", Got: ", data_strides.size()); + ORT_ENFORCE(std::equal(data_strides.begin(), data_strides.end(), param_strides.begin()), + "Parameter data stride value mismatch."); +#endif + ORT_ENFORCE(data.Get().DataType() == data_.Get().DataType(), + "Parameter data type mismatch. Expected: ", data_.Get().DataType(), + ", Got: ", data.Get().DataType()); + ORT_ENFORCE(data_transfer_manager != nullptr, + "Data transfer manager must be provided to copy data to the parameter. " + "Please create the TrainingSession before trying to update the parameter."); + + ORT_THROW_IF_ERROR(data_transfer_manager->CopyTensor(data_.Get(), *data.GetMutable())); + + return Status::OK(); +} + +Status Parameter::CopyFrom(const DataTransferManager* data_transfer_manager, const OrtValue& data) { + ORT_ENFORCE(data_.IsAllocated(), + "The checkpoint parameter is not allocated. Cannot copy the given parameter data to it."); + ORT_ENFORCE(data.IsTensor(), "Parameter data should be of tensor type."); + ORT_ENFORCE(data.Get().Shape() == data_.Get().Shape(), + "Parameter data shape mismatch. Expected: ", data_.Get().Shape().ToString(), + ", Got: ", data.Get().Shape().ToString()); +#ifdef ENABLE_STRIDED_TENSORS + auto data_strides = data.Get().Strides(); + auto param_strides = data_.Get().Strides(); + ORT_ENFORCE(data_strides.size() == param_strides.size(), + "Parameter data stride mismatch. Expected strides of size: ", param_strides.size(), + ", Got: ", data_strides.size()); + ORT_ENFORCE(std::equal(data_strides.begin(), data_strides.end(), param_strides.begin()), + "Parameter data stride value mismatch."); +#endif + ORT_ENFORCE(data.Get().DataType() == data_.Get().DataType(), + "Parameter data type mismatch. Expected: ", data_.Get().DataType(), + ", Got: ", data.Get().DataType()); + ORT_ENFORCE(data_transfer_manager != nullptr, + "Data transfer manager must be provided to copy data to the parameter. " + "Please create the TrainingSession before trying to update the parameter."); + + ORT_THROW_IF_ERROR(data_transfer_manager->CopyTensor(data.Get(), *data_.GetMutable())); + + return Status::OK(); +} + Status Parameter::SetGrad(const std::string& gradient_name, const OrtValue& param_grad) { // assert param is allocated ORT_ENFORCE(data_.IsAllocated(), "Parameter data should be allocated before allocating gradient."); @@ -334,6 +389,10 @@ Module::Module(const ModelIdentifiers& model_identifiers, } } +Module::~Module() { + state_->module_checkpoint_state.train_session_data_transfer_mgr = nullptr; +} + size_t Module::GetTrainingModelOutputCount() const noexcept { return train_output_names_.size(); } diff --git a/orttraining/orttraining/training_api/module.h b/orttraining/orttraining/training_api/module.h index adb633343263e..f323e6be72d49 100644 --- a/orttraining/orttraining/training_api/module.h +++ b/orttraining/orttraining/training_api/module.h @@ -21,6 +21,8 @@ struct Parameter { // Return the mutable data. OrtValue& Data() { return data_; } + Status CopyTo(const DataTransferManager* data_transfer_manager, OrtValue& data) const; + Status CopyFrom(const DataTransferManager* data_transfer_manager, const OrtValue& data); const std::string& Name() const { return name_; } // Returns whether this parameter is trainable or not. @@ -34,7 +36,6 @@ struct Parameter { // Reset and release the gradient buffer of this Parameter greedily. Status ResetGrad(); - protected: Status SetGrad(const std::string& gradient_name, const OrtValue& param_grad); private: @@ -83,6 +84,8 @@ struct Module { const std::vector>& providers, gsl::span op_domains = gsl::span()); + ~Module(); + // Return the trainable/nontrainable parameters std::vector> Parameters() const; diff --git a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc index 6693bba348648..38a9aad9640ea 100644 --- a/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc +++ b/orttraining/orttraining/training_api/onnxruntime_training_c_api.cc @@ -333,6 +333,10 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::LoadCheckpointFromBuffer, _In_ const void* _In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state) { API_IMPL_BEGIN + if (checkpoint_buffer == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Expected a valid checkpoint buffer. Actual: nullptr."); + } + *checkpoint_state = nullptr; auto chkpt_state = std::make_unique(); const auto* checkpoint_bytes = reinterpret_cast(checkpoint_buffer); @@ -559,6 +563,76 @@ ORT_API_STATUS_IMPL(OrtTrainingApis::GetProperty, _In_ const OrtCheckpointState* API_IMPL_END } +ORT_API_STATUS_IMPL(OrtTrainingApis::GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape) { + API_IMPL_BEGIN + + auto chkpt_state = reinterpret_cast(checkpoint_state); + auto it = chkpt_state->module_checkpoint_state.named_parameters.find(parameter_name); + if (it == chkpt_state->module_checkpoint_state.named_parameters.end()) { + std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state."; + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str()); + } + + return OrtApis::GetTensorTypeAndShape(&it->second->Data(), parameter_type_and_shape); + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtTrainingApis::UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _In_ OrtValue* parameter) { + API_IMPL_BEGIN + if (parameter == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Expected a valid parameter. Actual: nullptr."); + } + + auto chkpt_state = reinterpret_cast(checkpoint_state); + auto it = chkpt_state->module_checkpoint_state.named_parameters.find(parameter_name); + if (it == chkpt_state->module_checkpoint_state.named_parameters.end()) { + std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state."; + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str()); + } + ORT_API_RETURN_IF_STATUS_NOT_OK(it->second->CopyFrom( + chkpt_state->module_checkpoint_state.train_session_data_transfer_mgr, *parameter)); + + return nullptr; + API_IMPL_END +} + +ORT_API_STATUS_IMPL(OrtTrainingApis::GetParameter, _In_ const OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _Inout_ OrtAllocator* allocator, + _Outptr_ OrtValue** parameter) { + API_IMPL_BEGIN + + if (parameter == nullptr) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Expected a valid parameter. Actual: nullptr."); + } + + auto chkpt_state = reinterpret_cast(checkpoint_state); + auto it = chkpt_state->module_checkpoint_state.named_parameters.find(parameter_name); + if (it == chkpt_state->module_checkpoint_state.named_parameters.end()) { + std::string err_msg = "Parameter name " + std::string(parameter_name) + " not found in checkpoint state."; + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, err_msg.c_str()); + } + + if (!it->second->Data().IsTensor()) { + return OrtApis::CreateStatus(ORT_FAIL, "Expected a tensor type for the parameter. Found a non-tensor type."); + } + const auto& parameter_tensor = it->second->Data().Get(); + ORT_API_RETURN_IF_ERROR(OrtApis::CreateTensorAsOrtValue( + allocator, parameter_tensor.Shape().GetDims().data(), parameter_tensor.Shape().NumDimensions(), + ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, parameter)); + + auto status = it->second->CopyTo( + chkpt_state->module_checkpoint_state.train_session_data_transfer_mgr, **parameter); + if (!status.IsOK()) { + OrtApis::ReleaseValue(*parameter); + return onnxruntime::ToOrtStatus(status); + } + + return nullptr; + API_IMPL_END +} + static constexpr OrtTrainingApi ort_training_api = { // NOTE: The C# bindings depend on the API order within this struct. Since Training APIs are not officially // released, it is OK to change the order here, however a corresponding matching change should also be done in the @@ -592,7 +666,10 @@ static constexpr OrtTrainingApi ort_training_api = { &OrtTrainingApis::TrainingSessionGetEvalModelInputName, &OrtTrainingApis::AddProperty, &OrtTrainingApis::GetProperty, - &OrtTrainingApis::LoadCheckpointFromBuffer}; + &OrtTrainingApis::LoadCheckpointFromBuffer, + &OrtTrainingApis::GetParameterTypeAndShape, + &OrtTrainingApis::UpdateParameter, + &OrtTrainingApis::GetParameter}; ORT_API(const OrtTrainingApi*, OrtTrainingApis::GetTrainingApi, uint32_t) { // No constraints on the API version yet. diff --git a/orttraining/orttraining/training_api/ort_training_apis.h b/orttraining/orttraining/training_api/ort_training_apis.h index c87108957c975..2a8c1e30361c6 100644 --- a/orttraining/orttraining/training_api/ort_training_apis.h +++ b/orttraining/orttraining/training_api/ort_training_apis.h @@ -94,4 +94,14 @@ ORT_API_STATUS_IMPL(GetProperty, _In_ const OrtCheckpointState* checkpoint_state ORT_API_STATUS_IMPL(LoadCheckpointFromBuffer, _In_ const void* checkpoint_buffer, _In_ const size_t num_bytes, _Outptr_ OrtCheckpointState** checkpoint_state); +ORT_API_STATUS_IMPL(GetParameterTypeAndShape, _In_ const OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _Outptr_ OrtTensorTypeAndShapeInfo** parameter_type_and_shape); + +ORT_API_STATUS_IMPL(UpdateParameter, _Inout_ OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _In_ OrtValue* parameter); + +ORT_API_STATUS_IMPL(GetParameter, _In_ const OrtCheckpointState* checkpoint_state, + _In_ const char* parameter_name, _Inout_ OrtAllocator* allocator, + _Outptr_ OrtValue** parameter); + } // namespace OrtTrainingApis diff --git a/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc b/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc index 4e7fcbc95bb1d..e1d4be24861f5 100644 --- a/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc +++ b/orttraining/orttraining/training_ops/cpu/torch/torch_custom_function_kernel_base.cc @@ -153,8 +153,11 @@ void PythonOpBase::RunForward(OpKernelContext* context, inplace_ != 0, kernel_invoke_id_); - ORT_ENFORCE(1 + returned_ortvalues.size() == static_cast(context->OutputCount()), - "Output count mismatch for PythonOp run"); + const size_t returned_output_count = 1 + returned_ortvalues.size(); + const size_t kernel_output_count = static_cast(context->OutputCount()); + ORT_ENFORCE(returned_output_count == kernel_output_count, "Output count mismatch for PythonOp run, ", + "returned_output_count: ", returned_output_count, ", expected kernel_output_count: ", + kernel_output_count); } void PythonOpBase::SetOutputs(OpKernelContext* context, void* diff_ctx, std::vector& returned_args) const { diff --git a/requirements-dev.txt b/requirements-dev.txt index 73e04e6b37c0b..1b5ca65cf8037 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -17,4 +17,4 @@ scikit-learn scipy sympy wheel -setuptools>=41.4.0 +setuptools>=61.0.0 diff --git a/requirements-training.txt b/requirements-training.txt index 4b1be6cef9b7c..dbfd7305d1bec 100644 --- a/requirements-training.txt +++ b/requirements-training.txt @@ -6,4 +6,4 @@ onnx packaging protobuf sympy -setuptools>=41.4.0 +setuptools>=61.0.0 diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index c4fb5499983cb..c25a61ee5d568 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -2055,7 +2055,7 @@ def build_nuget_package( ): if not (is_windows() or is_linux()): raise BuildError( - "Currently csharp builds and nuget package creation is only supportted on Windows and Linux platforms." + "Currently csharp builds and nuget package creation is only supported on Windows and Linux platforms." ) csharp_build_dir = os.path.join(source_dir, "csharp") diff --git a/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml index 21bc1c481b3e6..395c190ce9e11 100644 --- a/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-ci-pipeline.yml @@ -56,10 +56,6 @@ stages: clean: true submodules: none - - task: NodeTool@0 - inputs: - versionSpec: '16.x' - - task: UsePythonVersion@0 inputs: versionSpec: '3.8' @@ -200,8 +196,11 @@ stages: - stage: arm64_test dependsOn: ['arm64_build'] jobs: - - template: templates/py-packaging-linux-test.yml + - template: templates/py-packaging-linux-test-cpu.yml parameters: arch: 'aarch64' machine_pool: 'onnxruntime-linux-ARM64-CPU-2019' - device: 'CPU' + base_image: 'arm64v8/almalinux:8' + devtoolset_rootpath: /opt/rh/gcc-toolset-12/root + ld_library_path_arg: /opt/rh/gcc-toolset-12/root/usr/lib64:/opt/rh/gcc-toolset-12/root/usr/lib:/opt/rh/gcc-toolset-12/root/usr/lib64/dyninst:/opt/rh/gcc-toolset-12/root/usr/lib/dyninst:/usr/local/lib64 + prepend_path: '/opt/rh/gcc-toolset-12/root/usr/bin:' diff --git a/tools/ci_build/github/azure-pipelines/linux-cpu-aten-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-cpu-aten-pipeline.yml index 2c5a69e216d14..146186e9eeaf5 100644 --- a/tools/ci_build/github/azure-pipelines/linux-cpu-aten-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-cpu-aten-pipeline.yml @@ -53,10 +53,6 @@ jobs: clean: true submodules: recursive - - task: NodeTool@0 - inputs: - versionSpec: '16.x' - - template: templates/get-docker-image-steps.yml parameters: Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2014_aten_cpu diff --git a/tools/ci_build/github/azure-pipelines/linux-multi-gpu-tensorrt-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-multi-gpu-tensorrt-ci-pipeline.yml index 0a7dc0e456a95..e4441853240e5 100644 --- a/tools/ci_build/github/azure-pipelines/linux-multi-gpu-tensorrt-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-multi-gpu-tensorrt-ci-pipeline.yml @@ -36,5 +36,3 @@ jobs: JobName: 'Linux_CI_Multi_GPU_TensorRT_Dev' # The latest TensorRT container only supports ubuntu20.04 and python 3.8 RunDockerBuildArgs: '-o ubuntu20.04 -d tensorrt -x "--enable_multi_device_test"' - DoNugetPack: 'false' - ArtifactName: 'drop-linux' diff --git a/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml index 93ee17b4cc7e6..c92fc93abba37 100644 --- a/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-openvino-ci-pipeline.yml @@ -33,6 +33,4 @@ jobs: AgentPool : 'Linux-CPU-2019' JobName: 'Linux_CI_Dev' RunDockerBuildArgs: '-o ubuntu20.04 -d openvino -v 2023.0.0 -x "--use_openvino CPU_FP32 --build_wheel"' - DoNugetPack: 'false' - ArtifactName: 'drop-linux' TimeoutInMinutes: 120 diff --git a/tools/ci_build/github/azure-pipelines/orttraining-linux-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-linux-ci-pipeline.yml index 007630edb25be..018672e0b2dea 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-linux-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-linux-ci-pipeline.yml @@ -54,10 +54,6 @@ jobs: clean: true submodules: none - - task: NodeTool@0 - inputs: - versionSpec: '16.x' - - task: UsePythonVersion@0 inputs: versionSpec: '3.8' @@ -88,6 +84,7 @@ jobs: mkdir -p $(Pipeline.Workspace)/ccache docker run --rm \ --volume /data/onnx:/data/onnx:ro \ + --volume /data/models:/build/models:ro \ --volume $(Build.SourcesDirectory):/onnxruntime_src \ --volume $(Build.BinariesDirectory):/build \ --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ @@ -109,51 +106,11 @@ jobs: --build_wheel \ --enable_onnx_tests \ --enable_training \ - --use_cache \ - --update --build; \ + --use_cache; \ ccache -sv; \ ccache -z" workingDirectory: $(Build.SourcesDirectory) - - task: CmdLine@2 - displayName: 'Install python deps' - inputs: - script: | - set -e -x - python3 -m pip uninstall -y ort-nightly-gpu ort-nightly onnxruntime onnxruntime-gpu onnxruntime-training onnxruntime-directml ort-nightly-directml onnx -qq - cp $(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt $(Build.BinariesDirectory)/requirements.txt - # Test ORT with the latest ONNX release. - sed -i "s/git+http:\/\/github\.com\/onnx\/onnx.*/onnx/" $(Build.BinariesDirectory)/requirements.txt - python3 -m pip install -r $(Build.BinariesDirectory)/requirements.txt - mkdir $(Build.BinariesDirectory)/requirements_torch_cpu/ - cp $(Build.SourcesDirectory)/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_cpu/requirements.txt $(Build.BinariesDirectory)/requirements_torch_cpu/requirements.txt - python3 -m pip install -r $(Build.BinariesDirectory)/requirements_torch_cpu/requirements.txt - - - task: CmdLine@2 - displayName: 'Install Release python package' - inputs: - script: | - rm -rf $(Build.BinariesDirectory)/Release/onnxruntime $(Build.BinariesDirectory)/Release/pybind11 - python3 -m pip install $(Build.BinariesDirectory)/Release/dist/*.whl - - - task: PythonScript@0 - displayName: 'Run Release unit tests' - inputs: - scriptPath: $(Build.SourcesDirectory)/tools/ci_build/build.py - workingDirectory: $(Build.BinariesDirectory)/Release - arguments: >- - --build_dir $(Build.BinariesDirectory) - --cmake_generator Ninja - --config Release - --test - --skip_submodule_sync - --build_shared_lib - --parallel - --build_wheel - --enable_onnx_tests - --enable_training - --ctest_path "" - - task: PublishTestResults@2 displayName: 'Publish unit test results' inputs: diff --git a/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ci-pipeline.yml index adf5695bd76eb..2d2719fef8f3d 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-linux-gpu-ci-pipeline.yml @@ -32,7 +32,6 @@ jobs: parameters: AgentPool : 'Onnxruntime-Linux-GPU-NC6sv3' JobName: 'Onnxruntime_Linux_GPU_Training' - SubmoduleCheckoutMode: 'recursive' RunDockerBuildArgs: > -o ubuntu20.04 -d gpu -t onnxruntime_orttraining_ortmodule_tests_image @@ -40,24 +39,16 @@ jobs: -e -x " --enable_training - --config $(buildConfig) + --config Release --use_cuda --cuda_version=11.8 --cuda_home=/usr/local/cuda-11.8 --cudnn_home=/usr/local/cuda-11.8 --build_wheel --enable_nvtx_profile --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=70 " - DoNugetPack: 'false' RunInjectedPipeline: 'true' InjectedPipeline: 'orttraining-linux-gpu-test-ci-pipeline.yml' DockerImageTag: 'onnxruntime_orttraining_ortmodule_tests_image' - BuildConfig: $(buildConfig) - ArtifactName: 'drop-linux' TimeoutInMinutes: 140 # Enable unreleased onnx opsets in CI builds # This facilitates testing the implementation for the new opsets AllowReleasedOpsetOnly: '0' - Strategy: - maxParallel: 2 - matrix: - Release: - buildConfig: Release diff --git a/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml index 523390debc887..8dd1f0c5c6461 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-pai-ci-pipeline.yml @@ -11,6 +11,14 @@ pr: - 'onnxruntime/core/providers/js' name: 'orttraining_ci_$(Date:yyyyMMdd)_$(Rev:r)' +resources: + repositories: + - repository: manylinux + type: Github + endpoint: Microsoft + name: pypa/manylinux + ref: 5eda9aded5462201e6310105728d33016e637ea7 + variables: - name: video value: 44 @@ -22,7 +30,101 @@ variables: value: Release jobs: -- job: Linux_Build +- job: Linux_Build_manylinux + variables: + skipComponentGovernanceDetection: true + CCACHE_DIR: $(Pipeline.Workspace)/ccache + TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] + workspace: + clean: all + pool: onnxruntime-Ubuntu2004-AMD-CPU + timeoutInMinutes: 120 + + steps: + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() + + - checkout: self + clean: true + submodules: recursive + + - template: templates/get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm + Context: tools/ci_build/github/linux/docker + DockerBuildArgs: >- + --build-arg INSTALL_DEPS_EXTRA_ARGS=-tmur + --network=host --build-arg POLICY=manylinux_2_28 --build-arg PLATFORM=x86_64 + --build-arg BUILD_UID=$(id -u) + --build-arg ROCM_VERSION=$(RocmVersion) + --build-arg DEVTOOLSET_ROOTPATH=/opt/rh/gcc-toolset-12/root + --build-arg PREPEND_PATH=/opt/rh/gcc-toolset-12/root/usr/bin: + --build-arg LD_LIBRARY_PATH_ARG=/opt/rh/gcc-toolset-12/root/usr/lib64:/opt/rh/gcc-toolset-12/root/usr/lib:/opt/rh/gcc-toolset-12/root/usr/lib64/dyninst:/opt/rh/gcc-toolset-12/root/usr/lib/dyninst:/usr/local/lib64:/usr/local/lib + Repository: onnxruntimetrainingrocm-cibuild-rocm$(RocmVersion)-manylinux-build + + - task: Cache@2 + inputs: + key: '"manylinux" | "$(TODAY)" | "$(Build.SourceBranch)" | "$(Build.SourceVersion)"' + path: $(CCACHE_DIR) + cacheHitVar: CACHE_RESTORED + restoreKeys: | + "manylinux" | "$(TODAY)" | "$(Build.SourceBranch)" + "manylinux" | "$(TODAY)" | + displayName: Cache Task + + - script: mkdir -p $(CCACHE_DIR) + condition: ne(variables.CACHE_RESTORED, 'true') + displayName: Create Cache Dir + + - task: CmdLine@2 + inputs: + script: |- + export ROCM_HOME=/opt/rocm + docker run --rm \ + --ipc=host \ + --network=host \ + --cap-add=SYS_PTRACE \ + --security-opt seccomp=unconfined \ + --shm-size=1024m \ + --user $UID:$(id -g $USER) \ + -e CC=/opt/rh/gcc-toolset-12/root/usr/bin/cc -e CXX=/opt/rh/gcc-toolset-12/root/usr/bin/c++ -e CFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" -e CXXFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" \ + -e CCACHE_DIR=/cache \ + --volume $(Build.SourcesDirectory):/onnxruntime_src \ + --volume $(Build.BinariesDirectory):/build \ + --volume $(CCACHE_DIR):/cache \ + --workdir /onnxruntime_src \ + onnxruntimetrainingrocm-cibuild-rocm$(RocmVersion)-manylinux-build \ + /bin/bash -c " + set -ex; \ + ccache -s; \ + /opt/python/cp38-cp38/bin/python3 tools/ci_build/build.py \ + --config $(BuildConfig) \ + --enable_training \ + --mpi_home /opt/ompi \ + --cmake_extra_defines \ + CMAKE_HIP_COMPILER=${ROCM_HOME}/llvm/bin/clang++ \ + onnxruntime_BUILD_UNIT_TESTS=OFF \ + FETCHCONTENT_TRY_FIND_PACKAGE_MODE=NEVER \ + --use_cache \ + --use_rocm \ + --rocm_version=$(RocmVersion) \ + --rocm_home ${ROCM_HOME} \ + --nccl_home ${ROCM_HOME}\ + --update \ + --build_dir /build \ + --build \ + --parallel \ + --build_wheel \ + --skip_submodule_sync \ + --skip_tests; \ + ccache -sv; \ + ccache -z" + displayName: 'Build onnxruntime' + + - template: templates/explicitly-defined-final-tasks.yml + +- job: Linux_Build_ubuntu variables: skipComponentGovernanceDetection: true CCACHE_DIR: $(Pipeline.Workspace)/ccache @@ -115,13 +217,12 @@ jobs: - template: templates/explicitly-defined-final-tasks.yml - -- job: Linux_Test +- job: Linux_Test_ubuntu workspace: clean: all pool: AMD-GPU dependsOn: - - Linux_Build + - Linux_Build_ubuntu timeoutInMinutes: 120 steps: diff --git a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-rocm.yml b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-rocm.yml index eb837b35af428..a45b7d57205d1 100644 --- a/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-rocm.yml +++ b/tools/ci_build/github/azure-pipelines/orttraining-py-packaging-pipeline-rocm.yml @@ -14,51 +14,39 @@ stages: - template: templates/rocm.yml parameters: PythonVersion: '3.8' - RocmVersion: '5.4.2' - - template: templates/rocm.yml - parameters: - PythonVersion: '3.9' - RocmVersion: '5.4.2' - - template: templates/rocm.yml - parameters: - PythonVersion: '3.10' - RocmVersion: '5.4.2' - - template: templates/rocm.yml - parameters: - PythonVersion: '3.8' - RocmVersion: '5.5' + RocmVersion: '5.6' - template: templates/rocm.yml parameters: PythonVersion: '3.9' - RocmVersion: '5.5' + RocmVersion: '5.6' - template: templates/rocm.yml parameters: PythonVersion: '3.10' - RocmVersion: '5.5' + RocmVersion: '5.6' - template: templates/rocm.yml parameters: PythonVersion: '3.8' - RocmVersion: '5.6' + RocmVersion: '5.7' - template: templates/rocm.yml parameters: PythonVersion: '3.9' - RocmVersion: '5.6' + RocmVersion: '5.7' - template: templates/rocm.yml parameters: PythonVersion: '3.10' - RocmVersion: '5.6' + RocmVersion: '5.7' - template: templates/rocm.yml parameters: PythonVersion: '3.8' - RocmVersion: '5.6' + RocmVersion: '5.7' BuildConfig: 'RelWithDebInfo' - template: templates/rocm.yml parameters: PythonVersion: '3.9' - RocmVersion: '5.6' + RocmVersion: '5.7' BuildConfig: 'RelWithDebInfo' - template: templates/rocm.yml parameters: PythonVersion: '3.10' - RocmVersion: '5.6' + RocmVersion: '5.7' BuildConfig: 'RelWithDebInfo' diff --git a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml index c684e08ba1258..2161a9205f22d 100644 --- a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml @@ -3,24 +3,38 @@ resources: - pipeline: build source: 'Python packaging pipeline' trigger: true + branch: main # branch to pick the artifact, Used only for manual triggered pipeline runs for testing the pipeline itself + #TODO: Remove the following dependency. Running python tests should not need to use manylinux. + repositories: + - repository: manylinux # The name used to reference this repository in the checkout step + type: Github + endpoint: Microsoft + name: pypa/manylinux + ref: 5eda9aded5462201e6310105728d33016e637ea7 stages: - stage: Linux_Test_CPU_x86_64_stage jobs: - - template: templates/py-packaging-linux-test.yml + - template: templates/py-packaging-linux-test-cpu.yml parameters: arch: 'x86_64' machine_pool: 'onnxruntime-Ubuntu2004-AMD-CPU' - device: 'CPU' + base_image: 'registry.access.redhat.com/ubi8/ubi' + devtoolset_rootpath: /opt/rh/gcc-toolset-12/root + ld_library_path_arg: /opt/rh/gcc-toolset-12/root/usr/lib64:/opt/rh/gcc-toolset-12/root/usr/lib:/opt/rh/gcc-toolset-12/root/usr/lib64/dyninst:/opt/rh/gcc-toolset-12/root/usr/lib/dyninst:/usr/local/lib64 + prepend_path: '/opt/rh/gcc-toolset-12/root/usr/bin:' - stage: Linux_Test_CPU_aarch64_stage dependsOn: [] jobs: - - template: templates/py-packaging-linux-test.yml + - template: templates/py-packaging-linux-test-cpu.yml parameters: arch: 'aarch64' machine_pool: 'aiinfra-linux-ARM64-CPU-2019' - device: 'CPU' + base_image: 'arm64v8/almalinux:8' + devtoolset_rootpath: /opt/rh/gcc-toolset-12/root + ld_library_path_arg: /opt/rh/gcc-toolset-12/root/usr/lib64:/opt/rh/gcc-toolset-12/root/usr/lib:/opt/rh/gcc-toolset-12/root/usr/lib64/dyninst:/opt/rh/gcc-toolset-12/root/usr/lib/dyninst:/usr/local/lib64 + prepend_path: '/opt/rh/gcc-toolset-12/root/usr/bin:' - stage: Packages_Somking_Test dependsOn: [] @@ -31,19 +45,6 @@ stages: machine_pool: vmImage: 'macOS-13' itemPattern: '*/*mac*x86_64.whl' - - template: templates/py-package-smoking-test.yml - parameters: - job_name: Test_WIN_64_Wheels - itemPattern: '*/*win_amd64.whl' - machine_pool: - vmImage: 'windows-2022' - - template: templates/py-package-smoking-test.yml - parameters: - job_name: Test_WIN_32_Wheels - itemPattern: '*/*win32.whl' - python_arch: 'x86' - machine_pool: - vmImage: 'windows-2022' - template: templates/py-package-smoking-test.yml parameters: job_name: Test_LINUX_x86_64_Wheels @@ -61,7 +62,7 @@ stages: - Linux_Test_CPU_aarch64_stage - Packages_Somking_Test jobs: - - template: templates/py-packaging-linux-test.yml + - template: templates/py-packaging-linux-test-cuda.yml parameters: arch: 'x86_64' machine_pool: 'Onnxruntime-Linux-GPU' diff --git a/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml index 796938dc22a67..15fcec0511741 100644 --- a/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/c-api-linux-cpu.yml @@ -68,7 +68,7 @@ jobs: script: | mkdir -p $HOME/.onnx docker run --rm -e CFLAGS="${{parameters.OnnxruntimeCFlags}}" -e CXXFLAGS="${{parameters.OnnxruntimeCXXFlags}}" --volume /data/onnx:/data/onnx:ro --volume $(Build.SourcesDirectory):/onnxruntime_src --volume $(Build.BinariesDirectory):/build \ - --volume $HOME/.onnx:/home/onnxruntimedev/.onnx -e NIGHTLY_BUILD onnxruntimecpubuildcentos8${{parameters.OnnxruntimeArch}} /bin/bash -c "python3 \ + --volume $HOME/.onnx:/home/onnxruntimedev/.onnx -e NIGHTLY_BUILD onnxruntimecpubuildcentos8${{parameters.OnnxruntimeArch}} /bin/bash -c "python3.9 \ /onnxruntime_src/tools/ci_build/build.py --build_java --build_nodejs --build_dir /build --config Release \ --skip_submodule_sync --parallel --build_shared_lib ${{ parameters.AdditionalBuildFlags }} && cd /build/Release && make install DESTDIR=/build/linux-${{parameters.OnnxruntimeArch}}" workingDirectory: $(Build.SourcesDirectory) diff --git a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml index f17bc8de5739b..cf73691a5eecc 100644 --- a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml @@ -11,7 +11,7 @@ steps: packageType: upack feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0' - version: 1.0.81 + version: 1.0.90 downloadPath: $(Build.BinariesDirectory)/deps # The private ADO project @@ -22,7 +22,7 @@ steps: packageType: upack feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325' definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a' - version: 1.0.81 + version: 1.0.90 downloadPath: $(Build.BinariesDirectory)/deps # You can add more ADO accounts at here. diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml index 8868e671a5fa5..09c52f4d5ba0d 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml @@ -31,7 +31,7 @@ steps: architecture: ${{parameters.BuildArch}} - script: | - python -m pip install -q setuptools wheel numpy flatbuffers + python -m pip install --upgrade "setuptools>=68.2.2" wheel numpy flatbuffers workingDirectory: '$(Build.BinariesDirectory)' displayName: 'Install python modules' diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml index 46f2ae7b97acc..3b1fde6cb6e4f 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-vs-2022-job.yml @@ -97,7 +97,7 @@ jobs: - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' force32bit: ${{ parameters.isX86 }} # Our build machine doesn't have java x86 diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-ci.yml b/tools/ci_build/github/azure-pipelines/templates/linux-ci.yml index 05b2dee77e689..7b9788d90b17d 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-ci.yml @@ -1,23 +1,14 @@ parameters: AgentPool : 'onnxruntime-Ubuntu2004-AMD-CPU' StageName : 'Linux_CI_Dev' - SubmoduleCheckoutMode: '' RunDockerBuildArgs: '-o ubuntu20.04 -d cpu -x "--build_wheel"' - DoNodejsPack: 'false' - DoNugetPack: 'false' NuPackScript: '' RunInjectedPipeline: 'false' InjectedPipeline: '' DockerImageTag: '' - BuildConfig: '' - ArtifactName: 'drop-linux' TimeoutInMinutes: 120 # Controls whether unreleased onnx opsets are allowed. Default is set to 1 AllowReleasedOpsetOnly: '1' - # to inject strategy, you need to pass in the whole yaml structure - - # https://docs.microsoft.com/en-us/azure/devops/pipelines/yaml-schema?view=azure-devops&tabs=schema#strategies - # see example in orttraining-linux-gpu-ci-pipeline.yml - Strategy: '' jobs: - job: ${{ parameters.StageName }} @@ -28,16 +19,8 @@ jobs: ALLOW_RELEASED_ONNX_OPSET_ONLY: ${{ parameters.AllowReleasedOpsetOnly }} skipComponentGovernanceDetection: true pool: ${{ parameters.AgentPool }} - ${{ if ne(parameters.Strategy, '') }}: - strategy: - ${{ parameters.Strategy }} steps: - checkout: self - ${{ if ne(parameters.SubmoduleCheckoutMode, '') }}: - submodules: ${{ parameters.SubmoduleCheckoutMode }} - - task: NodeTool@0 - inputs: - versionSpec: '16.x' - template: run-docker-build-steps.yml parameters: RunDockerBuildArgs: '${{ parameters.RunDockerBuildArgs }}' @@ -48,31 +31,10 @@ jobs: searchFolder: '$(Build.BinariesDirectory)' testRunTitle: 'Unit Test Run' condition: succeededOrFailed() - - ${{ if eq(parameters['DoNugetPack'], 'true') }}: - - script: | - ${{ parameters.NuPackScript }} - displayName: 'Create Artifacts' - - task: PublishPipelineArtifact@0 - displayName: 'Publish Pipeline Artifact' - inputs: - artifactName: ${{ parameters.ArtifactName }} - targetPath: '$(Build.ArtifactStagingDirectory)' - - ${{ if eq(parameters['DoNodejsPack'], 'true') }}: - - script: | - npm pack - cp $(Build.SourcesDirectory)/js/node/onnxruntime-*.tgz $(Build.ArtifactStagingDirectory) - cp -R $(Build.SourcesDirectory)/js/node/prebuilds $(Build.ArtifactStagingDirectory)/prebuilds - workingDirectory: '$(Build.SourcesDirectory)/js/node' - displayName: 'Create NPM Package' - - task: PublishPipelineArtifact@0 - displayName: 'Publish Pipeline Artifact: ${{ parameters.ArtifactName }}' - inputs: - artifactName: ${{ parameters.ArtifactName }} - targetPath: '$(Build.ArtifactStagingDirectory)' - ${{ if eq(parameters['RunInjectedPipeline'], 'true') }}: - template: | ${{ parameters.InjectedPipeline }} parameters: DockerImageTag: ${{ parameters.DockerImageTag }} - BuildConfig: ${{ parameters.BuildConfig }} + BuildConfig: Release - template: clean-agent-build-directory-step.yml diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml index 0e584b550f562..fe7f752513f3c 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml @@ -83,7 +83,7 @@ jobs: architecture: $(buildArch) - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - template: download-deps.yml - task: PythonScript@0 @@ -98,7 +98,6 @@ jobs: cd '$(Build.SourcesDirectory)/cmake/external/emsdk' ./emsdk install 3.1.44 ccache-git-emscripten-64bit ./emsdk activate 3.1.44 ccache-git-emscripten-64bit - ln -s $(Build.SourcesDirectory)/cmake/external/emsdk/ccache/git-emscripten_64bit/bin/ccache /usr/local/bin/ccache displayName: 'emsdk install and activate ccache for emscripten' condition: eq('${{ parameters.WithCache }}', 'true') diff --git a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml index adfcd98e37230..f5e5435cfaca0 100644 --- a/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml +++ b/tools/ci_build/github/azure-pipelines/templates/mac-cpu-packing-jobs.yml @@ -50,7 +50,7 @@ jobs: versionSpec: 3.11 - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - template: set-version-number-variables-step.yml diff --git a/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml b/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml index cee3bd9c9e968..8d5ca19a73535 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-package-smoking-test.yml @@ -39,36 +39,22 @@ jobs: versionSpec: $(PythonVersion) architecture: ${{ parameters.python_arch }} - - task: DownloadPipelineArtifact@2 - displayName: 'Download Pipeline Artifact' - inputs: - artifactName: 'onnxruntime' - targetPath: '$(Build.BinariesDirectory)/whl' - itemPattern: ${{parameters.itemPattern}} - # The public ADO project - ${{ if eq(variables['System.CollectionId'], 'f3ad12f2-e480-4533-baf2-635c95467d29') }}: - buildType: current - # The private ADO project - ${{ if eq(variables['System.CollectionId'], 'bc038106-a83b-4dab-9dd3-5a41bc58f34c') }}: - project: '530acbc4-21bc-487d-8cd8-348ff451d2ff' - definition: 841 - preferTriggeringPipeline: true - runVersion: 'latest' - buildType: specific + - download: build # pipeline resource identifier. + artifact: 'onnxruntime' - task: Bash@3 inputs: targetType: 'inline' script: | set -ex - files=(whl/*.whl) + files=(*.whl) FILE_NAME="${files[0]}" FILE_NAME=$(basename $FILE_NAME) PYTHON_PACKAGE_NAME=$(echo "$FILE_NAME" | cut -f 1 -d '-') - python3 -m pip install --find-links "$(Build.BinariesDirectory)/whl" $PYTHON_PACKAGE_NAME - pip show $PYTHON_PACKAGE_NAME - python -c "import onnxruntime as ort; print(ort.__version__)" - workingDirectory: $(Build.BinariesDirectory) + python3 -m pip install --find-links "$(Pipeline.Workspace)/build/onnxruntime" $PYTHON_PACKAGE_NAME + python3 -m pip show $PYTHON_PACKAGE_NAME + python3 -c "import onnxruntime as ort; print(ort.__version__)" + workingDirectory: $(Pipeline.Workspace)/build/onnxruntime displayName: Test Package Installation - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cpu.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cpu.yml new file mode 100644 index 0000000000000..cc90085e184dc --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cpu.yml @@ -0,0 +1,117 @@ +parameters: +- name: arch + type: string + +- name: base_image + type: string + +- name: devtoolset_rootpath + type: string + +- name: ld_library_path_arg + type: string + +- name: prepend_path + type: string + +- name: machine_pool + type: string + +- name: extra_job_id + type: string + default: '' + +- name: python_wheel_suffix + type: string + default: '' + + +# TODO: Ideally it should fetch information from the build that triggers it +- name: cmake_build_type + type: string + default: 'Release' + values: + - Debug + - Release + - RelWithDebInfo + - MinSizeRel + +- name: timeout + type: number + default: 120 + +jobs: +- job: Linux_Test_CPU${{ parameters.extra_job_id }}_${{ parameters.arch }} + timeoutInMinutes: ${{ parameters.timeout }} + variables: + skipComponentGovernanceDetection: true + workspace: + clean: all + pool: ${{ parameters.machine_pool }} + steps: + - checkout: self + clean: true + submodules: none + # The public ADO project + - ${{ if eq(variables['System.CollectionId'], 'f3ad12f2-e480-4533-baf2-635c95467d29') }}: + - download: current # pipeline resource identifier. + artifact: 'drop-linux-cpu-${{ parameters.arch }}' + + - download: current # pipeline resource identifier. + artifact: 'onnxruntime${{ parameters.python_wheel_suffix }}' + + - bash: | + set -e -x + mv "$(Pipeline.Workspace)/drop-linux-cpu-${{ parameters.arch }}" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} + mv "$(Pipeline.Workspace)/onnxruntime${{ parameters.python_wheel_suffix }}" "$(Build.BinariesDirectory)/whl" + cp -r "$(Build.BinariesDirectory)/whl" $(Build.BinariesDirectory)/tmp + find "$(Build.BinariesDirectory)/tmp" -name '*.whl' -exec bash -c 'unzip -d "${1%.*}" "$1"' _ {} \; + # The private ADO project + - ${{ if eq(variables['System.CollectionId'], 'bc038106-a83b-4dab-9dd3-5a41bc58f34c') }}: + - download: build # pipeline resource identifier. + artifact: 'drop-linux-cpu-${{ parameters.arch }}' + + - download: build # pipeline resource identifier. + artifact: 'onnxruntime${{ parameters.python_wheel_suffix }}' + + - bash: | + set -e -x + ls $(Pipeline.Workspace)/build + mv "$(Pipeline.Workspace)/build/drop-linux-cpu-${{ parameters.arch }}" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} + mv "$(Pipeline.Workspace)/build/onnxruntime${{ parameters.python_wheel_suffix }}" "$(Build.BinariesDirectory)/whl" + cp -r "$(Build.BinariesDirectory)/whl" $(Build.BinariesDirectory)/tmp + find "$(Build.BinariesDirectory)/tmp" -name '*.whl' -exec bash -c 'unzip -d "${1%.*}" "$1"' _ {} \; + + # The BinSkim task uses a dotnet program which doesn't support ARM CPUs yet + - ${{ if eq(parameters.arch, 'x86_64') }}: + - task: BinSkim@4 + displayName: 'Run BinSkim' + inputs: + AnalyzeTargetGlob: '$(Build.BinariesDirectory)/tmp/**/*.so' + continueOnError: true + + #- task: PostAnalysis@2 + # inputs: + # GdnBreakAllTools: true + # GdnBreakPolicy: M365 + # GdnBreakPolicyMinSev: Error + + - template: get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/inference/x64/python/cpu/Dockerfile.manylinux2_28_cpu + Context: tools/ci_build/github/linux/docker/inference/x64/python/cpu + DockerBuildArgs: "--build-arg POLICY=manylinux_2_28 --build-arg BUILD_UID=$( id -u ) --build-arg BASEIMAGE=${{ parameters.base_image }} --build-arg PLATFORM=${{ parameters.arch }} --build-arg PREPEND_PATH=${{ parameters.prepend_path }} --build-arg LD_LIBRARY_PATH_ARG=${{ parameters.ld_library_path_arg }} --build-arg DEVTOOLSET_ROOTPATH=${{ parameters.devtoolset_rootpath }}" + Repository: onnxruntimecpubuildpython${{ parameters.arch }} + ${{ if eq(parameters.arch, 'aarch64') }}: + UpdateDepsTxt: false + + - task: Bash@3 + displayName: 'Bash Script' + inputs: + targetType: filePath + filePath: tools/ci_build/github/linux/run_python_dockertest.sh + arguments: -d CPU -c ${{parameters.cmake_build_type}} -i onnxruntimecpubuildpython${{ parameters.arch }} + + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml new file mode 100644 index 0000000000000..43ed0172825bc --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml @@ -0,0 +1,98 @@ +parameters: +- name: arch + type: string + +- name: device + type: string + values: + - CPU + - GPU + +- name: machine_pool + type: string + +- name: extra_job_id + type: string + default: '' + +- name: python_wheel_suffix + type: string + default: '' + + +# TODO: Ideally it should fetch information from the build that triggers it +- name: cmake_build_type + type: string + default: 'Release' + values: + - Debug + - Release + - RelWithDebInfo + - MinSizeRel + +- name: timeout + type: number + default: 120 + +jobs: +- job: Linux_Test_GPU${{ parameters.extra_job_id }}_${{ parameters.arch }} + timeoutInMinutes: ${{ parameters.timeout }} + variables: + skipComponentGovernanceDetection: true + workspace: + clean: all + pool: ${{ parameters.machine_pool }} + steps: + - checkout: self + clean: true + submodules: none + # The public ADO project + # - ${{ if eq(variables['System.CollectionId'], 'f3ad12f2-e480-4533-baf2-635c95467d29') }}: + + # The private ADO project + - ${{ if eq(variables['System.CollectionId'], 'bc038106-a83b-4dab-9dd3-5a41bc58f34c') }}: + - download: build # pipeline resource identifier. + artifact: 'drop-linux-gpu-${{ parameters.arch }}' + + - download: build # pipeline resource identifier. + artifact: 'onnxruntime${{ parameters.python_wheel_suffix }}' + + - bash: | + set -e -x + ls $(Pipeline.Workspace)/build + mv "$(Pipeline.Workspace)/build/drop-linux-gpu-${{ parameters.arch }}" $(Build.BinariesDirectory)/${{parameters.cmake_build_type}} + mv "$(Pipeline.Workspace)/build/onnxruntime${{ parameters.python_wheel_suffix }}" "$(Build.BinariesDirectory)/whl" + cp -r "$(Build.BinariesDirectory)/whl" $(Build.BinariesDirectory)/tmp + find "$(Build.BinariesDirectory)/tmp" -name '*.whl' -exec bash -c 'unzip -d "${1%.*}" "$1"' _ {} \; + + # The BinSkim task uses a dotnet program which doesn't support ARM CPUs yet + - ${{ if eq(parameters.arch, 'x86_64') }}: + - task: BinSkim@4 + displayName: 'Run BinSkim' + inputs: + AnalyzeTargetGlob: '$(Build.BinariesDirectory)/tmp/**/*.so' + continueOnError: true + + #- task: PostAnalysis@2 + # inputs: + # GdnBreakAllTools: true + # GdnBreakPolicy: M365 + # GdnBreakPolicyMinSev: Error + + - template: get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_8_tensorrt8_6 + Context: tools/ci_build/github/linux/docker + DockerBuildArgs: "--network=host --build-arg POLICY=manylinux_2_28 --build-arg PLATFORM=x86_64 --build-arg PREPEND_PATH=/usr/local/cuda/bin --build-arg LD_LIBRARY_PATH_ARG=/usr/local/lib64 --build-arg DEVTOOLSET_ROOTPATH=/usr --build-arg BUILD_UID=$( id -u ) --build-arg PLATFORM=${{ parameters.arch }}" + Repository: onnxruntimecuda118xtrt86build${{ parameters.arch }} + + - task: Bash@3 + displayName: 'Bash Script' + inputs: + targetType: filePath + filePath: tools/ci_build/github/linux/run_python_dockertest.sh + arguments: -d GPU -c ${{parameters.cmake_build_type}} -i onnxruntimecuda118xtrt86build${{ parameters.arch }} + + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test.yml deleted file mode 100644 index 8ddc917e8591e..0000000000000 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test.yml +++ /dev/null @@ -1,85 +0,0 @@ -parameters: -- name: arch - type: string - -- name: device - type: string - -- name: machine_pool - type: string - -- name: extra_job_id - type: string - default: '' - -- name: python_wheel_suffix - type: string - default: '' - - -# TODO: Ideally it should fetch information from the build that triggers it -- name: cmake_build_type - type: string - default: 'Release' - values: - - Debug - - Release - - RelWithDebInfo - - MinSizeRel - -- name: timeout - type: number - default: 120 - -jobs: -- job: Linux_Test_${{ parameters.device }}${{ parameters.extra_job_id }}_${{ parameters.arch }} - timeoutInMinutes: ${{ parameters.timeout }} - variables: - skipComponentGovernanceDetection: true - workspace: - clean: all - pool: ${{ parameters.machine_pool }} - steps: - - task: DownloadPipelineArtifact@2 - displayName: 'Download Pipeline Artifact' - inputs: - artifactName: 'drop-linux-${{ lower(parameters.device) }}-${{ parameters.arch }}' - targetPath: '$(Build.BinariesDirectory)/${{parameters.cmake_build_type}}' - # The public ADO project - ${{ if eq(variables['System.CollectionId'], 'f3ad12f2-e480-4533-baf2-635c95467d29') }}: - buildType: current - # The private ADO project - ${{ if eq(variables['System.CollectionId'], 'bc038106-a83b-4dab-9dd3-5a41bc58f34c') }}: - project: '530acbc4-21bc-487d-8cd8-348ff451d2ff' - definition: 841 - preferTriggeringPipeline: true - runVersion: 'latest' - buildType: specific - - - task: DownloadPipelineArtifact@2 - displayName: 'Download Pipeline Artifact' - inputs: - artifactName: 'onnxruntime${{ parameters.python_wheel_suffix }}' - targetPath: '$(Build.BinariesDirectory)/whl' - # The public ADO project - ${{ if eq(variables['System.CollectionId'], 'f3ad12f2-e480-4533-baf2-635c95467d29') }}: - buildType: current - # The private ADO project - ${{ if eq(variables['System.CollectionId'], 'bc038106-a83b-4dab-9dd3-5a41bc58f34c') }}: - project: '530acbc4-21bc-487d-8cd8-348ff451d2ff' - definition: 841 - preferTriggeringPipeline: true - runVersion: 'latest' - buildType: specific - - - - task: Bash@3 - displayName: 'Bash Script' - inputs: - targetType: filePath - filePath: tools/ci_build/github/linux/run_python_tests.sh - arguments: -d ${{ parameters.device }} -c ${{parameters.cmake_build_type}} - - - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 - displayName: 'Clean Agent Directories' - condition: always() diff --git a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml index 8c54e71448992..e63939ae0114c 100644 --- a/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/react-native-ci.yml @@ -80,7 +80,7 @@ stages: - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - script: brew install coreutils ninja npm yarn diff --git a/tools/ci_build/github/azure-pipelines/templates/rocm.yml b/tools/ci_build/github/azure-pipelines/templates/rocm.yml index cc2e8745e8946..d43029266b4b0 100644 --- a/tools/ci_build/github/azure-pipelines/templates/rocm.yml +++ b/tools/ci_build/github/azure-pipelines/templates/rocm.yml @@ -51,7 +51,6 @@ jobs: --build-arg INSTALL_DEPS_EXTRA_ARGS=-tmur --build-arg BUILD_UID=$(id -u) --network=host --build-arg POLICY=manylinux_2_28 --build-arg PLATFORM=x86_64 - --build-arg ROCM_VERSION=$(RocmVersion) --build-arg DEVTOOLSET_ROOTPATH=/opt/rh/gcc-toolset-12/root --build-arg PREPEND_PATH=/opt/rh/gcc-toolset-12/root/usr/bin: --build-arg LD_LIBRARY_PATH_ARG=/opt/rh/gcc-toolset-12/root/usr/lib64:/opt/rh/gcc-toolset-12/root/usr/lib:/opt/rh/gcc-toolset-12/root/usr/lib64/dyninst:/opt/rh/gcc-toolset-12/root/usr/lib/dyninst:/usr/local/lib64:/usr/local/lib diff --git a/tools/ci_build/github/azure-pipelines/templates/web-browserstack-ci.yml b/tools/ci_build/github/azure-pipelines/templates/web-browserstack-ci.yml index 4494fd36b336e..96e6ff89cd4f1 100644 --- a/tools/ci_build/github/azure-pipelines/templates/web-browserstack-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/web-browserstack-ci.yml @@ -29,7 +29,7 @@ jobs: - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - task: DownloadPipelineArtifact@2 inputs: patterns: 'Release_*/**/*' diff --git a/tools/ci_build/github/azure-pipelines/templates/web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/web-ci.yml index 0b7bd3f645442..3f6c6af753a98 100644 --- a/tools/ci_build/github/azure-pipelines/templates/web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/web-ci.yml @@ -74,7 +74,7 @@ stages: displayName: 'Checkout submodule onnx' - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - template: linux-web-init-and-check.yml - task: Bash@3 displayName: 'Extract commit SHA and save to __commit.txt' diff --git a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml index 80d285f3fd3fb..8d28b4ce580b4 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-ci.yml @@ -101,7 +101,7 @@ stages: - task: NodeTool@0 condition: and(succeeded(), eq('${{ parameters.buildNodejs}}', true)) inputs: - versionSpec: '16.x' + versionSpec: '18.x' - template: jobs/set-winenv.yml parameters: diff --git a/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml index 9d36e2dbe4944..406683af80222 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml @@ -74,7 +74,7 @@ jobs: architecture: $(buildArch) - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - template: download-deps.yml - task: PythonScript@0 diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml index bad7448715936..788b02f539821 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml @@ -29,6 +29,7 @@ jobs: pool: ${{ parameters.PoolName }} variables: + webgpuCommandlineExtraFlags: '--chromium-flags=--ignore-gpu-blocklist --chromium-flags=--gpu-vendor-id=0x10de' runCodesignValidationInjection: false timeoutInMinutes: 60 workspace: @@ -72,7 +73,7 @@ jobs: displayName: 'Testing: force EOL to lf on windows for /js/**' - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - task: DownloadPipelineArtifact@2 inputs: patterns: '${{ parameters.BuildConfig }}_*/**/*' @@ -159,12 +160,22 @@ jobs: npm test -- -e=edge -b=webgl,wasm,xnnpack workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests (wasm,webgl,xnnpack backend)' - condition: ne('${{ parameters.RunWebGpuTests }}', 'true') + condition: eq('${{ parameters.RunWebGpuTests }}', 'false') - script: | - npm test -- -e=edge -b=webgl,wasm,xnnpack,webgpu --chromium-flags=--ignore-gpu-blocklist --chromium-flags=--gpu-vendor-id=0x10de + npm test -- -e=edge -b=webgl,wasm,xnnpack,webgpu $(webgpuCommandlineExtraFlags) workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests (ALL backends)' - condition: ne('${{ parameters.RunWebGpuTests }}', 'false') + condition: eq('${{ parameters.RunWebGpuTests }}', 'true') + - script: | + npm test -- suite1 -e=edge -b=webgpu --io-binding=gpu-tensor $(webgpuCommandlineExtraFlags) + workingDirectory: '$(Build.SourcesDirectory)\js\web' + displayName: 'Run ort-web tests (Suite1, webgpu, IO-binding=gpu-tensor)' + condition: eq('${{ parameters.RunWebGpuTests }}', 'true') + - script: | + npm test -- suite1 -e=edge -b=webgpu --io-binding=gpu-location $(webgpuCommandlineExtraFlags) + workingDirectory: '$(Build.SourcesDirectory)\js\web' + displayName: 'Run ort-web tests (Suite1, webgpu, IO-binding=gpu-location)' + condition: eq('${{ parameters.RunWebGpuTests }}', 'true') - script: | npm test -- --webgl-texture-pack-mode -b=webgl -e=edge workingDirectory: '$(Build.SourcesDirectory)\js\web' diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml index 723567389579d..f7876f15029c1 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-multi-browsers.yml @@ -33,7 +33,7 @@ jobs: displayName: 'Checkout submodule onnx' - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - task: DownloadPipelineArtifact@2 inputs: patterns: 'Release_*/**/*' diff --git a/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml b/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml index f3a5728d6519b..98f1bf7ea1a16 100644 --- a/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml +++ b/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml @@ -32,7 +32,7 @@ jobs: - task: NodeTool@0 inputs: - versionSpec: '16.x' + versionSpec: '18.x' - task: NuGetToolInstaller@0 displayName: Use Nuget 5.7.0 diff --git a/tools/ci_build/github/linux/copy_strip_binary.sh b/tools/ci_build/github/linux/copy_strip_binary.sh index b875a3937aaa9..63690b69fc91a 100755 --- a/tools/ci_build/github/linux/copy_strip_binary.sh +++ b/tools/ci_build/github/linux/copy_strip_binary.sh @@ -48,6 +48,7 @@ fi cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_c_api.h $BINARY_DIR/$ARTIFACT_NAME/include cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_cxx_api.h $BINARY_DIR/$ARTIFACT_NAME/include cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_cxx_inline.h $BINARY_DIR/$ARTIFACT_NAME/include +cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_float16.h $BINARY_DIR/$ARTIFACT_NAME/include cp $SOURCE_DIR/include/onnxruntime/core/providers/cpu/cpu_provider_factory.h $BINARY_DIR/$ARTIFACT_NAME/include cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h $BINARY_DIR/$ARTIFACT_NAME/include cp $SOURCE_DIR/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h $BINARY_DIR/$ARTIFACT_NAME/include diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu index a9a1e6b39a8cb..af87852561e0a 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cpu @@ -1,9 +1,9 @@ -ARG BASEIMAGE=amd64/almalinux:8 +ARG BASEIMAGE=registry.access.redhat.com/ubi8/ubi ARG POLICY=manylinux_2_28 ARG PLATFORM=x86_64 ARG DEVTOOLSET_ROOTPATH=/opt/rh/gcc-toolset-12/root ARG LD_LIBRARY_PATH_ARG=${DEVTOOLSET_ROOTPATH}/usr/lib64:${DEVTOOLSET_ROOTPATH}/usr/lib:${DEVTOOLSET_ROOTPATH}/usr/lib64/dyninst:${DEVTOOLSET_ROOTPATH}/usr/lib/dyninst:/usr/local/lib64 -ARG PREPEND_PATH=${DEVTOOLSET_ROOTPATH}/usr/bin: +ARG PREPEND_PATH=/usr/lib/jvm/msopenjdk-11/bin:${DEVTOOLSET_ROOTPATH}/usr/bin: #Build manylinux2014 docker image begin FROM $BASEIMAGE AS runtime_base @@ -26,7 +26,6 @@ COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors # setup entrypoint, this will wrap commands with `linux32` with i686 images COPY build_scripts/install-entrypoint.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ @@ -35,7 +34,6 @@ COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint ENTRYPOINT ["manylinux-entrypoint"] COPY build_scripts/install-runtime-packages.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ @@ -137,9 +135,7 @@ COPY --from=build_git /manylinux-rootfs / COPY --from=build_cpython /manylinux-rootfs / COPY --from=all_python /opt/_internal /opt/_internal/ COPY build_scripts/finalize.sh \ - build_scripts/update-system-packages.sh \ build_scripts/python-tag-abi-tag.py \ - build_scripts/requirements3.8.txt \ build_scripts/requirements3.9.txt \ build_scripts/requirements3.10.txt \ @@ -156,6 +152,7 @@ CMD ["/bin/bash"] #Build manylinux2014 docker image end ENV PATH ${DEVTOOLSET_ROOTPATH}/usr/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin +ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/manylinux/install_centos.sh diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11 index dab8df6703c4f..933b0211b0e6c 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11 +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11 @@ -31,7 +31,6 @@ COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors # setup entrypoint, this will wrap commands with `linux32` with i686 images COPY build_scripts/install-entrypoint.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ @@ -40,7 +39,6 @@ COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint ENTRYPOINT ["manylinux-entrypoint"] COPY build_scripts/install-runtime-packages.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ @@ -140,7 +138,6 @@ COPY --from=build_git /manylinux-rootfs / COPY --from=build_cpython /manylinux-rootfs / COPY --from=all_python /opt/_internal /opt/_internal/ COPY build_scripts/finalize.sh \ - build_scripts/update-system-packages.sh \ build_scripts/python-tag-abi-tag.py \ build_scripts/requirements3.8.txt \ build_scripts/requirements3.9.txt \ @@ -156,7 +153,7 @@ ENV SSL_CERT_FILE=/opt/_internal/certs.pem CMD ["/bin/bash"] #Build manylinux2014 docker image end - +ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 #Add our own dependencies ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/manylinux/install_centos.sh && /tmp/scripts/manylinux/install_deps.sh && rm -rf /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_6_tensorrt8_4 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_6_tensorrt8_4 index 303e83eb23bca..003bb2324c049 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_6_tensorrt8_4 +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_6_tensorrt8_4 @@ -31,7 +31,6 @@ COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors # setup entrypoint, this will wrap commands with `linux32` with i686 images COPY build_scripts/install-entrypoint.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ @@ -40,7 +39,6 @@ COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint ENTRYPOINT ["manylinux-entrypoint"] COPY build_scripts/install-runtime-packages.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ @@ -140,7 +138,6 @@ COPY --from=build_git /manylinux-rootfs / COPY --from=build_cpython /manylinux-rootfs / COPY --from=all_python /opt/_internal /opt/_internal/ COPY build_scripts/finalize.sh \ - build_scripts/update-system-packages.sh \ build_scripts/python-tag-abi-tag.py \ build_scripts/requirements3.8.txt \ build_scripts/requirements3.9.txt \ @@ -163,7 +160,7 @@ RUN v="8.4.1-1.cuda11.6" &&\ yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo &&\ yum -y install libnvinfer8-${v} libnvparsers8-${v} libnvonnxparsers8-${v} libnvinfer-plugin8-${v} \ libnvinfer-devel-${v} libnvparsers-devel-${v} libnvonnxparsers-devel-${v} libnvinfer-plugin-devel-${v} - +ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 #Add our own dependencies ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/manylinux/install_centos.sh && /tmp/scripts/manylinux/install_deps.sh && rm -rf /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_6_tensorrt8_5 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_6_tensorrt8_5 index d17e4b24582fe..0337ffc5e00a0 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_6_tensorrt8_5 +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_6_tensorrt8_5 @@ -31,7 +31,6 @@ COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors # setup entrypoint, this will wrap commands with `linux32` with i686 images COPY build_scripts/install-entrypoint.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ @@ -40,7 +39,6 @@ COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint ENTRYPOINT ["manylinux-entrypoint"] COPY build_scripts/install-runtime-packages.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ @@ -140,7 +138,6 @@ COPY --from=build_git /manylinux-rootfs / COPY --from=build_cpython /manylinux-rootfs / COPY --from=all_python /opt/_internal /opt/_internal/ COPY build_scripts/finalize.sh \ - build_scripts/update-system-packages.sh \ build_scripts/python-tag-abi-tag.py \ build_scripts/requirements3.8.txt \ build_scripts/requirements3.9.txt \ @@ -163,7 +160,7 @@ RUN v="8.5.1-1.cuda11.8" &&\ yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo &&\ yum -y install libnvinfer8-${v} libnvparsers8-${v} libnvonnxparsers8-${v} libnvinfer-plugin8-${v} \ libnvinfer-devel-${v} libnvparsers-devel-${v} libnvonnxparsers-devel-${v} libnvinfer-plugin-devel-${v} - +ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 #Add our own dependencies ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/manylinux/install_centos.sh && /tmp/scripts/manylinux/install_deps.sh && rm -rf /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_8_tensorrt8_6 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_8_tensorrt8_6 index 3c0ac22e38b5a..2c953a10cbf64 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_8_tensorrt8_6 +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_8_tensorrt8_6 @@ -31,7 +31,6 @@ COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors # setup entrypoint, this will wrap commands with `linux32` with i686 images COPY build_scripts/install-entrypoint.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ @@ -40,7 +39,6 @@ COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint ENTRYPOINT ["manylinux-entrypoint"] COPY build_scripts/install-runtime-packages.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ @@ -147,7 +145,6 @@ COPY --from=build_git /manylinux-rootfs / COPY --from=build_cpython /manylinux-rootfs / COPY --from=all_python /opt/_internal /opt/_internal/ COPY build_scripts/finalize.sh \ - build_scripts/update-system-packages.sh \ build_scripts/python-tag-abi-tag.py \ build_scripts/requirements3.7.txt \ build_scripts/requirements3.8.txt \ @@ -171,7 +168,7 @@ RUN v="8.6.1.6-1.cuda11.8" &&\ yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo &&\ yum -y install libnvinfer8-${v} libnvparsers8-${v} libnvonnxparsers8-${v} libnvinfer-plugin8-${v} libnvinfer-vc-plugin8-${v}\ libnvinfer-devel-${v} libnvparsers-devel-${v} libnvonnxparsers-devel-${v} libnvinfer-plugin-devel-${v} libnvinfer-vc-plugin-devel-${v} libnvinfer-headers-devel-${v} libnvinfer-headers-plugin-devel-${v} - +ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 #Add our own dependencies ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/manylinux/install_centos.sh && /tmp/scripts/manylinux/install_deps.sh && rm -rf /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm index 10ce8f0ed65f7..19599c9f613d4 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_rocm @@ -185,6 +185,13 @@ RUN cd /tmp/scripts && \ rm -rf /tmp/scripts +# Install ccache to reuse this dockerfile for CI +RUN mkdir -p /tmp/ccache && \ + cd /tmp/ccache && \ + wget -q -O - https://github.com/ccache/ccache/releases/download/v4.7.4/ccache-4.7.4-linux-x86_64.tar.xz | tar --strip 1 -J -xf - && \ + cp /tmp/ccache/ccache /usr/bin && \ + rm -rf /tmp/ccache + ARG BUILD_UID=1001 ARG BUILD_USER=onnxruntimedev RUN adduser --uid $BUILD_UID $BUILD_USER diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda11_8 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda11_8 index 326e15d58456a..09ab7951552a0 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda11_8 +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_training_cuda11_8 @@ -31,7 +31,6 @@ COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors # setup entrypoint, this will wrap commands with `linux32` with i686 images COPY build_scripts/install-entrypoint.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ @@ -40,7 +39,6 @@ COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint ENTRYPOINT ["manylinux-entrypoint"] COPY build_scripts/install-runtime-packages.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ @@ -140,7 +138,6 @@ COPY --from=build_git /manylinux-rootfs / COPY --from=build_cpython /manylinux-rootfs / COPY --from=all_python /opt/_internal /opt/_internal/ COPY build_scripts/finalize.sh \ - build_scripts/update-system-packages.sh \ build_scripts/python-tag-abi-tag.py \ build_scripts/requirements3.8.txt \ build_scripts/requirements3.9.txt \ diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 index cdf504c8e3b03..bbdb411b790a0 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 @@ -22,7 +22,7 @@ RUN dnf install -y \ ln -s /usr/bin/pip3 pip3.8; RUN pip3 install --upgrade pip -RUN pip3 install setuptools>=41.0.0 +RUN pip3 install setuptools>=68.2.2 # Install TensorRT RUN dnf install -y libnvinfer8 libnvonnxparsers8 libnvparsers8 libnvinfer-plugin8 libnvinfer-lean8 libnvinfer-vc-plugin8 libnvinfer-dispatch8 diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_cuda11_8_tensorrt8_6 b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_cuda11_8_tensorrt8_6 index c211fa9b9e2b8..83a974469234f 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_cuda11_8_tensorrt8_6 +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubuntu_cuda11_8_tensorrt8_6 @@ -7,40 +7,30 @@ # Build base image with required system packages FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu20.04 AS base -# The local directory into which to build and install CMAKE -ARG ONNXRUNTIME_LOCAL_CODE_DIR=/code - -ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${ONNXRUNTIME_LOCAL_CODE_DIR}/cmake-3.27.3-linux-x86_64/bin:/opt/miniconda/bin:${PATH} +ENV PATH /usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/src/tensorrt/bin:${PATH} ENV DEBIAN_FRONTEND=noninteractive RUN apt-get update &&\ - apt-get install -y sudo git bash unattended-upgrades wget -RUN unattended-upgrade + apt-get install -y git bash wget # Install python3 RUN apt-get install -y --no-install-recommends \ python3 \ python3-pip \ python3-dev \ - python3-wheel &&\ - cd /usr/local/bin &&\ - ln -s /usr/bin/python3 python &&\ - ln -s /usr/bin/pip3 pip; + python3-wheel + RUN pip install --upgrade pip -RUN pip install setuptools>=41.0.0 # Install TensorRT RUN v="8.6.1.6-1+cuda11.8" &&\ apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/7fa2af80.pub &&\ apt-get update &&\ - sudo apt-get install -y libnvinfer8=${v} libnvonnxparsers8=${v} libnvparsers8=${v} libnvinfer-plugin8=${v} libnvinfer-lean8=${v} libnvinfer-vc-plugin8=${v} libnvinfer-dispatch8=${v}\ + apt-get install -y libnvinfer8=${v} libnvonnxparsers8=${v} libnvparsers8=${v} libnvinfer-plugin8=${v} libnvinfer-lean8=${v} libnvinfer-vc-plugin8=${v} libnvinfer-dispatch8=${v}\ libnvinfer-headers-dev=${v} libnvinfer-headers-plugin-dev=${v} libnvinfer-dev=${v} libnvonnxparsers-dev=${v} libnvparsers-dev=${v} libnvinfer-plugin-dev=${v} libnvinfer-lean-dev=${v} libnvinfer-vc-plugin-dev=${v} libnvinfer-dispatch-dev=${v}\ python3-libnvinfer=${v} libnvinfer-samples=${v} tensorrt-dev=${v} tensorrt-libs=${v} -# Install Valgrind -RUN apt-get install -y valgrind - ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/install_dotnet.sh && rm -rf /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_6_tensorrt8_4 b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_6_tensorrt8_4 index 10f404c7c6a85..8b32425afce1c 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_6_tensorrt8_4 +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_6_tensorrt8_4 @@ -31,7 +31,7 @@ RUN apt-get install -y --no-install-recommends \ ln -s /usr/bin/pip3 pip; RUN pip install --upgrade pip -RUN pip install setuptools>=41.0.0 +RUN pip install setuptools>=68.2.2 # Install TensorRT RUN v="8.4.1-1+cuda11.6" &&\ diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_5 b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_5 index cacc09f0c7455..cfc7023ef8e61 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_5 +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_5 @@ -28,7 +28,7 @@ RUN apt-get install -y --no-install-recommends \ ln -s /usr/bin/pip3 pip; RUN pip install --upgrade pip -RUN pip install setuptools>=41.0.0 +RUN pip install setuptools>=68.2.2 # Install TensorRT RUN v="8.5.1-1+cuda11.8" &&\ diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6 b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6 index 0a4885e774047..edc41197be5c9 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6 +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6 @@ -28,7 +28,7 @@ RUN apt-get install -y --no-install-recommends \ ln -s /usr/bin/pip3 pip; RUN pip install --upgrade pip -RUN pip install setuptools>=41.0.0 +RUN pip install setuptools>=68.2.2 # Install TensorRT RUN v="8.6.1.6-1+cuda11.8" &&\ diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin index c9308ade37396..21b09b2d8978e 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin @@ -42,7 +42,7 @@ RUN apt-get install -y --no-install-recommends \ ln -s /usr/bin/pip3 pip; RUN pip install --upgrade pip -RUN pip install setuptools>=41.0.0 +RUN pip install setuptools>=68.2.2 # Install TensorRT from tar.gz RUN tar -xzvf /TensorRT-${TAR_TRT_VERSION}.tar.gz diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_centos.sh b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_centos.sh index a1ade39e57e16..adb0464d6496a 100755 --- a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_centos.sh +++ b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_centos.sh @@ -1,9 +1,8 @@ #!/bin/bash set -e -x -os_major_version=$(cat /etc/redhat-release | tr -dc '0-9.'|cut -d \. -f1) +os_major_version=$(tr -dc '0-9.' < /etc/redhat-release |cut -d \. -f1) echo "installing for CentOS version : $os_major_version" - -dnf install -y glibc-langpack-\* glibc-locale-source which gdb redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel java-11-openjdk-devel graphviz gcc-toolset-12-binutils gcc-toolset-12-gcc gcc-toolset-12-gcc-c++ gcc-toolset-12-gcc-gfortran -locale \ No newline at end of file +dnf install -y python39-devel glibc-langpack-\* glibc-locale-source which redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel java-11-openjdk-devel graphviz gcc-toolset-12-binutils gcc-toolset-12-gcc gcc-toolset-12-gcc-c++ gcc-toolset-12-gcc-gfortran +locale diff --git a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_deps.sh b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_deps.sh index 7ecd0525c7e7e..7598ab0a7a536 100755 --- a/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_deps.sh +++ b/tools/ci_build/github/linux/docker/inference/aarch64/default/cpu/scripts/install_deps.sh @@ -14,20 +14,20 @@ function GetFile { echo "File '$path' already exists. Skipping download" return 0 else - rm -rf $path + rm -rf "$path" fi fi if [[ -f $uri ]]; then echo "'$uri' is a file path, copying file to '$path'" - cp $uri $path + cp "$uri" "$path" return $? fi echo "Downloading $uri" # Use aria2c if available, otherwise use curl if command -v aria2c > /dev/null; then - aria2c -q -d $(dirname $path) -o $(basename $path) "$uri" + aria2c -q -d "$(dirname $path)" -o "$(basename $path)" "$uri" else curl "$uri" -sSL --retry $download_retries --retry-delay $retry_wait_time_seconds --create-dirs -o "$path" --fail fi @@ -38,9 +38,10 @@ mkdir -p /tmp/src cd /tmp/src +CPU_ARCH=$(uname -m) echo "Installing cmake" -GetFile https://github.com/Kitware/CMake/releases/download/v3.27.3/cmake-3.27.3-linux-`uname -m`.tar.gz /tmp/src/cmake-3.27.3-linux-`uname -m`.tar.gz -tar -zxf /tmp/src/cmake-3.27.3-linux-`uname -m`.tar.gz --strip=1 -C /usr +GetFile "https://github.com/Kitware/CMake/releases/download/v3.27.3/cmake-3.27.3-linux-$CPU_ARCH.tar.gz" "/tmp/src/cmake.tar.gz" +tar -zxf /tmp/src/cmake.tar.gz --strip=1 -C /usr echo "Installing Ninja" GetFile https://github.com/ninja-build/ninja/archive/v1.10.0.tar.gz /tmp/src/ninja-linux.tar.gz @@ -52,7 +53,7 @@ mv ./build-cmake/ninja /usr/bin popd echo "Installing Node.js" -CPU_ARCH=`uname -m` + if [[ "$CPU_ARCH" = "x86_64" ]]; then NODEJS_ARCH=x64 elif [[ "$CPU_ARCH" = "aarch64" ]]; then @@ -64,16 +65,5 @@ fi GetFile https://nodejs.org/dist/v18.17.1/node-v18.17.1-linux-${NODEJS_ARCH}.tar.gz /tmp/src/node-v18.17.1-linux-${NODEJS_ARCH}.tar.gz tar --strip 1 -xf /tmp/src/node-v18.17.1-linux-${NODEJS_ARCH}.tar.gz -C /usr -# The Python version in CentOS 7's python3 package is no longer supported (3.6) so we will build Python from source. -echo "Installing Python" -PYTHON_VERSION="3.8.17" -GetFile https://www.python.org/ftp/python/${PYTHON_VERSION}/Python-${PYTHON_VERSION}.tgz /tmp/src/Python-${PYTHON_VERSION}.tgz -tar -zxf Python-${PYTHON_VERSION}.tgz -pushd Python-${PYTHON_VERSION} -./configure -make -make install -popd - cd / rm -rf /tmp/src diff --git a/tools/ci_build/github/linux/docker/inference/x64/default/cpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/x64/default/cpu/Dockerfile index 0324f377b8e9e..caf9583807b62 100644 --- a/tools/ci_build/github/linux/docker/inference/x64/default/cpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x64/default/cpu/Dockerfile @@ -5,10 +5,10 @@ ARG BASEIMAGE=amd64/almalinux:8 FROM $BASEIMAGE -ENV PATH /opt/rh/gcc-toolset-12/root/usr/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin +ENV PATH /usr/lib/jvm/msopenjdk-11/bin:/opt/rh/gcc-toolset-12/root/usr/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin ENV LANG=en_US.UTF-8 ENV LC_ALL=en_US.UTF-8 - +ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/install_centos.sh && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/inference/x64/default/cpu/scripts/install_centos.sh b/tools/ci_build/github/linux/docker/inference/x64/default/cpu/scripts/install_centos.sh index 8e18a237a807e..b5f8bf1a49a19 100755 --- a/tools/ci_build/github/linux/docker/inference/x64/default/cpu/scripts/install_centos.sh +++ b/tools/ci_build/github/linux/docker/inference/x64/default/cpu/scripts/install_centos.sh @@ -1,9 +1,9 @@ #!/bin/bash set -e -x -os_major_version=$(cat /etc/redhat-release | tr -dc '0-9.'|cut -d \. -f1) +os_major_version=$(tr -dc '0-9.' < /etc/redhat-release |cut -d \. -f1) echo "installing for CentOS version : $os_major_version" - -dnf install -y python39-devel glibc-langpack-\* glibc-locale-source which gdb redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel java-11-openjdk-devel graphviz gcc-toolset-12-binutils gcc-toolset-12-gcc gcc-toolset-12-gcc-c++ gcc-toolset-12-gcc-gfortran -locale \ No newline at end of file +rpm -Uvh https://packages.microsoft.com/config/centos/$os_major_version/packages-microsoft-prod.rpm +dnf install -y python39-devel glibc-langpack-\* glibc-locale-source which redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel msopenjdk-11 graphviz gcc-toolset-12-binutils gcc-toolset-12-gcc gcc-toolset-12-gcc-c++ gcc-toolset-12-gcc-gfortran +locale diff --git a/tools/ci_build/github/linux/docker/inference/x64/default/gpu/Dockerfile b/tools/ci_build/github/linux/docker/inference/x64/default/gpu/Dockerfile index 386759890d085..318791072f46d 100644 --- a/tools/ci_build/github/linux/docker/inference/x64/default/gpu/Dockerfile +++ b/tools/ci_build/github/linux/docker/inference/x64/default/gpu/Dockerfile @@ -4,8 +4,10 @@ # This file is used by Zip-Nuget Packaging NoContribOps Pipeline,Zip-Nuget-Java Packaging Pipeline FROM nvidia/cuda:11.8.0-cudnn8-devel-ubi8 +ENV PATH /usr/lib/jvm/msopenjdk-11/bin:/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin ENV LANG=en_US.UTF-8 ENV LC_ALL=en_US.UTF-8 +ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 ADD scripts /tmp/scripts RUN cd /tmp/scripts && /tmp/scripts/install_centos.sh && /tmp/scripts/install_deps.sh && rm -rf /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/inference/x64/default/gpu/scripts/install_centos.sh b/tools/ci_build/github/linux/docker/inference/x64/default/gpu/scripts/install_centos.sh index 3cf259dc7240e..31e3e40f1b7ee 100755 --- a/tools/ci_build/github/linux/docker/inference/x64/default/gpu/scripts/install_centos.sh +++ b/tools/ci_build/github/linux/docker/inference/x64/default/gpu/scripts/install_centos.sh @@ -1,9 +1,9 @@ #!/bin/bash set -e -x -os_major_version=$(cat /etc/redhat-release | tr -dc '0-9.'|cut -d \. -f1) +os_major_version=$(tr -dc '0-9.' < /etc/redhat-release |cut -d \. -f1) echo "installing for CentOS version : $os_major_version" - -dnf install -y python39-devel python3-devel glibc-langpack-\* glibc-locale-source which gdb redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel java-11-openjdk-devel -locale \ No newline at end of file +rpm -Uvh https://packages.microsoft.com/config/centos/$os_major_version/packages-microsoft-prod.rpm +dnf install -y python39-devel glibc-langpack-\* glibc-locale-source which redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel msopenjdk-11 +locale diff --git a/tools/ci_build/github/linux/docker/inference/x64/python/cpu/Dockerfile.manylinux2_28_cpu b/tools/ci_build/github/linux/docker/inference/x64/python/cpu/Dockerfile.manylinux2_28_cpu index 33660cbb3f2e5..06e75ee1a39f6 100644 --- a/tools/ci_build/github/linux/docker/inference/x64/python/cpu/Dockerfile.manylinux2_28_cpu +++ b/tools/ci_build/github/linux/docker/inference/x64/python/cpu/Dockerfile.manylinux2_28_cpu @@ -26,7 +26,6 @@ COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors # setup entrypoint, this will wrap commands with `linux32` with i686 images COPY build_scripts/install-entrypoint.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ @@ -35,7 +34,6 @@ COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint ENTRYPOINT ["manylinux-entrypoint"] COPY build_scripts/install-runtime-packages.sh \ - build_scripts/update-system-packages.sh \ build_scripts/build_utils.sh \ /build_scripts/ RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ @@ -132,7 +130,6 @@ COPY --from=build_git /manylinux-rootfs / COPY --from=build_cpython /manylinux-rootfs / COPY --from=all_python /opt/_internal /opt/_internal/ COPY build_scripts/finalize.sh \ - build_scripts/update-system-packages.sh \ build_scripts/python-tag-abi-tag.py \ build_scripts/requirements3.8.txt \ build_scripts/requirements3.9.txt \ diff --git a/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/install_centos.sh b/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/install_centos.sh index 98bb730a43776..c81e57c60c9da 100755 --- a/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/install_centos.sh +++ b/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/install_centos.sh @@ -1,11 +1,11 @@ #!/bin/bash set -e -os_major_version=$(cat /etc/redhat-release | tr -dc '0-9.'|cut -d \. -f1) +os_major_version=$(tr -dc '0-9.' < /etc/redhat-release |cut -d \. -f1) echo "installing for os major version : $os_major_version" dnf install -y glibc-langpack-\* -yum install -y which gdb redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel perl-IPC-Cmd openssl-devel wget +yum install -y which redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel perl-IPC-Cmd openssl-devel wget # export PATH=/opt/python/cp38-cp38/bin:$PATH @@ -17,4 +17,4 @@ mkdir build cd build cmake .. cmake --install . -cd ../.. \ No newline at end of file +cd ../.. diff --git a/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt b/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt index c0c6505ca010d..5341ae062d332 100644 --- a/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt +++ b/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt @@ -2,9 +2,9 @@ numpy==1.21.6 ; python_version < '3.11' numpy==1.24.2 ; python_version >= '3.11' mypy pytest -setuptools>=41.4.0 +setuptools>=68.2.2 wheel -git+http://github.com/onnx/onnx.git@e2525550194ce3d8a2c4a3af451c9d9b3ae6650e#egg=onnx +git+http://github.com/onnx/onnx.git@fdefbe85ed9c362b95b9b401cd19db068a76141f#egg=onnx protobuf==3.20.2 -sympy==1.10.1 +sympy==1.12 flatbuffers diff --git a/tools/ci_build/github/linux/docker/manylinux.patch b/tools/ci_build/github/linux/docker/manylinux.patch index f1821f9197525..75923e746f93c 100644 --- a/tools/ci_build/github/linux/docker/manylinux.patch +++ b/tools/ci_build/github/linux/docker/manylinux.patch @@ -94,7 +94,7 @@ index 9ef1e99..ec52833 100755 +fi \ No newline at end of file diff --git a/install-runtime-packages.sh b/install-runtime-packages.sh -index 137d2e2..4269afb 100755 +index 137d2e2..203b4bc 100755 --- a/install-runtime-packages.sh +++ b/install-runtime-packages.sh @@ -33,7 +33,7 @@ source $MY_DIR/build_utils.sh @@ -130,7 +130,7 @@ index 137d2e2..4269afb 100755 elif [ "${AUDITWHEEL_ARCH}" == "aarch64" ] || [ "${AUDITWHEEL_ARCH}" == "ppc64le" ] || [ "${AUDITWHEEL_ARCH}" == "s390x" ]; then # Software collection (for devtoolset-10) yum -y install centos-release-scl-rh -@@ -86,19 +88,18 @@ if [ "${AUDITWHEEL_POLICY}" == "manylinux2014" ]; then +@@ -86,19 +88,21 @@ if [ "${AUDITWHEEL_POLICY}" == "manylinux2014" ]; then fi elif [ "${AUDITWHEEL_POLICY}" == "manylinux_2_28" ]; then PACKAGE_MANAGER=dnf @@ -148,6 +148,9 @@ index 137d2e2..4269afb 100755 - TOOLCHAIN_DEPS="gcc-toolset-12-binutils gcc-toolset-12-gcc gcc-toolset-12-gcc-c++ gcc-toolset-12-gcc-gfortran" - if [ "${AUDITWHEEL_ARCH}" == "x86_64" ]; then - TOOLCHAIN_DEPS="${TOOLCHAIN_DEPS} yasm" ++ if test -f "/etc/yum.repos.d/ubi.repo"; then ++ sed -i 's/enabled\s*=\s*1/enabled = 1\nexclude=dotnet* aspnet* netstandard*/g' /etc/yum.repos.d/ubi.repo ++ fi + if [[ -d /usr/local/cuda ]]; then + TOOLCHAIN_DEPS="gcc gcc-c++" + else @@ -155,7 +158,7 @@ index 137d2e2..4269afb 100755 fi elif [ "${AUDITWHEEL_POLICY}" == "musllinux_1_1" ]; then TOOLCHAIN_DEPS="binutils gcc g++ gfortran" -@@ -121,12 +122,6 @@ else +@@ -121,12 +125,6 @@ else exit 1 fi diff --git a/tools/ci_build/github/linux/docker/scripts/install_dotnet.sh b/tools/ci_build/github/linux/docker/scripts/install_dotnet.sh index b9accb134b26d..c4689ed19c148 100755 --- a/tools/ci_build/github/linux/docker/scripts/install_dotnet.sh +++ b/tools/ci_build/github/linux/docker/scripts/install_dotnet.sh @@ -2,13 +2,15 @@ set -e -x if [ -f /etc/redhat-release ]; then - dnf update --refresh -y \ - && dnf install -y dotnet-sdk-6.0 + # If you found the following command went successfully but dotnet command still reports no sdk was found, most likely + # it was because the dotnet packages were installed from more than one dnf repos. + dnf install -y dotnet-sdk-6.0 dotnet-runtime-6.0 elif [ -f /etc/os-release ]; then # Get Ubuntu version - declare repo_version=$(if command -v lsb_release &> /dev/null; then lsb_release -r -s; else grep -oP '(?<=^VERSION_ID=).+' /etc/os-release | tr -d '"'; fi) + declare repo_version + repo_version=$(if command -v lsb_release &> /dev/null; then lsb_release -r -s; else grep -oP '(?<=^VERSION_ID=).+' /etc/os-release | tr -d '"'; fi) # Download Microsoft signing key and repository - wget https://packages.microsoft.com/config/ubuntu/$repo_version/packages-microsoft-prod.deb -O packages-microsoft-prod.deb + wget "https://packages.microsoft.com/config/ubuntu/$repo_version/packages-microsoft-prod.deb" -O packages-microsoft-prod.deb # Install Microsoft signing key and repository dpkg -i packages-microsoft-prod.deb # Clean up diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/install_centos.sh b/tools/ci_build/github/linux/docker/scripts/manylinux/install_centos.sh index 4f544a50cb94d..63b953a95add6 100755 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/install_centos.sh +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/install_centos.sh @@ -1,17 +1,18 @@ #!/bin/bash set -e -os_major_version=$(cat /etc/redhat-release | tr -dc '0-9.'|cut -d \. -f1) +os_major_version=$(tr -dc '0-9.' < /etc/redhat-release |cut -d \. -f1) echo "installing for os major version : $os_major_version" if [ "$os_major_version" -gt 7 ]; then PACKAGE_MANAGER="dnf" - $PACKAGE_MANAGER install -y which gdb redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel perl-IPC-Cmd openssl-devel wget + $PACKAGE_MANAGER install -y which redhat-lsb-core expat-devel tar unzip zlib-devel make bzip2 bzip2-devel perl-IPC-Cmd openssl-devel wget else PACKAGE_MANAGER="yum" - $PACKAGE_MANAGER install -y which gdb redhat-lsb-core expat-devel tar unzip zlib-devel make libunwind bzip2 bzip2-devel perl-IPC-Cmd openssl-devel wget + $PACKAGE_MANAGER install -y which redhat-lsb-core expat-devel tar unzip zlib-devel make libunwind bzip2 bzip2-devel perl-IPC-Cmd openssl-devel wget fi +rpm -Uvh https://packages.microsoft.com/config/centos/$os_major_version/packages-microsoft-prod.rpm # Install Java # Install automatic documentation generation dependencies -$PACKAGE_MANAGER install -y java-11-openjdk-devel graphviz +$PACKAGE_MANAGER install -y msopenjdk-11 graphviz diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps.sh b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps.sh index a1cb4be5b72c9..8c79918120d8d 100755 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps.sh +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps.sh @@ -3,18 +3,20 @@ set -e -x # Development tools and libraries if [ -f /etc/redhat-release ]; then - yum update && yum -y install graphviz - os_major_version=$(cat /etc/redhat-release | tr -dc '0-9.'|cut -d \. -f1) + dnf -y install graphviz elif [ -f /etc/os-release ]; then apt-get update && apt-get install -y graphviz - os_major_version=$(cat /etc/os-release | tr -dc '0-9.'|cut -d \. -f1) else echo "Unsupported OS" exit 1 fi # Install dotnet -source $(cd "$(dirname "${BASH_SOURCE[0]}")/.." &> /dev/null && pwd)/install_dotnet.sh +LOCAL_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" &> /dev/null && pwd)" +PARENT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." &> /dev/null && pwd)" +# ShellCheck is unable to follow dynamic paths, such as source "$somedir/file". +# shellcheck disable=SC1091 +source "$PARENT_DIR/install_dotnet.sh" if [ ! -d "/opt/conda/bin" ]; then PYTHON_EXES=("/opt/python/cp38-cp38/bin/python3.8" "/opt/python/cp39-cp39/bin/python3.9" "/opt/python/cp310-cp310/bin/python3.10" "/opt/python/cp311-cp311/bin/python3.11") @@ -22,23 +24,17 @@ else PYTHON_EXES=("/opt/conda/bin/python") fi -SYS_LONG_BIT=$(getconf LONG_BIT) mkdir -p /tmp/src -GLIBC_VERSION=$(getconf GNU_LIBC_VERSION | cut -f 2 -d \.) - -if [[ $SYS_LONG_BIT = "64" ]]; then - LIBDIR="lib64" -else - LIBDIR="lib" -fi cd /tmp/src -source $(cd "$(dirname "${BASH_SOURCE[0]}")" &> /dev/null && pwd)/install_shared_deps.sh +# shellcheck disable=SC1091 +source "$LOCAL_DIR/install_shared_deps.sh" cd /tmp/src if ! [ -x "$(command -v protoc)" ]; then - source ${0/%install_deps.sh/..\/install_protobuf.sh} +# shellcheck disable=SC1091 + source "$PARENT_DIR/install_protobuf.sh" fi export ONNX_ML=1 @@ -46,7 +42,7 @@ export CMAKE_ARGS="-DONNX_GEN_PB_TYPE_STUBS=OFF -DONNX_WERROR=OFF" for PYTHON_EXE in "${PYTHON_EXES[@]}" do - ${PYTHON_EXE} -m pip install -r ${0/%install_deps\.sh/requirements\.txt} + ${PYTHON_EXE} -m pip install -r "${0/%install_deps\.sh/requirements\.txt}" done cd / diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_aten.sh b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_aten.sh index ed220b487d06c..1f85f72aef423 100755 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_aten.sh +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_aten.sh @@ -11,7 +11,7 @@ else PYTHON_EXES=("/opt/conda/bin/python") fi -os_major_version=$(cat /etc/redhat-release | tr -dc '0-9.'|cut -d \. -f1) +os_major_version=$(tr -dc '0-9.' < /etc/redhat-release |cut -d \. -f1) SYS_LONG_BIT=$(getconf LONG_BIT) mkdir -p /tmp/src diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_eager.sh b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_eager.sh index e141e0793a2bd..ad3366b0bb3b6 100755 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_eager.sh +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/install_deps_eager.sh @@ -11,7 +11,7 @@ else PYTHON_EXES=("/opt/conda/bin/python") fi -os_major_version=$(cat /etc/redhat-release | tr -dc '0-9.'|cut -d \. -f1) +os_major_version=$(tr -dc '0-9.' < /etc/redhat-release |cut -d \. -f1) SYS_LONG_BIT=$(getconf LONG_BIT) mkdir -p /tmp/src diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt b/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt index c8ff7a804e1df..b2893286803b0 100644 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt @@ -2,10 +2,10 @@ numpy==1.21.6 ; python_version < '3.11' numpy==1.24.2 ; python_version >= '3.11' mypy pytest -setuptools>=41.4.0 +setuptools>=68.2.2 wheel -git+http://github.com/onnx/onnx.git@e2525550194ce3d8a2c4a3af451c9d9b3ae6650e#egg=onnx +git+http://github.com/onnx/onnx.git@fdefbe85ed9c362b95b9b401cd19db068a76141f#egg=onnx protobuf==3.20.2 -sympy==1.10.1 +sympy==1.12 flatbuffers neural-compressor>=2.2.1 diff --git a/tools/ci_build/github/linux/docker/scripts/requirements.txt b/tools/ci_build/github/linux/docker/scripts/requirements.txt index 2248652c98043..5d48a93b09c90 100644 --- a/tools/ci_build/github/linux/docker/scripts/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/requirements.txt @@ -3,11 +3,11 @@ numpy==1.21.6 ; python_version < '3.11' numpy==1.24.2 ; python_version >= '3.11' mypy pytest -setuptools>=41.4.0 +setuptools>=68.2.2 wheel>=0.35.1 -git+http://github.com/onnx/onnx.git@e2525550194ce3d8a2c4a3af451c9d9b3ae6650e#egg=onnx +git+http://github.com/onnx/onnx.git@fdefbe85ed9c362b95b9b401cd19db068a76141f#egg=onnx argparse -sympy==1.10.1 +sympy==1.12 flatbuffers protobuf==3.20.2 packaging diff --git a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch2.0.0_cu11.8/requirements.txt b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch2.0.0_cu11.8/requirements.txt index fa28a810370f7..b3b2651c8d26d 100644 --- a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch2.0.0_cu11.8/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch2.0.0_cu11.8/requirements.txt @@ -5,4 +5,4 @@ torchvision==0.15.1+cu118 torchtext==0.15.1 # TODO(bmeswani): packaging 22.0 removes support for LegacyVersion leading to errors because transformers 4.4.2 uses LegacyVersion packaging==21.3 -setuptools>=41.4.0 +setuptools>=68.2.2 diff --git a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_cpu/requirements.txt b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_cpu/requirements.txt index 94b16d7ff4894..95d02b8400339 100644 --- a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_cpu/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_cpu/requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/torch_stable.html torch==2.0.0+cpu -setuptools>=41.4.0 +setuptools>=68.2.2 diff --git a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/torch_eager_cpu/requirements.txt b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/torch_eager_cpu/requirements.txt index b071770c629e2..08e251eddbf96 100644 --- a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/torch_eager_cpu/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/torch_eager_cpu/requirements.txt @@ -1,7 +1,7 @@ --pre -f https://download.pytorch.org/whl/torch_stable.html torch==1.13.1+cpu -setuptools>=41.4.0 +setuptools>=68.2.2 cerberus h5py scikit-learn diff --git a/tools/ci_build/github/linux/run_python_dockertest.sh b/tools/ci_build/github/linux/run_python_dockertest.sh new file mode 100755 index 0000000000000..332dd9c7284c0 --- /dev/null +++ b/tools/ci_build/github/linux/run_python_dockertest.sh @@ -0,0 +1,29 @@ +#!/bin/bash +set -e -x +BUILD_CONFIG="Release" + +while getopts "i:d:x:c:" parameter_Option +do case "${parameter_Option}" +in +i) DOCKER_IMAGE=${OPTARG};; +d) DEVICE=${OPTARG};; +c) BUILD_CONFIG=${OPTARG};; +esac +done + +if [ $DEVICE = "GPU" ]; then + ADDITIONAL_DOCKER_PARAMETER="--gpus all" +fi + +mkdir -p $HOME/.onnx +docker run --rm \ + --volume /data/onnx:/data/onnx:ro \ + --volume $BUILD_SOURCESDIRECTORY:/onnxruntime_src \ + --volume $BUILD_BINARIESDIRECTORY:/build \ + --volume /data/models:/build/models:ro \ + --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ + -w /onnxruntime_src \ + -e NIGHTLY_BUILD \ + -e BUILD_BUILDNUMBER \ + $ADDITIONAL_DOCKER_PARAMETER \ + $DOCKER_IMAGE tools/ci_build/github/linux/run_python_tests.sh -d $DEVICE -c $BUILD_CONFIG diff --git a/tools/ci_build/github/linux/run_python_tests.sh b/tools/ci_build/github/linux/run_python_tests.sh index c11ea42cd0541..f080c7e8c39d8 100755 --- a/tools/ci_build/github/linux/run_python_tests.sh +++ b/tools/ci_build/github/linux/run_python_tests.sh @@ -15,7 +15,8 @@ c) BUILD_CONFIG=${OPTARG};; esac done -cd $BUILD_BINARIESDIRECTORY +export PATH=/opt/python/cp38-cp38/bin:$PATH +cd /build files=(whl/*.whl) FILE_NAME="${files[0]}" FILE_NAME=$(basename $FILE_NAME) @@ -23,7 +24,7 @@ PYTHON_PACKAGE_NAME=$(echo "$FILE_NAME" | cut -f 1 -d '-') echo "Package name:$PYTHON_PACKAGE_NAME" -BUILD_ARGS="--build_dir $BUILD_BINARIESDIRECTORY --config $BUILD_CONFIG --test --skip_submodule_sync --parallel --enable_lto --build_wheel " +BUILD_ARGS="--build_dir /build --config $BUILD_CONFIG --test --skip_submodule_sync --parallel --enable_lto --build_wheel " ARCH=$(uname -m) @@ -35,20 +36,15 @@ if [ $BUILD_DEVICE == "GPU" ]; then BUILD_ARGS="$BUILD_ARGS --use_cuda --use_tensorrt --cuda_version=11.8 --tensorrt_home=/usr --cuda_home=/usr/local/cuda-11.8 --cudnn_home=/usr/local/cuda-11.8" fi # We assume the machine doesn't have gcc and python development header files, so we don't build onnxruntime from source -sudo rm -rf /build /onnxruntime_src -sudo ln -s $BUILD_SOURCESDIRECTORY /onnxruntime_src python3 -m pip install --upgrade pip -python3 -m pip uninstall -y $PYTHON_PACKAGE_NAME ort-nightly-gpu ort-nightly onnxruntime onnxruntime-gpu onnxruntime-training onnxruntime-directml ort-nightly-directml onnx -qq # Install the packages that are needed for installing the onnxruntime python package -python3 -m pip install -r $BUILD_BINARIESDIRECTORY/$BUILD_CONFIG/requirements.txt +python3 -m pip install -r /build/$BUILD_CONFIG/requirements.txt # Install the packages that are needed for running test scripts -# Install the latest ONNX release which may contain not fixed bugs. However, it is what most people use. -python3 -m pip install onnx pytest +python3 -m pip install pytest # The "--no-index" flag is crucial. The local whl folder is just an additional source. Pypi's doc says "there is no # ordering in the locations that are searched" if we don't disable the default one with "--no-index" -python3 -m pip install --no-index --find-links $BUILD_BINARIESDIRECTORY/whl $PYTHON_PACKAGE_NAME -ln -s /data/models $BUILD_BINARIESDIRECTORY -cd $BUILD_BINARIESDIRECTORY/$BUILD_CONFIG +python3 -m pip install --no-index --find-links /build/whl $PYTHON_PACKAGE_NAME +cd /build/$BUILD_CONFIG # Restore file permissions xargs -a perms.txt chmod a+x -python3 $BUILD_SOURCESDIRECTORY/tools/ci_build/build.py $BUILD_ARGS --ctest_path '' +python3 /onnxruntime_src/tools/ci_build/build.py $BUILD_ARGS --ctest_path '' diff --git a/tools/ci_build/github/windows/helpers.ps1 b/tools/ci_build/github/windows/helpers.ps1 index 6e81f901a8288..20df10b244408 100644 --- a/tools/ci_build/github/windows/helpers.ps1 +++ b/tools/ci_build/github/windows/helpers.ps1 @@ -444,7 +444,7 @@ function Install-Abseil { .Description The Install-UTF8-Range function installs Google's utf8_range library from source. utf8_range depends on Abseil. - + .PARAMETER cmake_path The full path of cmake.exe @@ -604,14 +604,17 @@ function Install-ONNX { pushd . Write-Host "Uninstalling onnx and ignore errors if there is any..." - python.exe -m pip uninstall -y onnx -qq + [string[]]$pip_args ="-m", "pip", "uninstall", "-y", "onnx", "-qq" + &"python.exe" $pip_args + if ($lastExitCode -ne 0) { + exit $lastExitCode + } Write-Host "Installing python packages..." - $p = Start-Process -NoNewWindow -Wait -PassThru -FilePath "python.exe" -ArgumentList "-m", "pip", "install", "--disable-pip-version-check", "setuptools", "wheel", "numpy", "protobuf==$protobuf_version" - $exitCode = $p.ExitCode - if ($exitCode -ne 0) { - Write-Host -Object "Install dependent python wheels failed. Exitcode: $exitCode" - exit $exitCode + [string[]]$pip_args = "-m", "pip", "install", "-qq", "--disable-pip-version-check", "setuptools>=68.2.2", "wheel", "numpy", "protobuf==$protobuf_version" + &"python.exe" $pip_args + if ($lastExitCode -ne 0) { + exit $lastExitCode } $url=Get-DownloadURL -name onnx -src_root $src_root diff --git a/tools/perf_view/ort_perf_view.html b/tools/perf_view/ort_perf_view.html index e00e38702d342..509fe5593f6a1 100644 --- a/tools/perf_view/ort_perf_view.html +++ b/tools/perf_view/ort_perf_view.html @@ -5,7 +5,7 @@ - +