diff --git a/cgmanifests/generate_cgmanifest.py b/cgmanifests/generate_cgmanifest.py index a9eaacc6f2938..81181d3ccfb20 100644 --- a/cgmanifests/generate_cgmanifest.py +++ b/cgmanifests/generate_cgmanifest.py @@ -90,55 +90,6 @@ def add_github_dep(name, parsed_url): git_deps[dep] = name -with open( - os.path.join(REPO_DIR, "tools", "ci_build", "github", "linux", "docker", "Dockerfile.manylinux2_28_cuda11"), -) as f: - for line in f: - if not line.strip(): - package_name = None - package_filename = None - package_url = None - if package_filename is None: - m = re.match(r"RUN\s+export\s+(.+?)_ROOT=(\S+).*", line) - if m is not None: - package_name = m.group(1) - package_filename = m.group(2) - else: - m = re.match(r"RUN\s+export\s+(.+?)_VERSION=(\S+).*", line) - if m is not None: - package_name = m.group(1) - package_filename = m.group(2) - elif package_url is None: - m = re.match(r"(.+?)_DOWNLOAD_URL=(\S+)", line) - if m is not None: - package_url = m.group(2) - if package_name == "LIBXCRYPT": - package_url = m.group(2) + "/v" + package_filename + ".tar.gz" - elif package_name == "CMAKE": - package_url = m.group(2) + "/v" + package_filename + "/cmake-" + package_filename + ".tar.gz" - else: - package_url = m.group(2) + "/" + package_filename + ".tar.gz" - parsed_url = urlparse(package_url) - if parsed_url.hostname == "github.com": - add_github_dep("manylinux dependency " + package_name, parsed_url) - else: - registration = { - "Component": { - "Type": "other", - "other": { - "Name": package_name.lower(), - "Version": package_filename.split("-")[-1], - "DownloadUrl": package_url, - }, - "comments": "manylinux dependency", - } - } - registrations.append(registration) - package_name = None - package_filename = None - package_url = None - - def normalize_path_separators(path): return path.replace(os.path.sep, "/") diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index 6f1ca84e1a304..08ca90d7c3b7f 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -2,112 +2,6 @@ "$schema": "https://json.schemastore.org/component-detection-manifest.json", "Version": 1, "Registrations": [ - { - "Component": { - "Type": "other", - "other": { - "Name": "autoconf", - "Version": "2.71", - "DownloadUrl": "http://ftp.gnu.org/gnu/autoconf/autoconf-2.71.tar.gz" - }, - "comments": "manylinux dependency" - } - }, - { - "Component": { - "Type": "other", - "other": { - "Name": "automake", - "Version": "1.16.5", - "DownloadUrl": "http://ftp.gnu.org/gnu/automake/automake-1.16.5.tar.gz" - }, - "comments": "manylinux dependency" - } - }, - { - "Component": { - "Type": "other", - "other": { - "Name": "libtool", - "Version": "2.4.7", - "DownloadUrl": "http://ftp.gnu.org/gnu/libtool/libtool-2.4.7.tar.gz" - }, - "comments": "manylinux dependency" - } - }, - { - "Component": { - "Type": "other", - "other": { - "Name": "git", - "Version": "2.36.2", - "DownloadUrl": "https://www.kernel.org/pub/software/scm/git/git-2.36.2.tar.gz" - }, - "comments": "manylinux dependency" - } - }, - { - "Component": { - "Type": "other", - "other": { - "Name": "sqlite_autoconf", - "Version": "3390200", - "DownloadUrl": "https://www.sqlite.org/2022/sqlite-autoconf-3390200.tar.gz" - }, - "comments": "manylinux dependency" - } - }, - { - "Component": { - "Type": "other", - "other": { - "Name": "openssl", - "Version": "1.1.1q", - "DownloadUrl": "https://www.openssl.org/source/openssl-1.1.1q.tar.gz" - }, - "comments": "manylinux dependency" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "50cf2b6dd4fdf04309445f2eec8de7051d953abf", - "repositoryUrl": "https://github.com/besser82/libxcrypt.git" - }, - "comments": "manylinux dependency LIBXCRYPT" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "a896e3d066448b3530dbcaa48869fafefd738f57", - "repositoryUrl": "https://github.com/emscripten-core/emsdk.git" - }, - "comments": "git submodule at cmake/external/emsdk" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "7a2ed51a6b682a83e345ff49fc4cfd7ca47550db", - "repositoryUrl": "https://github.com/google/libprotobuf-mutator.git" - }, - "comments": "git submodule at cmake/external/libprotobuf-mutator" - } - }, - { - "component": { - "type": "git", - "git": { - "commitHash": "e2525550194ce3d8a2c4a3af451c9d9b3ae6650e", - "repositoryUrl": "https://github.com/onnx/onnx.git" - }, - "comments": "git submodule at cmake/external/onnx" - } - }, { "component": { "type": "git", @@ -268,6 +162,16 @@ "comments": "mp11" } }, + { + "component": { + "type": "git", + "git": { + "commitHash": "fdefbe85ed9c362b95b9b401cd19db068a76141f", + "repositoryUrl": "https://github.com/onnx/onnx.git" + }, + "comments": "onnx" + } + }, { "component": { "type": "git", diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 8d5b135b53ddc..08886c8fcf5f3 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -73,6 +73,11 @@ option(onnxruntime_ENABLE_PYTHON "Enable python buildings" OFF) # Enable it may cause LNK1169 error option(onnxruntime_ENABLE_MEMLEAK_CHECKER "Experimental: Enable memory leak checker in Windows debug build" OFF) option(onnxruntime_USE_CUDA "Build with CUDA support" OFF) +# Enable ONNX Runtime CUDA EP's internal unit tests that directly access the EP's internal functions instead of through +# OpKernels. When the option is ON, we will have two copies of GTest library in the same process. It is not a typical +# use. If you hit any problem with that, please do not report it to GTest. Turn OFF the following build option instead. +cmake_dependent_option(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS "Build with CUDA unit tests" OFF "onnxruntime_USE_CUDA;onnxruntime_BUILD_UNIT_TESTS;LINUX" OFF) + option(onnxruntime_ENABLE_CUDA_LINE_NUMBER_INFO "When building with CUDA support, generate device code line number information." OFF) option(onnxruntime_USE_OPENVINO "Build with OpenVINO support" OFF) option(onnxruntime_USE_COREML "Build with CoreML support" OFF) @@ -146,10 +151,11 @@ option(onnxruntime_DISABLE_SPARSE_TENSORS "Disable sparse tensors data types" OF option(onnxruntime_DISABLE_OPTIONAL_TYPE "Disable optional type" OFF) option(onnxruntime_DISABLE_FLOAT8_TYPES "Disable float 8 types" OFF) option(onnxruntime_MINIMAL_BUILD "Exclude as much as possible from the build. Support ORT format models. No support for ONNX format models." OFF) -cmake_dependent_option(onnxruntime_DISABLE_RTTI "Disable RTTI" ON "NOT onnxruntime_ENABLE_PYTHON" OFF) +cmake_dependent_option(onnxruntime_DISABLE_RTTI "Disable RTTI" ON "NOT onnxruntime_ENABLE_PYTHON;NOT onnxruntime_USE_CUDA" OFF) # For now onnxruntime_DISABLE_EXCEPTIONS will only work with onnxruntime_MINIMAL_BUILD, more changes (ONNX, non-CPU EP, ...) are required to run this standalone cmake_dependent_option(onnxruntime_DISABLE_EXCEPTIONS "Disable exception handling. Requires onnxruntime_MINIMAL_BUILD currently." ON "onnxruntime_MINIMAL_BUILD;NOT onnxruntime_ENABLE_PYTHON" OFF) -option(onnxruntime_DISABLE_ABSEIL "Do not link to Abseil. Redefine Inlined containers to STD containers." OFF) +# Even when onnxruntime_DISABLE_ABSEIL is ON, ONNX Runtime still needs to link to abseil. +option(onnxruntime_DISABLE_ABSEIL "Do not use Abseil data structures in ONNX Runtime source code. Redefine Inlined containers to STD containers." OFF) option(onnxruntime_EXTENDED_MINIMAL_BUILD "onnxruntime_MINIMAL_BUILD with support for execution providers that compile kernels." OFF) option(onnxruntime_MINIMAL_BUILD_CUSTOM_OPS "Add custom operator kernels support to a minimal build." OFF) @@ -269,10 +275,6 @@ if (onnxruntime_ENABLE_TRAINING_APIS) endif() endif() -if (onnxruntime_USE_CUDA) - set(onnxruntime_DISABLE_RTTI OFF) -endif() - if (onnxruntime_USE_ROCM) if (WIN32) message(FATAL_ERROR "ROCM does not support build in Windows!") diff --git a/cmake/deps.txt b/cmake/deps.txt index 279b5ca649dba..7cf49f02333a4 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -24,7 +24,7 @@ microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf36 microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5 mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee7d34223d0567892db5179849939c8769dc41 mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063 -onnx;https://github.com/onnx/onnx/archive/e2525550194ce3d8a2c4a3af451c9d9b3ae6650e.zip;782f23d788185887f520a90535513e244218e928 +onnx;https://github.com/onnx/onnx/archive/14303de049144035dfd94ace5f7a3b44773b1aad.zip;250eab9690392b248d75b56e605fb49eca373442 #use the commit of supporting all the plugins and TRT 8.6-GA (https://github.com/onnx/onnx-tensorrt/commit/0462dc31ae78f48744b6141ae376df1f96d3f459) onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/0462dc31ae78f48744b6141ae376df1f96d3f459.zip;5ff086361956cceb81ed17453a1fd8db2aa4328d protobuf;https://github.com/protocolbuffers/protobuf/archive/refs/tags/v21.12.zip;7cf2733949036c7d52fda017badcab093fe73bfa @@ -44,4 +44,4 @@ tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2 cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.0.0.zip;0f95b3c1fc1bd1175c4a90b2c9e39074d1bccefd utf8_range;https://github.com/protocolbuffers/utf8_range/archive/72c943dea2b9240cd09efde15191e144bc7c7d38.zip;9925739c9debc0efa2adcb194d371a35b6a03156 extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d8391c9791ec71c38336436319a2d4ac7a0.zip;4365ac5140338b4cb75a39944a4be276e3829b3c -composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/d52ec01652b7d620386251db92455968d8d90bdc.zip;6b5ce8edf3625f8817086c194fbf94b664e1b0e0 \ No newline at end of file +composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/d52ec01652b7d620386251db92455968d8d90bdc.zip;6b5ce8edf3625f8817086c194fbf94b664e1b0e0 diff --git a/cmake/deps_update_and_upload.py b/cmake/deps_update_and_upload.py new file mode 100644 index 0000000000000..194d21435f0be --- /dev/null +++ b/cmake/deps_update_and_upload.py @@ -0,0 +1,56 @@ +# in case deps.txt is updated, run this file to update and upload the dependencies so that CI can use them. +# Before running the script, increase the version number found at: +# https://aiinfra.visualstudio.com/Lotus/_artifacts/feed/Lotus/UPack/onnxruntime_build_dependencies/versions +# Run without --do-upload once to verify downloading. Use --do-upload when you are ready to publish. +# python cmake/deps_update_and_upload.py --root-path C:/temp/onnxruntime_deps --version 1.0.82 --do-upload +# update version number in tools\ci_build\github\azure-pipelines\templates\download-deps.yml +import re +import subprocess +import os +import argparse +import tempfile + +parser = argparse.ArgumentParser(description="Update dependencies and publish to Azure Artifacts") +parser.add_argument( + "--root-path", type=str, default=tempfile.gettempdir(), help="Target root path for downloaded files" +) +parser.add_argument("--version", type=str, default="1.0.82", help="Package version to publish") +parser.add_argument("--do-upload", action="store_true", help="Upload the package to Azure Artifacts") +args = parser.parse_args() + +with open("cmake/deps.txt") as file: + text = file.read() + +lines = [line for line in text.split("\n") if not line.startswith("#") and ";" in line] + +root_path = args.root_path + +for line in lines: + url = re.sub("^[^;]+?;https://([^;]+?);.*", r"https://\1", line) + filename = re.sub("^[^;]+?;https://([^;]+?);.*", r"\1", line) + full_path = os.path.join(root_path, filename) + subprocess.run(["curl", "-sSL", "--create-dirs", "-o", full_path, url]) + +package_name = "onnxruntime_build_dependencies" +version = args.version + +# Check if the user is logged in to Azure +result = subprocess.run("az account show", shell=True, capture_output=True, text=True) +if "No subscriptions found" in result.stderr: + # Prompt the user to log in to Azure + print("You are not logged in to Azure. Please log in to continue.") + subprocess.run("az login", shell=True) + +# Publish the package to Azure Artifacts if --no-upload is not specified + +cmd = f'az artifacts universal publish --organization https://dev.azure.com/onnxruntime --feed onnxruntime --name {package_name} --version {version} --description "onnxruntime build time dependencies" --path {root_path}' +if args.do_upload: + subprocess.run(cmd, shell=True) +else: + print("would have run: " + cmd) + +cmd = f'az artifacts universal publish --organization https://dev.azure.com/aiinfra --feed Lotus --name {package_name} --version {version} --description "onnxruntime build time dependencies" --path {root_path}' +if args.do_upload: + subprocess.run(cmd, shell=True) +else: + print("would have run: " + cmd) diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake index e1671bcf43ed9..019c6341d2e46 100644 --- a/cmake/external/onnxruntime_external_deps.cmake +++ b/cmake/external/onnxruntime_external_deps.cmake @@ -37,8 +37,12 @@ if (onnxruntime_BUILD_UNIT_TESTS) set(gtest_disable_pthreads ON) endif() set(INSTALL_GTEST OFF CACHE BOOL "" FORCE) - # Set it to ON will cause crashes in onnxruntime_test_all when onnxruntime_USE_CUDA is ON - set(GTEST_HAS_ABSL OFF CACHE BOOL "" FORCE) + if (CMAKE_SYSTEM_NAME STREQUAL "iOS") + # Needs to update onnxruntime/test/xctest/xcgtest.mm + set(GTEST_HAS_ABSL OFF CACHE BOOL "" FORCE) + else() + set(GTEST_HAS_ABSL ON CACHE BOOL "" FORCE) + endif() # gtest and gmock FetchContent_Declare( googletest diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index 59ebf8eca4306..0fe9a0a4b0bfb 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -18,35 +18,21 @@ if (${CMAKE_SYSTEM_NAME} STREQUAL "iOS") set(OUTPUT_STYLE xcode) endif() -set(ONNXRUNTIME_PUBLIC_HEADERS - "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_c_api.h" - "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_cxx_api.h" - "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_float16.h" - "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_cxx_inline.h" - "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h" - "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h" -) - -if (onnxruntime_ENABLE_TRAINING_APIS) - list(APPEND ${_HEADERS} "${REPO_ROOT}/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h") - list(APPEND ${_HEADERS} "${REPO_ROOT}/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h") - list(APPEND ${_HEADERS} "${REPO_ROOT}/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h") -endif() - -# This macro is to get the path of header files for mobile packaging, for iOS and Android -macro(get_mobile_api_headers _HEADERS) - # include both c and cxx api - set(${_HEADERS} +# Gets the public C/C++ API header files +function(get_c_cxx_api_headers HEADERS_VAR) + set(_headers "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_c_api.h" "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_cxx_api.h" - "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_float16.h" "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_cxx_inline.h" + "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_float16.h" + "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h" + "${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h" ) if (onnxruntime_ENABLE_TRAINING_APIS) - list(APPEND ${_HEADERS} "${REPO_ROOT}/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h") - list(APPEND ${_HEADERS} "${REPO_ROOT}/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h") - list(APPEND ${_HEADERS} "${REPO_ROOT}/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h") + list(APPEND _headers "${REPO_ROOT}/orttraining/orttraining/training_api/include/onnxruntime_training_c_api.h") + list(APPEND _headers "${REPO_ROOT}/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_api.h") + list(APPEND _headers "${REPO_ROOT}/orttraining/orttraining/training_api/include/onnxruntime_training_cxx_inline.h") endif() # need to add header files for enabled EPs @@ -54,10 +40,13 @@ macro(get_mobile_api_headers _HEADERS) file(GLOB _provider_headers CONFIGURE_DEPENDS "${REPO_ROOT}/include/onnxruntime/core/providers/${f}/*.h" ) - list(APPEND ${_HEADERS} "${_provider_headers}") - unset(_provider_headers) + list(APPEND _headers ${_provider_headers}) endforeach() -endmacro() + + set(${HEADERS_VAR} ${_headers} PARENT_SCOPE) +endfunction() + +get_c_cxx_api_headers(ONNXRUNTIME_PUBLIC_HEADERS) #If you want to verify if there is any extra line in symbols.txt, run # nm -C -g --defined libonnxruntime.so |grep -v '\sA\s' | cut -f 3 -d ' ' | sort @@ -84,11 +73,9 @@ if(WIN32) "${ONNXRUNTIME_ROOT}/core/dll/onnxruntime.rc" ) elseif(onnxruntime_BUILD_APPLE_FRAMEWORK) - get_mobile_api_headers(APPLE_FRAMEWORK_HEADERS) - # apple framework requires the header file be part of the library onnxruntime_add_shared_library(onnxruntime - ${APPLE_FRAMEWORK_HEADERS} + ${ONNXRUNTIME_PUBLIC_HEADERS} "${CMAKE_CURRENT_BINARY_DIR}/generated_source.c" ) @@ -107,10 +94,9 @@ elseif(onnxruntime_BUILD_APPLE_FRAMEWORK) set_target_properties(onnxruntime PROPERTIES FRAMEWORK TRUE FRAMEWORK_VERSION A - PUBLIC_HEADER "${APPLE_FRAMEWORK_HEADERS}" - MACOSX_FRAMEWORK_INFO_PLIST ${CMAKE_CURRENT_BINARY_DIR}/Info.plist - VERSION ${ORT_VERSION} - SOVERSION ${ORT_VERSION} + MACOSX_FRAMEWORK_INFO_PLIST ${INFO_PLIST_PATH} + SOVERSION ${ORT_VERSION} + # Note: The PUBLIC_HEADER and VERSION properties for the 'onnxruntime' target will be set later in this file. ) else() onnxruntime_add_shared_library(onnxruntime ${CMAKE_CURRENT_BINARY_DIR}/generated_source.c) @@ -180,11 +166,10 @@ endif() # we need to copy C/C++ API headers to be packed into Android AAR package if(CMAKE_SYSTEM_NAME STREQUAL "Android" AND onnxruntime_BUILD_JAVA) - get_mobile_api_headers(ANDROID_AAR_HEADERS) set(ANDROID_HEADERS_DIR ${CMAKE_CURRENT_BINARY_DIR}/android/headers) file(MAKE_DIRECTORY ${ANDROID_HEADERS_DIR}) # copy the header files one by one - foreach(h_ ${ANDROID_AAR_HEADERS}) + foreach(h_ ${ONNXRUNTIME_PUBLIC_HEADERS}) get_filename_component(HEADER_NAME_ ${h_} NAME) add_custom_command(TARGET onnxruntime POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different ${h_} ${ANDROID_HEADERS_DIR}/${HEADER_NAME_}) endforeach() @@ -328,7 +313,7 @@ if(onnxruntime_BUILD_APPLE_FRAMEWORK) file(MAKE_DIRECTORY ${STATIC_FRAMEWORK_HEADER_DIR}) # copy the header files one by one, and the Info.plist - foreach(h_ ${APPLE_FRAMEWORK_HEADERS}) + foreach(h_ ${ONNXRUNTIME_PUBLIC_HEADERS}) get_filename_component(HEADER_NAME_ ${h_} NAME) add_custom_command(TARGET onnxruntime POST_BUILD COMMAND ${CMAKE_COMMAND} -E copy_if_different ${h_} ${STATIC_FRAMEWORK_HEADER_DIR}/${HEADER_NAME_}) endforeach() diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index e0ccc504d7b27..992908392c946 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -581,7 +581,7 @@ set_target_properties(onnxruntime_mlas PROPERTIES FOLDER "ONNXRuntime") if (WIN32) target_compile_options(onnxruntime_mlas PRIVATE "$<$:/wd6385>" "$<$:/wd4127>") if (onnxruntime_ENABLE_STATIC_ANALYSIS) - target_compile_options(onnxruntime_mlas PRIVATE "$<$:/analyze:stacksize" 131072>) + target_compile_options(onnxruntime_mlas PRIVATE "$<$:/analyze:stacksize 131072>") endif() endif() diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index b9e7873132089..96c05e5282bb5 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -460,13 +460,17 @@ if (onnxruntime_USE_CUDA) if (onnxruntime_REDUCED_OPS_BUILD) substitute_op_reduction_srcs(onnxruntime_providers_cuda_src) endif() - # cuda_provider_interface.cc is removed from the object target: onnxruntime_providers_cuda_obj and - # add to the lib onnxruntime_providers_cuda separatedly. - # onnxruntime_providers_cuda_ut can share all the object files with onnxruntime_providers_cuda except cuda_provider_interface.cc. - set(cuda_provider_interface_src ${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_provider_interface.cc) - list(REMOVE_ITEM onnxruntime_providers_cuda_src ${cuda_provider_interface_src}) - onnxruntime_add_object_library(onnxruntime_providers_cuda_obj ${onnxruntime_providers_cuda_src}) - onnxruntime_add_shared_library_module(onnxruntime_providers_cuda ${cuda_provider_interface_src} $) + if(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS) + # cuda_provider_interface.cc is removed from the object target: onnxruntime_providers_cuda_obj and + # add to the lib onnxruntime_providers_cuda separatedly. + # onnxruntime_providers_cuda_ut can share all the object files with onnxruntime_providers_cuda except cuda_provider_interface.cc. + set(cuda_provider_interface_src ${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_provider_interface.cc) + list(REMOVE_ITEM onnxruntime_providers_cuda_src ${cuda_provider_interface_src}) + onnxruntime_add_object_library(onnxruntime_providers_cuda_obj ${onnxruntime_providers_cuda_src}) + onnxruntime_add_shared_library_module(onnxruntime_providers_cuda ${cuda_provider_interface_src} $) + else() + onnxruntime_add_shared_library_module(onnxruntime_providers_cuda ${onnxruntime_providers_cuda_src}) + endif() # config_cuda_provider_shared_module can be used to config onnxruntime_providers_cuda_obj, onnxruntime_providers_cuda & onnxruntime_providers_cuda_ut. # This function guarantees that all 3 targets have the same configurations. function(config_cuda_provider_shared_module target) @@ -600,7 +604,9 @@ if (onnxruntime_USE_CUDA) target_compile_definitions(${target} PRIVATE ENABLE_ATEN) endif() endfunction() - config_cuda_provider_shared_module(onnxruntime_providers_cuda_obj) + if(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS) + config_cuda_provider_shared_module(onnxruntime_providers_cuda_obj) + endif() config_cuda_provider_shared_module(onnxruntime_providers_cuda) install(TARGETS onnxruntime_providers_cuda @@ -1750,7 +1756,7 @@ if (onnxruntime_USE_TVM) target_include_directories(onnxruntime_providers_tvm PRIVATE ${TVM_INCLUDES} - ${PYTHON_INLCUDE_DIRS}) + ${PYTHON_INCLUDE_DIRS}) onnxruntime_add_include_to_target(onnxruntime_providers_tvm onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface) add_dependencies(onnxruntime_providers_tvm ${onnxruntime_EXTERNAL_DEPENDENCIES}) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index ec83eb2095071..0e642c5a7e0aa 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -35,13 +35,17 @@ function(AddTest) if (MSVC AND NOT CMAKE_SIZEOF_VOID_P EQUAL 8) #TODO: fix the warnings, they are dangerous - target_compile_options(${_UT_TARGET} PRIVATE "/wd4244") + target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options /wd4244>" + "$<$>:/wd4244>") endif() if (MSVC) - target_compile_options(${_UT_TARGET} PRIVATE "/wd6330") + target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options /wd6330>" + "$<$>:/wd6330>") #Abseil has a lot of C4127/C4324 warnings. - target_compile_options(${_UT_TARGET} PRIVATE "/wd4127") - target_compile_options(${_UT_TARGET} PRIVATE "/wd4324") + target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options /wd4127>" + "$<$>:/wd4127>") + target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options /wd4324>" + "$<$>:/wd4324>") endif() set_target_properties(${_UT_TARGET} PROPERTIES FOLDER "ONNXRuntimeTest") @@ -60,6 +64,11 @@ function(AddTest) Threads::Threads) target_compile_definitions(${_UT_TARGET} PRIVATE -DUSE_ONNXRUNTIME_DLL) else() + if(onnxruntime_USE_CUDA) + #XXX: we should not need to do this. onnxruntime_test_all.exe should not have direct dependency on CUDA DLLs, + # otherwise it will impact when CUDA DLLs can be unloaded. + target_link_libraries(${_UT_TARGET} PRIVATE cudart) + endif() target_link_libraries(${_UT_TARGET} PRIVATE ${_UT_LIBS} GTest::gtest GTest::gmock ${onnxruntime_EXTERNAL_LIBRARIES}) endif() @@ -85,29 +94,22 @@ function(AddTest) # include dbghelp in case tests throw an ORT exception, as that exception includes a stacktrace, which requires dbghelp. target_link_libraries(${_UT_TARGET} PRIVATE debug dbghelp) - if (onnxruntime_USE_CUDA) - # disable a warning from the CUDA headers about unreferenced local functions - if (MSVC) - target_compile_options(${_UT_TARGET} PRIVATE "$<$:-Xcompiler /wd4505>" - "$<$>:/wd4505>") - endif() - endif() if (MSVC) # warning C6326: Potential comparison of a constant with another constant. # Lot of such things came from gtest - target_compile_options(${_UT_TARGET} PRIVATE "$<$:-Xcompiler /wd6326>" + target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options /wd6326>" "$<$>:/wd6326>") # Raw new and delete. A lot of such things came from googletest. - target_compile_options(${_UT_TARGET} PRIVATE "$<$:-Xcompiler /wd26409>" + target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options /wd26409>" "$<$>:/wd26409>") # "Global initializer calls a non-constexpr function." - target_compile_options(${_UT_TARGET} PRIVATE "$<$:-Xcompiler /wd26426>" + target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options /wd26426>" "$<$>:/wd26426>") endif() target_compile_options(${_UT_TARGET} PRIVATE ${disabled_warnings}) else() target_compile_options(${_UT_TARGET} PRIVATE ${DISABLED_WARNINGS_FOR_TVM}) - target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:-Xcompiler -Wno-error=sign-compare>" + target_compile_options(${_UT_TARGET} PRIVATE "$<$:SHELL:--compiler-options -Wno-error=sign-compare>" "$<$>:-Wno-error=sign-compare>") target_compile_options(${_UT_TARGET} PRIVATE "-Wno-error=uninitialized") endif() @@ -698,7 +700,7 @@ onnxruntime_add_static_library(onnxruntime_test_utils ${onnxruntime_test_utils_s if(MSVC) target_compile_options(onnxruntime_test_utils PRIVATE "$<$:SHELL:--compiler-options /utf-8>" "$<$>:/utf-8>") - target_compile_options(onnxruntime_test_utils PRIVATE "$<$:-Xcompiler /wd6326>" + target_compile_options(onnxruntime_test_utils PRIVATE "$<$:SHELL:--compiler-options /wd6326>" "$<$>:/wd6326>") else() target_compile_definitions(onnxruntime_test_utils PUBLIC -DNSYNC_ATOMIC_CPP11) @@ -755,13 +757,8 @@ set_target_properties(onnx_test_runner_common PROPERTIES FOLDER "ONNXRuntimeTest set(all_tests ${onnxruntime_test_common_src} ${onnxruntime_test_ir_src} ${onnxruntime_test_optimizer_src} ${onnxruntime_test_framework_src} ${onnxruntime_test_providers_src} ${onnxruntime_test_quantiztion_src}) -if(NOT TARGET onnxruntime AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") - list(APPEND all_tests ${onnxruntime_shared_lib_test_SRC}) -endif() -if (onnxruntime_USE_CUDA) - onnxruntime_add_static_library(onnxruntime_test_cuda_ops_lib ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu) - list(APPEND onnxruntime_test_common_libs onnxruntime_test_cuda_ops_lib) +if (onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS) file(GLOB onnxruntime_test_providers_cuda_ut_src CONFIGURE_DEPENDS "${TEST_SRC_DIR}/providers/cuda/test_cases/*" ) @@ -822,7 +819,9 @@ if (onnxruntime_USE_TENSORRT) # made test name contain the "ep" and "model path" information, so we can easily filter the tests using cuda ep or other ep with *cpu_* or *xxx_*. list(APPEND test_all_args "--gtest_filter=-*cpu_*:*cuda_*" ) endif () - +if(NOT onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS) + list(REMOVE_ITEM all_tests ${TEST_SRC_DIR}/providers/cuda/cuda_provider_test.cc) +endif() AddTest( TARGET onnxruntime_test_all SOURCES ${all_tests} ${onnxruntime_unittest_main_src} @@ -832,11 +831,15 @@ AddTest( DEPENDS ${all_dependencies} TEST_ARGS ${test_all_args} ) + if (MSVC) # The warning means the type of two integral values around a binary operator is narrow than their result. # If we promote the two input values first, it could be more tolerant to integer overflow. # However, this is test code. We are less concerned. - target_compile_options(onnxruntime_test_all PRIVATE "/wd26451" "/wd4244") + target_compile_options(onnxruntime_test_all PRIVATE "$<$:SHELL:--compiler-options /wd26451>" + "$<$>:/wd26451>") + target_compile_options(onnxruntime_test_all PRIVATE "$<$:SHELL:--compiler-options /wd4244>" + "$<$>:/wd4244>") else() target_compile_options(onnxruntime_test_all PRIVATE "-Wno-parentheses") endif() @@ -1092,18 +1095,18 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) target_include_directories(onnxruntime_benchmark PRIVATE ${ONNXRUNTIME_ROOT} ${onnxruntime_graph_header} ${ONNXRUNTIME_ROOT}/core/mlas/inc) target_compile_definitions(onnxruntime_benchmark PRIVATE BENCHMARK_STATIC_DEFINE) if(WIN32) - target_compile_options(onnxruntime_benchmark PRIVATE "$<$:-Xcompiler /wd4141>" + target_compile_options(onnxruntime_benchmark PRIVATE "$<$:SHELL:--compiler-options /wd4141>" "$<$>:/wd4141>") # Avoid using new and delete. But this is a benchmark program, it's ok if it has a chance to leak. - target_compile_options(onnxruntime_benchmark PRIVATE "$<$:-Xcompiler /wd26409>" + target_compile_options(onnxruntime_benchmark PRIVATE "$<$:SHELL:--compiler-options /wd26409>" "$<$>:/wd26409>") - target_compile_options(onnxruntime_benchmark PRIVATE "$<$:-Xcompiler /wd26400>" + target_compile_options(onnxruntime_benchmark PRIVATE "$<$:SHELL:--compiler-options /wd26400>" "$<$>:/wd26400>") - target_compile_options(onnxruntime_benchmark PRIVATE "$<$:-Xcompiler /wd26814>" + target_compile_options(onnxruntime_benchmark PRIVATE "$<$:SHELL:--compiler-options /wd26814>" "$<$>:/wd26814>") - target_compile_options(onnxruntime_benchmark PRIVATE "$<$:-Xcompiler /wd26814>" + target_compile_options(onnxruntime_benchmark PRIVATE "$<$:SHELL:--compiler-options /wd26814>" "$<$>:/wd26497>") - target_compile_options(onnxruntime_benchmark PRIVATE "$<$:-Xcompiler /wd26426>" + target_compile_options(onnxruntime_benchmark PRIVATE "$<$:SHELL:--compiler-options /wd26426>" "$<$>:/wd26426>") target_compile_options(onnxruntime_benchmark PRIVATE "$<$:SHELL:--compiler-options /utf-8>" "$<$>:/utf-8>") @@ -1255,7 +1258,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) list(APPEND onnxruntime_shared_lib_test_LIBS cpuinfo) endif() if (onnxruntime_USE_CUDA) - list(APPEND onnxruntime_shared_lib_test_LIBS onnxruntime_test_cuda_ops_lib cudart) + list(APPEND onnxruntime_shared_lib_test_LIBS cudart) endif() if (onnxruntime_USE_TENSORRT) list(APPEND onnxruntime_shared_lib_test_LIBS ${TENSORRT_LIBRARY_INFER}) @@ -1270,7 +1273,10 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) LIBS ${onnxruntime_shared_lib_test_LIBS} DEPENDS ${all_dependencies} ) - + if (onnxruntime_USE_CUDA) + target_include_directories(onnxruntime_shared_lib_test PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + target_sources(onnxruntime_shared_lib_test PRIVATE ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu) + endif() if (CMAKE_SYSTEM_NAME STREQUAL "Android") target_sources(onnxruntime_shared_lib_test PRIVATE "${ONNXRUNTIME_ROOT}/core/platform/android/cxa_demangle.cc" @@ -1356,13 +1362,13 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) ) onnxruntime_add_executable(onnxruntime_mlas_test ${onnxruntime_mlas_test_src}) if(MSVC) - target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:-Xcompiler /wd26409>" + target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /wd26409>" "$<$>:/wd26409>") target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /utf-8>" "$<$>:/utf-8>") - target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:-Xcompiler /wd6326>" + target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /wd6326>" "$<$>:/wd6326>") - target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:-Xcompiler /wd26426>" + target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /wd26426>" "$<$>:/wd26426>") endif() if(${CMAKE_SYSTEM_NAME} STREQUAL "iOS") @@ -1476,7 +1482,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") "${TEST_SRC_DIR}/testdata/custom_op_library/cuda/cuda_ops.*") list(APPEND custom_op_lib_include ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} ${onnxruntime_CUDNN_HOME}/include) if (HAS_QSPECTRE) - list(APPEND custom_op_lib_option "$<$:SHELL:-Xcompiler /Qspectre>") + list(APPEND custom_op_lib_option "$<$:SHELL:--compiler-options /Qspectre>") endif() endif() @@ -1503,7 +1509,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") else() set(ONNXRUNTIME_CUSTOM_OP_LIB_LINK_FLAG "-DEF:${TEST_SRC_DIR}/testdata/custom_op_library/custom_op_library.def") if (NOT onnxruntime_USE_CUDA) - target_compile_options(custom_op_library PRIVATE "$<$:-Xcompiler /wd26409>" + target_compile_options(custom_op_library PRIVATE "$<$:SHELL:--compiler-options /wd26409>" "$<$>:/wd26409>") endif() endif() diff --git a/cmake/patches/onnx/onnx.patch b/cmake/patches/onnx/onnx.patch index 155d153019f85..a2d7672a3d48d 100644 --- a/cmake/patches/onnx/onnx.patch +++ b/cmake/patches/onnx/onnx.patch @@ -64,16 +64,3 @@ index 0aab3e26..0f859267 100644 +#endif + #endif // ! ONNX_ONNX_PB_H -diff --git a/onnx/checker.cc b/onnx/checker.cc -index 8fdaf037..1beb1b88 100644 ---- a/onnx/checker.cc -+++ b/onnx/checker.cc -@@ -190,7 +190,7 @@ void check_tensor(const TensorProto& tensor, const CheckerContext& ctx) { - } - std::string data_path = path_join(ctx.get_model_dir(), relative_path); - // use stat64 to check whether the file exists --#ifdef __APPLE__ -+#if defined(__APPLE__) || defined(__wasm__) - struct stat buffer; // APPLE does not have stat64 - if (stat((data_path).c_str(), &buffer) != 0) { - #else diff --git a/docs/FAQ.md b/docs/FAQ.md index e039f4e4e4160..70bedbd02e944 100644 --- a/docs/FAQ.md +++ b/docs/FAQ.md @@ -2,13 +2,13 @@ Here are some commonly raised questions from users of ONNX Runtime and brought up in [Issues](https://github.com/microsoft/onnxruntime/issues). ## Do the GPU builds support quantized models? -The default CUDA build supports 3 standard quantization operators: QuantizeLinear, DequantizeLinear, and MatMulInteger. The TensorRT EP has limited support for INT8 quantized ops. In general, support of quantized models through ORT is continuing to expand on a model-driven basis. For performance improvements, quantization is not always required, and we suggest trying alternative strategies to [performance tune](./ONNX_Runtime_Perf_Tuning.md) before determining that quantization is necessary. +The default CUDA build supports 3 standard quantization operators: QuantizeLinear, DequantizeLinear, and MatMulInteger. The TensorRT EP has limited support for INT8 quantized ops. In general, support of quantized models through ORT is continuing to expand on a model-driven basis. For performance improvements, quantization is not always required, and we suggest trying alternative strategies to [performance tune](https://onnxruntime.ai/docs/performance/tune-performance/) before determining that quantization is necessary. ## How do I change the severity level of the default logger to something other than the default (WARNING)? Setting the severity level to VERBOSE is most useful when debugging errors. Refer to the API documentation: -* Python - [RunOptions.log_severity_level](https://microsoft.github.io/onnxruntime/python/api_summary.html#onnxruntime.RunOptions.log_severity_level) +* Python - [RunOptions.log_severity_level](https://onnxruntime.ai/docs/api/python/api_summary.html#onnxruntime.RunOptions.log_severity_level) ``` import onnxruntime as ort ort.set_default_logger_severity(0) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 33c187a28b62e..14b6b339c11f3 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -67,7 +67,8 @@ Do not modify directly.* |||[11, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[4, 10]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |ConcatFromSequence|*in* input_sequence:**S**
*out* concat_result:**T**|11+|**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| -|ConstantOfShape|*in* input:**T1**
*out* output:**T2**|9+|**T1** = tensor(int64)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|ConstantOfShape|*in* input:**T1**
*out* output:**T2**|20+|**T1** = tensor(int64)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[9, 19]|**T1** = tensor(int64)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Conv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|11+|**T** = tensor(float)| |||[1, 10]|**T** = tensor(float)| |ConvInteger|*in* x:**T1**
*in* w:**T2**
*in* x_zero_point:**T1**
*in* w_zero_point:**T2**
*out* y:**T3**|10+|**T1** = tensor(uint8)
**T2** = tensor(uint8)
**T3** = tensor(int32)| @@ -78,7 +79,7 @@ Do not modify directly.* |Crop|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float)| |CumSum|*in* x:**T**
*in* axis:**T2**
*out* y:**T**|14+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int32), tensor(int64)| |||[11, 13]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)
**T2** = tensor(int32), tensor(int64)| -|DFT|*in* input:**T1**
*in* dft_length:**T2**
*out* output:**T1**|17+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int32), tensor(int64)| +|DFT|*in* input:**T1**
*in* dft_length:**T2**
*in* axis:**tensor(int64)**
*out* output:**T1**

or

*in* input:**T1**
*in* dft_length:**T2**
*out* output:**T1**|17+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int32), tensor(int64)| |DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float)| |||[11, 12]|**T** = tensor(double), tensor(float)| |||[1, 10]|**T** = tensor(double), tensor(float)| @@ -935,7 +936,7 @@ Do not modify directly.* |Crop|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |CumSum|*in* x:**T**
*in* axis:**T2**
*out* y:**T**|14+|**T** = tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |||11+|**T** = tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| -|DFT|*in* input:**T1**
*in* dft_length:**T2**
*out* output:**T1**|17+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(int64)| +|DFT|*in* input:**T1**
*in* dft_length:**T2**
*in* axis:**tensor(int64)**
*out* output:**T1**

or

*in* input:**T1**
*in* dft_length:**T2**
*out* output:**T1**|17+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(int64)| |DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/include/onnxruntime/core/framework/tensor.h b/include/onnxruntime/core/framework/tensor.h index 7f3f26fa4aa02..a867ab6066485 100644 --- a/include/onnxruntime/core/framework/tensor.h +++ b/include/onnxruntime/core/framework/tensor.h @@ -3,8 +3,9 @@ #pragma once -#include +#include #include +#include #include #include @@ -37,7 +38,8 @@ namespace onnxruntime { class Tensor final { public: - // NB! Removing Create() methods returning unique_ptr. Still available in other EPs that are dynamically linked. + // NB! Removing Create() methods returning unique_ptr. + // Still available in other EPs that are dynamically linked. // Strive not to allocate Tensor with new/delete as it is a shallow class and using it by value is just fine. // Use InitOrtValue() methods to allocate for OrtValue. @@ -45,105 +47,104 @@ class Tensor final { /** * Create tensor with given type, shape, pre-allocated memory and allocator info. - * This function won't check if the preallocated buffer(p_data) has enough room for the shape. - * \param p_type Data type of the tensor + * This function does not check if the preallocated buffer(p_data) has enough room for the shape. + * \param elt_type Data type of the tensor elements. * \param shape Shape of the tensor * \param p_data A preallocated buffer. Can be NULL if the shape is empty. - * Tensor does not own the data and will not delete it - * \param alloc Where the buffer('p_data') was allocated from + * Tensor does not own the data and will not delete it + * \param location Memory info for location of p_data. * \param offset Offset in bytes to start of Tensor within p_data. * \param strides Strides span. Can be empty if the tensor is contiguous. */ - Tensor(MLDataType p_type, const TensorShape& shape, void* p_data, const OrtMemoryInfo& alloc, + Tensor(MLDataType elt_type, const TensorShape& shape, void* p_data, const OrtMemoryInfo& location, ptrdiff_t offset = 0, gsl::span strides = {}); - /// - /// Creates an instance of Tensor on the heap using the appropriate __ctor and - /// initializes OrtValue with it. - /// - /// - /// - /// - /// - /// - /// - static void InitOrtValue(MLDataType p_type, const TensorShape& shape, - void* p_data, const OrtMemoryInfo& location, - OrtValue& ort_value, ptrdiff_t offset = 0, - gsl::span strides = {}); - - /// - /// Creates an instance of Tensor who own the pre-allocated buffer. - /// - /// - /// - /// - /// - /// - /// - static void InitOrtValue(MLDataType p_type, const TensorShape& shape, - void* p_data, std::shared_ptr allocator, - OrtValue& ort_value, ptrdiff_t offset = 0, - gsl::span strides = {}); - - static size_t CalculateTensorStorageSize(MLDataType p_type, - const TensorShape& shape, - gsl::span strides = {}); - /** - * Deprecated. The original design is this Tensor class won't do any allocation / release. - * However, this function will allocate the buffer for the shape, and do placement new if p_type is string tensor. - */ - Tensor(MLDataType p_type, const TensorShape& shape, std::shared_ptr allocator, - gsl::span strides = {}); - - /// - /// Creates an instance of Tensor on the heap using the appropriate __ctor and - /// initializes OrtValue with it. - /// - /// - /// - /// - /// - /// - static void InitOrtValue(MLDataType elt_type, - const TensorShape& shape, - std::shared_ptr allocator, - OrtValue& ort_value, - gsl::span strides = {}); - - /// - /// Creates an instance of Tensor on the heap using the appropriate __ctor and - /// initializes OrtValue with it. - /// - /// - /// - static void InitOrtValue(Tensor&& tensor, OrtValue& ort_value); - - /** - * Create tensor with given type, shape, pre-allocated memory and allocator which will be used to free the pre-allocated memory. - * This function won't check if the preallocated buffer(p_data) has enough room for the shape. - * However, this function will de-allocate the buffer upon the tensor getting destructed. - * \param p_type Data type of the tensor + * Create tensor with given type, shape, pre-allocated memory and allocator which will be used to free the + * pre-allocated memory. The Tensor will take over ownership of p_data. + * This function does not check if the preallocated buffer(p_data) has enough room for the shape. + * \param elt_type Data type of the tensor elements. * \param shape Shape of the tensor * \param p_data A preallocated buffer. Can be NULL if the shape is empty. - * Tensor will own the memory and will delete it when the tensor instance is destructed. + * Tensor will own the memory and will delete it when the tensor instance is destructed. * \param deleter Allocator used to free the pre-allocated memory * \param offset Offset in bytes to start of Tensor within p_data. * \param strides Strides span. Can be empty if the tensor is contiguous. */ - Tensor(MLDataType p_type, const TensorShape& shape, void* p_data, std::shared_ptr deleter, + Tensor(MLDataType elt_type, const TensorShape& shape, void* p_data, std::shared_ptr deleter, ptrdiff_t offset = 0, gsl::span strides = {}); + /// + /// Create a Tensor that allocates and owns the buffer required for the specified shape. + /// + /// Data type of the tensor elements. + /// Tensor shape. + /// Allocator to use to create and free buffer. + Tensor(MLDataType elt_type, const TensorShape& shape, std::shared_ptr allocator); + ~Tensor(); // Move is allowed ORT_DISALLOW_COPY_AND_ASSIGNMENT(Tensor); Tensor(Tensor&& other) noexcept; - Tensor& operator=(Tensor&& other) noexcept; + /// + /// Creates an instance of Tensor on the heap and initializes OrtValue with it. + /// + /// Data type of the tensor elements. + /// Tensor shape. + /// Tensor data. + /// Memory info for location of p_data. + /// OrtValue to populate with Tensor. + /// Optional offset if Tensor refers to a subset of p_data. + /// Optional strides if Tensor refers to a subset of p_data. + static void InitOrtValue(MLDataType elt_type, const TensorShape& shape, void* p_data, const OrtMemoryInfo& location, + OrtValue& ort_value, + ptrdiff_t offset = 0, gsl::span strides = {}); + + /// + /// Creates an instance of Tensor on the heap which will take over ownership of the pre-allocated buffer. + /// + /// Data type of the tensor elements. + /// + /// Tensor data. + /// Allocator that was used to create p_data and will be used to free it. + /// OrtValue to populate with Tensor. + /// Optional offset if Tensor refers to a subset of p_data. + /// Optional strides if Tensor refers to a subset of p_data. + static void InitOrtValue(MLDataType elt_type, const TensorShape& shape, void* p_data, + std::shared_ptr allocator, + OrtValue& ort_value, + ptrdiff_t offset = 0, gsl::span strides = {}); + + /// + /// Creates an instance of Tensor on the heap and initializes OrtValue with it. + /// The Tensor instance will allocate and own the data required for `shape`. + /// + /// Data type of the tensor elements. + /// Tensor shape. + /// Allocator that was used to create p_data and will be used to free it. + /// OrtValue to populate with Tensor. + static void InitOrtValue(MLDataType elt_type, const TensorShape& shape, std::shared_ptr allocator, + OrtValue& ort_value); + + /// + /// Initializes OrtValue with an existing Tensor. + /// + /// Tensor. + /// OrtValue to populate with Tensor. + static void InitOrtValue(Tensor&& tensor, OrtValue& ort_value); + + /// + /// Calculate the required storage for the tensor. + /// + /// Data type of the tensor elements. + /// Tensor shape. + /// Bytes required. + static size_t CalculateTensorStorageSize(MLDataType elt_type, const TensorShape& shape); + /** Returns the data type. */ @@ -294,7 +295,7 @@ class Tensor final { // More API methods. private: - void Init(MLDataType p_type, + void Init(MLDataType elt_type, const TensorShape& shape, void* p_raw_data, AllocatorPtr deleter, diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index e483c67a0cfe6..486e2ff2b90a2 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -745,6 +745,8 @@ struct OrtApi { /** \brief Create an OrtEnv * + * \note Invoking this function will return the same instance of the environment as that returned by a previous call + * to another env creation function; all arguments to this function will be ignored. * \param[in] log_severity_level The log severity level. * \param[in] logid The log identifier. * \param[out] out Returned newly created OrtEnv. Must be freed with OrtApi::ReleaseEnv @@ -755,17 +757,20 @@ struct OrtApi { /** \brief Create an OrtEnv * + * \note Invoking this function will return the same instance of the environment as that returned by a previous call + * to another env creation function; all arguments to this function will be ignored. If you want to provide your + * own logging function, consider setting it using the SetUserLoggingFunction API instead. * \param[in] logging_function A pointer to a logging function. * \param[in] logger_param A pointer to arbitrary data passed as the ::OrtLoggingFunction `param` parameter to - * `logging_function`. + * `logging_function`. This parameter is optional. * \param[in] log_severity_level The log severity level. * \param[in] logid The log identifier. * \param[out] out Returned newly created OrtEnv. Must be freed with OrtApi::ReleaseEnv * * \snippet{doc} snippets.dox OrtStatus Return Value */ - ORT_API2_STATUS(CreateEnvWithCustomLogger, OrtLoggingFunction logging_function, _In_opt_ void* logger_param, - OrtLoggingLevel log_severity_level, _In_ const char* logid, _Outptr_ OrtEnv** out); + ORT_API2_STATUS(CreateEnvWithCustomLogger, _In_ OrtLoggingFunction logging_function, _In_opt_ void* logger_param, + _In_ OrtLoggingLevel log_severity_level, _In_ const char* logid, _Outptr_ OrtEnv** out); /** \brief Enable Telemetry * @@ -3592,6 +3597,9 @@ struct OrtApi { * "rpc_control_latency": QNN RPC control latency. * "htp_performance_mode": QNN performance mode, options: "burst", "balanced", "default", "high_performance", * "high_power_saver", "low_balanced", "low_power_saver", "power_saver", "sustained_high_performance". Default to "default". + * "qnn_saver_path": File path to the QNN Saver backend library. If specified, QNN Saver will be enabled and will + * dump QNN API calls to disk for replay/debugging. QNN Saver produces incorrect model inference results and + * may alter model/EP partitioning. Use only for debugging. * * SNPE supported keys: * "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16", @@ -4413,6 +4421,26 @@ struct OrtApi { * \since Version 1.16. */ ORT_API2_STATUS(KernelContext_GetResource, _In_ const OrtKernelContext* context, _In_ int resouce_version, _In_ int resource_id, _Outptr_ void** resource); + + /** \brief Set user logging function + * + * By default the logger created by the CreateEnv* functions is used to create the session logger as well. + * This function allows a user to override this default session logger with a logger of their own choosing. This way + * the user doesn't have to create a separate environment with a custom logger. This addresses the problem when + * the user already created an env but now wants to use a different logger for a specific session (for debugging or + * other reasons). + * + * \param[in] options + * \param[in] user_logging_function A pointer to a logging function. + * \param[in] user_logging_param A pointer to arbitrary data passed as the ::OrtLoggingFunction `param` parameter to + * `user_logging_function`. This parameter is optional. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.17. + */ + ORT_API2_STATUS(SetUserLoggingFunction, _Inout_ OrtSessionOptions* options, + _In_ OrtLoggingFunction user_logging_function, _In_opt_ void* user_logging_param); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 47356c3fe3608..45f81783421e0 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -2055,6 +2055,7 @@ struct KernelContext { void* GetGPUComputeStream() const; Logger GetLogger() const; OrtAllocator* GetAllocator(const OrtMemoryInfo& memory_info) const; + OrtKernelContext* GetOrtKernelContext() const { return ctx_; } private: OrtKernelContext* ctx_; diff --git a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h index fd42824e81a56..887339ebc50d6 100644 --- a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h +++ b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h @@ -18,22 +18,81 @@ #include "onnxruntime_cxx_api.h" #include #include +#include #include namespace Ort { namespace Custom { -class TensorBase { +class ArgBase { public: - TensorBase(OrtKernelContext* ctx) : ctx_(ctx) {} - virtual ~TensorBase() {} + ArgBase(OrtKernelContext* ctx, + size_t indice, + bool is_input) : ctx_(ctx), indice_(indice), is_input_(is_input) {} + virtual ~ArgBase(){}; + + protected: + struct KernelContext ctx_; + size_t indice_; + bool is_input_; +}; + +using ArgPtr = std::unique_ptr; +using ArgPtrs = std::vector; + +class TensorBase : public ArgBase { + public: + TensorBase(OrtKernelContext* ctx, + size_t indice, + bool is_input) : ArgBase(ctx, indice, is_input) {} + operator bool() const { return shape_.has_value(); } + const std::vector& Shape() const { + if (!shape_.has_value()) { + ORT_CXX_API_THROW("tensor shape is not yet initialized", OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + return shape_.value(); + } + + ONNXTensorElementDataType Type() const { + return type_; + } + + int64_t NumberOfElement() const { + if (shape_.has_value()) { + return std::accumulate(shape_->begin(), shape_->end(), 1LL, std::multiplies()); + } else { + return 0; + } + } + + std::string Shape2Str() const { + if (shape_.has_value()) { + std::string shape_str; + for (const auto& dim : *shape_) { + shape_str.append(std::to_string(dim)); + shape_str.append(", "); + } + return shape_str; + } else { + return "empty"; + } + } + + bool IsCpuTensor() const { + return strcmp("Cpu", mem_type_) == 0; + } + + virtual const void* DataRaw() const = 0; + virtual size_t SizeInBytes() const = 0; + protected: - struct KernelContext ctx_; std::optional> shape_; + ONNXTensorElementDataType type_ = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED; + const char* mem_type_ = "Cpu"; }; template @@ -48,13 +107,14 @@ struct Span { T operator[](size_t indice) const { return data_[indice]; } + const T* data() const { return data_; } }; template class Tensor : public TensorBase { public: using TT = typename std::remove_reference::type; - Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx), indice_(indice), is_input_(is_input) { + Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) { if (is_input_) { if (indice >= ctx_.GetInputCount()) { ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT); @@ -64,19 +124,6 @@ class Tensor : public TensorBase { shape_ = type_shape_info.GetShape(); } } - const std::vector& Shape() const { - if (!shape_.has_value()) { - ORT_CXX_API_THROW("tensor shape is not yet initialized", OrtErrorCode::ORT_RUNTIME_EXCEPTION); - } - return shape_.value(); - } - int64_t NumberOfElement() const { - if (shape_.has_value()) { - return std::accumulate(shape_->begin(), shape_->end(), 1LL, std::multiplies()); - } else { - return 0; - } - } const TT* Data() const { return reinterpret_cast(const_value_.GetTensorRawData()); } @@ -104,10 +151,15 @@ class Tensor : public TensorBase { } return *Data(); } + const void* DataRaw() const override { + return reinterpret_cast(Data()); + } + + size_t SizeInBytes() const override { + return sizeof(TT) * static_cast(NumberOfElement()); + } private: - size_t indice_; - bool is_input_; ConstValue const_value_; // for input TT* data_{}; // for output Span span_; @@ -118,7 +170,7 @@ class Tensor : public TensorBase { public: using strings = std::vector; - Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx), indice_(indice), is_input_(is_input) { + Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) { if (is_input_) { if (indice >= ctx_.GetInputCount()) { ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT); @@ -147,16 +199,21 @@ class Tensor : public TensorBase { } } } - int64_t NumberOfElement() const { - if (shape_.has_value()) { - return std::accumulate(shape_->begin(), shape_->end(), 1ULL, std::multiplies()); - } else { - return 0; - } - } const strings& Data() const { return input_strings_; } + const void* DataRaw() const override { + if (input_strings_.size() != 1) { + ORT_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION); + } + return reinterpret_cast(input_strings_[0].c_str()); + } + size_t SizeInBytes() const override { + if (input_strings_.size() != 1) { + ORT_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION); + } + return input_strings_[0].size(); + } void SetStringOutput(const strings& ss, const std::vector& dims) { shape_ = dims; std::vector raw; @@ -179,8 +236,6 @@ class Tensor : public TensorBase { } private: - size_t indice_; - bool is_input_; std::vector input_strings_; // for input }; @@ -190,7 +245,7 @@ class Tensor : public TensorBase { using strings = std::vector; using string_views = std::vector; - Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx), indice_(indice), is_input_(is_input) { + Tensor(OrtKernelContext* ctx, size_t indice, bool is_input) : TensorBase(ctx, indice, is_input) { if (is_input_) { if (indice >= ctx_.GetInputCount()) { ORT_CXX_API_THROW("invalid indice for Ort::Custom::Tensor", OrtErrorCode::ORT_INVALID_ARGUMENT); @@ -211,16 +266,21 @@ class Tensor : public TensorBase { } } } - int64_t NumberOfElement() const { - if (shape_.has_value()) { - return std::accumulate(shape_->begin(), shape_->end(), 1ULL, std::multiplies()); - } else { - return 0; - } - } const string_views& Data() const { return input_string_views_; } + const void* DataRaw() const override { + if (input_string_views_.size() != 1) { + ORT_CXX_API_THROW("DataRaw() only applies to string scalar", ORT_RUNTIME_EXCEPTION); + } + return reinterpret_cast(input_string_views_[0].data()); + } + size_t SizeInBytes() const override { + if (input_string_views_.size() != 1) { + ORT_CXX_API_THROW("SizeInBytes() only applies to string scalar", ORT_RUNTIME_EXCEPTION); + } + return input_string_views_[0].size(); + } void SetStringOutput(const strings& ss, const std::vector& dims) { shape_ = dims; std::vector raw; @@ -243,15 +303,101 @@ class Tensor : public TensorBase { } private: - size_t indice_; - bool is_input_; std::vector chars_; // for input std::vector input_string_views_; // for input }; using TensorPtr = std::unique_ptr; +using TensorPtrs = std::vector; -//////////////////////////// OrtLiteCustomOp //////////////////////////////// +struct TensorArray : public ArgBase { + TensorArray(OrtKernelContext* ctx, + size_t start_indice, + bool is_input) : ArgBase(ctx, + start_indice, + is_input) { + if (is_input) { + auto input_count = ctx_.GetInputCount(); + for (size_t ith_input = start_indice; ith_input < input_count; ++ith_input) { + auto const_value = ctx_.GetInput(start_indice); + auto type_shape_info = const_value.GetTensorTypeAndShapeInfo(); + auto type = type_shape_info.GetElementType(); + TensorPtr tensor; + switch (type) { + case ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT8: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT16: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT16: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT32: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64: + tensor = std::make_unique>(ctx, ith_input, true); + break; + case ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING: + tensor = std::make_unique>(ctx, ith_input, true); + break; + default: + ORT_CXX_API_THROW("unknow input type", ORT_RUNTIME_EXCEPTION); + break; + } + tensors_.emplace_back(tensor.release()); + } // for + } + } + template + T* AllocateOutput(size_t ith_output, const std::vector& shape) { + // ith_output is the indice of output relative to the tensor array + // indice_ + ith_output is the indice relative to context + auto tensor = std::make_unique>(ctx_.GetOrtKernelContext(), indice_ + ith_output, false); + auto raw_output = tensor.get()->Allocate(shape); + tensors_.emplace_back(tensor.release()); + return raw_output; + } + Tensor& AllocateStringTensor(size_t ith_output) { + // ith_output is the indice of output relative to the tensor array + // indice_ + ith_output is the indice relative to context + auto tensor = std::make_unique>(ctx_.GetOrtKernelContext(), indice_ + ith_output, false); + Tensor& output = *tensor; + tensors_.emplace_back(tensor.release()); + return output; + } + size_t Size() const { + return tensors_.size(); + } + const TensorPtr& operator[](size_t ith_input) const { + // ith_input is the indice of output relative to the tensor array + return tensors_.at(ith_input); + } + + private: + TensorPtrs tensors_; +}; + +using Variadic = TensorArray; struct OrtLiteCustomOp : public OrtCustomOp { using ConstOptionalFloatTensor = std::optional&>; @@ -260,34 +406,34 @@ struct OrtLiteCustomOp : public OrtCustomOp { // CreateTuple template static typename std::enable_if>::type - CreateTuple(OrtKernelContext*, std::vector&, size_t, size_t, const std::string&) { + CreateTuple(OrtKernelContext*, ArgPtrs&, size_t, size_t, const std::string&) { return std::make_tuple(); } template static typename std::enable_if::value, std::tuple>::type - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { std::tuple current = std::tuple{context}; - auto next = CreateTuple(context, tensors, num_input, num_output, ep); + auto next = CreateTuple(context, args, num_input, num_output, ep); return std::tuple_cat(current, next); } template static typename std::enable_if::value, std::tuple>::type - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { std::tuple current = std::tuple{*context}; - auto next = CreateTuple(context, tensors, num_input, num_output, ep); + auto next = CreateTuple(context, args, num_input, num_output, ep); return std::tuple_cat(current, next); } #ifdef ORT_CUDA_CTX template static typename std::enable_if::value, std::tuple>::type - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { thread_local CudaContext cuda_context; cuda_context.Init(*context); std::tuple current = std::tuple{cuda_context}; - auto next = CreateTuple(context, tensors, num_input, num_output, ep); + auto next = CreateTuple(context, args, num_input, num_output, ep); return std::tuple_cat(current, next); } #endif @@ -295,143 +441,179 @@ struct OrtLiteCustomOp : public OrtCustomOp { #ifdef ORT_ROCM_CTX template static typename std::enable_if::value, std::tuple>::type - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { thread_local RocmContext rocm_context; rocm_context.Init(*context); std::tuple current = std::tuple{rocm_context}; - auto next = CreateTuple(context, tensors, num_input, num_output, ep); + auto next = CreateTuple(context, args, num_input, num_output, ep); return std::tuple_cat(current, next); } #endif -#define CREATE_TUPLE_INPUT(data_type) \ - template \ - static typename std::enable_if*>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - tensors.push_back(std::make_unique>(context, ith_input, true)); \ - std::tuple current = std::tuple{reinterpret_cast(tensors.back().get())}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - template \ - static typename std::enable_if&>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - tensors.push_back(std::make_unique>(context, ith_input, true)); \ - std::tuple current = std::tuple{reinterpret_cast(*tensors.back().get())}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - template \ - static typename std::enable_if*>>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - if (ith_input < num_input) { \ - tensors.push_back(std::make_unique>(context, ith_input, true)); \ - std::tuple current = std::tuple{reinterpret_cast*>(tensors.back().get())}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } else { \ - std::tuple current = std::tuple{}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - } \ - template \ - static typename std::enable_if*>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - if ("CPUExecutionProvider" != ep) { \ - ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ - } \ - tensors.push_back(std::make_unique>(context, ith_input, true)); \ - std::tuple current = std::tuple{&reinterpret_cast*>(tensors.back().get())->AsSpan()}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - template \ - static typename std::enable_if&>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - if ("CPUExecutionProvider" != ep) { \ - ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ - } \ - tensors.push_back(std::make_unique>(context, ith_input, true)); \ - std::tuple current = std::tuple{reinterpret_cast*>(tensors.back().get())->AsSpan()}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - template \ - static typename std::enable_if*>>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - if (ith_input < num_input) { \ - if ("CPUExecutionProvider" != ep) { \ - ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ - } \ - tensors.push_back(std::make_unique>(context, ith_input, true)); \ - std::tuple current = std::tuple{&reinterpret_cast*>(tensors.back().get())->AsSpan()}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } else { \ - std::tuple current = std::tuple{}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - } \ - template \ - static typename std::enable_if::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - if ("CPUExecutionProvider" != ep) { \ - ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ - } \ - tensors.push_back(std::make_unique>(context, ith_input, true)); \ - std::tuple current = std::tuple{reinterpret_cast*>(tensors.back().get())->AsScalar()}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - template \ - static typename std::enable_if>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - if (ith_input < num_input) { \ - if ("CPUExecutionProvider" != ep) { \ - ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ - } \ - tensors.push_back(std::make_unique>(context, ith_input, true)); \ - std::tuple current = std::tuple{reinterpret_cast*>(tensors.back().get())->AsScalar()}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } else { \ - std::tuple current = std::tuple{}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - } -#define CREATE_TUPLE_OUTPUT(data_type) \ - template \ - static typename std::enable_if*>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - tensors.push_back(std::make_unique>(context, ith_output, false)); \ - std::tuple current = std::tuple{reinterpret_cast(tensors.back().get())}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - template \ - static typename std::enable_if&>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - tensors.push_back(std::make_unique>(context, ith_output, false)); \ - std::tuple current = std::tuple{reinterpret_cast(*tensors.back().get())}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ - template \ - static typename std::enable_if*>>::value, std::tuple>::type \ - CreateTuple(OrtKernelContext* context, std::vector& tensors, size_t num_input, size_t num_output, const std::string& ep) { \ - if (ith_output < num_output) { \ - tensors.push_back(std::make_unique>(context, ith_output, false)); \ - std::tuple current = std::tuple{reinterpret_cast*>(tensors.back().get())}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } else { \ - std::tuple current = std::tuple{}; \ - auto next = CreateTuple(context, tensors, num_input, num_output, ep); \ - return std::tuple_cat(current, next); \ - } \ + template + static typename std::enable_if::value, std::tuple>::type + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { + args.push_back(std::make_unique(context, ith_input, true)); + std::tuple current = std::tuple{reinterpret_cast(args.back().get())}; + auto next = CreateTuple(context, args, num_input, num_output, ep); + return std::tuple_cat(current, next); + } + + template + static typename std::enable_if::value, std::tuple>::type + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { + args.push_back(std::make_unique(context, ith_input, true)); + std::tuple current = std::tuple{reinterpret_cast(*args.back().get())}; + auto next = CreateTuple(context, args, num_input, num_output, ep); + return std::tuple_cat(current, next); + } + + template + static typename std::enable_if::value, std::tuple>::type + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { + args.push_back(std::make_unique(context, ith_output, false)); + std::tuple current = std::tuple{reinterpret_cast(args.back().get())}; + auto next = CreateTuple(context, args, num_input, num_output, ep); + return std::tuple_cat(current, next); + } + + template + static typename std::enable_if::value, std::tuple>::type + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { + args.push_back(std::make_unique(context, ith_output, false)); + std::tuple current = std::tuple{reinterpret_cast(*args.back().get())}; + auto next = CreateTuple(context, args, num_input, num_output, ep); + return std::tuple_cat(current, next); + } + +#define CREATE_TUPLE_INPUT(data_type) \ + template \ + static typename std::enable_if*>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{reinterpret_cast(args.back().get())}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if&>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{reinterpret_cast(*args.back().get())}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if*>>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + if (ith_input < num_input) { \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{reinterpret_cast*>(args.back().get())}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } else { \ + std::tuple current = std::tuple{}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + } \ + template \ + static typename std::enable_if*>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + if ("CPUExecutionProvider" != ep) { \ + ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ + } \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{&reinterpret_cast*>(args.back().get())->AsSpan()}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if&>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + if ("CPUExecutionProvider" != ep) { \ + ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ + } \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{reinterpret_cast*>(args.back().get())->AsSpan()}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if*>>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + if (ith_input < num_input) { \ + if ("CPUExecutionProvider" != ep) { \ + ORT_CXX_API_THROW("span input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ + } \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{&reinterpret_cast*>(args.back().get())->AsSpan()}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } else { \ + std::tuple current = std::tuple{}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + } \ + template \ + static typename std::enable_if::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + if ("CPUExecutionProvider" != ep) { \ + ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ + } \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{reinterpret_cast*>(args.back().get())->AsScalar()}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + if (ith_input < num_input) { \ + if ("CPUExecutionProvider" != ep) { \ + ORT_CXX_API_THROW("scalar input could only be applied to CPU EP", OrtErrorCode::ORT_RUNTIME_EXCEPTION); \ + } \ + args.push_back(std::make_unique>(context, ith_input, true)); \ + std::tuple current = std::tuple{reinterpret_cast*>(args.back().get())->AsScalar()}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } else { \ + std::tuple current = std::tuple{}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + } +#define CREATE_TUPLE_OUTPUT(data_type) \ + template \ + static typename std::enable_if*>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + args.push_back(std::make_unique>(context, ith_output, false)); \ + std::tuple current = std::tuple{reinterpret_cast(args.back().get())}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if&>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + args.push_back(std::make_unique>(context, ith_output, false)); \ + std::tuple current = std::tuple{reinterpret_cast(*args.back().get())}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ + template \ + static typename std::enable_if*>>::value, std::tuple>::type \ + CreateTuple(OrtKernelContext* context, ArgPtrs& args, size_t num_input, size_t num_output, const std::string& ep) { \ + if (ith_output < num_output) { \ + args.push_back(std::make_unique>(context, ith_output, false)); \ + std::tuple current = std::tuple{reinterpret_cast*>(args.back().get())}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } else { \ + std::tuple current = std::tuple{}; \ + auto next = CreateTuple(context, args, num_input, num_output, ep); \ + return std::tuple_cat(current, next); \ + } \ } #define CREATE_TUPLE(data_type) \ CREATE_TUPLE_INPUT(data_type) \ @@ -491,6 +673,34 @@ struct OrtLiteCustomOp : public OrtCustomOp { } #endif + template + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type + ParseArgs(std::vector& input_types, std::vector& output_types) { + input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); + ParseArgs(input_types, output_types); + } + + template + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type + ParseArgs(std::vector& input_types, std::vector& output_types) { + input_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); + ParseArgs(input_types, output_types); + } + + template + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type + ParseArgs(std::vector& input_types, std::vector& output_types) { + output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); + ParseArgs(input_types, output_types); + } + + template + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type + ParseArgs(std::vector& input_types, std::vector& output_types) { + output_types.push_back(ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED); + ParseArgs(input_types, output_types); + } + #define PARSE_INPUT_BASE(pack_type, onnx_type) \ template \ static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same::value>::type \ @@ -499,6 +709,12 @@ struct OrtLiteCustomOp : public OrtCustomOp { ParseArgs(input_types, output_types); \ } \ template \ + static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same>::value>::type \ + ParseArgs(std::vector& input_types, std::vector& output_types) { \ + input_types.push_back(onnx_type); \ + ParseArgs(input_types, output_types); \ + } \ + template \ static typename std::enable_if<0 <= sizeof...(Ts) && std::is_same>::value>::type \ ParseArgs(std::vector& input_types, std::vector& output_types) { \ input_types.push_back(onnx_type); \ @@ -585,18 +801,40 @@ struct OrtLiteCustomOp : public OrtCustomOp { return self->output_types_[indice]; }; - OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp*, size_t) { - return INPUT_OUTPUT_OPTIONAL; + OrtCustomOp::GetInputCharacteristic = [](const OrtCustomOp* op, size_t indice) { + auto self = reinterpret_cast(op); + return self->input_types_[indice] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ? INPUT_OUTPUT_VARIADIC : INPUT_OUTPUT_OPTIONAL; }; - OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp*, size_t) { - return INPUT_OUTPUT_OPTIONAL; + OrtCustomOp::GetOutputCharacteristic = [](const OrtCustomOp* op, size_t indice) { + auto self = reinterpret_cast(op); + return self->output_types_[indice] == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED ? INPUT_OUTPUT_VARIADIC : INPUT_OUTPUT_OPTIONAL; + }; + + OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) { + return 1; + }; + + OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) { + return 0; + }; + + OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) { + return 1; + }; + + OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) { + return 0; }; OrtCustomOp::GetVariadicInputMinArity = [](const OrtCustomOp*) { return 0; }; OrtCustomOp::GetVariadicInputHomogeneity = [](const OrtCustomOp*) { return 0; }; OrtCustomOp::GetVariadicOutputMinArity = [](const OrtCustomOp*) { return 0; }; OrtCustomOp::GetVariadicOutputHomogeneity = [](const OrtCustomOp*) { return 0; }; + + OrtCustomOp::CreateKernelV2 = {}; + OrtCustomOp::KernelComputeV2 = {}; + OrtCustomOp::KernelCompute = {}; } const std::string op_name_; @@ -619,12 +857,14 @@ struct OrtLiteCustomOp : public OrtCustomOp { template struct OrtLiteCustomFunc : public OrtLiteCustomOp { using ComputeFn = void (*)(Args...); + using ComputeFnReturnStatus = Status (*)(Args...); using MyType = OrtLiteCustomFunc; struct Kernel { size_t num_input_{}; size_t num_output_{}; ComputeFn compute_fn_{}; + ComputeFnReturnStatus compute_fn_return_status_{}; std::string ep_{}; }; @@ -636,8 +876,8 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { auto kernel = reinterpret_cast(op_kernel); - std::vector tensors; - auto t = CreateTuple<0, 0, Args...>(context, tensors, kernel->num_input_, kernel->num_output_, kernel->ep_); + std::vector args; + auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_); std::apply([kernel](Args const&... t_args) { kernel->compute_fn_(t_args...); }, t); }; @@ -656,7 +896,36 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { }; } - ComputeFn compute_fn_; + OrtLiteCustomFunc(const char* op_name, + const char* execution_provider, + ComputeFnReturnStatus compute_fn_return_status) : OrtLiteCustomOp(op_name, execution_provider), + compute_fn_return_status_(compute_fn_return_status) { + ParseArgs(input_types_, output_types_); + + OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr { + auto kernel = reinterpret_cast(op_kernel); + std::vector args; + auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_); + return std::apply([kernel](Args const&... t_args) { Status status = kernel->compute_fn_return_status_(t_args...); return status.release(); }, t); + }; + + OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { + auto kernel = std::make_unique(); + kernel->compute_fn_return_status_ = static_cast(this_)->compute_fn_return_status_; + Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_)); + Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_)); + auto self = static_cast(this_); + kernel->ep_ = self->execution_provider_; + return reinterpret_cast(kernel.release()); + }; + + OrtCustomOp::KernelDestroy = [](void* op_kernel) { + delete reinterpret_cast(op_kernel); + }; + } + + ComputeFn compute_fn_ = {}; + ComputeFnReturnStatus compute_fn_return_status_ = {}; }; // struct OrtLiteCustomFunc /////////////////////////// OrtLiteCustomStruct /////////////////////////// @@ -679,6 +948,10 @@ template struct OrtLiteCustomStruct : public OrtLiteCustomOp { template using CustomComputeFn = void (CustomOp::*)(Args...); + + template + using CustomComputeFnReturnStatus = Status (CustomOp::*)(Args...); + using MyType = OrtLiteCustomStruct; struct Kernel { @@ -692,18 +965,6 @@ struct OrtLiteCustomStruct : public OrtLiteCustomOp { const char* execution_provider) : OrtLiteCustomOp(op_name, execution_provider) { init(&CustomOp::Compute); - } - - template - void init(CustomComputeFn) { - ParseArgs(input_types_, output_types_); - - OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { - auto kernel = reinterpret_cast(op_kernel); - std::vector tensors; - auto t = CreateTuple<0, 0, Args...>(context, tensors, kernel->num_input_, kernel->num_output_, kernel->ep_); - std::apply([kernel](Args const&... t_args) { kernel->custom_op_->Compute(t_args...); }, t); - }; OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { auto kernel = std::make_unique(); @@ -719,6 +980,28 @@ struct OrtLiteCustomStruct : public OrtLiteCustomOp { delete reinterpret_cast(op_kernel); }; } + + template + void init(CustomComputeFn) { + ParseArgs(input_types_, output_types_); + OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { + auto kernel = reinterpret_cast(op_kernel); + ArgPtrs args; + auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_); + std::apply([kernel](Args const&... t_args) { kernel->custom_op_->Compute(t_args...); }, t); + }; + } + + template + void init(CustomComputeFnReturnStatus) { + ParseArgs(input_types_, output_types_); + OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr { + auto kernel = reinterpret_cast(op_kernel); + ArgPtrs args; + auto t = CreateTuple<0, 0, Args...>(context, args, kernel->num_input_, kernel->num_output_, kernel->ep_); + return std::apply([kernel](Args const&... t_args) { Status status = kernel->custom_op_->Compute(t_args...); return status.release(); }, t); + }; + } }; // struct OrtLiteCustomStruct /////////////////////////// CreateLiteCustomOp //////////////////////////// @@ -731,6 +1014,14 @@ OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name, return std::make_unique(op_name, execution_provider, custom_compute_fn).release(); } +template +OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name, + const char* execution_provider, + Status (*custom_compute_fn_v2)(Args...)) { + using LiteOp = OrtLiteCustomFunc; + return std::make_unique(op_name, execution_provider, custom_compute_fn_v2).release(); +} + template OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name, const char* execution_provider) { @@ -739,4 +1030,4 @@ OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name, } } // namespace Custom -} // namespace Ort +} // namespace Ort \ No newline at end of file diff --git a/java/src/main/java/ai/onnxruntime/TensorInfo.java b/java/src/main/java/ai/onnxruntime/TensorInfo.java index 34d635b5c40f8..69ccb954e8afe 100644 --- a/java/src/main/java/ai/onnxruntime/TensorInfo.java +++ b/java/src/main/java/ai/onnxruntime/TensorInfo.java @@ -58,14 +58,37 @@ public enum OnnxTensorType { */ ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16( 16), // Non-IEEE floating-point format based on IEEE754 single-precision - ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN( - 17), // Non-IEEE floating-point format based on IEEE754 single-precision - ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ( - 18), // Non-IEEE floating-point format based on IEEE754 single-precision - ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2( - 19), // Non-IEEE floating-point format based on IEEE754 single-precision - ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ( - 20); // Non-IEEE floating-point format based on IEEE754 single-precision + /** + * A non-IEEE 8-bit floating point format with 4 exponent bits and 3 mantissa bits, with NaN and + * no infinite values (FN). + * + *

See the float 8 ONNX standard for + * details. + */ + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN(17), + /** + * A non-IEEE 8-bit floating point format with 4 exponent bits and 3 mantissa bits, with NaN, no + * infinite values (FN) and no negative zero (UZ). + * + *

See the float 8 ONNX standard for + * details. + */ + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ(18), + /** + * A non-IEEE 8-bit floating point format with 5 exponent bits and 2 mantissa bits. + * + *

See the float 8 ONNX standard for + * details. + */ + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2(19), + /** + * A non-IEEE 8-bit floating point format with 5 exponent bits and 2 mantissa bits, with NaN, no + * infinite values (FN) and no negative zero (UZ). + * + *

See the float 8 ONNX standard for + * details. + */ + ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ(20); /** The int id on the native side. */ public final int value; diff --git a/js/common/lib/backend-impl.ts b/js/common/lib/backend-impl.ts index 75feba1d0ae08..e129c6971a85c 100644 --- a/js/common/lib/backend-impl.ts +++ b/js/common/lib/backend-impl.ts @@ -26,7 +26,7 @@ const backendsSortedByPriority: string[] = []; * @ignore */ export const registerBackend = (name: string, backend: Backend, priority: number): void => { - if (backend && typeof backend.init === 'function' && typeof backend.createSessionHandler === 'function') { + if (backend && typeof backend.init === 'function' && typeof backend.createInferenceSessionHandler === 'function') { const currentBackend = backends.get(name); if (currentBackend === undefined) { backends.set(name, {backend, priority}); diff --git a/js/common/lib/backend.ts b/js/common/lib/backend.ts index 804f33f00d103..dd04ef3f15997 100644 --- a/js/common/lib/backend.ts +++ b/js/common/lib/backend.ts @@ -3,6 +3,7 @@ import {InferenceSession} from './inference-session.js'; import {OnnxValue} from './onnx-value.js'; +import {TrainingSession} from './training-session.js'; /** * @ignore @@ -14,16 +15,23 @@ export declare namespace SessionHandler { } /** - * Represent a handler instance of an inference session. + * Represents shared SessionHandler functionality * * @ignore */ -export interface SessionHandler { +interface SessionHandler { dispose(): Promise; readonly inputNames: readonly string[]; readonly outputNames: readonly string[]; +} +/** + * Represent a handler instance of an inference session. + * + * @ignore + */ +export interface InferenceSessionHandler extends SessionHandler { startProfiling(): void; endProfiling(): void; @@ -31,6 +39,20 @@ export interface SessionHandler { options: InferenceSession.RunOptions): Promise; } +/** + * Represent a handler instance of a training inference session. + * + * @ignore + */ +export interface TrainingSessionHandler extends SessionHandler { + runTrainStep( + feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, + options: InferenceSession.RunOptions): Promise; + + loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise; + getContiguousParameters(trainableOnly: boolean): Promise; +} + /** * Represent a backend that provides implementation of model inferencing. * @@ -42,8 +64,13 @@ export interface Backend { */ init(): Promise; - createSessionHandler(uriOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): - Promise; + createInferenceSessionHandler(uriOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): + Promise; + + createTrainingSessionHandler? + (checkpointStateUriOrBuffer: TrainingSession.URIorBuffer, trainModelUriOrBuffer: TrainingSession.URIorBuffer, + evalModelUriOrBuffer: TrainingSession.URIorBuffer, optimizerModelUriOrBuffer: TrainingSession.URIorBuffer, + options: InferenceSession.SessionOptions): Promise; } export {registerBackend} from './backend-impl.js'; diff --git a/js/common/lib/env.ts b/js/common/lib/env.ts index 525272294c587..76575ef7b9368 100644 --- a/js/common/lib/env.ts +++ b/js/common/lib/env.ts @@ -9,6 +9,7 @@ export declare namespace Env { 'ort-wasm.wasm'?: string; 'ort-wasm-threaded.wasm'?: string; 'ort-wasm-simd.wasm'?: string; + 'ort-training-wasm-simd.wasm'?: string; 'ort-wasm-simd-threaded.wasm'?: string; /* eslint-enable @typescript-eslint/naming-convention */ }; @@ -105,6 +106,12 @@ export declare namespace Env { * see comments on {@link GpuBufferType} for more details about why not use types defined in "@webgpu/types". */ readonly device: unknown; + /** + * Set or get whether validate input content. + * + * @defaultValue `false` + */ + validateInputContent?: boolean; } } diff --git a/js/common/lib/index.ts b/js/common/lib/index.ts index 85df1747f8576..9cbfcc4e8bcdc 100644 --- a/js/common/lib/index.ts +++ b/js/common/lib/index.ts @@ -22,3 +22,4 @@ export * from './env.js'; export * from './inference-session.js'; export * from './tensor.js'; export * from './onnx-value.js'; +export * from './training-session.js'; diff --git a/js/common/lib/inference-session-impl.ts b/js/common/lib/inference-session-impl.ts index 06949b4a26c0d..9bc2088f2088a 100644 --- a/js/common/lib/inference-session-impl.ts +++ b/js/common/lib/inference-session-impl.ts @@ -2,7 +2,7 @@ // Licensed under the MIT License. import {resolveBackend} from './backend-impl.js'; -import {SessionHandler} from './backend.js'; +import {InferenceSessionHandler} from './backend.js'; import {InferenceSession as InferenceSessionInterface} from './inference-session.js'; import {OnnxValue} from './onnx-value.js'; import {Tensor} from './tensor.js'; @@ -14,7 +14,7 @@ type FetchesType = InferenceSessionInterface.FetchesType; type ReturnType = InferenceSessionInterface.ReturnType; export class InferenceSession implements InferenceSessionInterface { - private constructor(handler: SessionHandler) { + private constructor(handler: InferenceSessionHandler) { this.handler = handler; } run(feeds: FeedsType, options?: RunOptions): Promise; @@ -195,7 +195,7 @@ export class InferenceSession implements InferenceSessionInterface { const eps = options.executionProviders || []; const backendHints = eps.map(i => typeof i === 'string' ? i : i.name); const backend = await resolveBackend(backendHints); - const handler = await backend.createSessionHandler(filePathOrUint8Array, options); + const handler = await backend.createInferenceSessionHandler(filePathOrUint8Array, options); return new InferenceSession(handler); } @@ -213,5 +213,5 @@ export class InferenceSession implements InferenceSessionInterface { return this.handler.outputNames; } - private handler: SessionHandler; + private handler: InferenceSessionHandler; } diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts index 71a5912df2464..8c1e69a68ca7e 100644 --- a/js/common/lib/inference-session.ts +++ b/js/common/lib/inference-session.ts @@ -192,6 +192,7 @@ export declare namespace InferenceSession { wasm: WebAssemblyExecutionProviderOption; webgl: WebGLExecutionProviderOption; xnnpack: XnnpackExecutionProviderOption; + webgpu: WebGpuExecutionProviderOption; webnn: WebNNExecutionProviderOption; nnapi: NnapiExecutionProviderOption; } @@ -233,6 +234,10 @@ export declare namespace InferenceSession { export interface XnnpackExecutionProviderOption extends ExecutionProviderOption { readonly name: 'xnnpack'; } + export interface WebGpuExecutionProviderOption extends ExecutionProviderOption { + readonly name: 'webgpu'; + preferredLayout?: 'NCHW'|'NHWC'; + } export interface WebNNExecutionProviderOption extends ExecutionProviderOption { readonly name: 'webnn'; deviceType?: 'cpu'|'gpu'; diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts new file mode 100644 index 0000000000000..f06d06bda035f --- /dev/null +++ b/js/common/lib/training-session-impl.ts @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {TrainingSessionHandler} from './backend.js'; +import {InferenceSession as InferenceSession} from './inference-session.js'; +import {TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions} from './training-session.js'; + +type SessionOptions = InferenceSession.SessionOptions; + +export class TrainingSession implements TrainingSessionInterface { + private constructor(handler: TrainingSessionHandler) { + this.handler = handler; + } + private handler: TrainingSessionHandler; + + get inputNames(): readonly string[] { + return this.handler.inputNames; + } + get outputNames(): readonly string[] { + return this.handler.outputNames; + } + + static async create(_trainingOptions: TrainingSessionCreateOptions, _sessionOptions?: SessionOptions): + Promise { + throw new Error('Method not implemented'); + } + + async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { + throw new Error('Method not implemented.'); + } + + async getContiguousParameters(_trainableOnly: boolean): Promise { + throw new Error('Method not implemented.'); + } + + runTrainStep(feeds: InferenceSession.OnnxValueMapType, options?: InferenceSession.RunOptions|undefined): + Promise; + runTrainStep( + feeds: InferenceSession.OnnxValueMapType, fetches: InferenceSession.FetchesType, + options?: InferenceSession.RunOptions|undefined): Promise; + async runTrainStep(_feeds: unknown, _fetches?: unknown, _options?: unknown): + Promise { + throw new Error('Method not implemented.'); + } + + async release(): Promise { + return this.handler.dispose(); + } +} diff --git a/js/common/lib/training-session.ts b/js/common/lib/training-session.ts new file mode 100644 index 0000000000000..0967d79b33434 --- /dev/null +++ b/js/common/lib/training-session.ts @@ -0,0 +1,134 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {InferenceSession} from './inference-session.js'; +import {TrainingSession as TrainingSessionImpl} from './training-session-impl.js'; + +/* eslint-disable @typescript-eslint/no-redeclare */ + +export declare namespace TrainingSession { + /** + * Either URI file path (string) or Uint8Array containing model or checkpoint information. + */ + type URIorBuffer = string|Uint8Array; +} + +/** + * Represent a runtime instance of an ONNX training session, + * which contains a model that can be trained, and, optionally, + * an eval and optimizer model. + */ +export interface TrainingSession { + // #region run() + + /** + * Run TrainStep asynchronously with the given feeds and options. + * + * @param feeds - Representation of the model input. See type description of `InferenceSession.InputType` for + detail. + * @param options - Optional. A set of options that controls the behavior of model training. + * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding values. + */ + runTrainStep(feeds: InferenceSession.FeedsType, options?: InferenceSession.RunOptions): + Promise; + + /** + * Run a single train step with the given inputs and options. + * + * @param feeds - Representation of the model input. + * @param fetches - Representation of the model output. + * detail. + * @param options - Optional. A set of options that controls the behavior of model inference. + * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding + values. + */ + runTrainStep( + feeds: InferenceSession.FeedsType, fetches: InferenceSession.FetchesType, + options?: InferenceSession.RunOptions): Promise; + + // #endregion + + // #region copy parameters + /** + * Copies from a buffer containing parameters to the TrainingSession parameters. + * + * @param buffer - buffer containing parameters + * @param trainableOnly - True if trainable parameters only to be modified, false otherwise. + */ + loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise; + + /** + * Copies from the TrainingSession parameters to a buffer. + * + * @param trainableOnly - True if trainable parameters only to be copied, false othrwise. + * @returns A promise that resolves to a buffer of the requested parameters. + */ + getContiguousParameters(trainableOnly: boolean): Promise; + // #endregion + + // #region release() + + /** + * Release the inference session and the underlying resources. + */ + release(): Promise; + // #endregion + + // #region metadata + + /** + * Get input names of the loaded model. + */ + readonly inputNames: readonly string[]; + + /** + * Get output names of the loaded model. + */ + readonly outputNames: readonly string[]; + // #endregion +} + +/** + * Represents the optional parameters that can be passed into the TrainingSessionFactory. + */ +export interface TrainingSessionCreateOptions { + /** + * URI or buffer for a .ckpt file that contains the checkpoint for the training model. + */ + checkpointState: TrainingSession.URIorBuffer; + /** + * URI or buffer for the .onnx training file. + */ + trainModel: TrainingSession.URIorBuffer; + /** + * Optional. URI or buffer for the .onnx optimizer model file. + */ + optimizerModel?: TrainingSession.URIorBuffer; + /** + * Optional. URI or buffer for the .onnx eval model file. + */ + evalModel?: TrainingSession.URIorBuffer; +} + +/** + * Defines method overload possibilities for creating a TrainingSession. + */ +export interface TrainingSessionFactory { + // #region create() + + /** + * Creates a new TrainingSession and asynchronously loads any models passed in through trainingOptions + * + * @param trainingOptions specify models and checkpoints to load into the Training Session + * @param sessionOptions specify configuration for training session behavior + * + * @returns Promise that resolves to a TrainingSession object + */ + create(trainingOptions: TrainingSessionCreateOptions, sessionOptions?: InferenceSession.SessionOptions): + Promise; + + // #endregion +} + +// eslint-disable-next-line @typescript-eslint/naming-convention +export const TrainingSession: TrainingSessionFactory = TrainingSessionImpl; diff --git a/js/node/lib/backend.ts b/js/node/lib/backend.ts index d3680f9d44236..5f5ad49a2dea8 100644 --- a/js/node/lib/backend.ts +++ b/js/node/lib/backend.ts @@ -1,11 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Backend, InferenceSession, SessionHandler} from 'onnxruntime-common'; +import {Backend, InferenceSession, InferenceSessionHandler, SessionHandler} from 'onnxruntime-common'; import {Binding, binding} from './binding'; -class OnnxruntimeSessionHandler implements SessionHandler { +class OnnxruntimeSessionHandler implements InferenceSessionHandler { #inferenceSession: Binding.InferenceSession; constructor(pathOrBuffer: string|Uint8Array, options: InferenceSession.SessionOptions) { @@ -53,8 +53,8 @@ class OnnxruntimeBackend implements Backend { return Promise.resolve(); } - async createSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): - Promise { + async createInferenceSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): + Promise { return new Promise((resolve, reject) => { process.nextTick(() => { try { diff --git a/js/react_native/e2e/package.json b/js/react_native/e2e/package.json index 969c70c110123..cd97ec1d099e4 100644 --- a/js/react_native/e2e/package.json +++ b/js/react_native/e2e/package.json @@ -10,7 +10,8 @@ }, "dependencies": { "react": "^18.1.0", - "react-native": "^0.69.1" + "react-native": "^0.69.1", + "react-native-fs": "^2.20.0" }, "devDependencies": { "@babel/core": "^7.17.0", diff --git a/js/react_native/e2e/src/App.tsx b/js/react_native/e2e/src/App.tsx index f3e415f0c5a55..8a76edabc613e 100644 --- a/js/react_native/e2e/src/App.tsx +++ b/js/react_native/e2e/src/App.tsx @@ -8,6 +8,7 @@ import { Image, Text, TextInput, View } from 'react-native'; import { InferenceSession, Tensor } from 'onnxruntime-react-native'; import MNIST, { MNISTInput, MNISTOutput, MNISTResult, } from './mnist-data-handler'; import { Buffer } from 'buffer'; +import { readFile } from 'react-native-fs'; interface State { session: @@ -39,10 +40,21 @@ export default class App extends React.PureComponent<{}, State> { this.setState({ imagePath }); const modelPath = await MNIST.getLocalModelPath(); - const session: InferenceSession = await InferenceSession.create(modelPath); + + // test creating session with path + console.log('Creating with path'); + const pathSession: InferenceSession = await InferenceSession.create(modelPath); + pathSession.release(); + + // and with bytes + console.log('Creating with bytes'); + const base64Str = await readFile(modelPath, 'base64'); + const bytes = Buffer.from(base64Str, 'base64'); + const session: InferenceSession = await InferenceSession.create(bytes); this.setState({ session }); - void this.infer(); + console.log('Test session created'); + void await this.infer(); } catch (err) { console.log(err.message); } diff --git a/js/react_native/lib/backend.ts b/js/react_native/lib/backend.ts index b3f0c466308a5..3b1852699ac48 100644 --- a/js/react_native/lib/backend.ts +++ b/js/react_native/lib/backend.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Backend, InferenceSession, SessionHandler, Tensor,} from 'onnxruntime-common'; +import {Backend, InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} from 'onnxruntime-common'; import {Platform} from 'react-native'; import {binding, Binding, JSIBlob, jsiHelper} from './binding'; @@ -43,7 +43,7 @@ const normalizePath = (path: string): string => { return path; }; -class OnnxruntimeSessionHandler implements SessionHandler { +class OnnxruntimeSessionHandler implements InferenceSessionHandler { #inferenceSession: Binding.InferenceSession; #key: string; @@ -66,12 +66,14 @@ class OnnxruntimeSessionHandler implements SessionHandler { let results: Binding.ModelLoadInfoType; // load a model if (typeof this.#pathOrBuffer === 'string') { + // load model from model path results = await this.#inferenceSession.loadModel(normalizePath(this.#pathOrBuffer), options); } else { + // load model from buffer if (!this.#inferenceSession.loadModelFromBlob) { throw new Error('Native module method "loadModelFromBlob" is not defined'); } - const modelBlob = jsiHelper.storeArrayBuffer(this.#pathOrBuffer); + const modelBlob = jsiHelper.storeArrayBuffer(this.#pathOrBuffer.buffer); results = await this.#inferenceSession.loadModelFromBlob(modelBlob, options); } // resolve promise if onnxruntime session is successfully created @@ -163,8 +165,8 @@ class OnnxruntimeBackend implements Backend { return Promise.resolve(); } - async createSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): - Promise { + async createInferenceSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): + Promise { const handler = new OnnxruntimeSessionHandler(pathOrBuffer); await handler.loadModel(options || {}); return handler; diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index a87a894e3b3c5..44003021293b0 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -21,6 +21,8 @@ Do not modify directly.* | Atan | ai.onnx(7+) | | | Atanh | ai.onnx(9+) | | | AveragePool | ai.onnx(7-9,10,11+); com.ms.internal.nhwc(11+) | need perf optimization; need implementing activation | +| BiasAdd | com.microsoft(1+) | | +| BiasSplitGelu | com.microsoft(1+) | | | Cast | ai.onnx(6-8,9-12,13-18,19+) | | | Ceil | ai.onnx(6-12,13+) | | | Clip | ai.onnx(6-10,11,12,13+) | | @@ -62,6 +64,7 @@ Do not modify directly.* | Not | ai.onnx(1+) | | | Pad | ai.onnx(2-10,11-12,13-17,18,19+) | | | Pow | ai.onnx(7-11,12,13-14,15+) | | +| Range | ai.onnx(11+) | | | Reciprocal | ai.onnx(6-12,13+) | | | ReduceL1 | ai.onnx(1-10,11-12,13-17,18+) | | | ReduceL2 | ai.onnx(1-10,11-12,13-17,18+) | | @@ -93,3 +96,4 @@ Do not modify directly.* | Tile | ai.onnx(6-12,13+) | | | Transpose | ai.onnx(1-12,13+) | need perf optimization | | Unsqueeze | ai.onnx(1-10,11-12,13+) | | +| Where | ai.onnx(9-15,16+) | | diff --git a/js/web/lib/backend-onnxjs.ts b/js/web/lib/backend-onnxjs.ts index 18a068e0ced8b..5ea7de809a495 100644 --- a/js/web/lib/backend-onnxjs.ts +++ b/js/web/lib/backend-onnxjs.ts @@ -2,7 +2,7 @@ // Licensed under the MIT License. /* eslint-disable import/no-internal-modules */ -import {Backend, InferenceSession, SessionHandler} from 'onnxruntime-common'; +import {Backend, InferenceSession, InferenceSessionHandler} from 'onnxruntime-common'; import {Session} from './onnxjs/session'; import {OnnxjsSessionHandler} from './onnxjs/session-handler'; @@ -11,8 +11,8 @@ class OnnxjsBackend implements Backend { // eslint-disable-next-line @typescript-eslint/no-empty-function async init(): Promise {} - async createSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): - Promise { + async createInferenceSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): + Promise { // NOTE: Session.Config(from onnx.js) is not compatible with InferenceSession.SessionOptions(from // onnxruntime-common). // In future we should remove Session.Config and use InferenceSession.SessionOptions. diff --git a/js/web/lib/backend-wasm.ts b/js/web/lib/backend-wasm.ts index ceb20044d97b6..04108c2ad0f66 100644 --- a/js/web/lib/backend-wasm.ts +++ b/js/web/lib/backend-wasm.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Backend, env, InferenceSession, SessionHandler} from 'onnxruntime-common'; +import {Backend, env, InferenceSession, InferenceSessionHandler} from 'onnxruntime-common'; import {cpus} from 'os'; import {initializeWebAssemblyInstance} from './wasm/proxy-wrapper'; @@ -40,10 +40,12 @@ class OnnxruntimeWebAssemblyBackend implements Backend { // init wasm await initializeWebAssemblyInstance(); } - createSessionHandler(path: string, options?: InferenceSession.SessionOptions): Promise; - createSessionHandler(buffer: Uint8Array, options?: InferenceSession.SessionOptions): Promise; - async createSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): - Promise { + createInferenceSessionHandler(path: string, options?: InferenceSession.SessionOptions): + Promise; + createInferenceSessionHandler(buffer: Uint8Array, options?: InferenceSession.SessionOptions): + Promise; + async createInferenceSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): + Promise { const handler = new OnnxruntimeWebAssemblySessionHandler(); await handler.loadModel(pathOrBuffer, options); return Promise.resolve(handler); diff --git a/js/web/lib/onnxjs/session-handler.ts b/js/web/lib/onnxjs/session-handler.ts index 0b06a7a747a44..47e50aeab673a 100644 --- a/js/web/lib/onnxjs/session-handler.ts +++ b/js/web/lib/onnxjs/session-handler.ts @@ -1,12 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {InferenceSession, SessionHandler, Tensor} from 'onnxruntime-common'; +import {InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} from 'onnxruntime-common'; import {Session} from './session'; import {Tensor as OnnxjsTensor} from './tensor'; -export class OnnxjsSessionHandler implements SessionHandler { +export class OnnxjsSessionHandler implements InferenceSessionHandler { constructor(private session: Session) { this.inputNames = this.session.inputNames; this.outputNames = this.session.outputNames; diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index 59da1369e152e..b7b2ff4537095 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import type {Tensor} from 'onnxruntime-common'; + export declare namespace JSEP { type BackendType = unknown; type AllocFunction = (size: number) => number; @@ -9,11 +11,8 @@ export declare namespace JSEP { type DownloadFunction = (gpuDataId: number, dataOffset: number, size: number) => Promise; type CreateKernelFunction = (name: string, kernel: number, attribute: unknown) => void; type ReleaseKernelFunction = (kernel: number) => void; - type RunFunction = (kernel: number, contextDataOffset: number, sessionState: SessionState) => number; - export interface SessionState { - sessionId: number; - errors: Array>; - } + type RunFunction = + (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array>) => number; } export interface OrtWasmModule extends EmscriptenModule { @@ -40,14 +39,23 @@ export interface OrtWasmModule extends EmscriptenModule { _OrtFree(stringHandle: number): void; - _OrtCreateTensor(dataType: number, dataOffset: number, dataLength: number, dimsOffset: number, dimsLength: number): - number; + _OrtCreateTensor( + dataType: number, dataOffset: number, dataLength: number, dimsOffset: number, dimsLength: number, + dataLocation: number): number; _OrtGetTensorData(tensorHandle: number, dataType: number, dataOffset: number, dimsOffset: number, dimsLength: number): number; _OrtReleaseTensor(tensorHandle: number): void; + _OrtCreateBinding(sessionHandle: number): number; + _OrtBindInput(bindingHandle: number, nameOffset: number, tensorHandle: number): Promise; + _OrtBindOutput(bindingHandle: number, nameOffset: number, tensorHandle: number, location: number): number; + _OrtClearBoundOutputs(ioBindingHandle: number): void; + _OrtReleaseBinding(ioBindingHandle: number): void; + _OrtRunWithBinding( + sessionHandle: number, ioBindingHandle: number, outputCount: number, outputsOffset: number, + runOptionsHandle: number): Promise; _OrtRun( sessionHandle: number, inputNamesOffset: number, inputsOffset: number, inputCount: number, - outputNamesOffset: number, outputCount: number, outputsOffset: number, runOptionsHandle: number): number; + outputNamesOffset: number, outputCount: number, outputsOffset: number, runOptionsHandle: number): Promise; _OrtCreateSessionOptions( graphOptimizationLevel: number, enableCpuMemArena: boolean, enableMemPattern: boolean, executionMode: number, @@ -102,17 +110,67 @@ export interface OrtWasmModule extends EmscriptenModule { // #endregion // #region JSEP + /** + * This is the entry of JSEP initialization. This function is called once when initializing ONNX Runtime. + * This function initializes WebGPU backend and registers a few callbacks that will be called in C++ code. + */ jsepInit? (backend: JSEP.BackendType, alloc: JSEP.AllocFunction, free: JSEP.FreeFunction, upload: JSEP.UploadFunction, download: JSEP.DownloadFunction, createKernel: JSEP.CreateKernelFunction, releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction): void; + /** + * [exported from wasm] Specify a kernel's output when running OpKernel::Compute(). + * + * @param context - specify the kernel context pointer. + * @param index - specify the index of the output. + * @param data - specify the pointer to encoded data of type and dims. + */ _JsepOutput(context: number, index: number, data: number): number; + /** + * [exported from wasm] Get name of an operator node. + * + * @param kernel - specify the kernel pointer. + * @returns the pointer to a C-style UTF8 encoded string representing the node name. + */ _JsepGetNodeName(kernel: number): number; - jsepOnRunStart?(sessionId: number): void; - jsepOnRunEnd?(sessionId: number): Promise; - jsepRunPromise?: Promise; + /** + * [exported from js_internal_api.js] Register a user GPU buffer for usage of a session's input or output. + * + * @param sessionId - specify the session ID. + * @param index - specify an integer to represent which input/output it is registering for. For input, it is the + * input_index corresponding to the session's inputNames. For output, it is the inputCount + output_index + * corresponding to the session's ouputNames. + * @param buffer - specify the GPU buffer to register. + * @param size - specify the original data size in byte. + * @returns the GPU data ID for the registered GPU buffer. + */ + jsepRegisterBuffer: (sessionId: number, index: number, buffer: GPUBuffer, size: number) => number; + /** + * [exported from js_internal_api.js] Unregister all user GPU buffers for a session. + * + * @param sessionId - specify the session ID. + */ + jsepUnregisterBuffers?: (sessionId: number) => void; + /** + * [exported from js_internal_api.js] Get the GPU buffer by GPU data ID. + * + * @param dataId - specify the GPU data ID + * @returns the GPU buffer. + */ + jsepGetBuffer: (dataId: number) => GPUBuffer; + /** + * [exported from js_internal_api.js] Create a function to be used to create a GPU Tensor. + * + * @param gpuBuffer - specify the GPU buffer + * @param size - specify the original data size in byte. + * @param type - specify the tensor type. + * @returns the generated downloader function. + */ + jsepCreateDownloader: + (gpuBuffer: GPUBuffer, size: number, + type: Tensor.GpuBufferDataTypes) => () => Promise; // #endregion } diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 5e77a0343b4ee..5bec562b157ac 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -1,11 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Env} from 'onnxruntime-common'; +import {Env, Tensor} from 'onnxruntime-common'; import {configureLogger, LOG_DEBUG} from './log'; -import {TensorView} from './tensor-view'; -import {createGpuDataManager, GpuDataManager} from './webgpu/gpu-data-manager'; +import {createView, TensorView} from './tensor-view'; +import {createGpuDataManager, downloadGpuData, GpuDataManager} from './webgpu/gpu-data-manager'; import {RunFunction, WEBGPU_OP_RESOLVE_RULES} from './webgpu/op-resolve-rules'; import {ProgramManager} from './webgpu/program-manager'; import {ComputeContext, GpuData, ProgramInfo, ProgramInfoLoader} from './webgpu/types'; @@ -98,6 +98,11 @@ export class WebGpuBackend { env: Env; + /** + * a SessionID -> a Map of (InputOutputIndex -> [ID, GPUBuffer]) mapping. + */ + sessionExternalDataMapping: Map> = new Map(); + async initialize(env: Env): Promise { if (!navigator.gpu) { // WebGPU is not available. @@ -192,11 +197,13 @@ export class WebGpuBackend { } flush(): void { - this.endComputePass(); - this.device.queue.submit([this.getCommandEncoder().finish()]); - this.gpuDataManager.refreshPendingBuffers(); - this.commandEncoder = null; - this.pendingDispatchNumber = 0; + if (this.commandEncoder) { + this.endComputePass(); + this.device.queue.submit([this.getCommandEncoder().finish()]); + this.gpuDataManager.refreshPendingBuffers(); + this.commandEncoder = null; + this.pendingDispatchNumber = 0; + } } /** @@ -304,12 +311,9 @@ export class WebGpuBackend { } async download(gpuDataId: number, getTargetBuffer: () => Uint8Array): Promise { - const arrayBuffer = await this.gpuDataManager.download(gpuDataId); - // the underlying buffer may be changed after the async function is called. so we use a getter function to make sure // the buffer is up-to-date. - const data = getTargetBuffer(); - data.set(new Uint8Array(arrayBuffer, 0, data.byteLength)); + await this.gpuDataManager.download(gpuDataId, getTargetBuffer); } alloc(size: number): number { @@ -372,7 +376,7 @@ export class WebGpuBackend { kernelEntry(context, attributes[1]); return 0; // ORT_OK } catch (e) { - LOG_DEBUG('warning', `[WebGPU] Kernel "[${opType}] ${nodeName}" failed. Error: ${e}`); + errors.push(Promise.resolve(`[WebGPU] Kernel "[${opType}] ${nodeName}" failed. ${e}`)); return 1; // ORT_FAIL } finally { if (useErrorScope) { @@ -387,4 +391,40 @@ export class WebGpuBackend { this.currentKernelId = null; } } + + // #region external buffer + registerBuffer(sessionId: number, index: number, buffer: GPUBuffer, size: number): number { + let sessionInputOutputMapping = this.sessionExternalDataMapping.get(sessionId); + if (!sessionInputOutputMapping) { + sessionInputOutputMapping = new Map(); + this.sessionExternalDataMapping.set(sessionId, sessionInputOutputMapping); + } + + const previousBuffer = sessionInputOutputMapping.get(index); + const id = this.gpuDataManager.registerExternalBuffer(buffer, size, previousBuffer?.[1]); + sessionInputOutputMapping.set(index, [id, buffer]); + return id; + } + unregisterBuffers(sessionId: number): void { + const sessionInputOutputMapping = this.sessionExternalDataMapping.get(sessionId); + if (sessionInputOutputMapping) { + sessionInputOutputMapping.forEach(bufferInfo => this.gpuDataManager.unregisterExternalBuffer(bufferInfo[1])); + this.sessionExternalDataMapping.delete(sessionId); + } + } + getBuffer(gpuDataId: number): GPUBuffer { + const gpuData = this.gpuDataManager.get(gpuDataId); + if (!gpuData) { + throw new Error(`no GPU data for buffer: ${gpuDataId}`); + } + return gpuData.buffer; + } + createDownloader(gpuBuffer: GPUBuffer, size: number, type: Tensor.GpuBufferDataTypes): + () => Promise { + return async () => { + const data = await downloadGpuData(this, gpuBuffer, size); + return createView(data.buffer, type); + }; + } + // #endregion } diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index 78316cbe1c825..6ff3971d720fd 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -3,7 +3,7 @@ import {Env} from 'onnxruntime-common'; -import {JSEP, OrtWasmModule} from '../binding/ort-wasm'; +import {OrtWasmModule} from '../binding/ort-wasm'; import {DataType, getTensorElementSize} from '../wasm-common'; import {WebGpuBackend} from './backend-webgpu'; @@ -120,6 +120,11 @@ class ComputeContextImpl implements ComputeContext { this.module.HEAPU32[offset++] = dims[i]; } return this.module._JsepOutput(this.opKernelContext, index, data); + } catch (e) { + throw new Error( + `Failed to generate kernel's output[${index}] with dims [${dims}]. ` + + 'If you are running with pre-allocated output, please make sure the output type/dims are correct. ' + + `Error: ${e}`); } finally { this.module.stackRestore(stack); } @@ -138,7 +143,7 @@ export const init = async(module: OrtWasmModule, env: Env): Promise => { init( // backend - {backend}, + backend, // jsepAlloc() (size: number) => backend.alloc(size), @@ -178,13 +183,13 @@ export const init = async(module: OrtWasmModule, env: Env): Promise => { (kernel: number) => backend.releaseKernel(kernel), // jsepRun - (kernel: number, contextDataOffset: number, sessionState: JSEP.SessionState) => { + (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array>) => { LOG_DEBUG( 'verbose', - () => `[WebGPU] jsepRun: sessionId=${sessionState.sessionId}, kernel=${kernel}, contextDataOffset=${ + () => `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${ contextDataOffset}`); const context = new ComputeContextImpl(module, backend, contextDataOffset); - return backend.computeKernel(kernel, context, sessionState.errors); + return backend.computeKernel(kernel, context, errors); }); } }; diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts index 92fdd5abc3892..131f7a9bfa29b 100644 --- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts @@ -35,7 +35,7 @@ export interface GpuDataManager { /** * copy data from GPU to CPU. */ - download(id: GpuDataId): Promise; + download(id: GpuDataId, getTargetBuffer: () => Uint8Array): Promise; /** * refresh the buffers that marked for release. @@ -46,6 +46,19 @@ export interface GpuDataManager { */ refreshPendingBuffers(): void; + /** + * register an external buffer for IO Binding. If the buffer is already registered, return the existing GPU data ID. + * + * GPU data manager only manages a mapping between the buffer and the GPU data ID. It will not manage the lifecycle of + * the external buffer. + */ + registerExternalBuffer(buffer: GPUBuffer, originalSize: number, previousBuffer?: GPUBuffer): number; + + /** + * unregister an external buffer for IO Binding. + */ + unregisterExternalBuffer(buffer: GPUBuffer): void; + /** * destroy all gpu buffers. Call this when the session.release is called. */ @@ -62,12 +75,56 @@ interface StorageCacheValue { */ const calcNormalizedBufferSize = (size: number) => Math.ceil(size / 16) * 16; -let guid = 0; +let guid = 1; const createNewGpuDataId = () => guid++; +/** + * exported standard download function. This function is used by the session to download the data from GPU, and also by + * factory to create GPU tensors with the capacity of downloading data from GPU. + * + * @param backend - the WebGPU backend + * @param gpuBuffer - the GPU buffer to download + * @param originalSize - the original size of the data + * @param getTargetBuffer - optional. If provided, the data will be copied to the target buffer. Otherwise, a new buffer + * will be created and returned. + */ +export const downloadGpuData = + async(backend: WebGpuBackend, gpuBuffer: GPUBuffer, originalSize: number, getTargetBuffer?: () => Uint8Array): + Promise => { + const bufferSize = calcNormalizedBufferSize(originalSize); + const gpuReadBuffer = backend.device.createBuffer( + // eslint-disable-next-line no-bitwise + {size: bufferSize, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ}); + try { + const commandEncoder = backend.getCommandEncoder(); + backend.endComputePass(); + commandEncoder.copyBufferToBuffer( + gpuBuffer /* source buffer */, 0 /* source offset */, gpuReadBuffer /* destination buffer */, + 0 /* destination offset */, bufferSize /* size */ + ); + backend.flush(); + + await gpuReadBuffer.mapAsync(GPUMapMode.READ); + + const arrayBuffer = gpuReadBuffer.getMappedRange(); + if (getTargetBuffer) { + // if we already have a CPU buffer to accept the data, no need to clone the ArrayBuffer. + const targetBuffer = getTargetBuffer(); + targetBuffer.set(new Uint8Array(arrayBuffer, 0, originalSize)); + return targetBuffer; + } else { + // the mapped ArrayBuffer will be released when the GPU buffer is destroyed. Need to clone the + // ArrayBuffer. + return new Uint8Array(arrayBuffer.slice(0, originalSize)); + } + } finally { + gpuReadBuffer.destroy(); + } + }; + class GpuDataManagerImpl implements GpuDataManager { // GPU Data ID => GPU Data ( storage buffer ) - storageCache: Map; + private storageCache: Map; // pending buffers for uploading ( data is unmapped ) private buffersForUploadingPending: GPUBuffer[]; @@ -77,11 +134,15 @@ class GpuDataManagerImpl implements GpuDataManager { // The reusable storage buffers for computing. private freeBuffers: Map; + // The external buffers registered users for IO Binding. + private externalBuffers: Map; + constructor(private backend: WebGpuBackend) { this.storageCache = new Map(); this.freeBuffers = new Map(); this.buffersForUploadingPending = []; this.buffersPending = []; + this.externalBuffers = new Map(); } upload(id: GpuDataId, data: Uint8Array): void { @@ -143,6 +204,42 @@ class GpuDataManagerImpl implements GpuDataManager { sourceGpuDataCache.gpuData.buffer, 0, destinationGpuDataCache.gpuData.buffer, 0, size); } + registerExternalBuffer(buffer: GPUBuffer, originalSize: number, previousBuffer?: GPUBuffer): number { + let id: number|undefined; + if (previousBuffer) { + id = this.externalBuffers.get(previousBuffer); + if (id === undefined) { + throw new Error('previous buffer is not registered'); + } + if (buffer === previousBuffer) { + LOG_DEBUG( + 'verbose', + () => `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${ + id}, buffer is the same, skip.`); + return id; + } + this.externalBuffers.delete(previousBuffer); + } else { + id = createNewGpuDataId(); + } + + this.storageCache.set(id, {gpuData: {id, type: GpuDataType.default, buffer}, originalSize}); + this.externalBuffers.set(buffer, id); + LOG_DEBUG( + 'verbose', + () => `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${id}, registered.`); + return id; + } + + unregisterExternalBuffer(buffer: GPUBuffer): void { + const id = this.externalBuffers.get(buffer); + if (id !== undefined) { + this.storageCache.delete(id); + this.externalBuffers.delete(buffer); + LOG_DEBUG('verbose', () => `[WebGPU] GpuDataManager.unregisterExternalBuffer() => id=${id}`); + } + } + // eslint-disable-next-line no-bitwise create(size: number, usage = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST): GpuData { const bufferSize = calcNormalizedBufferSize(size); @@ -193,31 +290,13 @@ class GpuDataManagerImpl implements GpuDataManager { return cachedData.originalSize; } - async download(id: GpuDataId): Promise { + async download(id: GpuDataId, getTargetBuffer: () => Uint8Array): Promise { const cachedData = this.storageCache.get(id); if (!cachedData) { throw new Error('data does not exist'); } - const commandEncoder = this.backend.getCommandEncoder(); - this.backend.endComputePass(); - const bufferSize = calcNormalizedBufferSize(cachedData.originalSize); - const gpuReadBuffer = this.backend.device.createBuffer( - // eslint-disable-next-line no-bitwise - {size: bufferSize, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ}); - commandEncoder.copyBufferToBuffer( - cachedData.gpuData.buffer /* source buffer */, 0 /* source offset */, gpuReadBuffer /* destination buffer */, - 0 /* destination offset */, bufferSize /* size */ - ); - this.backend.flush(); - - return new Promise((resolve) => { - gpuReadBuffer.mapAsync(GPUMapMode.READ).then(() => { - const data = gpuReadBuffer.getMappedRange().slice(0); - gpuReadBuffer.destroy(); - resolve(data); - }); - }); + await downloadGpuData(this.backend, cachedData.gpuData.buffer, cachedData.originalSize, getTargetBuffer); } refreshPendingBuffers(): void { diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index e92e6696d9a78..40309c1849bcc 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -2,6 +2,8 @@ // Licensed under the MIT License. import {argMax, argMin, parseArgMinMaxAttributes} from './ops/argminmax'; +import {biasAdd} from './ops/bias-add'; +import {biasSplitGelu} from './ops/bias-split-gelu'; import * as binaryOps from './ops/binary-op'; import {concat, parseConcatAttributes} from './ops/concat'; import {conv, parseConvAttributes} from './ops/conv'; @@ -16,6 +18,7 @@ import {layerNorm, parseLayerNormAttributes} from './ops/layer-norm'; import {matMul} from './ops/matmul'; import {pad, parsePadAttributes} from './ops/pad'; import * as pool from './ops/pool'; +import {range} from './ops/range'; import {parseReduceAttributes, reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce'; import {parseResizeAttributes, resize} from './ops/resize'; import {parseSkipLayerNormAttributes, skipLayerNorm} from './ops/skip-layer-norm'; @@ -25,6 +28,7 @@ import {parseSplitAttributes, split} from './ops/split'; import {tile} from './ops/tile'; import {parseTransposeAttributes, transpose} from './ops/transpose'; import * as unaryOps from './ops/unary-op'; +import {where} from './ops/where'; import {ComputeContext} from './types'; export type RunFunction = (context: ComputeContext, attribute?: unknown) => void; @@ -44,6 +48,8 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Atanh', [unaryOps.atanh]], // TODO: support new attributes for AveragePool-10 ['AveragePool', [pool.averagePool, pool.parseAveragePoolAttributes]], + ['BiasAdd', [biasAdd]], + ['BiasSplitGelu', [biasSplitGelu]], ['Cast', [unaryOps.cast, unaryOps.parseCastAttributes]], ['Ceil', [unaryOps.ceil]], ['ClipV10', [unaryOps.clipV10]], @@ -83,6 +89,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Not', [unaryOps.not]], ['Pad', [pad, parsePadAttributes]], ['Pow', [binaryOps.pow]], + ['Range', [range]], ['Reciprocal', [unaryOps.reciprocal]], ['ReduceMin', [reduceMin, parseReduceAttributes]], ['ReduceMean', [reduceMean, parseReduceAttributes]], @@ -110,4 +117,5 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['ThresholdedRelu', [unaryOps.thresholdedRelu, unaryOps.parseAlphaAttributes]], ['Tile', [tile]], ['Transpose', [transpose, parseTransposeAttributes]], + ['Where', [where]], ]); diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts index dd4f13e76ee04..22b91d680a9b4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts @@ -21,16 +21,16 @@ export declare type Activation = 'linear' | 'relu' | 'prelu' | 'elu' | 'relu6' | 'leakyrelu' | 'sigmoid' | 'gelu'; -export const typeSnippet = (component: number) => { +export const typeSnippet = (component: number, dataType: string) => { switch (component) { case 1: - return 'f32'; + return dataType; case 2: - return 'vec2'; + return `vec2<${dataType}>`; case 3: - return 'vec3'; + return `vec3<${dataType}>`; case 4: - return 'vec4'; + return `vec4<${dataType}>`; default: throw new Error(`${component}-component is not supported.`); } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index 08b1d1f30b233..e6d4039d8131b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -23,6 +23,7 @@ import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types'; +import {tensorTypeToWsglStorageType} from '../common'; import {ConvAttributes} from '../conv'; import {Activation, activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util'; @@ -32,13 +33,13 @@ import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packe const conv2dCommonSnippet = (isChannelsLast: boolean, fitAOuter: boolean, fitBOuter: boolean, fitInner: boolean, addBias = false, activation?: Activation, hasPreluActivationWeights = false, innerElementSizeX = 4, innerElementSizeW = 4, - innerElementSize = 4): string => { + innerElementSize = 4, dataType = 'f32'): string => { const getXSnippet = (innerElementSize: number) => { switch (innerElementSize) { case 1: return 'resData = x[xIndex];'; case 3: - return 'resData = vec3(x[xIndex], x[xIndex + 1], x[xIndex + 2]);'; + return `resData = vec3<${dataType}>(x[xIndex], x[xIndex + 1], x[xIndex + 2]);`; case 4: return 'resData = x[xIndex / 4];'; default: @@ -92,7 +93,7 @@ const conv2dCommonSnippet = let xRow = outRow * stride[0] + dilation[0] * WRow - pad[0]; let xCol = outCol * stride[1] + dilation[1] * WCol - pad[1]; let xCh = ${col} % inChannels; - var resData = ${typeSnippet(innerElementSizeX)}(0.0); + var resData = ${typeSnippet(innerElementSizeX, dataType)}(0.0); // The bounds checking is always needed since we use it to pad zero for // the 'same' padding type. if (xRow >= 0 && xRow < ${xHeight} && xCol >= 0 && xCol < ${xWidth}) { @@ -110,7 +111,7 @@ const conv2dCommonSnippet = if (row < dimAOuter && col < dimInner) { ${readXSnippet} } - return ${typeSnippet(innerElementSizeX)}(0.0);`) : + return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`) : (fitInner && fitBOuter ? ` let col = colIn * ${innerElementSizeX}; ${readXSnippet}` : @@ -119,13 +120,15 @@ const conv2dCommonSnippet = if (row < dimInner && col < dimBOuter) { ${readXSnippet} } - return ${typeSnippet(innerElementSizeX)}(0.0);`); + return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`); const sampleW = `${getWSnippet(innerElementSizeW)}`; - const resType = typeSnippet(innerElementSize); - const aType = isChannelsLast ? typeSnippet(innerElementSizeX) : typeSnippet(innerElementSizeW); - const bType = isChannelsLast ? typeSnippet(innerElementSizeW) : typeSnippet(innerElementSizeX); + const resType = typeSnippet(innerElementSize, dataType); + const aType = + isChannelsLast ? typeSnippet(innerElementSizeX, dataType) : typeSnippet(innerElementSizeW, dataType); + const bType = + isChannelsLast ? typeSnippet(innerElementSizeW, dataType) : typeSnippet(innerElementSizeX, dataType); const userCode = ` ${activationFnSnippet(activation, hasPreluActivationWeights, innerElementSize === 4, 4)} fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${aType} { @@ -190,23 +193,24 @@ export const createConv2DMatMulProgramInfo = const fitInner = dimInner % tileInner === 0; const elementsSize = isVec4 ? [innerElementSize, 4, 4] : [1, 1, 1]; + const t = tensorTypeToWsglStorageType(inputs[0].dataType); const declareInputs = [ - `@group(0) @binding(0) var x: array<${isVec4 && innerElementSize === 4 ? 'vec4' : 'f32'}>;`, - `@group(0) @binding(1) var w: array<${isVec4 ? 'vec4' : 'f32'}>;` + `@group(0) @binding(0) var x: array<${isVec4 && innerElementSize === 4 ? `vec4<${t}>` : t}>;`, + `@group(0) @binding(1) var w: array<${isVec4 ? `vec4<${t}>` : t}>;` ]; let declareFunctions = ` - fn setOutputAtIndex(flatIndex : i32, value : ${isVec4 ? 'vec4' : 'f32'}) { - result[flatIndex] = ${isVec4 ? 'vec4' : 'f32'}(value); + fn setOutputAtIndex(flatIndex : i32, value : ${isVec4 ? `vec4<${t}>` : t}) { + result[flatIndex] = ${isVec4 ? `vec4<${t}>` : t}(value); } - fn setOutputAtCoords(d0 : i32, d1 : i32, d2 : i32, d3 : i32, value : ${isVec4 ? 'vec4' : 'f32'}) { + fn setOutputAtCoords(d0 : i32, d1 : i32, d2 : i32, d3 : i32, value : ${isVec4 ? `vec4<${t}>` : t}) { let flatIndex = getOutputIndexFromCoords(vec4(d0, d1, d2, d3)); setOutputAtIndex(flatIndex ${isVec4 ? '/ 4' : ''}, value); }`; if (hasBias) { - declareInputs.push(`@group(0) @binding(2) var bias: array<${isVec4 ? 'vec4' : 'f32'}>;`); + declareInputs.push(`@group(0) @binding(2) var bias: array<${isVec4 ? `vec4<${t}>` : t}>;`); declareFunctions += ` - fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { + fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? `vec4<${t}>` : t} { return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; }`; } @@ -222,7 +226,7 @@ export const createConv2DMatMulProgramInfo = // dilation : vec2, dimAOuter : i32, dimBOuter : i32, dimInner : i32 }; ${declareInputs.join('')} @group(0) @binding(${declareInputs.length}) var result: array<${ - isVec4 ? 'vec4' : 'f32'}>; + isVec4 ? `vec4<${t}>` : t}>; //@group(0) @binding(${declareInputs.length + 1}) var uniforms: Uniforms; const xShape : vec4 = vec4(${inputs[0].dims.join(',')}); @@ -240,12 +244,12 @@ export const createConv2DMatMulProgramInfo = ${ conv2dCommonSnippet( isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, undefined, false, elementsSize[0], - elementsSize[1], elementsSize[2])} + elementsSize[1], elementsSize[2], t)} ${ isVec4 ? - makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, undefined, !isChannelsLast, tileInner) : + makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner) : makeMatMulPackedSource( - elementsPerThread, workGroupSize, undefined, !isChannelsLast, tileInner, false, undefined, + elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner, false, undefined, sequentialAccessByThreads)}` }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts new file mode 100644 index 0000000000000..f41d0d058a624 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -0,0 +1,244 @@ +/** + * @license + * Copyright 2021 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +// sampled from [@tensorflow/tfjs] tfjs-backend-webgpu/src/conv_backprop_mm_webgpu.ts +// +// modified to fit the needs of the project + +import {LOG_DEBUG} from '../../../log'; +import {TensorView} from '../../../tensor-view'; +import {ShapeUtil} from '../../../util'; +import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types'; +import {ConvTransposeAttributes} from '../conv-transpose'; + +import {Activation, activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util'; +import {utilFunctions} from './conv_util'; +import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu'; + +const conv2dTransposeCommonSnippet = + (isChannelsLast: boolean, addBias = false, activation?: Activation, hasPreluActivationWeights = false, + innerElementSize = 4): string => { + const type = typeSnippet(innerElementSize, 'f32'); + const getWSnippet = (innerElementSize: number) => { + switch (innerElementSize) { + case 1: + return 'return W[getIndexFromCoords4D(coord, wShape)];'; + case 4: + return ` + let coord1 = vec4(coordX, coordY, col + 1, rowInner); + let coord2 = vec4(coordX, coordY, col + 2, rowInner); + let coord3 = vec4(coordX, coordY, col + 3, rowInner); + let v0 = W[getIndexFromCoords4D(coord, wShape)]; + let v1 = W[getIndexFromCoords4D(coord1, wShape)]; + let v2 = W[getIndexFromCoords4D(coord2, wShape)]; + let v3 = W[getIndexFromCoords4D(coord3, wShape)]; + return vec4(v0, v1, v2, v3); + `; + default: + throw new Error(`innerElementSize ${innerElementSize} is not supported.`); + } + }; + const coordASnippet = isChannelsLast ? ` + let coord = vec4(batch, iXR, iXC, xCh); + ` : + ` + let coord = vec4(batch, xCh, iXR, iXC); + `; + + const coordResSnippet = isChannelsLast ? ` + let coords = vec4( + batch, + row / outWidth, + row % outWidth, + col); + ` : + ` + let coords = vec4( + batch, + row, + col / outWidth, + col % outWidth); + `; + + const xHeight = isChannelsLast ? 'outBackprop[1]' : 'outBackprop[2]'; + const xWidth = isChannelsLast ? 'outBackprop[2]' : 'outBackprop[3]'; + const row = isChannelsLast ? 'row' : 'col'; + const col = isChannelsLast ? 'col' : 'row'; + + const readASnippet = ` + let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'}; + let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'}; + let outRow = ${row} / outWidth; + let outCol = ${row} % outWidth; + + let WRow = ${col} / (filterDims[1] * inChannels); + let WCol = ${col} / inChannels % filterDims[1]; + let xR = f32(outRow - pads[0] + dilation[0] * WRow) / f32(strides[0]); + let xC = f32(outCol - pads[1] + dilation[1] * WCol) / f32(strides[1]); + if (xR < 0.0 || xR >= f32(${xHeight}) || fract(xR) > 0.0) { + return ${type}(0.0); + } + if (xC < 0.0 || xC >= f32(${xWidth}) || fract(xC) > 0.0) { + return ${type}(0.0); + } + let iXR = i32(xR); + let iXC = i32(xC); + let xCh = ${col} % inChannels; + ${coordASnippet} + return x[getIndexFromCoords4D(coord, xShape)/${innerElementSize}];`; + + const sampleA = isChannelsLast ? ` + let col = colIn * ${innerElementSize}; + if (row < dimAOuter && col < dimInner) { + ${readASnippet} + } + return ${type}(0.0);` : + ` + let col = colIn * ${innerElementSize}; + if (row < dimInner && col < dimBOuter) { + ${readASnippet} + } + return ${type}(0.0);`; + + const sampleW = ` + let col = colIn * ${innerElementSize}; + let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'}; + let coordX = filterDims.x - 1 - row / (filterDims[1] * inChannels); + let coordY = filterDims.y - 1 - (row / inChannels) % filterDims[1]; + if (${ + isChannelsLast ? 'row < dimInner && col < dimBOuter' : + 'row < dimInner && col < dimAOuter'} && coordX >= 0 && coordY >= 0) { + let rowInner = row % inChannels; + let coord = vec4(coordX, coordY, col, rowInner); + ${getWSnippet(innerElementSize)} + } + return ${type}(0.0); + `; + + + const userCode = ` + ${activationFnSnippet(activation, hasPreluActivationWeights, innerElementSize === 4, 4)} + fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${type} { + ${isChannelsLast ? sampleA : sampleW} + } + + fn mm_readB(batch: i32, row : i32, colIn : i32) -> ${type} { + ${isChannelsLast ? sampleW : sampleA} + } + + fn mm_write(batch: i32, row : i32, colIn : i32, valueInput : ${type}) { + let col = colIn * ${innerElementSize}; + if (row < dimAOuter && col < dimBOuter) { + var value = valueInput; + let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'}; + ${coordResSnippet} + ${biasActivationSnippet(addBias, activation)} + result[getIndexFromCoords4D(coords, outShape)/${innerElementSize}] = value; + } + }`; + return userCode; + }; + +export const createConv2DTransposeMatMulProgramInfo = + (inputs: readonly TensorView[], metadata: ProgramMetadata, attributes: ConvTransposeAttributes, + outputShape: readonly number[], dimAOuter: number, dimBOuter: number, dimInner: number, hasBias: boolean, + sequentialAccessByThreads: boolean): ProgramInfo => { + const isChannelsLast = attributes.format === 'NHWC'; + const inChannels = isChannelsLast ? inputs[0].dims[3] : inputs[0].dims[1]; + const batchSize = outputShape[0]; + const outWidth = isChannelsLast ? outputShape[2] : outputShape[3]; + const outHeight = isChannelsLast ? outputShape[1] : outputShape[2]; + const outChannels = isChannelsLast ? outputShape[3] : outputShape[1]; + const isVec4 = + isChannelsLast ? inChannels % 4 === 0 && outChannels % 4 === 0 : outWidth % 4 === 0 && outChannels % 4 === 0; + + // TODO: fine tune size + const dispatchX = isChannelsLast ? outChannels : outWidth * outHeight; + const dispatchY = isChannelsLast ? outWidth * outHeight : outChannels; + const workGroupSize: [number, number, number] = isVec4 ? + [8, 8, 1] : + [(dispatchX <= 4 || dispatchY <= 4) ? 4 : 16, dispatchX > 4 && dispatchY <= 4 ? 4 : 16, 1]; + const elementsPerThread = + isVec4 ? [4, 4, 1] : [dispatchX <= 4 ? 1 : 4, dispatchX > 4 && dispatchY <= 4 ? 1 : 4, 1]; + const dispatch = [ + Math.ceil(dispatchX / workGroupSize[0] / elementsPerThread[0]), + Math.ceil(dispatchY / workGroupSize[1] / elementsPerThread[1]), + Math.ceil(batchSize / workGroupSize[2] / elementsPerThread[2]) + ]; + + LOG_DEBUG('verbose', () => `[conv_backprop_mm_webgpu] dispatch = ${dispatch}`); + + const innerElementSize = isVec4 ? 4 : 1; + const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]); + + + const declareInputs = [ + `@group(0) @binding(0) var x: array<${isVec4 ? 'vec4' : 'f32'}>;`, + '@group(0) @binding(1) var W: array;' + ]; + let declareFunctions = ''; + if (hasBias) { + declareInputs.push(`@group(0) @binding(2) var bias: array<${isVec4 ? 'vec4' : 'f32'}>;`); + declareFunctions += ` + fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { + return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; + }`; + } + return { + ...metadata, + outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], + dispatchGroup: () => ({x: dispatch[0], y: dispatch[1], z: dispatch[2]}), + getShaderSource: () => ` + ${utilFunctions} + ${declareInputs.join('\n')} + @group(0) @binding(${declareInputs.length}) var result: array<${ + isVec4 ? 'vec4' : 'f32'}>; + const outBackprop : vec4 = vec4(${inputs[0].dims.join(',')}); + const xShape : vec4 = vec4(${inputs[0].dims.join(',')}); + const wShape : vec4 = vec4(${inputs[1].dims.join(',')}); + const outShape : vec4 = vec4(${outputShape.join(',')}); + const outShapeStrides : vec3 = vec3(${ShapeUtil.computeStrides(outputShape).slice(0, 3).join(',')}); + const filterDims : vec2 = vec2(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${ + attributes.kernelShape[isChannelsLast ? 2 : 3]}); + const effectiveFilterDims : vec2 = filterDims + vec2( + ${ + attributes.dilations[0] <= 1 ? + 0 : + (attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)}, + ${ + attributes.dilations[1] <= 1 ? + 0 : + (attributes.kernelShape[isChannelsLast ? 2 : 3] - 1) * (attributes.dilations[1] - 1)}); + const pads : vec2 = vec2(i32(effectiveFilterDims[0]) - 1 - (${ + attributes.pads[0] + attributes.pads[2]})/2, + i32(effectiveFilterDims[1]) - 1 - (${ + attributes.pads[1] + attributes.pads[3]})/2); + const strides : vec2 = vec2(${attributes.strides[0]}, ${attributes.strides[1]}); + const dilation : vec2 = vec2(${attributes.dilations[0]}, ${attributes.dilations[1]}); + const dimAOuter : i32 = ${dimAOuter}; + const dimBOuter : i32 = ${dimBOuter}; + const dimInner : i32 = ${dimInner}; + ${declareFunctions} + ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, undefined, false, innerElementSize)} + ${ + isVec4 ? makeMatMulPackedVec4Source( + elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner) : + makeMatMulPackedSource( + elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner, false, + undefined, sequentialAccessByThreads)}` + }; + }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts index ec6df438129fb..c60b94056a360 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts @@ -21,12 +21,13 @@ import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types'; -import {inputVariable, outputVariable, ShaderHelper} from '../common'; +import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; import {ConvTransposeAttributes} from '../conv-transpose'; const createConvTranspose2DOpProgramShaderSource = (shaderHelper: ShaderHelper, inputs: readonly TensorView[], attributes: ConvTransposeAttributes, - outputShape: readonly number[], hasBias: boolean, is1DimensionDispatch: boolean, isVec4 = false): string => { + outputShape: readonly number[], hasBias: boolean, is1DimensionDispatch: boolean, isVec4 = false, + dataType: string): string => { const isChannelsLast = attributes.format === 'NHWC'; const rowDim = isChannelsLast ? 1 : 2; const colDim = isChannelsLast ? 2 : 3; @@ -39,12 +40,12 @@ const createConvTranspose2DOpProgramShaderSource = const outputChannelsPerGroup = wShape[1]; let declareFunctions = ` - fn setOutputAtIndex(flatIndex : u32, value : ${isVec4 ? 'vec4' : 'f32'}) { - result[flatIndex] = ${isVec4 ? 'vec4' : 'f32'}(value); + fn setOutputAtIndex(flatIndex : u32, value : ${isVec4 ? `vec4<${dataType}>` : dataType}) { + result[flatIndex] = ${isVec4 ? `vec4<${dataType}>` : dataType}(value); }`; if (hasBias) { declareFunctions += ` - fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { + fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? `vec4<${dataType}>` : dataType} { return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; }`; } @@ -66,33 +67,33 @@ const createConvTranspose2DOpProgramShaderSource = // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1). // ? = to be determined. : = across all values in that axis. - var dotProd: array, ${workPerThread}>; + var dotProd: array, ${workPerThread}>; for (var i = 0; i < ${workPerThread}; i++) { - dotProd[i] = vec4(0.0); + dotProd[i] = vec4<${dataType}>(0.0); } for (var wR: u32 = 0; wR < filterDims[0]; wR = wR + 1) { - var dyR = (f32(dyCorner.x) + f32(wR)) / f32(strides.x); + var dyR = (${dataType}(dyCorner.x) + ${dataType}(wR)) / ${dataType}(strides.x); let wRPerm = filterDims[0] - 1 - wR; - if (dyR < 0.0 || dyR >= f32(outBackprop[1]) || + if (dyR < 0.0 || dyR >= ${dataType}(outBackprop[1]) || fract(dyR) > 0.0 || wRPerm < 0) { continue; } let idyR: u32 = u32(dyR); for (var wC: u32 = 0; wC < filterDims[1]; wC = wC + 1) { - let dyC = (f32(dyCorner.y) + f32(wC)) / f32(strides.y); - let dyC2 = (f32(dyCorner.y) + 1.0 + f32(wC)) / f32(strides.y); + let dyC = (${dataType}(dyCorner.y) + ${dataType}(wC)) / ${dataType}(strides.y); + let dyC2 = (${dataType}(dyCorner.y) + 1.0 + ${dataType}(wC)) / ${dataType}(strides.y); let wCPerm = filterDims[1] - 1 - wC; if (wCPerm < 0) { continue; } var bDyCVal = true; var bDyCVal2 = true; - if (dyC < 0.0 || dyC >= f32(outBackprop[2]) || + if (dyC < 0.0 || dyC >= ${dataType}(outBackprop[2]) || fract(dyC) > 0.0) { bDyCVal = false; } - if (dyC2 < 0.0 || dyC2 >= f32(outBackprop[2]) || + if (dyC2 < 0.0 || dyC2 >= ${dataType}(outBackprop[2]) || fract(dyC2) > 0.0) { bDyCVal2 = false; } @@ -108,7 +109,7 @@ const createConvTranspose2DOpProgramShaderSource = let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')}; var xValue = ${dy.get('batch', 'idyR', 'idyC', 'd2')}; - let tmpval = vec4(dot(xValue, wValue0), + let tmpval = vec4<${dataType}>(dot(xValue, wValue0), dot(xValue, wValue1), dot(xValue, wValue2), dot(xValue, wValue3)); @@ -116,7 +117,7 @@ const createConvTranspose2DOpProgramShaderSource = xValue = ${dy.get('batch', 'idyR', 'idyC2', 'd2')}; - dotProd[1] = dotProd[1] + vec4(dot(xValue, wValue0), + dotProd[1] = dotProd[1] + vec4<${dataType}>(dot(xValue, wValue0), dot(xValue, wValue1), dot(xValue, wValue2), dot(xValue, wValue3)); @@ -130,7 +131,7 @@ const createConvTranspose2DOpProgramShaderSource = let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')}; var xValue = ${dy.get('batch', 'idyR', 'idyC', 'd2')}; - let tmpval = vec4(dot(xValue, wValue0), + let tmpval = vec4<${dataType}>(dot(xValue, wValue0), dot(xValue, wValue1), dot(xValue, wValue2), dot(xValue, wValue3)); @@ -145,7 +146,7 @@ const createConvTranspose2DOpProgramShaderSource = let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')}; var xValue = ${dy.get('batch', 'idyR', 'idyC2', 'd2')}; - let tmpval = vec4(dot(xValue, wValue0), + let tmpval = vec4<${dataType}>(dot(xValue, wValue0), dot(xValue, wValue1), dot(xValue, wValue2), dot(xValue, wValue3)); @@ -178,9 +179,9 @@ const createConvTranspose2DOpProgramShaderSource = if (wR % dilations.x != 0) { continue; } - let dyR = (f32(dyRCorner) + f32(wR)) / f32(strides[0]); + let dyR = (${dataType}(dyRCorner) + ${dataType}(wR)) / ${dataType}(strides[0]); let wRPerm = filterDims.x - 1 - wR / dilations.x; - if (dyR < 0.0 || dyR >= f32(outBackprop[${rowDim}]) || fract(dyR) > 0.0 || + if (dyR < 0.0 || dyR >= ${dataType}(outBackprop[${rowDim}]) || fract(dyR) > 0.0 || wRPerm < 0) { continue; } @@ -190,21 +191,21 @@ const createConvTranspose2DOpProgramShaderSource = if (wC % dilations.y != 0) { continue; } - let dyC = (f32(dyCCorner) + f32(wC)) / f32(strides.y); + let dyC = (${dataType}(dyCCorner) + ${dataType}(wC)) / ${dataType}(strides.y); let wCPerm = filterDims.y - 1 - wC / dilations.y; - if (dyC < 0.0 || dyC >= f32(outBackprop[${colDim}]) || + if (dyC < 0.0 || dyC >= ${dataType}(outBackprop[${colDim}]) || fract(dyC) > 0.0 || wCPerm < 0) { continue; } let idyC: u32 = u32(dyC); - + var inputChannel = groupId * ${inputChannelsPerGroup}; for (var d2: u32 = 0; d2 < ${inputChannelsPerGroup}; d2 = d2 + 1) { - let inputChannel = groupId * ${inputChannelsPerGroup} + d2; let xValue = ${ isChannelsLast ? dy.get('batch', 'idyR', 'idyC', 'inputChannel') : dy.get('batch', 'inputChannel', 'idyR', 'idyC')}; let wValue = ${w.get('inputChannel', 'wOutChannel', 'u32(wRPerm)', 'u32(wCPerm)')}; dotProd = dotProd + xValue * wValue; + inputChannel = inputChannel + 1; } } } @@ -256,6 +257,7 @@ export const createConvTranspose2DProgramInfo = ]; LOG_DEBUG('verbose', () => `[conv2d_backprop_webgpu] dispatch = ${dispatch}`); + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); return { ...metadata, outputs: [{ @@ -265,6 +267,7 @@ export const createConvTranspose2DProgramInfo = }], dispatchGroup: () => ({x: dispatch[0], y: dispatch[1], z: dispatch[2]}), getShaderSource: (shaderHelper: ShaderHelper) => createConvTranspose2DOpProgramShaderSource( - shaderHelper, inputs, attributes, outputShape, hasBias, dispatch[1] === 1 && dispatch[2] === 1), + shaderHelper, inputs, attributes, outputShape, hasBias, dispatch[1] === 1 && dispatch[2] === 1, false, + dataType), }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index 8d43dbb378a69..82f8c82291f4b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -22,7 +22,7 @@ import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types'; -import {getBroadcastDims, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from '../common'; +import {getBroadcastDims, IndicesHelper, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; import {getActicationSnippet, InternalActivationAttributes} from '../fuse-utils'; import {typeSnippet} from './activation_util'; @@ -70,8 +70,8 @@ const calculateResultSnippet = (transposeA: boolean, innerElementSize: number) = }; export const makeMatMulPackedVec4Source = - (workPerThread: number[], workgroupSize: [number, number, number], batchDims?: IndicesHelper, transposeA = false, - tileInner = 32, splitK = false, splitedDimInner = 32): string => { + (workPerThread: number[], workgroupSize: [number, number, number], type = 'f32', batchDims?: IndicesHelper, + transposeA = false, tileInner = 32, splitK = false, splitedDimInner = 32): string => { const tileAOuter = workgroupSize[1] * workPerThread[1]; const tileBOuter = workgroupSize[0] * workPerThread[0]; const tileAWidth = transposeA ? tileAOuter : tileInner; @@ -90,8 +90,8 @@ export const makeMatMulPackedVec4Source = workPerThread[0]} must be 4.`); } return ` -var mm_Asub : array, ${tileAWidth / innerElementSize}>, ${tileAHight}>; -var mm_Bsub : array, ${tileBOuter / workPerThread[0]}>, ${tileInner}>; +var mm_Asub : array, ${tileAWidth / innerElementSize}>, ${tileAHight}>; +var mm_Bsub : array, ${tileBOuter / workPerThread[0]}>, ${tileInner}>; const rowPerThread = ${workPerThread[1]}; const colPerThread = ${workPerThread[0]}; @@ -115,7 +115,7 @@ fn main(@builtin(local_invocation_id) localId : vec3, let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(dimInner - 1) / tileInner + 1'}; var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'}; - var acc: array, rowPerThread>; + var acc: array, rowPerThread>; // Loop over shared dimension. let tileRowB = localRow * ${rowPerThreadB}; @@ -179,8 +179,9 @@ const readDataFromSubASnippet = (transposeA: boolean) => // sequentialAccessByThreads means sequential data in memory is accessed by // threads, instead of a single thread (default behavior). export const makeMatMulPackedSource = - (workPerThread: number[], workgroupSize: [number, number, number], batchDims?: IndicesHelper, transposeA = false, - tileInner = 32, splitK = false, splitedDimInner = 32, sequentialAccessByThreads = false): string => { + (workPerThread: number[], workgroupSize: [number, number, number], type = 'f32', batchDims?: IndicesHelper, + transposeA = false, tileInner = 32, splitK = false, splitedDimInner = 32, + sequentialAccessByThreads = false): string => { const tileAOuter = workPerThread[1] * workgroupSize[1]; const tileBOuter = workPerThread[0] * workgroupSize[0]; const tileAWidth = transposeA ? tileAOuter : tileInner; @@ -222,7 +223,7 @@ export const makeMatMulPackedSource = workgroupBarrier(); // Compute acc values for a single thread. - var BCached : array; + var BCached : array<${type}, colPerThread>; for (var k = 0; k < tileInner; k = k + 1) { for (var inner = 0; inner < colPerThread; inner = inner + 1) { BCached[inner] = mm_Bsub[k][localCol + inner * ${workgroupSize[0]}]; @@ -283,7 +284,7 @@ for (var t = 0; t < numTiles; t = t + 1) { workgroupBarrier(); // Compute acc values for a single thread. - var BCached : array; + var BCached : array<${type}, colPerThread>; for (var k = 0; k < tileInner; k = k + 1) { for (var inner = 0; inner < colPerThread; inner = inner + 1) { BCached[inner] = mm_Bsub[k][tileCol + inner]; @@ -309,8 +310,8 @@ for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { `; return ` - var mm_Asub : array, ${tileAHight}>; - var mm_Bsub : array, ${tileInner}>; + var mm_Asub : array, ${tileAHight}>; + var mm_Bsub : array, ${tileInner}>; const rowPerThread = ${workPerThread[1]}; const colPerThread = ${workPerThread[0]}; const tileInner = ${tileInner}; @@ -324,7 +325,7 @@ fn main(@builtin(local_invocation_id) localId : vec3, let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(dimInner - 1) / tileInner + 1'}; var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'}; - var acc : array, rowPerThread>; + var acc : array, rowPerThread>; // Without this initialization strange values show up in acc. for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { @@ -347,6 +348,7 @@ const matMulReadWriteFnSource = const outputVariable = variables[5]; const broadCastADims = getBroadcastDims(batchAVariable.shape, batchVariable.shape); const broadCastBDims = getBroadcastDims(batchBVariable.shape, batchVariable.shape); + const dataType = tensorTypeToWsglStorageType(variables[0].type.tensor); const getAIndices = () => { const aRank = aVariable.shape.length; const batchRank = batchVariable.shape.length; @@ -377,8 +379,8 @@ const matMulReadWriteFnSource = }; const source = ` fn mm_readA(batch: i32, row: i32, colIn: i32, batchIndices: ${batchVariable.type.indices}) -> ${ - typeSnippet(component)} { - var value = ${typeSnippet(component)}(0.0); + typeSnippet(component, dataType)} { + var value = ${typeSnippet(component, dataType)}(0.0); let col = colIn * ${component}; if(row < dimAOuter && col < dimInner) { @@ -389,8 +391,8 @@ const matMulReadWriteFnSource = } fn mm_readB(batch: i32, row: i32, colIn: i32, batchIndices: ${batchVariable.type.indices}) -> ${ - typeSnippet(component)} { - var value = ${typeSnippet(component)}(0.0); + typeSnippet(component, dataType)} { + var value = ${typeSnippet(component, dataType)}(0.0); let col = colIn * ${component}; if(row < dimInner && col < dimBOuter) { @@ -400,7 +402,7 @@ const matMulReadWriteFnSource = return value; } - fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: ${typeSnippet(component)}) { + fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: ${typeSnippet(component, dataType)}) { let col = colIn * ${component}; if (row < dimAOuter && col < dimBOuter) { var value = valueIn; @@ -444,6 +446,7 @@ export const createMatmulProgramInfo = Math.ceil(batchSize / workgroupSize[2] / elementsPerThread[2]) ]; + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); const components = isVec4 ? 4 : 1; const A = inputVariable('a', inputs[0].dataType, [...outerDimsA, dimAOuter, dimInner / components], components); const B = inputVariable('b', inputs[1].dataType, [...outerDimsB, dimInner, dimBOuter / components], components); @@ -466,8 +469,8 @@ export const createMatmulProgramInfo = ${declareFunctions} ${activationFunction} ${ - isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, batchDims) : - makeMatMulPackedSource(elementsPerThread, workgroupSize, batchDims)} + isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) : + makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)} ${batchDims.impl()}`; return { ...metadata, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/bias-add.ts b/js/web/lib/wasm/jsep/webgpu/ops/bias-add.ts new file mode 100644 index 0000000000000..688bedc619ce6 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/bias-add.ts @@ -0,0 +1,69 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {TensorView} from '../../tensor-view'; +import {ShapeUtil} from '../../util'; +import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; + +import {inputVariable, outputVariable, ShaderHelper} from './common'; + +const validateInputs = (inputs: readonly TensorView[]): void => { + if (inputs[0].dims.length !== 3) { + throw new Error('input should have 3 dimensions'); + } + + if (![320, 640, 1280].includes(inputs[0].dims[2])) { + throw new Error('number of channels should be 320, 640 or 1280'); + } + + if (inputs[1].dims.length !== 1) { + throw new Error('bias is expected to have 1 dimensions'); + } + + if (inputs[0].dims[2] !== inputs[1].dims[0]) { + throw new Error('last dimension of input and bias are not the same'); + } +}; + +const createBiasAddProgramInfo = (metadata: ProgramMetadata, inputs: readonly TensorView[]): ProgramInfo => { + const outputShape = inputs[0].dims; + + const channels = inputs[0].dims[2]; + // since channel number can be only 320/640/1280, it's always divisable by 4 + const outputSize = ShapeUtil.size(outputShape) / 4; + + const dataType = inputs[0].dataType; + const input = inputVariable('input', dataType, outputShape, 4); + const bias = inputVariable('bias', dataType, [channels], 4); + const residual = inputVariable('residual', dataType, outputShape, 4); + const output = outputVariable('output', dataType, outputShape, 4); + + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const channels = ${channels}u / 4; + ${shaderHelper.declareVariables(input, bias, residual, output)} + + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + let value = ${input.getByOffset('global_idx')} + + ${bias.getByOffset('global_idx % channels')} + ${residual.getByOffset('global_idx')}; + ${output.setByOffset('global_idx', 'value')} + }`; + + return { + ...metadata, + outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], + getShaderSource, + dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)}) + }; +}; + +export const biasAdd = (context: ComputeContext): void => { + validateInputs(context.inputs); + const inputTypes = Array(context.inputs.length).fill(GpuDataType.default); + const metadata = { + name: 'BiasAdd', + inputTypes, + }; + + context.compute(createBiasAddProgramInfo(metadata, context.inputs)); +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts b/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts new file mode 100644 index 0000000000000..8ec4ff5c870d3 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/bias-split-gelu.ts @@ -0,0 +1,76 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {TensorView} from '../../tensor-view'; +import {ShapeUtil} from '../../util'; +import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; + +import {inputVariable, outputVariable, ShaderHelper} from './common'; +import {erfImpl} from './unary-op'; + +const validateInputs = (inputs: readonly TensorView[]): void => { + if (inputs[0].dims.length !== 3) { + throw new Error('input should have 3 dimensions'); + } + + if (![2560, 5120, 10240].includes(inputs[0].dims[2])) { + throw new Error('hidden state should be 2560, 5120 or 10240'); + } + + if (inputs[1].dims.length !== 1) { + throw new Error('bias is expected to have 1 dimensions'); + } + + if (inputs[0].dims[2] !== inputs[1].dims[0]) { + throw new Error('last dimension of input and bias are not the same'); + } +}; + +const createBiasSplitGeluProgramInfo = (metadata: ProgramMetadata, inputs: readonly TensorView[]): ProgramInfo => { + const outputShape = inputs[0].dims.slice(); + outputShape[2] = outputShape[2] / 2; + + const input = inputVariable('input', inputs[0].dataType, inputs[0].dims, 4); + const bias = inputVariable('bias', inputs[0].dataType, [inputs[0].dims[2]], 4); + const output = outputVariable('output', inputs[0].dataType, outputShape, 4); + + const outputSize = ShapeUtil.size(outputShape) / 4; + + const getShaderSource = (shaderHelper: ShaderHelper) => ` + const M_SQRT2 = sqrt(2.0); + const halfChannels = ${inputs[0].dims[2] / 4 / 2}u; + + ${shaderHelper.declareVariables(input, bias, output)} + + ${erfImpl('vec4f')} + + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + let biasIdx = global_idx % halfChannels; + let batchIndex = global_idx / halfChannels; + let inputOffset = biasIdx + batchIndex * halfChannels * 2; + let valueLeft = input[inputOffset] + bias[biasIdx]; + let valueRight = input[inputOffset + halfChannels] + bias[biasIdx + halfChannels]; + let geluRight = valueRight * 0.5 * (erf_vf32(valueRight / M_SQRT2) + 1); + + ${output.setByOffset('global_idx', 'valueLeft * geluRight')} + }`; + + return { + ...metadata, + outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], + getShaderSource, + dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)}) + }; +}; + +export const biasSplitGelu = (context: ComputeContext): void => { + validateInputs(context.inputs); + + const metadata = { + name: 'BiasSplitGelu', + inputTypes: [GpuDataType.default, GpuDataType.default], + }; + + context.compute(createBiasSplitGeluProgramInfo(metadata, context.inputs)); +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 0ab777bfbdee9..fb800d66b59a2 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -102,6 +102,16 @@ export interface IndicesHelper { */ readonly indicesToOffset: (varIndices: string) => string; + /** + * WGSL code of an `u32` expression for getting original offset from broadcasted indices. + * + * @param varIndices - a `type.indices` expression representing the output indices. + * @param output - output IndicesHelper. + * + * @returns an `u32` expression + */ + readonly broadcastedIndicesToOffset: (varIndices: string, output: IndicesHelper) => string; + /** * WGSL code of generating an indices literal * @@ -262,6 +272,7 @@ const createIndicesHelper = const implementationUsed = { offsetToIndices: false, indicesToOffset: false, + broadcastedIndicesToOffset: false, set: false, setByIndices: false, get: false, @@ -310,6 +321,26 @@ const createIndicesHelper = return rank < 2 ? varIndices : `i2o_${name}(${varIndices})`; }; + const broadcastedIndicesToOffsetImplementation: {[key: string]: string} = {}; + const broadcastedIndicesToOffset = (varIndices: string, output: IndicesHelper) => { + implementationUsed.broadcastedIndicesToOffset = true; + const implKey = `${output.name}broadcastedIndicesTo${name}Offset`; + if (implKey in broadcastedIndicesToOffsetImplementation) { + return `${implKey}(${varIndices})`; + } + const offsets = []; + for (let i = shape.length - 1; i >= 0; i--) { + const idx = output.indicesGet('outputIndices', i + output.shape.length - shape.length); + offsets.push(`${strides[i]}u * (${idx} % ${shape[i]}u)`); + } + broadcastedIndicesToOffsetImplementation[implKey] = + `fn ${implKey}(outputIndices: ${output.type.indices}) -> u32 { + return ${offsets.length > 0 ? offsets.join('+') : '0u'}; + }`; + + return `${implKey}(${varIndices})`; + }; + const indices = (...init: ReadonlyArray) => rank === 0 ? '0u' : `${type.indices}(${init.map(normalizeDim).join(',')})`; @@ -462,6 +493,9 @@ const createIndicesHelper = if (implementationUsed.indicesToOffset) { impls.push(indicesToOffsetImplementation); } + if (implementationUsed.broadcastedIndicesToOffset) { + Object.values(broadcastedIndicesToOffsetImplementation).forEach(impl => impls.push(impl)); + } if (implementationUsed.set) { impls.push(setImplementation); } @@ -482,6 +516,7 @@ const createIndicesHelper = type, offsetToIndices, indicesToOffset, + broadcastedIndicesToOffset, indices, indicesGet, indicesSet, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts index e7d1ddf771650..59f11d6d9abba 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts @@ -1,14 +1,15 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, GpuDataType, ProgramInfoLoader, ProgramMetadata} from '../types'; import {createConvTranspose2DProgramInfo} from './3rd-party/conv_backprop_webgpu'; import {ConvAttributes} from './conv'; +import {createConv2DTransposeMatMulProgramInfoLoader} from './conv2dtranspose-mm'; import {parseInternalActivationAttributes} from './fuse-utils'; +import {createTransposeProgramInfo, TransposeAttributes, transposeProgramMetadata} from './transpose'; const computeTotalPad = (inDim: number, stride: number, adj: number, kernel: number, dilation: number, outSize: number) => @@ -63,7 +64,7 @@ const getAdjustedConvTransposeAttributes = (attributes: T, inputs: readonly TensorView[]): T => { const kernelShape = attributes.kernelShape.slice(); // if kernelShape is not specified in the attributes of this op, infer it from the weight tensor dims - if (attributes.kernelShape.length === 0 || attributes.kernelShape.reduce((a, b) => a * b, 0) === 0) { + if (attributes.kernelShape.length === 0 || attributes.kernelShape.reduce((a, b) => a * b, 1) === 0) { kernelShape.length = 0; for (let i = 2; i < inputs[1].dims.length; ++i) { kernelShape.push(inputs[1].dims[i]); @@ -95,9 +96,11 @@ const getAdjustedConvTransposeAttributes = // always return a new object so does not modify the original attributes const newAttributes: T = Object.assign({}, attributes); - Object.assign( - newAttributes, - {kernelShape, pads, outputPadding, outputShape, dilations, strides, cacheKey: attributes.cacheKey}); + const cacheKey = attributes.cacheKey + [ + kernelShape.join('n,'), pads.join(','), strides.join(','), outputPadding.join(','), outputShape.join(','), + dilations.join(',') + ].join('_'); + Object.assign(newAttributes, {kernelShape, pads, outputPadding, outputShape, dilations, strides, cacheKey}); return newAttributes; }; @@ -197,15 +200,6 @@ const validateInputs = (inputs: readonly TensorView[], attributes: ConvTranspose if (attributes.outputShape.length !== 0 && attributes.outputShape.length !== inputs[0].dims.length - 2) { throw new Error('invalid output shape'); } - - // TODO : Need to add support for float64 - if (inputs[0].dataType !== DataType.float || inputs[1].dataType !== DataType.float) { - throw new Error('ConvTranspose input(X,W) should be float tensor'); - } - - if (inputs.length === 3 && inputs[2].dataType !== DataType.float) { - throw new Error('ConvTranspose input(bias) should be float tensor'); - } }; const createConvTranspose2DProgramMetadata = (hasBias: boolean, cacheHint: string): ProgramMetadata => ({ @@ -226,12 +220,64 @@ const createConvTranspose2DProgramInfoLoader = }; }; +// for transposing weight tensor from [C, M/group, KH, KW] to [KH, KW, M/group, C] +const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({perm: [2, 3, 1, 0]}); + const convTranspose2d = (context: ComputeContext, inputs: readonly TensorView[], attributes: ConvTransposeAttributes): void => { const adjustedAttributes = getAdjustedConvTransposeAttributes(attributes, inputs); + const isChannelsLast = attributes.format === 'NHWC'; + const hasBias = inputs.length === 3; + if (adjustedAttributes.group !== 1) { + context.compute(createConvTranspose2DProgramInfoLoader(inputs, adjustedAttributes)); + return; + } + const outputShape = adjustedAttributes.outputShape; + const outHeight = outputShape[isChannelsLast ? 1 : 2]; + const outWidth = outputShape[isChannelsLast ? 2 : 3]; + const outChannels = outputShape[isChannelsLast ? 3 : 1]; + const weightHeight = inputs[1].dims[2]; + const weightWidth = inputs[1].dims[3]; + const inputChannels = inputs[0].dims[isChannelsLast ? 3 : 1]; + + const dimAOuter = isChannelsLast ? outHeight * outWidth : outChannels; + const dimBOuter = isChannelsLast ? outChannels : outHeight * outWidth; + const dimInner = weightHeight * weightWidth * inputChannels; - context.compute(createConvTranspose2DProgramInfoLoader(inputs, adjustedAttributes)); + const sequentialAccessByThreads = /* backend.adapterInfo.isIntel() */ true; + + + // STEP.1: transpose weight + const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ?? + context.compute( + { + ...transposeProgramMetadata, + cacheHint: weightTransposeAttribute.cacheKey, + get: () => createTransposeProgramInfo(inputs[1], weightTransposeAttribute.perm) + }, + {inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0]; + if (attributes.wIsConst && !context.kernelCustomData.wT) { + context.kernelCustomData.wT = transposedWeight; + } + + // STEP.2: prepare reshaped inputs + const convTransposeInputs = [inputs[0], transposedWeight]; + if (hasBias) { + if (!isChannelsLast && inputs[2].dims.length === 1) { + convTransposeInputs.push(inputs[2].reshape([inputs[2].dims[0], 1, 1])); + } else { + convTransposeInputs.push(inputs[2]); + } + } + + // STEP.3: compute matmul + context.compute( + createConv2DTransposeMatMulProgramInfoLoader( + convTransposeInputs, adjustedAttributes, outputShape, dimAOuter, dimBOuter, dimInner, hasBias, + sequentialAccessByThreads), + {inputs: convTransposeInputs}); }; + const convTranspose1d = (context: ComputeContext, attributes: ConvTransposeAttributes): void => { // extend the input to 2D by adding H dimension const isChannelLast = attributes.format === 'NHWC'; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts index 95a64e5787841..7afc3ce1b9d77 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {PoolConvUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -93,15 +92,6 @@ const validateInputs = (inputs: readonly TensorView[], attributes: ConvAttribute if (attributes.kernelShape.length !== 0 && attributes.kernelShape.length !== inputs[1].dims.length - 2) { throw new Error('invalid kernel shape'); } - - // TODO : Need to add support for float64 - if (inputs[0].dataType !== DataType.float || inputs[1].dataType !== DataType.float) { - throw new Error('Conv input(X,W) should be float tensor'); - } - - if (inputs.length === 3 && inputs[2].dataType !== DataType.float) { - throw new Error('Conv input(bias) should be float tensor'); - } }; const getAdjustedConvAttributes = (attributes: T, inputs: readonly TensorView[]): T => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv2dtranspose-mm.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv2dtranspose-mm.ts new file mode 100644 index 0000000000000..da04b5063a9f0 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv2dtranspose-mm.ts @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {TensorView} from '../../tensor-view'; +import {GpuDataType, ProgramInfoLoader, ProgramMetadata} from '../types'; + +import {createConv2DTransposeMatMulProgramInfo} from './3rd-party/conv_backprop_mm_webgpu'; +import {ConvTransposeAttributes} from './conv-transpose'; + + +const createConv2DTransposeMatMulProgramMetadata = (hasBias: boolean, cacheHint: string): ProgramMetadata => ({ + name: 'Conv2DTransposeMatMul', + inputTypes: hasBias ? [GpuDataType.default, GpuDataType.default, GpuDataType.default] : + [GpuDataType.default, GpuDataType.default], + cacheHint +}); + +export const createConv2DTransposeMatMulProgramInfoLoader = + (inputs: readonly TensorView[], attributes: ConvTransposeAttributes, outputShape: readonly number[], + dimAOuter: number, dimBOuter: number, dimInner: number, hasBias: boolean, + sequentialAccessByThreads: boolean): ProgramInfoLoader => { + const metadata = createConv2DTransposeMatMulProgramMetadata(hasBias, attributes.cacheKey); + return { + ...metadata, + get: () => createConv2DTransposeMatMulProgramInfo( + inputs, metadata, attributes, outputShape, dimAOuter, dimBOuter, dimInner, hasBias, + sequentialAccessByThreads) + }; + }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts index 837ac8410f291..7dadf9a6205ea 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {BroadcastUtil} from '../../util'; import {ComputeContext, GpuDataType, ProgramInfoLoader} from '../types'; @@ -9,7 +8,6 @@ import {ComputeContext, GpuDataType, ProgramInfoLoader} from '../types'; import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu'; import {InternalActivationAttributes} from './fuse-utils'; - const createMatmulProgramMetadata = (hasBias: boolean, cacheHint: string) => ({ name: 'MatMul', inputTypes: hasBias ? [GpuDataType.default, GpuDataType.default, GpuDataType.default] : @@ -35,10 +33,6 @@ const validateInputs = (inputs: readonly TensorView[]): void => { if (inputs[0].dims[inputs[0].dims.length - 1] !== inputs[1].dims[inputs[1].dims.length - 2]) { throw new Error('shared dimension does not match.'); } - - if (inputs[0].dataType !== DataType.float || inputs[1].dataType !== DataType.float) { - throw new Error('inputs should be float type'); - } }; export const matMul = (context: ComputeContext): void => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts index c2f89fd2845df..ccaf9f16ea770 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pad.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pad.ts @@ -209,7 +209,7 @@ const createPadProgramInfo = const createPadAttributesFromInputs = (inputs: readonly TensorView[], attributes: PadAttributes): PadAttributes => { if (inputs.length > 1) { const bigInt64Pads = inputs[1].getBigInt64Array(); - const value = (inputs.length >= 3) ? inputs[2].getFloat32Array()[0] : 0.0; + const value = (inputs.length >= 3 && inputs[2].data) ? inputs[2].getFloat32Array()[0] : 0.0; const inputRank = inputs[0].dims.length; const updatePads = new Int32Array(2 * inputRank).fill(0); @@ -220,7 +220,7 @@ const createPadAttributesFromInputs = (inputs: readonly TensorView[], attributes updatePads[Number(axes[i]) + inputRank] = Number(bigInt64Pads[i + axes.length]); } } else { - bigInt64Pads.forEach((i, v) => updatePads[Number(i)] = (Number(v))); + bigInt64Pads.forEach((v, i) => updatePads[Number(i)] = (Number(v))); } const pads: number[] = []; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/range.ts b/js/web/lib/wasm/jsep/webgpu/ops/range.ts new file mode 100644 index 0000000000000..3ecb3308b1899 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/range.ts @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {env} from 'onnxruntime-common'; + +import {DataType} from '../../../wasm-common'; +import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; + +import {outputVariable, ShaderHelper} from './common'; + +const validateInputsContent = (start: number, limit: number, delta: number): void => { + const sameStartLimit = start === limit; + const increasingRangeNegativeStep = start < limit && delta < 0; + const decreasingRangePositiveStep = start > limit && delta > 0; + + if (sameStartLimit || increasingRangeNegativeStep || decreasingRangePositiveStep) { + throw new Error('Range these inputs\' contents are invalid.'); + } +}; + +const createRangeProgramInfo = + (metadata: ProgramMetadata, start: number, limit: number, delta: number, dataType: DataType): ProgramInfo => { + const numElements = Math.abs(Math.ceil((limit - start) / delta)); + const outputShape: number[] = [numElements]; + const outputSize = numElements; + + const output = outputVariable('output', dataType, outputShape); + const wgslType = output.type.storage; + + const getShaderSource = (shaderHelper: ShaderHelper) => ` + ${shaderHelper.declareVariables(output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + output[global_idx] = ${wgslType}(${start}) + ${wgslType}(global_idx) * ${wgslType}(${delta}); + }`; + return { + ...metadata, + getShaderSource, + outputs: [{dims: outputShape, dataType, gpuDataType: GpuDataType.default}], + dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)}) + }; + }; + +export const range = (context: ComputeContext): void => { + let start = 0; + let limit = 0; + let delta = 0; + if (context.inputs[0].dataType === DataType.int32) { + start = context.inputs[0].getInt32Array()[0]; + limit = context.inputs[1].getInt32Array()[0]; + delta = context.inputs[2].getInt32Array()[0]; + } else if (context.inputs[0].dataType === DataType.float) { + start = context.inputs[0].getFloat32Array()[0]; + limit = context.inputs[1].getFloat32Array()[0]; + delta = context.inputs[2].getFloat32Array()[0]; + } + if (env.webgpu.validateInputContent) { + validateInputsContent(start, limit, delta); + } + + const cacheHint = [start, limit, delta].map(x => x.toString()).join('_'); + const metadata: ProgramMetadata = {name: 'Range', inputTypes: [], cacheHint}; + context.compute( + {...metadata, get: () => createRangeProgramInfo(metadata, start, limit, delta, context.inputs[0].dataType)}, + {inputs: []}); +}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/where.ts b/js/web/lib/wasm/jsep/webgpu/ops/where.ts new file mode 100644 index 0000000000000..4c595bb90b4bc --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/where.ts @@ -0,0 +1,110 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {DataType} from '../../../wasm-common'; +import {TensorView} from '../../tensor-view'; +import {BroadcastUtil, ShapeUtil} from '../../util'; +import {ComputeContext, GpuDataType, ProgramInfo, ProgramInfoLoader, ProgramMetadata} from '../types'; + +import {inputVariable, outputVariable, ShaderHelper} from './common'; + +const createWhereOpProgramShader = + (shaderHelper: ShaderHelper, inputs: readonly TensorView[], dimsOutput: readonly number[], isBroadcast: boolean, + typeOutput: number) => { + const outputSize = ShapeUtil.size(dimsOutput); + const vecSize = Math.ceil(outputSize / 4); + + const output = outputVariable('outputData', typeOutput, dimsOutput, 4); + const a = inputVariable('aData', inputs[1].dataType, inputs[1].dims, 4); + const b = inputVariable('bData', inputs[2].dataType, inputs[2].dims, 4); + const c = inputVariable('cData', inputs[0].dataType, inputs[0].dims, 4); + + let assignment: string; + const expression = (a: string, b: string, c: string) => `select(${b}, ${a}, ${c})`; + if (!isBroadcast) { + assignment = output.setByOffset( + 'global_idx', + expression(a.getByOffset('global_idx'), b.getByOffset('global_idx'), c.getByOffset('global_idx'))); + } else { + const singleAssignment = (resStr: string, x: number, typeCast = '') => { + const expressionA = `aData[indexA${x}][componentA${x}]`; + const expressionB = `bData[indexB${x}][componentB${x}]`; + // eslint-disable-next-line no-bitwise + const expressionC = `bool(cData[indexC${x}] & ${0xff000000 >>> ((3 - x) * 8)}u)`; + return ` + let outputIndices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)}; + let offsetA${x} = ${a.broadcastedIndicesToOffset(`outputIndices${x}`, output)}; + let offsetB${x} = ${b.broadcastedIndicesToOffset(`outputIndices${x}`, output)}; + let offsetC${x} = ${c.broadcastedIndicesToOffset(`outputIndices${x}`, output)}; + let indexA${x} = offsetA${x} / 4u; + let indexB${x} = offsetB${x} / 4u; + let indexC${x} = offsetC${x} / 4u; + let componentA${x} = offsetA${x} % 4u; + let componentB${x} = offsetB${x} % 4u; + ${resStr}[${x}] = ${typeCast}(${expression(expressionA, expressionB, expressionC)}); + `; + }; + if (typeOutput === DataType.bool) { + assignment = ` + var data = vec4(0); + ${singleAssignment('data', 0, 'u32')} + ${singleAssignment('data', 1, 'u32')} + ${singleAssignment('data', 2, 'u32')} + ${singleAssignment('data', 3, 'u32')} + outputData[global_idx] = dot(vec4(0x1, 0x100, 0x10000, 0x1000000), vec4(data));`; + } else { + assignment = ` + ${singleAssignment('outputData[global_idx]', 0)} + ${singleAssignment('outputData[global_idx]', 1)} + ${singleAssignment('outputData[global_idx]', 2)} + ${singleAssignment('outputData[global_idx]', 3)} + `; + } + } + + return ` + ${shaderHelper.declareVariables(c, a, b, output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(vecSize)} + ${assignment} + }`; + }; + +const createWhereOpProgramInfo = (metadata: ProgramMetadata, inputs: readonly TensorView[]): ProgramInfo => { + const dimsA = inputs[1].dims; + const dimsB = inputs[2].dims; + const dimsC = inputs[0].dims; + const outputDataType = inputs[1].dataType; + + const isBroadcast = !(ShapeUtil.areEqual(dimsA, dimsB) && ShapeUtil.areEqual(dimsB, dimsC)); + let outputShape = dimsA; + let outputSize = ShapeUtil.size(dimsA); + // TODO: deal with zero-sized tensors (eg. dims=[1,0]) + + if (isBroadcast) { + const calculatedShape = BroadcastUtil.calcShape(BroadcastUtil.calcShape(dimsA, dimsB, false)!, dimsC, false); + if (!calculatedShape) { + throw new Error('Can\'t perform where op on the given tensors'); + } + outputShape = calculatedShape; + outputSize = ShapeUtil.size(outputShape); + } + + return { + ...metadata, + getShaderSource: (shaderHelper) => + createWhereOpProgramShader(shaderHelper, inputs, outputShape, isBroadcast, outputDataType), + outputs: [{dims: outputShape, dataType: outputDataType, gpuDataType: GpuDataType.default}], + dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* vec size */)}) + }; +}; + +const createWhereOpProgramInfoLoader = (inputs: readonly TensorView[], name: string): ProgramInfoLoader => { + const inputTypes = [GpuDataType.default, GpuDataType.default, GpuDataType.default]; + const metadata: ProgramMetadata = {name, inputTypes}; + return {...metadata, get: () => createWhereOpProgramInfo(metadata, inputs)}; +}; + +export const where = (context: ComputeContext): void => { + context.compute(createWhereOpProgramInfoLoader(context.inputs, 'Where')); +}; diff --git a/js/web/lib/wasm/proxy-messages.ts b/js/web/lib/wasm/proxy-messages.ts index e5a2d8c2351b8..43f70c23f7193 100644 --- a/js/web/lib/wasm/proxy-messages.ts +++ b/js/web/lib/wasm/proxy-messages.ts @@ -3,20 +3,24 @@ import {Env, InferenceSession, Tensor} from 'onnxruntime-common'; -/** - * tuple elements are: ORT element type; dims; tensor data - */ -export type SerializableTensor = [Tensor.Type, readonly number[], Tensor.DataType]; +export type SerializableTensorMetadata = + [dataType: Tensor.Type, dims: readonly number[], data: Tensor.DataType, location: 'cpu']; -/** - * tuple elements are: InferenceSession handle; input names; output names - */ -export type SerializableSessionMetadata = [number, string[], string[]]; +export type GpuBufferMetadata = { + gpuBuffer: Tensor.GpuBufferType; + download?: () => Promise; + dispose?: () => void; +}; -/** - * tuple elements are: modeldata.offset, modeldata.length - */ -export type SerializableModeldata = [number, number]; +export type UnserializableTensorMetadata = + [dataType: Tensor.Type, dims: readonly number[], data: GpuBufferMetadata, location: 'gpu-buffer']| + [dataType: Tensor.Type, dims: readonly number[], data: Tensor.DataType, location: 'cpu-pinned']; + +export type TensorMetadata = SerializableTensorMetadata|UnserializableTensorMetadata; + +export type SerializableSessionMetadata = [sessionHandle: number, inputNames: string[], outputNames: string[]]; + +export type SerializableModeldata = [modelDataOffset: number, modelDataLength: number]; interface MessageError { err?: string; @@ -58,10 +62,10 @@ interface MessageReleaseSession extends MessageError { interface MessageRun extends MessageError { type: 'run'; in ?: { - sessionId: number; inputIndices: number[]; inputs: SerializableTensor[]; outputIndices: number[]; + sessionId: number; inputIndices: number[]; inputs: SerializableTensorMetadata[]; outputIndices: number[]; options: InferenceSession.RunOptions; }; - out?: SerializableTensor[]; + out?: SerializableTensorMetadata[]; } interface MesssageEndProfiling extends MessageError { diff --git a/js/web/lib/wasm/proxy-wrapper.ts b/js/web/lib/wasm/proxy-wrapper.ts index 815b223e40379..202209ed3bfed 100644 --- a/js/web/lib/wasm/proxy-wrapper.ts +++ b/js/web/lib/wasm/proxy-wrapper.ts @@ -3,7 +3,7 @@ import {Env, env, InferenceSession} from 'onnxruntime-common'; -import {OrtWasmMessage, SerializableModeldata, SerializableSessionMetadata, SerializableTensor} from './proxy-messages'; +import {OrtWasmMessage, SerializableModeldata, SerializableSessionMetadata, SerializableTensorMetadata, TensorMetadata} from './proxy-messages'; import * as core from './wasm-core-impl'; import {initializeWebAssembly} from './wasm-factory'; @@ -22,7 +22,7 @@ const createSessionAllocateCallbacks: Array> = []; const createSessionCallbacks: Array> = []; const releaseSessionCallbacks: Array> = []; -const runCallbacks: Array> = []; +const runCallbacks: Array> = []; const endProfilingCallbacks: Array> = []; const ensureWorker = (): void => { @@ -177,6 +177,10 @@ export const createSessionFinalize = async(modeldata: SerializableModeldata, opt export const createSession = async(model: Uint8Array, options?: InferenceSession.SessionOptions): Promise => { if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { + // check unsupported options + if (options?.preferredOutputLocation) { + throw new Error('session option "preferredOutputLocation" is not supported for proxy.'); + } ensureWorker(); return new Promise((resolve, reject) => { createSessionCallbacks.push([resolve, reject]); @@ -202,17 +206,27 @@ export const releaseSession = async(sessionId: number): Promise => { }; export const run = async( - sessionId: number, inputIndices: number[], inputs: SerializableTensor[], outputIndices: number[], - options: InferenceSession.RunOptions): Promise => { + sessionId: number, inputIndices: number[], inputs: TensorMetadata[], outputIndices: number[], + outputs: Array, options: InferenceSession.RunOptions): Promise => { if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { + // check inputs location + if (inputs.some(t => t[3] !== 'cpu')) { + throw new Error('input tensor on GPU is not supported for proxy.'); + } + // check outputs location + if (outputs.some(t => t)) { + throw new Error('pre-allocated output tensor is not supported for proxy.'); + } ensureWorker(); - return new Promise((resolve, reject) => { + return new Promise((resolve, reject) => { runCallbacks.push([resolve, reject]); - const message: OrtWasmMessage = {type: 'run', in : {sessionId, inputIndices, inputs, outputIndices, options}}; - proxyWorker!.postMessage(message, core.extractTransferableBuffers(inputs)); + const serializableInputs = inputs as SerializableTensorMetadata[]; // every input is on CPU. + const message: OrtWasmMessage = + {type: 'run', in : {sessionId, inputIndices, inputs: serializableInputs, outputIndices, options}}; + proxyWorker!.postMessage(message, core.extractTransferableBuffers(serializableInputs)); }); } else { - return core.run(sessionId, inputIndices, inputs, outputIndices, options); + return core.run(sessionId, inputIndices, inputs, outputIndices, outputs, options); } }; diff --git a/js/web/lib/wasm/session-handler.ts b/js/web/lib/wasm/session-handler.ts index d8c5ae7886fe4..7bc467449c33a 100644 --- a/js/web/lib/wasm/session-handler.ts +++ b/js/web/lib/wasm/session-handler.ts @@ -2,16 +2,45 @@ // Licensed under the MIT License. import {readFile} from 'fs'; -import {env, InferenceSession, SessionHandler, Tensor} from 'onnxruntime-common'; +import {env, InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} from 'onnxruntime-common'; import {promisify} from 'util'; -import {SerializableModeldata} from './proxy-messages'; +import {SerializableModeldata, TensorMetadata} from './proxy-messages'; import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, releaseSession, run} from './proxy-wrapper'; +import {isGpuBufferSupportedType} from './wasm-common'; let runtimeInitialized: boolean; let runtimeInitializationPromise: Promise|undefined; -export class OnnxruntimeWebAssemblySessionHandler implements SessionHandler { +const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => { + switch (tensor.location) { + case 'cpu': + return [tensor.type, tensor.dims, tensor.data, 'cpu']; + case 'gpu-buffer': + return [tensor.type, tensor.dims, {gpuBuffer: tensor.gpuBuffer}, 'gpu-buffer']; + default: + throw new Error(`invalid data location: ${tensor.location} for ${getName()}`); + } +}; + +const decodeTensorMetadata = (tensor: TensorMetadata): Tensor => { + switch (tensor[3]) { + case 'cpu': + return new Tensor(tensor[0], tensor[2], tensor[1]); + case 'gpu-buffer': { + const dataType = tensor[0]; + if (!isGpuBufferSupportedType(dataType)) { + throw new Error(`not supported data type: ${dataType} for deserializing GPU tensor`); + } + const {gpuBuffer, download, dispose} = tensor[2]; + return Tensor.fromGpuBuffer(gpuBuffer, {dataType, dims: tensor[1], download, dispose}); + } + default: + throw new Error(`invalid data location: ${tensor[3]}`); + } +}; + +export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHandler { private sessionId: number; inputNames: string[]; @@ -74,25 +103,31 @@ export class OnnxruntimeWebAssemblySessionHandler implements SessionHandler { inputIndices.push(index); }); + const outputArray: Array = []; const outputIndices: number[] = []; Object.entries(fetches).forEach(kvp => { const name = kvp[0]; - // TODO: support pre-allocated output + const tensor = kvp[1]; const index = this.outputNames.indexOf(name); if (index === -1) { throw new Error(`invalid output '${name}'`); } + outputArray.push(tensor); outputIndices.push(index); }); - const outputs = - await run(this.sessionId, inputIndices, inputArray.map(t => [t.type, t.dims, t.data]), outputIndices, options); + const inputs = + inputArray.map((t, i) => encodeTensorMetadata(t, () => `input "${this.inputNames[inputIndices[i]]}"`)); + const outputs = outputArray.map( + (t, i) => t ? encodeTensorMetadata(t, () => `output "${this.outputNames[outputIndices[i]]}"`) : null); + + const results = await run(this.sessionId, inputIndices, inputs, outputIndices, outputs, options); - const result: SessionHandler.ReturnType = {}; - for (let i = 0; i < outputs.length; i++) { - result[this.outputNames[outputIndices[i]]] = new Tensor(outputs[i][0], outputs[i][2], outputs[i][1]); + const resultMap: SessionHandler.ReturnType = {}; + for (let i = 0; i < results.length; i++) { + resultMap[this.outputNames[outputIndices[i]]] = outputArray[i] ?? decodeTensorMetadata(results[i]); } - return result; + return resultMap; } startProfiling(): void { diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index 2659b471733f5..02ff229cc4954 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -88,6 +88,21 @@ const setExecutionProviders = break; case 'webgpu': epName = 'JS'; + if (typeof ep !== 'string') { + const webgpuOptions = ep as InferenceSession.WebGpuExecutionProviderOption; + if (webgpuOptions?.preferredLayout) { + if (webgpuOptions.preferredLayout !== 'NCHW' && webgpuOptions.preferredLayout !== 'NHWC') { + throw new Error(`preferredLayout must be either 'NCHW' or 'NHWC': ${webgpuOptions.preferredLayout}`); + } + const keyDataOffset = allocWasmString('preferredLayout', allocs); + const valueDataOffset = allocWasmString(webgpuOptions.preferredLayout, allocs); + if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !== + 0) { + checkLastError( + `Can't set a session config entry: 'preferredLayout' - ${webgpuOptions.preferredLayout}.`); + } + } + } break; case 'wasm': case 'cpu': diff --git a/js/web/lib/wasm/wasm-common.ts b/js/web/lib/wasm/wasm-common.ts index 389773f3e8884..b9eff45e890c4 100644 --- a/js/web/lib/wasm/wasm-common.ts +++ b/js/web/lib/wasm/wasm-common.ts @@ -164,3 +164,35 @@ export const logLevelStringToEnum = (logLevel?: 'verbose'|'info'|'warning'|'erro throw new Error(`unsupported logging level: ${logLevel}`); } }; + +/** + * Check whether the given tensor type is supported by GPU buffer + */ +export const isGpuBufferSupportedType = (type: Tensor.Type): type is Tensor.GpuBufferDataTypes => type === 'float32' || + type === 'int32' || type === 'int64' || type === 'bool' || type === 'float16' || type === 'uint32'; + +/** + * Map string data location to integer value + */ +export const dataLocationStringToEnum = (location: Tensor.DataLocation): number => { + switch (location) { + case 'none': + return 0; + case 'cpu': + return 1; + case 'cpu-pinned': + return 2; + case 'texture': + return 3; + case 'gpu-buffer': + return 4; + default: + throw new Error(`unsupported data location: ${location}`); + } +}; + +/** + * Map integer data location to string value + */ +export const dataLocationEnumToString = (location: number): Tensor.DataLocation|undefined => + (['none', 'cpu', 'cpu-pinned', 'texture', 'gpu-buffer'] as const)[location]; diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index fcca82ab2aa54..5b49a1d4202e3 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -3,10 +3,10 @@ import {Env, InferenceSession, Tensor} from 'onnxruntime-common'; -import {SerializableModeldata, SerializableSessionMetadata, SerializableTensor} from './proxy-messages'; +import {SerializableModeldata, SerializableSessionMetadata, SerializableTensorMetadata, TensorMetadata} from './proxy-messages'; import {setRunOptions} from './run-options'; import {setSessionOptions} from './session-options'; -import {logLevelStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common'; +import {dataLocationStringToEnum, getTensorElementSize, isGpuBufferSupportedType, logLevelStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common'; import {getInstance} from './wasm-factory'; import {allocWasmString, checkLastError} from './wasm-utils'; @@ -60,9 +60,36 @@ export const initRuntime = async(env: Env): Promise => { }; /** - * tuple elements are: InferenceSession ID; inputNamesUTF8Encoded; outputNamesUTF8Encoded + * valid data locations for input/output tensors. */ -type SessionMetadata = [number, number[], number[]]; +type SupportedTensorDataLocationForInputOutput = 'cpu'|'cpu-pinned'|'gpu-buffer'; + +type IOBindingState = { + /** + * the handle of IO binding. + */ + readonly handle: number; + + /** + * the preferred location for each output tensor. + * + * value is one of 'cpu', 'cpu-pinned', 'gpu-buffer'. + */ + readonly outputPreferredLocations: readonly SupportedTensorDataLocationForInputOutput[]; + + /** + * enum value of the preferred location for each output tensor. + */ + readonly outputPreferredLocationsEncoded: readonly number[]; +}; + +/** + * tuple elements are: InferenceSession ID; inputNamesUTF8Encoded; outputNamesUTF8Encoded; bindingState + */ +type SessionMetadata = [ + inferenceSessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[], + bindingState: IOBindingState|null +]; const activeSessions = new Map(); @@ -92,6 +119,7 @@ export const createSessionFinalize = let sessionHandle = 0; let sessionOptionsHandle = 0; + let ioBindingHandle = 0; let allocs: number[] = []; const inputNamesUTF8Encoded = []; const outputNamesUTF8Encoded = []; @@ -108,6 +136,7 @@ export const createSessionFinalize = const inputNames = []; const outputNames = []; + const outputPreferredLocations: SupportedTensorDataLocationForInputOutput[] = []; for (let i = 0; i < inputCount; i++) { const name = wasm._OrtGetInputName(sessionHandle, i); if (name === 0) { @@ -122,15 +151,45 @@ export const createSessionFinalize = checkLastError('Can\'t get an output name.'); } outputNamesUTF8Encoded.push(name); - outputNames.push(wasm.UTF8ToString(name)); + const nameString = wasm.UTF8ToString(name); + outputNames.push(nameString); + + if (!BUILD_DEFS.DISABLE_WEBGPU) { + const location = typeof options?.preferredOutputLocation === 'string' ? + options.preferredOutputLocation : + options?.preferredOutputLocation?.[nameString] ?? 'cpu'; + if (location !== 'cpu' && location !== 'cpu-pinned' && location !== 'gpu-buffer') { + throw new Error(`Not supported preferred output location: ${location}.`); + } + outputPreferredLocations.push(location); + } + } + + // use IO binding only when at least one output is preffered to be on GPU. + let bindingState: IOBindingState|null = null; + if (!BUILD_DEFS.DISABLE_WEBGPU && outputPreferredLocations.some(l => l === 'gpu-buffer')) { + ioBindingHandle = wasm._OrtCreateBinding(sessionHandle); + if (ioBindingHandle === 0) { + checkLastError('Can\'t create IO binding.'); + } + + bindingState = { + handle: ioBindingHandle, + outputPreferredLocations, + outputPreferredLocationsEncoded: outputPreferredLocations.map(l => dataLocationStringToEnum(l)), + }; } - activeSessions.set(sessionHandle, [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded]); + activeSessions.set(sessionHandle, [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, bindingState]); return [sessionHandle, inputNames, outputNames]; } catch (e) { inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + if (ioBindingHandle !== 0) { + wasm._OrtReleaseBinding(ioBindingHandle); + } + if (sessionHandle !== 0) { wasm._OrtReleaseSession(sessionHandle); } @@ -161,7 +220,13 @@ export const releaseSession = (sessionId: number): void => { if (!session) { throw new Error(`cannot release session. invalid session id: ${sessionId}`); } - const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded] = session; + const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState] = session; + + if (ioBindingState) { + wasm._OrtReleaseBinding(ioBindingState.handle); + } + + wasm.jsepUnregisterBuffers?.(sessionId); inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); @@ -169,18 +234,84 @@ export const releaseSession = (sessionId: number): void => { activeSessions.delete(sessionId); }; +const prepareInputOutputTensor = + (tensor: TensorMetadata|null, tensorHandles: number[], allocs: number[], sessionId: number, index: number): + void => { + if (!tensor) { + tensorHandles.push(0); + return; + } + + const wasm = getInstance(); + + const dataType = tensor[0]; + const dims = tensor[1]; + const location = tensor[3]; + + let rawData: number; + let dataByteLength: number; + + if (dataType === 'string' && location === 'gpu-buffer') { + throw new Error('String tensor is not supported on GPU.'); + } + + if (location === 'gpu-buffer') { + const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer; + const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!; + dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes; + rawData = wasm.jsepRegisterBuffer(sessionId, index, gpuBuffer, dataByteLength); + } else { + const data = tensor[2]; + + if (Array.isArray(data)) { + // string tensor + dataByteLength = 4 * data.length; + rawData = wasm._malloc(dataByteLength); + allocs.push(rawData); + let dataIndex = rawData / 4; + for (let i = 0; i < data.length; i++) { + if (typeof data[i] !== 'string') { + throw new TypeError(`tensor data at index ${i} is not a string`); + } + wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], allocs); + } + } else { + dataByteLength = data.byteLength; + rawData = wasm._malloc(dataByteLength); + allocs.push(rawData); + wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData); + } + } + + const stack = wasm.stackSave(); + const dimsOffset = wasm.stackAlloc(4 * dims.length); + try { + let dimIndex = dimsOffset / 4; + dims.forEach(d => wasm.HEAP32[dimIndex++] = d); + const tensor = wasm._OrtCreateTensor( + tensorDataTypeStringToEnum(dataType), rawData, dataByteLength, dimsOffset, dims.length, + dataLocationStringToEnum(location)); + if (tensor === 0) { + checkLastError(`Can't create tensor for input/output. session=${sessionId}, index=${index}.`); + } + tensorHandles.push(tensor); + } finally { + wasm.stackRestore(stack); + } + }; + /** * perform inference run */ export const run = async( - sessionId: number, inputIndices: number[], inputs: SerializableTensor[], outputIndices: number[], - options: InferenceSession.RunOptions): Promise => { + sessionId: number, inputIndices: number[], inputTensors: TensorMetadata[], outputIndices: number[], + outputTensors: Array, options: InferenceSession.RunOptions): Promise => { const wasm = getInstance(); const session = activeSessions.get(sessionId); if (!session) { throw new Error(`cannot run inference. invalid session id: ${sessionId}`); } - const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded] = session; + const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState] = session; const inputCount = inputIndices.length; const outputCount = outputIndices.length; @@ -188,171 +319,200 @@ export const run = async( let runOptionsHandle = 0; let runOptionsAllocs: number[] = []; - const inputValues: number[] = []; - const inputAllocs: number[] = []; + const inputTensorHandles: number[] = []; + const outputTensorHandles: number[] = []; + const inputOutputAllocs: number[] = []; + + const beforeRunStack = wasm.stackSave(); + const inputValuesOffset = wasm.stackAlloc(inputCount * 4); + const inputNamesOffset = wasm.stackAlloc(inputCount * 4); + const outputValuesOffset = wasm.stackAlloc(outputCount * 4); + const outputNamesOffset = wasm.stackAlloc(outputCount * 4); try { [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); // create input tensors for (let i = 0; i < inputCount; i++) { - const dataType = inputs[i][0]; - const dims = inputs[i][1]; - const data = inputs[i][2]; - - let dataOffset: number; - let dataByteLength: number; - - if (Array.isArray(data)) { - // string tensor - dataByteLength = 4 * data.length; - dataOffset = wasm._malloc(dataByteLength); - inputAllocs.push(dataOffset); - let dataIndex = dataOffset / 4; - for (let i = 0; i < data.length; i++) { - if (typeof data[i] !== 'string') { - throw new TypeError(`tensor data at index ${i} is not a string`); - } - wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], inputAllocs); - } - } else { - dataByteLength = data.byteLength; - dataOffset = wasm._malloc(dataByteLength); - inputAllocs.push(dataOffset); - wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), dataOffset); - } + prepareInputOutputTensor(inputTensors[i], inputTensorHandles, inputOutputAllocs, sessionId, inputIndices[i]); + } - const stack = wasm.stackSave(); - const dimsOffset = wasm.stackAlloc(4 * dims.length); - try { - let dimIndex = dimsOffset / 4; - dims.forEach(d => wasm.HEAP32[dimIndex++] = d); - const tensor = wasm._OrtCreateTensor( - tensorDataTypeStringToEnum(dataType), dataOffset, dataByteLength, dimsOffset, dims.length); - if (tensor === 0) { - checkLastError(`Can't create tensor for input[${i}].`); - } - inputValues.push(tensor); - } finally { - wasm.stackRestore(stack); - } + // create output tensors + for (let i = 0; i < outputCount; i++) { + prepareInputOutputTensor( + outputTensors[i], outputTensorHandles, inputOutputAllocs, sessionId, inputCount + outputIndices[i]); + } + + let inputValuesIndex = inputValuesOffset / 4; + let inputNamesIndex = inputNamesOffset / 4; + let outputValuesIndex = outputValuesOffset / 4; + let outputNamesIndex = outputNamesOffset / 4; + for (let i = 0; i < inputCount; i++) { + wasm.HEAPU32[inputValuesIndex++] = inputTensorHandles[i]; + wasm.HEAPU32[inputNamesIndex++] = inputNamesUTF8Encoded[inputIndices[i]]; + } + for (let i = 0; i < outputCount; i++) { + wasm.HEAPU32[outputValuesIndex++] = outputTensorHandles[i]; + wasm.HEAPU32[outputNamesIndex++] = outputNamesUTF8Encoded[outputIndices[i]]; } - const beforeRunStack = wasm.stackSave(); - const inputValuesOffset = wasm.stackAlloc(inputCount * 4); - const inputNamesOffset = wasm.stackAlloc(inputCount * 4); - const outputValuesOffset = wasm.stackAlloc(outputCount * 4); - const outputNamesOffset = wasm.stackAlloc(outputCount * 4); - - try { - let inputValuesIndex = inputValuesOffset / 4; - let inputNamesIndex = inputNamesOffset / 4; - let outputValuesIndex = outputValuesOffset / 4; - let outputNamesIndex = outputNamesOffset / 4; + if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) { + const {handle, outputPreferredLocations, outputPreferredLocationsEncoded} = ioBindingState; + + if (inputNamesUTF8Encoded.length !== inputCount) { + throw new Error(`input count from feeds (${ + inputCount}) is expected to be always equal to model's input count (${inputNamesUTF8Encoded.length}).`); + } + + // process inputs for (let i = 0; i < inputCount; i++) { - wasm.HEAPU32[inputValuesIndex++] = inputValues[i]; - wasm.HEAPU32[inputNamesIndex++] = inputNamesUTF8Encoded[inputIndices[i]]; + const index = inputIndices[i]; + const errorCode = await wasm._OrtBindInput(handle, inputNamesUTF8Encoded[index], inputTensorHandles[i]); + if (errorCode !== 0) { + checkLastError(`Can't bind input[${i}] for session=${sessionId}.`); + } } + + // process pre-allocated outputs for (let i = 0; i < outputCount; i++) { - wasm.HEAPU32[outputValuesIndex++] = 0; - wasm.HEAPU32[outputNamesIndex++] = outputNamesUTF8Encoded[outputIndices[i]]; + const index = outputIndices[i]; + const location = outputTensors[i]?.[3]; // undefined means output is not pre-allocated. + + if (location) { + // output is pre-allocated. bind the tensor. + const errorCode = wasm._OrtBindOutput(handle, outputNamesUTF8Encoded[index], outputTensorHandles[i], 0); + if (errorCode !== 0) { + checkLastError(`Can't bind pre-allocated output[${i}] for session=${sessionId}.`); + } + } else { + // output is not pre-allocated. reset preferred location. + const errorCode = + wasm._OrtBindOutput(handle, outputNamesUTF8Encoded[index], 0, outputPreferredLocationsEncoded[index]); + if (errorCode !== 0) { + checkLastError(`Can't bind output[${i}] to ${outputPreferredLocations[i]} for session=${sessionId}.`); + } + } } + } - // jsepOnRunStart is only available when JSEP is enabled. - wasm.jsepOnRunStart?.(sessionId); + let errorCode: number; - // support RunOptions - let errorCode = wasm._OrtRun( + if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) { + errorCode = await wasm._OrtRunWithBinding( + sessionHandle, ioBindingState.handle, outputCount, outputValuesOffset, runOptionsHandle); + } else { + errorCode = await wasm._OrtRun( sessionHandle, inputNamesOffset, inputValuesOffset, inputCount, outputNamesOffset, outputCount, outputValuesOffset, runOptionsHandle); + } - const runPromise = wasm.jsepRunPromise; - if (runPromise) { - // jsepRunPromise is a Promise object. It is only available when JSEP is enabled. - // - // OrtRun() is a synchrnous call, but it internally calls async functions. Emscripten's ASYNCIFY allows it to - // work in this way. However, OrtRun() does not return a promise, so when code reaches here, it is earlier than - // the async functions are finished. - // - // To make it work, we created a Promise and resolve the promise when the C++ code actually reaches the end of - // OrtRun(). If the promise exists, we need to await for the promise to be resolved. - errorCode = await runPromise; - } + if (errorCode !== 0) { + checkLastError('failed to call OrtRun().'); + } - const jsepOnRunEnd = wasm.jsepOnRunEnd; - if (jsepOnRunEnd) { - // jsepOnRunEnd is only available when JSEP is enabled. - // - // it returns a promise, which is resolved or rejected when the following async functions are finished: - // - collecting GPU validation errors. - await jsepOnRunEnd(sessionId); + const output: TensorMetadata[] = []; + + for (let i = 0; i < outputCount; i++) { + const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i]; + if (tensor === outputTensorHandles[i]) { + // output tensor is pre-allocated. no need to copy data. + output.push(outputTensors[i]!); + continue; } - const output: SerializableTensor[] = []; + const beforeGetTensorDataStack = wasm.stackSave(); + // stack allocate 4 pointer value + const tensorDataOffset = wasm.stackAlloc(4 * 4); - if (errorCode !== 0) { - checkLastError('failed to call OrtRun().'); - } + let keepOutputTensor = false; + let type: Tensor.Type|undefined, dataOffset = 0; + try { + const errorCode = wasm._OrtGetTensorData( + tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12); + if (errorCode !== 0) { + checkLastError(`Can't access output tensor data on index ${i}.`); + } + let tensorDataIndex = tensorDataOffset / 4; + const dataType = wasm.HEAPU32[tensorDataIndex++]; + dataOffset = wasm.HEAPU32[tensorDataIndex++]; + const dimsOffset = wasm.HEAPU32[tensorDataIndex++]; + const dimsLength = wasm.HEAPU32[tensorDataIndex++]; + const dims = []; + for (let i = 0; i < dimsLength; i++) { + dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); + } + wasm._OrtFree(dimsOffset); - for (let i = 0; i < outputCount; i++) { - const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i]; + const size = dims.reduce((a, b) => a * b, 1); + type = tensorDataTypeEnumToString(dataType); - const beforeGetTensorDataStack = wasm.stackSave(); - // stack allocate 4 pointer value - const tensorDataOffset = wasm.stackAlloc(4 * 4); + const preferredLocation = ioBindingState?.outputPreferredLocations[outputIndices[i]]; - let type: Tensor.Type|undefined, dataOffset = 0; - try { - errorCode = wasm._OrtGetTensorData( - tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12); - if (errorCode !== 0) { - checkLastError(`Can't access output tensor data on index ${i}.`); + if (type === 'string') { + if (preferredLocation === 'gpu-buffer') { + throw new Error('String tensor is not supported on GPU.'); } - let tensorDataIndex = tensorDataOffset / 4; - const dataType = wasm.HEAPU32[tensorDataIndex++]; - dataOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsLength = wasm.HEAPU32[tensorDataIndex++]; - const dims = []; - for (let i = 0; i < dimsLength; i++) { - dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); + const stringData: string[] = []; + let dataIndex = dataOffset / 4; + for (let i = 0; i < size; i++) { + const offset = wasm.HEAPU32[dataIndex++]; + const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset; + stringData.push(wasm.UTF8ToString(offset, maxBytesToRead)); } - wasm._OrtFree(dimsOffset); - - const size = dims.length === 0 ? 1 : dims.reduce((a, b) => a * b); - type = tensorDataTypeEnumToString(dataType); - if (type === 'string') { - const stringData: string[] = []; - let dataIndex = dataOffset / 4; - for (let i = 0; i < size; i++) { - const offset = wasm.HEAPU32[dataIndex++]; - const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset; - stringData.push(wasm.UTF8ToString(offset, maxBytesToRead)); + output.push([type, dims, stringData, 'cpu']); + } else { + // If a certain output's preferred location is GPU but the tensor is empty, we still need to create a CPU + // tensor for it. There is no mapping GPU buffer for an empty tensor. + if (preferredLocation === 'gpu-buffer' && size > 0) { + const gpuBuffer = wasm.jsepGetBuffer(dataOffset); + const elementSize = getTensorElementSize(dataType); + if (elementSize === undefined || !isGpuBufferSupportedType(type)) { + throw new Error(`Unsupported data type: ${type}`); } - output.push([type, dims, stringData]); + + // do not release the tensor right now. it will be released when user calls tensor.dispose(). + keepOutputTensor = true; + + output.push([ + type, dims, { + gpuBuffer, + download: wasm.jsepCreateDownloader(gpuBuffer, size * elementSize, type), + dispose: () => { + wasm._OrtReleaseTensor(tensor); + } + }, + 'gpu-buffer' + ]); } else { const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type); const data = new typedArrayConstructor(size); new Uint8Array(data.buffer, data.byteOffset, data.byteLength) .set(wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength)); - output.push([type, dims, data]); - } - } finally { - wasm.stackRestore(beforeGetTensorDataStack); - if (type === 'string' && dataOffset) { - wasm._free(dataOffset); + output.push([type, dims, data, 'cpu']); } + } + } finally { + wasm.stackRestore(beforeGetTensorDataStack); + if (type === 'string' && dataOffset) { + wasm._free(dataOffset); + } + if (!keepOutputTensor) { wasm._OrtReleaseTensor(tensor); } } + } - return output; - } finally { - wasm.stackRestore(beforeRunStack); + if (ioBindingState) { + wasm._OrtClearBoundOutputs(ioBindingState.handle); } + + return output; } finally { - inputValues.forEach(v => wasm._OrtReleaseTensor(v)); - inputAllocs.forEach(p => wasm._free(p)); + wasm.stackRestore(beforeRunStack); + + inputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v)); + outputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v)); + inputOutputAllocs.forEach(p => wasm._free(p)); if (runOptionsHandle !== 0) { wasm._OrtReleaseRunOptions(runOptionsHandle); @@ -380,11 +540,11 @@ export const endProfiling = (sessionId: number): void => { wasm._OrtFree(profileFileName); }; -export const extractTransferableBuffers = (tensors: readonly SerializableTensor[]): ArrayBufferLike[] => { +export const extractTransferableBuffers = (tensors: readonly SerializableTensorMetadata[]): ArrayBufferLike[] => { const buffers: ArrayBufferLike[] = []; for (const tensor of tensors) { const data = tensor[2]; - if (!Array.isArray(data) && data.buffer) { + if (!Array.isArray(data) && 'buffer' in data) { buffers.push(data.buffer); } } diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts index f90f568879146..31bca8b94306d 100644 --- a/js/web/script/test-runner-cli-args.ts +++ b/js/web/script/test-runner-cli-args.ts @@ -51,6 +51,10 @@ Options: -P[=<...>], --perf[=<...>] Generate performance number. Cannot be used with flag --debug. This flag can be used with a number as value, specifying the total count of test cases to run. The test cases may be used multiple times. Default value is 10. -c, --file-cache Enable file cache. + -i=<...>, --io-binding=<...> Specify the IO binding testing type. Should be one of the following: + none (default) + gpu-tensor use pre-allocated GPU tensors for inputs and outputs + gpu-location use pre-allocated GPU tensors for inputs and set preferredOutputLocation to 'gpu-buffer' *** Session Options *** -u=<...>, --optimized-model-file-path=<...> Specify whether to dump the optimized model. @@ -109,6 +113,7 @@ export declare namespace TestRunnerCliArgs { type Backend = 'cpu'|'webgl'|'webgpu'|'wasm'|'onnxruntime'|'xnnpack'|'webnn'; type Environment = 'chrome'|'edge'|'firefox'|'electron'|'safari'|'node'|'bs'; type BundleMode = 'prod'|'dev'|'perf'; + type IOBindingMode = 'none'|'gpu-tensor'|'gpu-location'; } export interface TestRunnerCliArgs { @@ -140,6 +145,8 @@ export interface TestRunnerCliArgs { */ bundleMode: TestRunnerCliArgs.BundleMode; + ioBindingMode: TestRunnerCliArgs.IOBindingMode; + logConfig: Test.Config['log']; /** @@ -326,7 +333,11 @@ function parseWebgpuFlags(args: minimist.ParsedArgs): Partial { if (profilingMode !== undefined && profilingMode !== 'off' && profilingMode !== 'default') { throw new Error('Flag "webgpu-profiling-mode" is invalid'); } - return {profilingMode}; + const validateInputContent = args['webgpu-validate-input-content']; + if (validateInputContent !== undefined && typeof validateInputContent !== 'boolean') { + throw new Error('Flag "webgpu-validate-input-content" is invalid'); + } + return {profilingMode, validateInputContent}; } function parseGlobalEnvFlags(args: minimist.ParsedArgs): NonNullable { @@ -416,6 +427,13 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs logConfig.push({category: 'TestRunner.Perf', config: {minimalSeverity: 'verbose'}}); } + // Option: -i=<...>, --io-binding=<...> + const ioBindingArg = args['io-binding'] || args.i; + const ioBindingMode = (typeof ioBindingArg !== 'string') ? 'none' : ioBindingArg; + if (['none', 'gpu-tensor', 'gpu-location'].indexOf(ioBindingMode) === -1) { + throw new Error(`not supported io binding mode ${ioBindingMode}`); + } + // Option: -u, --optimized-model-file-path const optimizedModelFilePath = args['optimized-model-file-path'] || args.u || undefined; if (typeof optimizedModelFilePath !== 'undefined' && typeof optimizedModelFilePath !== 'string') { @@ -455,6 +473,7 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs npmlog.verbose('TestRunnerCli.Init', ` Env: ${env}`); npmlog.verbose('TestRunnerCli.Init', ` Debug: ${debug}`); npmlog.verbose('TestRunnerCli.Init', ` Backend: ${backend}`); + npmlog.verbose('TestRunnerCli.Init', ` IO Binding Mode: ${ioBindingMode}`); npmlog.verbose('TestRunnerCli.Init', 'Parsing commandline arguments... DONE'); return { @@ -467,6 +486,7 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs logConfig, profile, times: perf ? times : undefined, + ioBindingMode: ioBindingMode as TestRunnerCliArgs['ioBindingMode'], optimizedModelFilePath, graphOptimizationLevel: graphOptimizationLevel as TestRunnerCliArgs['graphOptimizationLevel'], fileCache, diff --git a/js/web/script/test-runner-cli.ts b/js/web/script/test-runner-cli.ts index f3764e63fcf45..d8fecec1b8084 100644 --- a/js/web/script/test-runner-cli.ts +++ b/js/web/script/test-runner-cli.ts @@ -257,7 +257,7 @@ async function main() { times?: number): Test.ModelTest { if (times === 0) { npmlog.verbose('TestRunnerCli.Init.Model', `Skip test data from folder: ${testDataRootFolder}`); - return {name: path.basename(testDataRootFolder), backend, modelUrl: '', cases: []}; + return {name: path.basename(testDataRootFolder), backend, modelUrl: '', cases: [], ioBinding: args.ioBindingMode}; } let modelUrl: string|null = null; @@ -323,6 +323,16 @@ async function main() { } } + let ioBinding: Test.IOBindingMode; + if (backend !== 'webgpu' && args.ioBindingMode !== 'none') { + npmlog.warn( + 'TestRunnerCli.Init.Model', `Ignoring IO Binding Mode "${args.ioBindingMode}" for backend "${backend}".`); + ioBinding = 'none'; + } else { + ioBinding = args.ioBindingMode; + } + + npmlog.verbose('TestRunnerCli.Init.Model', 'Finished preparing test data.'); npmlog.verbose('TestRunnerCli.Init.Model', '==============================================================='); npmlog.verbose('TestRunnerCli.Init.Model', ` Model file: ${modelUrl}`); @@ -330,7 +340,7 @@ async function main() { npmlog.verbose('TestRunnerCli.Init.Model', ` Test set(s): ${cases.length} (${caseCount})`); npmlog.verbose('TestRunnerCli.Init.Model', '==============================================================='); - return {name: path.basename(testDataRootFolder), platformCondition, modelUrl, backend, cases}; + return {name: path.basename(testDataRootFolder), platformCondition, modelUrl, backend, cases, ioBinding}; } function tryLocateModelTestFolder(searchPattern: string): string { @@ -390,6 +400,13 @@ async function main() { for (const test of tests) { test.backend = backend; test.opset = test.opset || {domain: '', version: MAX_OPSET_VERSION}; + if (backend !== 'webgpu' && args.ioBindingMode !== 'none') { + npmlog.warn( + 'TestRunnerCli.Init.Op', `Ignoring IO Binding Mode "${args.ioBindingMode}" for backend "${backend}".`); + test.ioBinding = 'none'; + } else { + test.ioBinding = args.ioBindingMode; + } } npmlog.verbose('TestRunnerCli.Init.Op', 'Finished preparing test data.'); npmlog.verbose('TestRunnerCli.Init.Op', '==============================================================='); diff --git a/js/web/test/data/ops/bias-add.jsonc b/js/web/test/data/ops/bias-add.jsonc new file mode 100644 index 0000000000000..e89c5dd81cc23 --- /dev/null +++ b/js/web/test/data/ops/bias-add.jsonc @@ -0,0 +1,874 @@ +[ + { + "name": "BiasAdd", + "operator": "BiasAdd", + "attributes": [], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "bias add [2,2,320]x[320]x[2,2,320]", + "inputs": [ + { + "data": [ + -0.43078827160569055, -1.3343044914862014, -1.3875186857333706, 0.7550491056943578, -0.015677493769832296, + 1.2761569600658005, -0.033221806310058, 1.3590872185896297, 1.706516873231478, 1.429932535224122, + -0.39598259309643424, -0.6735725816667406, 0.5551536414690581, 1.9642482005713608, 0.1481069760277398, + -0.6706041659325672, 1.7626582023316804, 1.8587314247747715, 0.6761988056588422, -1.2865903673587722, + 0.17451384248240753, -1.468579522325724, -1.4426371855482047, -1.101469412150644, 0.32065561169114787, + -1.7280217475756476, 0.8472414982170147, 1.089925472742105, -1.9188568763699578, -1.6897769809368715, + -0.22443847491362767, -1.592027916773227, -1.403545711986184, -0.992398507333391, 1.501364958055385, + 1.1969883087000177, -1.742362862568375, -1.3122082038099396, -1.602285316113413, 1.7122406231186638, + -1.7331729420495074, -1.1660018974894326, 1.8466385944978407, 0.8059189037532883, 0.42819875266849383, + 1.2622036788694384, 1.09253289978553, 1.7169775828443141, 0.8965924110901806, 1.731515754612638, + -0.710260858323644, 0.6949611010518701, 1.9392089898271783, -0.06855145845621369, 1.831602787077924, + 1.919258334566825, -0.1750228495679611, -1.385684174878624, 0.1864456393673155, -0.7627497248975219, + 0.5460459705476177, 1.0151263124631544, 1.0476468224745439, -1.4662684764200389, 0.8737230293064417, + -1.5018396956497346, 1.7171541562949724, 0.33089656971471637, -0.4497250989135697, 0.03155539207838842, + 0.051811401515028166, -0.6417182241860351, 0.2892848388089577, 1.7313934166665597, 0.07283972370973935, + -1.1648958342544793, 0.31235891981372976, 0.8667946196377363, -1.8955234844794324, -0.7935013829177411, + 1.7461436471789726, -0.7863083195781497, -0.24062975690975286, 0.5319240675275907, -1.093540712132092, + -0.08133035383567844, 0.24368896952864638, -0.7817379372747375, 0.09213174790435374, 0.7008130043325052, + 1.2549161788773802, -1.3617356636970648, 0.9203525826936687, -0.17473559280037954, 0.8615190853991157, + 1.7227407352985065, -1.6408918805405301, -1.2530895034442668, -1.561750342643168, 0.9122283638044753, + 1.52061857140668, -1.757751307093299, 1.9174275778315568, 0.8751522548565935, -0.40373986100628567, + -0.9081360776169474, -1.0200280593711373, 1.8684056043061243, -0.8656213251095854, 1.2016289009297738, + -0.6420074354987753, 0.8277984540094971, -0.7794691695697935, -0.3682767999575489, -0.8431108361591697, + -1.8337997630639924, -1.497840701527875, 0.3582228355620778, -1.7648364421707035, 1.2295209603744386, + -1.4630737828552771, -1.0901242069115051, 0.16944674169381635, -1.1367614122844483, 1.1108472254071025, + -1.823264607399035, -1.9393711989501217, 1.5655717488598917, -1.2791502509294475, 0.5264599738022175, + 1.2793776364048464, 0.6643881105536087, -1.8021867427977876, 1.2507197673418382, 1.5096918274472868, + -0.35787327298378635, -0.8142338946526149, 0.05554038558981045, -0.8292295885917778, -1.9802571383548457, + 0.15711486642700123, 1.0606485766817002, -0.75648306593627, 0.7821766109776371, 1.5576015081398582, + 1.274065256431233, -0.11137614589209388, 1.5523940512159609, -0.9225168834642474, -1.52372850136818, + -0.09285110141468689, 0.951570914266842, 1.4257064421442056, -1.1151817450836203, -1.0405979771695408, + -0.3512294660068509, 0.5544676624698717, -1.1641268738929513, -1.1024667332826468, -0.8470834386045532, + 1.0200780368213875, -0.4111438148101385, 1.5060564363688043, -0.9264136525377458, -1.1144625610145829, + -0.8955596783917263, 1.1298063812792858, 1.1669470152018224, -0.9624856700584665, -0.5105634488042536, + 0.3350150739180062, 1.8252865056498973, 1.8318837072890988, 1.247181799708187, 0.953253999223505, + 0.33583064317486144, 0.8060935063640473, 0.025351359439778953, -1.9020713785704064, -1.4395100399524523, + -1.4434709436802136, 1.1887525549807645, -0.7404500263025655, 0.43567229698745447, 0.697905733863994, + 1.4955234099092705, 1.8798334809531436, 0.28521714680739496, -1.8247616864327485, -0.8899636892784697, + -1.9902986052685518, -0.8612351129594646, 1.4473654540376284, -0.027164144754272534, -1.3409554794285556, + 1.8216393608213144, 1.6287547173387606, 0.34546830263617245, -1.8554217975954366, -0.1599936883574884, + -1.1517327823070804, -1.1043403244478194, 0.9959357690900248, 0.47246889924079394, 1.3040046448975406, + 1.3219564442035914, 1.5522794215260332, -1.3132538270303629, -0.6568357397541558, -0.3166961264888011, + 1.155856117058299, -0.2856093347506521, 0.37282953339474023, 1.3624325643069906, -0.5961360461760625, + 0.15430022597664106, 1.459554220213934, -0.4578169766295437, -1.5520231063597194, -0.9692019226485984, + -0.7868623417517249, -0.47648838559948903, 0.521518493577898, -0.43202750995667394, -0.2800695092149823, + -0.062073964072573595, -1.280928976710845, 1.653227202229826, -1.3266797422187144, -1.8026162116413413, + -1.6342686696555697, -1.8535208355884771, -0.6740602051435252, -1.1112082500330684, -0.2784013709506814, + -0.38477554625820254, -1.8030977138882127, 1.0792754568917307, 1.7171188079371378, 1.0025920437036717, + 0.4909327973461517, 1.7583392556519017, -0.9513053055982352, 0.9836640607205354, 0.676230592061307, + -1.532719596602865, 1.7094356339043584, 1.275239052465298, -0.7097409099122736, -0.07447816392320394, + 0.6054981041837371, -1.368606109962121, 0.5491138352978151, 0.9607716616110995, -0.8472145355857039, + -1.5962071764683028, -1.3945278279448647, 1.6804737734010837, -0.2799307641611124, -1.5199944467651827, + -0.6284155111805374, -1.387427601467464, -1.3689101475786654, 0.5479648937942034, -0.35058122336125397, + -0.01913895405776156, 1.9421209900010803, 1.521339125582216, 0.44317460300853284, 1.4673321860144632, + 1.0844686121268872, 1.5454854860881193, -0.10676223292802778, 1.304830685737219, 1.5144316149285784, + -0.9206202207515908, 1.800878209478248, -1.4022748170800003, 1.8475728691446207, -0.6049511260966813, + 1.3794138268485217, -0.5504308384418284, -1.308589620881369, -1.5124914949500843, -0.9843890898875198, + -0.47517812559034134, 0.019190610059508728, 1.3102818494016093, 0.02046481460734384, 0.8966955765481393, + -1.9544981815442322, 1.4818794231229235, 1.701164813487166, -0.1853726987447688, -0.9284129805631327, + -0.6665475855012435, 1.7174739766102922, -0.6866660477705144, 1.3741442823600236, 1.3144332615465366, + -1.6821345841035624, 1.4369774381581397, -0.40113703543299106, 1.2155358501458204, 0.3746014766056245, + -1.2797835093659478, -1.2733146708696594, 1.7592472733050606, 0.7411515442284751, -0.8417009900350942, + 0.43950532643238294, -0.4849697519720575, 0.4561330711561924, -0.9993716519294349, 1.2993450426953688, + 0.6077681589510648, -1.1884025219126162, 0.7423783440021925, 0.2164111466551839, -1.8404401979886122, + -1.6417250174486764, 0.33169566624656355, 1.4670345153781224, 0.05770292022609169, -0.7574653293187339, + -1.8332930679907182, 0.30545703812352265, 1.2645746651782588, 0.06851662509789058, -1.5007737660843672, + -0.7895770838504896, -0.16973917251686643, -1.1309776011957018, -0.5245832860094533, 1.2865060486729964, + -0.6877713804065921, 0.8603087489217494, -1.6760835939805823, -0.3421307621860663, -0.08163355960848584, + 0.8960283575910557, -0.07414831058651483, 0.06552900030788233, -0.28343554832982676, 0.8580233436387052, + 0.9756624200728901, -0.31454210863680565, -1.3961591961928228, 1.4357660237613095, 0.995011454280724, + 1.7963312095595976, 0.44298210926821735, -0.4481873251636479, 0.04661154364950004, -1.0205093044441567, + -1.416501713829339, -1.5609202237237856, -1.9678529588325686, 0.054492766820211536, 1.9234783545287293, + -1.730337423937696, 1.0002270533854967, 0.912626942921996, 1.6082280213507438, -0.0007848079092749316, + -0.16260961343095381, -0.4951675344800499, -1.358161813033476, 0.0952066467959174, 1.1451260564923738, + 0.5658192251004577, 1.638639222136547, 1.4561159943128157, 1.394720518360975, -1.6647338550020887, + -0.35227744791908844, -1.4997218325299349, 1.7269852905741923, -1.6265531305359326, 0.08177780426753678, + -0.2548418092940219, -0.5491200865930868, -1.3448390466679427, -1.4669459626844805, 1.6556888857778747, + 0.9797534441152536, 0.011435434178378223, -1.0444316805293496, -0.21951923449225497, 1.2775877064028922, + 1.3912532203430734, -0.41034439679528845, 1.6871573166908078, -0.2949753341249375, -0.19366027881779058, + -0.8063505536355873, -1.9683493413513462, 1.1343233715358245, 0.5863358659604163, -1.7161271586916902, + -1.1511237584065572, -1.026334237004506, 0.051935455424261256, -0.47999388673650234, 1.3653965992268642, + -1.6246601676055645, 1.0720439180368135, 1.7513891096415195, -1.9557074487732296, -0.6497504028268937, + 1.5657013366841293, 1.480374539039186, 0.6254160749725743, -0.1386074990036681, 1.4879074098885399, + 0.8608723264201785, 1.3707693807951804, -1.9599278528285744, -0.08178821222314703, -0.8801306486657889, + 0.2672127121047323, -1.0905213718834403, 1.2845923771997692, 0.7483674434994203, 1.692846569262815, + -1.4081461907911006, 0.4847990697153808, -1.384710972149036, -0.48477323237460457, 0.5075573472017414, + -0.469975647682892, -1.7330175330396296, 0.4900846176343876, -0.5504384887286493, -1.3105683633488505, + 1.7832520063588246, -0.008133111014310579, -0.4117632340667363, -1.364428903076842, 1.1502921666750634, + -1.297368143184646, 1.2657228492344146, -0.8339630526727912, 1.6036841813714835, 1.7388392748246462, + 1.0764044765977916, 0.8635247292187316, 0.8915437658945047, 0.048593859303998954, 0.6397243841879581, + -0.32613694518247716, -0.5352551370184315, 1.0102009591194312, 1.2995349433229801, 0.873915978492783, + 0.28688858188963984, 1.0299796543068158, -0.21558135465200134, -0.8175153872775942, -0.2228988257546698, + 1.3651108393401516, 1.7419897496979102, -0.4287431897946634, -1.7210342396761868, -1.1328295428945232, + 0.45940815613566865, 1.225728222018235, -1.8492110470394127, 0.11914217556997819, -0.5298005043189082, + 0.2548881953681379, 1.4159379393782983, 1.4882705620355736, -1.778746106600524, -0.1322461839301079, + 0.959787258509837, -1.2087764883807948, -0.656679223823172, -0.8052024037155077, 0.28857972466710535, + -1.9964605007110707, -1.618829946697912, 0.0003216720189040956, 1.3814498186817072, 1.267459171727996, + 1.7808899740756745, -0.026284995100580133, 1.8415787116931153, 1.588128771971907, -1.379239932131604, + 1.4954639033867787, -0.5365530743272542, -1.689396586297355, -1.5263365207258657, -0.15914036263121556, + -0.2236484993504826, -0.233339567785662, -0.6479528728470942, 1.7496861376407091, 1.9658328381866337, + 0.7465900143344948, 0.9538445157519684, -1.233501367636971, 1.7832767928842772, -0.05854857072903297, + -1.4741131369232807, -0.8878273465278168, 0.7857197910561684, -1.3138933804749309, -1.6218392819603826, + -1.5976871654455351, 0.7862199318904839, 0.7930452731942381, -0.3851655998300325, 1.454630867585033, + 0.07699835053088844, 1.9635265550996897, -0.9018052833921555, -1.3260562107679466, 1.591219679914813, + -1.9325066606486736, 1.6231422543543284, 0.06461353520395186, 1.938690006387282, 1.18237528878611, + 0.08913168691813134, -0.24910533784113476, -0.24270271174439095, -0.058892972285330636, + 0.3017252069250542, -1.6443343279597213, -0.35086444759952506, 1.2568748876859326, 1.2564694463685155, + 0.5001323821158774, 1.6753738435417835, 0.12536895878538168, 0.11445707906508318, -1.4674734226308148, + -1.664480357904826, 1.2455086575046002, 0.016219310045000768, 0.6915540154848818, 0.8863957336976913, + -1.902221644800517, 1.1943261373616156, 1.5343838485344943, 0.26951920872184143, 1.1650565743686334, + 0.8796643687219943, -0.9914368645961655, -0.4140197685853506, -0.31070218781242964, 0.18576121794650557, + -1.8127899542033798, 0.7364855692315055, 0.61436653795095, -1.340047295554938, -0.9787870418028497, + 1.467065674277963, -1.3989058830195544, 0.8760353392917564, 1.103715585163978, 1.5993888842634876, + -1.4487285544801765, 1.0735445052923094, -1.7221746626643117, 1.3406229033114787, -0.31808486639617684, + 0.18778714988957468, 0.3360116223452154, 0.9957693335536861, -0.7082173962327145, 0.13563127391102814, + -0.6103565638388746, -0.27366984218234336, 1.670182285266585, 0.7450288088790513, -0.7376055316855803, + -1.9155349843040552, -0.5647372929979877, 0.38985718472222164, -1.0859956865362559, -1.2774336101522819, + 1.6716179824020179, -1.8248740137755952, -0.2462417745015948, -1.6330831414998679, 1.2633365236447727, + 1.8405595764836615, -1.5041331151923663, 0.7795364178340263, 0.30228912276856246, -1.4199604845773672, + -1.9675070264789785, -1.5847567364669262, 0.8417404542568088, 1.411592224813667, -0.7798761695294161, + 0.3836369752883826, -1.8656778968593093, -0.5928302815127546, -1.9170044216421092, -0.05015509660959072, + -0.2949046704017606, 0.7706683291518104, -1.0318399043402815, 1.673370077717097, -0.014510791777835763, + -1.7993886007705928, -1.6310599365379153, 0.47517801288072725, 1.7067820519317198, 0.2398133204231856, + 0.8136128605181492, 1.7661699247824538, -1.7400434804598, 1.6440515677939507, 1.0630270197569356, + 1.604276559525923, -0.6932776253265054, -0.5527120052345769, -0.15050627035753017, 1.9862178971553677, + -0.015247839923115514, -1.973561045412188, -0.7527150187762626, -1.8069828707308844, -1.0877483004326844, + 1.5632037263398146, -0.7907173278699853, 0.020757273269920162, -0.9162761989634776, -0.871619440202827, + 0.9274209555443278, 0.2788106849606935, -1.139267389509799, 1.8554359980704245, -1.2900715232363025, + 1.5740910757851019, 1.2022687452616108, -1.3263358341556861, 1.705785771419051, -0.2966767936771886, + 1.0184489507711243, 0.9331775150786781, -0.09496306724418613, 1.6169644094277178, -0.43071270921050253, + 0.00617695499471882, 1.9046927055754255, -0.3800897035987143, -1.2677566100047573, -1.6177276055783159, + 0.41451247814580316, 1.9136674265154827, -1.4827469552090689, -1.7754603228705612, -0.6490018494468908, + 1.0832837177200183, 1.5236866225466352, -1.3663768537968632, 0.3345867923992154, -1.3264457751116838, + 1.1302163417527789, 0.7987367218594539, -0.16727587952795364, -0.7323977510251147, -1.3812007830618693, + 0.7210535269663234, 1.1937695850586403, 1.9659603230065574, -0.12903581925605057, 0.22599793243849042, + -1.055270780364089, 1.0726029698666473, -1.8921222720165147, 1.232307457573829, -1.6501289040499376, + -1.2203689042981933, -1.5782693671316723, 0.84090024006949, 1.1443588250664956, -0.8862909159261179, + -0.34912733758720993, 0.15172846490974035, 0.6801601864117464, -1.3948321257969702, 1.9080269171533608, + -1.851455291444804, -0.8525139187927957, -0.9629354426957466, -0.3970802170728076, 0.2714456125847784, + 1.9355349765888947, 1.643295056500194, 0.461049347109487, -0.819228054522112, -0.7773196244370615, + 0.27305821451894285, -0.5007383808686932, -0.5307070901422906, -1.6087255013287924, 1.708746273758031, + -0.18903643771546896, 0.4237658727537639, 0.7912222637914885, -0.06713232018529425, -0.9303076017910259, + -0.9715788400465986, 1.9239773722802864, -0.13605413264730704, 1.63369388301102, -1.4098364226383078, + -1.8177000155999794, -1.0786108587058756, 1.1467184875003493, 1.6942868827942474, 0.3743937965110735, + -0.5205992959118086, 1.9008470494268277, -0.33613881942363744, 1.408798355509151, 1.4374267676314432, + -1.6266431395044751, -1.2795792184706105, -1.6942133113310183, 0.44536668255477885, -0.5438457389607523, + -0.10839952389615792, -0.741753360770752, 1.9252204052044393, -1.298683868110981, -0.6674482773778925, + 1.7250903096066814, -0.31851968039415723, 0.09917784413898989, -1.2300134366181839, -1.6661221342983064, + -1.7392500850180967, 1.1202143474944348, -0.008191271742616912, -1.9527622539466778, -1.112511983718873, + 0.8023412387190376, -0.8667541685114779, 1.5324025634620666, -0.885176692529595, 0.28103075310068526, + 0.5555445946177473, 0.0746344154810279, 1.8279995093264354, 1.589298526844452, -1.1728977455798502, + -1.2266193029402643, -0.08324504041088154, -1.5997622864209005, -1.9929554997781622, 1.649795503298476, + -1.825621659313522, 1.2037940913995477, -0.1478578137836779, -1.791209833936434, -1.5875938503209248, + -1.54554126750838, 0.22767426425123372, -0.9730859655005233, 1.01498892879457, -0.14002274071109522, + 1.5196250773427282, -0.262915623574389, -0.8395076912698629, -1.6509617804598316, 0.4853743370864656, + -1.2426649457879062, -0.34746037671673236, -0.2668537063224701, -1.5727266735586252, -1.117657993368196, + 0.38564526206938243, 1.8759791120426446, -0.9537042932874478, 1.8291799127445367, -0.005069359588546263, + 0.9881621403906173, -1.0839433778149212, -0.4270854569893858, 1.6796287101836995, 1.4763471217127018, + -1.0142364314111942, -1.9083290608311998, 1.8813585092075078, -1.7461290703297596, -0.21674842672674544, + 0.5502685329995662, -0.4411322123830974, -1.8416200674737109, -1.0844672555196375, -1.4812521607836615, + 1.5083908944091222, -1.038918979055473, 0.7142655565242073, -1.1043893519874795, -0.9397073551320325, + 1.335714667963627, -0.39552262419583606, 0.941189318776745, -1.973265716530558, 0.29314288363790375, + 1.1149136203907473, -1.3280789878295156, 0.4344131789940944, -1.4976865930360805, -0.18683967016091696, + 1.6841429470215354, 0.9990587834251263, -0.9896683309850935, 1.7344823188063838, 1.2809894053746431, + 0.11253057753370044, 0.6202153889005659, -1.1393527545232924, -0.13267950871413792, 1.9929188473410813, + -0.5751171946115745, -1.5738159418363038, -0.722979035597497, -0.3845867260905491, 0.6726496775295967, + 1.5922485171567802, -1.2068498206133187, 0.412853384927236, -1.1136351543496117, 0.4046306514337976, + 1.2348205714395197, -0.6837477630512465, 0.4396481646202046, -1.689933367495156, 1.2059642129302333, + -0.9663178383992985, 1.1606541969900555, 1.6008140707565817, -0.49901900408756017, -0.3295636539441862, + -1.6961254597784983, -1.2594474125668986, -1.8290342261999246, -1.3195471501043112, -1.3360729012617636, + 0.4819319828437685, -0.6044650413524977, 1.5401325916637765, 1.3108212503564856, 0.6641610189431937, + -1.167107917049922, 1.3717452437078013, 1.8234968598766077, -0.059000791459610014, 0.5078759939367101, + 1.7012186950957178, -0.8153543329038886, -0.8116555600682265, 1.1042603281155614, -0.5230370384584662, + 1.5907666644047485, 1.3585126059484214, -1.1604013321546773, 0.2832653250904853, 0.6146831527317715, + 1.710136815942171, -0.4339250659200289, 0.5404568535663827, 0.9731252061576328, 1.667624932562064, + -0.8395294570873553, 0.9655900545408684, -0.8459497721445768, -0.3303605936459242, -1.0228351996527785, + 1.64618479653826, -0.17369144560780647, -0.762588585662475, -1.420812072676508, 0.04977508086731497, + 1.1411840603073538, 1.2658579855151144, -0.7695492083057207, -1.5802557368206935, 1.8063629418173504, + -1.3314629726856042, 0.926332179655863, -0.29403591774091886, -1.4368324532624133, 1.794172160269225, + -1.45190145484798, -0.6970484207532923, -0.7393596578156938, -1.3777134665955009, 1.698548607127142, + -1.5277458031231408, 0.8121596602029895, -0.4871450419173451, -0.8973481250486168, 0.999939728035903, + -1.622149524743996, -1.8097564354752569, -1.9184903777986992, 1.6699213455758075, -1.8966494026411178, + -1.5537605568683883, 1.0114197035541261, -1.7582102654477056, -1.881605728865896, 1.9424710063889181, + -1.387404030261334, 1.3858022761207254, 0.6698691050652563, -1.7882425787208467, 0.05463416162273482, + 1.151588572666963, 0.9448022669750591, 1.2079032520591921, 1.3271868932748552, 0.7459734492068639, + 1.4448931947680101, 1.741554075288172, -0.49778064496799157, -0.05357363118867564, 0.17060355537450267, + 1.8858952797885928, 1.1828850394664219, 1.4398264570692447, -1.0269803221490008, 0.12795611159499387, + 1.4338295119095292, 1.7680805378740985, 1.8889001943775678, -0.023646131688371597, -0.055364618226636075, + 1.3107732868213482, -1.3726761197824935, -0.48421640176631975, 1.8520978683112554, 0.14900451528494418, + 0.7553309487914097, 0.995210897988966, -0.2653148753757497, -1.0047335940870337, -0.15140716923573905, + 1.0342357378533045, 0.9590192011054128, 1.3276618340182669, 1.7076552004070518, 1.3639368693762144, + 1.0626034699464553, 1.0985888634186862, 1.0871213821052033, 0.3518298069849042, 0.6905794127769393, + -0.5700252629850588, -0.10814050178161683, 1.3965639143955952, 0.8292785089561896, 0.25327348015151596, + -1.334944218927732, 1.6209990328336517, -1.26244979705121, 1.771347153639546, -1.4436659851102362, + -0.5033590550326617, 1.484309499478445, -0.3804774758165417, -0.6854434358446646, 0.11814627802495625, + -1.8940220672649586, 1.9948327521339193, -0.31419774418955715, -1.708608376028823, -1.8717143806284637, + 1.4405268554284918, 1.8275766986420505, 1.613878636732296, -1.6842307903910925, 0.40437384436799473, + -1.7974817731693786, 1.1222020737272933, 0.8616912290576968, 0.7450260507858868, -1.262819663341098, + -1.3890825964448705, -0.24733521021608595, -1.9566763230316564, 1.240599031645397, 1.0005078570110628, + -1.2429059760645886, -0.8343788480444481, 1.8180187022655172, -1.2406155795725864, 1.206905860440954, + 1.0657535671563094, 1.0688089382594308, -0.1623146723818225, -0.9394369384107097, 0.1644095126054408, + -1.5836754669396766, -0.45886757503464093, 0.02309988717518241, -0.11609000844996142, -1.8627934732063016, + 1.9415986762986073, -1.1741977923279308, 1.2850766678166368, -1.6650362245910895, 0.8711235235579853, + 1.703790813181027, -0.20031635110543, 0.7709122587840396, 0.4676763407673441, 0.7333591027965438, + 0.010661459873729129, -0.6248209657856156, 1.3499622620431584, -1.010555890938674, -1.9094767924575482, + -0.4118954261411796, -1.9764407569645153, -0.11597198863706204, -1.2058671493925148, -0.9480128119239577, + -0.5293044156931037, -1.0802020289569132, -0.17428061346628443, 1.79479586490669, 0.4507608914027035, + -1.5890677193516893, -0.4180241158854212, 1.1247910122551152, -0.10769135533882057, 0.2413054436244062, + 1.3070197809453399, -0.7234463442247714, -0.9044481440724681, 0.808060474881219, -1.9681916392611978, + 0.6794353030118225, -1.8140413117066592, -0.21172484209703857, -0.3970612901969721, -0.22610168646442563, + 1.8446444889972504, -0.14161684848047962, -1.0612317380319158, 1.6805704263182024, -1.5680342684533937, + 1.6367583739045974, 1.9603810572547848, -1.9461850695662726, -1.669279137293902, -0.4582040515383534, + -0.4903387593204007, 1.700009791169296, -0.1006260106974528, 0.46356012096704813, -1.148937310426322, + 1.1291686959534264, -0.43326682883595513, -0.7294554760312604, -1.3404464141642505, 1.4428709283346048, + -0.31527585219642784, 0.5232849548965959, 0.840594052127317, 0.8605144687711537, -1.6471991161039679, + -0.32119284017874694, -1.199582906673192, 0.3080262169024959, 0.14151294810583348, -1.7354321287231, + 0.4873457081904098, 0.30837874931057474, 1.3003825539901728, 1.9636934942685267, -1.017158000827136, + -1.7596484196087827, -0.6817110071876389, 1.8933995404689652, -0.27989446483984093, -1.067564992804905, + 0.48291319348514605, -1.6740080386493696, 0.2807422201361458, -1.9110968101706796, 1.6831495761856532, + 0.11512823651793447, -0.01736208420134222, 0.4405414565435697, -1.8085718479527353, -0.9564467319845411, + -1.58340676850203, -0.647955866711154, 1.6543925512752926, -0.7724666857844449, 0.5113949921548908, + -0.39267895944443065, 0.40761631065773063, 1.7690968773142668, -0.21764226027462374, -0.5169606846714005, + 1.4412302687210286, -1.8896763215831802, 0.07979887381656514, -0.11436797170178448, 1.0634323241712869, + 0.26858767414261475, 1.582753510128553, -0.010528543666477042, -0.3613643892495162, 1.391514209117238, + -1.4700872595733046, -1.5362821122874086, -1.818586245442571, 1.9678697208900475, -0.9362374796105595, + -1.4709960962767532, -1.182374325374778, 1.6385607669867177, -0.46775448579892487, 1.9576437315276696, + 0.915531228777362, -1.0860235734385926, 1.0655104012081509, -1.7770877181643954, 0.1443659128293806, + -0.23955298993629803, 0.41725367891443366, -0.8558589757408512, 1.024674449305122, -0.8538581096220099, + -0.17121172366938264, -1.0495343198650096, -0.8461809157835463, 1.956660524400533, -0.10451516941234473, + -0.9119888509755709, 0.9341633453090434, 0.5765821236303488, -1.8017153374435075, -1.6959921212218267, + 0.3565838506194048, 1.986423720658717, -1.1810787364750697, -1.3554314442606277, 1.2292087344595828, + 1.9389467760629646, -0.06060251846881748, 0.6471281822482204, 1.028562584237319, -0.9889764039700069, + -0.30300382154607064, 1.8809742113734886, -1.7911374091327446, 0.4093234223382991, -0.692170253260544, + -0.5766217362114325, -1.234711065294488, 1.4845455677723791, 1.1142993730640143, 1.0351547495051978, + -0.2304804542756207, 0.7896680860540854, 0.7368394967498872, 0.6117647784314304, 0.9649509001647774, + -1.4794529756304886, -1.5330264541276408, -1.1347331780500776, -1.957296773370273, 1.0497217009949296, + -1.876577007676679, -1.8707142834400772, 1.0355676671507679, -1.3024864669068572, -0.8110172097035955, + -0.7956308468122133, 1.5651626086889294, -0.5950947090287055, -0.15363512205018193, -1.7408026469236138, + -1.8514840915078024, -0.40821529034130855, 1.2085979600022174, -0.953165571253348, -0.03303673441403454, + -1.6012777563202603, -1.1080907034689993, 0.4974859939502432, -1.2774517368694216, -0.20163785863448247, + 0.8261780324851822, -1.5127015843190126, 0.14100033563999403, -1.052646319316592, 1.6782929279009817, + -1.5464154280508744, 1.7486715792103427, 1.7780537663957405, 1.7209562957702067, -1.5054888719925499, + 1.5292916602749358, 1.179965119787405, 1.5170349093126205, 0.7433092687415837, 0.9522185073327041, + -1.2380413345480523, 1.9354870169011447, 0.625636243747028, -0.09000816987169546, -0.5335972012344188, + 0.9674628266238745, -0.03967494279717343, -1.8591428816092792, 0.2309446236016406, -0.9041639531030761, + -1.9026103934874206, 1.5541589920102101, -0.45696090245009824, 1.8466298423672463, -0.36055327204706167, + -1.2458073226198056, -1.4410345639464017, 0.4731557626110279, 1.2307498218360111, -0.30913913858563724, + 1.8557259220973865, 1.1724797822135766, -1.516241961681641, 1.147047572924638, 0.009295148811337306, + 0.8291735590935811, 0.9825251314963639, -1.6702374836146134, -1.3070146895265724, -0.35977729833032246, + 0.10882986028094521, -0.1545812635060546, 0.5966946312401102, -0.12463585998219351, 0.764026848253426, + -0.0653501987613172, 0.5337159207310522, -1.3783865008607394, 0.3440914524635428, -1.6128660868012537, + 1.0505520366072156, 1.4508195056160966, -0.3811605116562866, -1.0184337989154448, 0.3472432185953034, + 0.7934690008453043, -1.1871814411996455, 1.7160465328415073, -0.8932034391682153, 0.22684342695521842, + 0.006601468173127678, -0.8703158970162406, 0.7854001417512366, -1.6096149032006064, 0.5734105371918883, + 1.8323642183966413, 0.8494195484926621, 1.52159530384528, -0.43666400911265324, 0.6949749758230679, + 0.09014001060463439, 1.1181673671725294, 0.23797216532766896, 0.9091467606318648, -0.051242214293259813, + -0.3492957666583303 + ], + "dims": [2, 2, 320], + "type": "float32" + }, + { + "data": [ + 0.0018182522274203805, 0.06756509596322857, 1.889723294866065, -0.10289095754140298, 1.5711519216894745, + 0.027529292075774592, 0.9603256438495507, -1.497309631471758, 1.9251601219617065, 0.8851491878732389, + 0.05078780805071137, 0.40903741455911735, -1.6644840015459215, -1.348225759557871, 1.615832737926227, + 1.042719864089511, -1.9289326046242312, 1.3417535199012995, -1.710655801290117, -1.130165128147044, + -0.3755000776719024, 0.6155781582426902, -0.5883485771887473, -1.7159986811406176, -0.16333854572017525, + -0.06079239446971929, 1.6926064002585495, -1.8776332892248098, 1.4601970742576578, -1.3202352800423185, + -0.12899708506012164, -0.6003093613879029, -0.1726349092091164, 1.2394146364350664, -1.769629141089184, + 1.4197330981295524, -0.9504267735259635, -1.240675610662361, 1.4018548317486923, 0.5332018345413356, + -0.16073415033536875, 0.15303724703170385, -1.8037963193841238, -1.311714810716846, -0.5740602095553404, + -1.0372165240096223, -1.370949121899355, 0.29661966702940035, 0.07816374250571023, -0.41396787300651905, + -0.3694698645575212, 0.6759765867037197, 0.2952400780995772, -0.06275069272676337, -0.9130274419561628, + -1.8944701092982958, 0.33465806810173593, -0.404939193749847, 1.4043718178232805, -0.5590711165263631, + 1.2184926422968934, -0.7087036307842709, -1.6055109382118182, 1.968257767003597, 0.529695028652811, + -1.9967381817454308, 1.595078125176956, 0.9871155490120058, 1.6566751957870993, -1.609626148231829, + -1.1397801527001823, 0.02238544560446254, 1.4873497063814245, -0.4755743745599572, -1.6926423664304844, + -1.4161828320433028, 0.346372427157398, -0.18203459832580027, 1.4635583159542183, 1.5944148599650028, + 1.3186726267824955, 1.5675687012032968, -1.9754809408365706, 1.44557963549327, 0.9397875688795354, + 1.4424046221061442, -1.4458135310352649, 0.9975520389856136, -1.9027511578082317, 0.9144382540308484, + 1.052124261689804, -1.4732678674195494, 0.29024955712503164, -1.2231252144665383, 0.34787712508784985, + 0.3556934319800238, -1.7419738471239645, 0.8630538908485903, 0.5386452782458866, 1.7600516786463105, + 1.8905437777505014, -1.5744952794523028, -0.7530004157782235, 1.5678919268380707, 0.034533101389558674, + 0.7325333516090975, 0.9775064333478163, -1.6408433791748216, -0.7414398323785214, 1.6725433719876586, + 1.0072882099919305, 1.4341931058179327, 0.7139948421146176, 0.40545031341822124, -0.11478362063979386, + -0.9345270890825441, 1.4281286745225614, -1.39970554180245, -1.1485396410325386, 1.1495990036520798, + -1.2916127094423402, 1.4211660589871826, -1.0749317173140405, 1.4370776307284663, 0.7880288709576773, + 0.46732661965227873, 1.56798877542517, -0.9531716195760707, -1.3739051298849194, 0.9766290318098436, + -1.307661662111708, 1.8574559170417002, 1.35797073743995, -1.6940130226054606, 0.28491131826133387, + -0.36419491260352554, -1.3047662545854015, -0.9266815176320033, -1.2358711507932467, 0.9127887752631247, + -1.6466848327495578, 1.1607458121027339, 0.46297657760513733, -1.5495718508374514, 0.3292413137438217, + 0.7675934897387728, 1.4008121445440214, 1.4898570624591958, 1.6030744917802648, 1.2925872420362232, + -0.8421561750911684, -0.3407292616133608, 0.38924919209979336, 1.6793513775487527, -1.0373013949726966, + -1.5337353736283532, 1.5143316995909872, -1.6320472160478126, -1.3996482770156646, 0.6337864872715988, + 0.5406528347636357, 1.2967809902878562, -1.5182702863754916, -0.7399098341126589, 0.31978027899894723, + -0.4320909370805026, -0.057815767103424065, -0.42656779912656795, -0.7191461156604344, 1.732444695508872, + -0.16793165663622744, -1.029044319841585, -0.7379183254565955, 0.6335667491493258, -0.7407757474651113, + 0.737814588729532, -0.41713542698826167, -0.862992043249343, -1.8537968903371889, -0.480058608858549, + 0.04028745468513595, -0.11696118988455151, 1.2159286584219329, -1.4551651039165874, 1.8518920420484895, + 0.8324620148383071, -1.9503205997190287, 1.3118092522348013, -0.41781057862944326, 0.47025354333711356, + -0.08599400306878557, -1.398138636933056, 1.799030968066016, -0.9016154689967486, -0.44642885397970034, + -1.6161407274817075, 0.6108393015698415, -0.9652371448534662, 1.472448459030451, -0.12097411552763226, + -1.7427779621544364, 0.6772588555443013, 1.239525535102806, 1.4978793781566582, 0.9794171716198061, + 0.37480400234555056, 1.7099069435864092, 0.5339030487857208, -1.7368267186422761, 0.3401246801395308, + -1.495349576003802, 1.1154539341471592, -0.5739747352480027, 1.7719108709631328, 0.7087464471791378, + -1.407094251765498, -0.5711994993106657, 1.6197007171162792, -1.665245693725593, -1.4093290138388097, + 0.8150971020478908, 1.1565262598728276, 0.007036682898540647, -1.067724969488646, 1.1760370444772006, + -0.4660822995530971, 0.18663889657333232, 0.8600384570962394, 0.07639203983671461, 1.7055162765205303, + -1.7134292208088802, -1.3413558800873675, -1.338677372528159, 1.4246968540400653, -1.1823984287999973, + 1.4751654585472211, 0.5262834049380078, 1.5117343050060867, -1.409416488118043, -0.39544742603356386, + -0.6577586706586658, -0.5919201797053688, 0.6013534842506445, -1.1862135968111707, -1.229417973714626, + -1.803412419156234, -0.7655790098575235, 0.9128632433156794, 0.9036623476529559, -1.9831271121679324, + 0.8324308647368319, -1.759507307385337, -1.7725931616787687, -1.7039303423725647, -1.4439967872268928, + -1.7432455401143834, -0.02216033991501387, 1.2819676717165, 0.16659457648361364, -1.167642388668959, + -1.7143084152722228, -0.7289345444538382, -0.02925241516287791, 1.9566358667801342, -1.2857581699546135, + 1.0915031830445114, 0.05084200795390714, 1.083568818422366, -1.1315486700234478, -0.8881346175534794, + -0.63619987674788, 0.3799832019858531, 0.2477670922101094, -0.6132210208290614, -1.8451948781462812, + -0.22847217268867048, 0.0025467735349682386, 0.1315834394384794, 0.1776926575489597, -0.8691295174311664, + 1.6637912565242994, 0.448901769947029, 1.233816013145204, 1.7971799993597228, 1.8719614934816882, + 1.655937636621596, -0.27359273976124054, 0.08461142131684696, -0.2757947346097396, -0.9521228519499276, + 1.766034536643284, 0.8831916052200137, 0.9813027219865562, -0.322591101625501, -0.20675723380495992, + 1.0866641329284041, -0.6397672290782843, -1.9715973970816654, -0.36395252045986304, -1.4160336028155198, + -0.7487477697571272, 1.4091533113140509, 1.2152244001439598, 1.0139512253023701, -0.5841820989850488, + 0.36171343432432046, 1.1810326691265303, -0.044977125366693294, -1.5719763377131732, 1.636814383280785, + -0.8254090686593019, -0.2739258751225844, -1.5838736296117837, 0.057544692367468286, -1.6536791042504957, + 0.8676152862870037, -0.6012988236535559, 1.0789190140651197, -0.21655562768188386, 1.5865699400089461 + ], + "dims": [320], + "type": "float32" + }, + { + "data": [ + 1.535672932186043, -0.3469466691127403, 0.7594896463626952, -0.05450122463129414, 1.4639377922956394, + -0.6333990278356465, -0.8242789470237648, 0.5117653543833605, 1.6078505759993273, -1.410275750604895, + -1.6792951377646883, -1.783057576321041, 1.1956662347204423, -1.3979831191193002, 1.7644067312268517, + 0.4240762243207543, 1.986096182518743, 0.36545941859180964, -0.8774236745093011, -0.8647372274160503, + 0.8720148666725347, -1.022106286236455, 0.5503111675120635, 1.0204841436521281, -0.9254965061314904, + -1.3449432022823808, 0.006824458535456657, 0.07690008991648423, -0.8426817383905396, 0.9996621329373534, + -0.23056243949407484, -1.0440039859718286, 1.9168909615768683, 1.5600000104620682, -1.9890822883775865, + -0.3604004168107382, -0.028511959235538065, 0.8476098198214288, 0.22053970034789216, -0.42929632097288817, + 0.6599479925924321, 0.6647860485919495, 0.10175396167639938, 0.22650892002231515, 0.4701540897019987, + 0.624214514356682, -0.6652257805050041, 0.8518349008799753, -0.9562813618340789, -1.657496508881473, + -0.3312572279583206, -1.5494034812904562, 0.18877981986543801, -1.2351800795813066, -0.07918559380797063, + -0.09391536586009241, -0.2856357420391582, 1.9393958604954182, -0.7529216437305211, 1.525084648903749, + -0.07883509109638975, 1.376637107607113, 0.5783362536287875, 0.961847664027677, -1.6855455725917468, + -0.5830772019897683, 0.4271291901307981, -1.8373857521152086, -1.7965394924729141, 1.7115697467771378, + 0.2565457539488545, -1.3360260284983019, 0.4353676471582455, 1.7248708601658969, -0.9750598890096729, + 0.05312641822767361, 1.5787554531472985, 0.9667162219022503, -1.364971428290251, -0.2814850946411962, + 0.9013643208289279, -1.4725055888862961, 0.6001425944665559, 1.2723681158746203, 1.7714493392964075, + 1.4044899825398272, -1.5787548082153382, -1.1589036159974757, -0.4012478414167475, -0.06868641055197777, + -0.7534521481998526, -0.8700101449208493, -1.1662115104665567, 1.7611310737805477, -0.2501517942331226, + 0.12866215308587936, -1.4699875001512854, 1.0395395370450604, -0.5782390952646876, 0.63115653417037, + 0.10138581116634082, -0.07007439881121424, -0.4276277546360472, 0.418589841403306, 0.9267207479900215, + -1.0293664343515356, -0.1495871781602336, 1.4452889339030666, -1.7189823464809564, 1.8323799237149645, + 1.8914008693919682, 0.3829486757403364, -0.8735369861149813, 1.602486711188683, -0.39959917784662924, + -0.8673792916868024, -1.2627215362178648, 1.8597348040684398, -0.8688156300975107, 0.15713415561611388, + -0.13148226217512082, 1.883732805180382, 0.11420203807616502, -1.6552288945493094, 1.0335466032430753, + 1.9806710089769703, 1.988269693866676, 0.29427412741632075, 1.4966799360753749, 0.6937827119996989, + -1.298620046493725, -1.752952308784005, 0.46645438478103873, -0.898908219432915, 0.7139098459371658, + -0.16215199540773462, 0.07954281050960432, 0.795652990025399, 1.5967568490712063, -1.2445652996859247, + 1.9127555713254205, -0.4996844935898572, -1.1156627480592256, -1.2948343944985163, 0.18276720875230268, + -0.748683470251498, -1.5079466014120557, -0.2494558107532141, -0.9231537960141623, -1.4121241243829443, + 1.2059834829573104, 1.905725511300579, -0.39337905860681044, 1.8425190053973166, -1.6566221588247219, + -0.242919176072947, 1.2425502129492436, -1.4417507121400348, 0.015600407032383856, -0.2242098694907284, + 0.18796276556529357, 0.08107732342066765, 0.7149451467441841, -0.20769007081368773, 0.4421202004832834, + 1.5233025839787455, -0.6642431462292846, 1.5564028464468986, -0.1586815058735116, 0.6099306071219655, + 0.8180887224937807, 1.9911546339103818, -0.005984685083011421, 0.6777759409892354, 1.7289968623869099, + 0.5264262640237458, -0.511038272902959, -1.7235775305068346, -1.138944875679032, -0.9623892814614488, + -0.6380738572168294, 1.8832106250881075, 0.028541651530706424, 0.7394956760616829, -1.5455450050824036, + 0.598697699776686, 1.44227094769795, -1.842926293477114, -1.9786511228960357, -1.774125089606943, + 0.4273755931309067, -1.1833540770674968, -0.29742688579612864, -1.7932368057978882, -1.3999703979662605, + -0.5494229951060436, -0.7692231154827809, -1.0112160791506497, 1.633993910846237, 1.3945699010195831, + -0.8649776103569309, 1.921348771224042, 0.832322610715301, 1.3754060709990767, 0.3497018723561567, + -0.7191957838389857, -1.9794221990125722, -1.84384806203993, 1.2324522851211803, 1.7698494016317143, + 0.006624102198243165, 1.5911519918365267, -0.762455861009844, 1.4479210196035108, 0.7818151145500849, + -1.9876926272814606, 1.8202062970885162, -1.6446357331454369, 0.8692666690506741, -0.7358532212979823, + -0.8444806659707744, -1.6015224446062994, 0.7479281419258141, -1.6523782603794155, 0.6710725185977369, + -1.1710932073100304, 1.4784513737588512, -1.212966513263102, -1.3741809040280142, 0.25437428444308896, + 0.8440351752665407, -1.1722116121570672, 1.104161389783421, -1.645735790976162, -0.4286533806712738, + -0.37044520217626875, -1.574330285391767, 1.4899314272896893, -0.8495642882336822, -1.714377156019676, + -0.4893435327563349, -0.7616337581393848, -0.5339391487933929, -0.3003289730553087, 1.3489307896261735, + -0.14680109166432054, 0.1026969670558735, 0.32430953678969043, 0.1795871726769951, 0.9696062238740311, + -1.5296687271207166, -0.2770372362376037, -1.0409472934130868, -0.17306368093190905, -1.1960408781183967, + 0.984219061951209, -0.4077661181651919, 1.7423047847942446, 0.5608878908901787, 1.4329493489434109, + 1.8986413512869937, 0.19154199669760352, -1.3315756935180012, 1.8870822754754517, 0.5674631415439482, + 1.1017148980678542, 0.7256621357674105, -1.8682573426264009, -1.2687446906641284, 0.5430939279068951, + 1.8279931413962558, 0.15890686973919443, 1.394841983124743, -0.8330211159668224, -1.2412716683059033, + -1.1755274256803165, 0.3146422214936937, 0.5127310756940888, -0.6223681329826247, -1.3728009148038876, + -0.5073994704733549, -1.1727465329222264, -0.07518002339175833, -1.6218851358655701, -1.3314808424730247, + 0.2696107099425271, -1.353815758928219, 1.6070801592460056, -0.7018653814032136, -1.594649470877921, + 1.8662880614030657, 0.009632792539534307, 0.885433106263176, 0.7081454198732997, 0.12480572241808119, + -1.9002637028711113, -0.8823815470565757, -0.12794198437065507, 0.3682196882451354, -1.1962414622570767, + -1.101920787984521, 0.1703217046774217, 1.2755057257388405, 1.2757461273763866, -1.7253598839195732, + -0.3935586680170111, -1.6790297555951925, 1.1726640337873802, -0.7187759606615787, -1.5997974808572053, + 1.9512036824878933, 0.8991982283799391, -1.516998597379371, -1.0918962406357053, 0.12845929863120276, + 0.387447437135819, -1.6766371647631972, 0.4172435231617522, -0.8587881195399367, 0.8973805509978297, + 0.5384910477202398, 0.22290981983700497, -1.9824980848037859, -0.19789410371539873, 1.0396641208977249, + 1.9498654847750698, 1.752979186273122, -0.10251547854421528, -1.7031576116596918, -0.6422947693243835, + -0.5947775282776488, 0.25094777162345583, 0.5519773563378578, 1.1845669608153342, 0.07011886849473115, + -1.5689347607142432, -0.9068208446502926, 0.4518736648271817, -1.1908598340431444, 0.9123278060019366, + 0.3808045721687314, -0.5161116183400685, 1.4633312728276353, -0.24955275031843804, -1.9270793627181808, + 0.5510310380033525, 0.002103402836195478, -1.9722027133603266, 0.8207770388309132, -1.2709862666051333, + 0.03660015849392373, -0.08721025552259398, 0.1480447971653538, 1.3975878551198289, -1.8688681862560603, + -0.2735983144132472, -0.29150197793885635, -1.349553505848272, -0.14289894302424067, -1.0632608448362548, + 0.9197316019995538, 1.6766092374653363, 1.4333994578157911, 1.8497508886723608, -1.8365902161760914, + 0.3329875047259945, 0.28711035354851955, 0.018743287980965917, -0.47550704561352664, 0.026002587809994537, + -0.9815518239812109, -0.30422655490353545, 1.1236748290508274, -1.996970334350796, -1.663190926732148, + -1.4930228184840004, 1.2293686779591093, -0.11228295031816327, -1.399262159949875, -1.2745774075202778, + 1.0404471355251506, -0.9042932188930193, -0.483855773240883, -0.051899262666108115, -1.1517591694487734, + 1.631117268451015, -0.4341760983538707, -1.5093199848977354, -1.524695207871412, -1.179033179719653, + 1.203939869858461, 1.2278371883112191, -0.7764972190751465, 0.12469436067847539, 1.1254668267275294, + 0.1253270059252225, 1.02529025377972, 0.37477534712132243, -0.816607896481754, 0.7652933238577306, + 0.0816252203587613, -1.6877073529228523, 0.282188424454314, -0.48899417877023144, -1.10579595806544, + 0.4180711569457314, 1.239608967084651, -0.8553327976952234, 0.7601553028351749, 1.017191993054694, + -1.561711107008871, 0.18166558203866234, 1.4575039351725838, -0.7919992885427041, -0.05528739747934974, + 0.5393789182198327, 0.9208003955213648, 0.8037584630910892, 0.2508199691349171, 0.5025718274381168, + 0.40223725437742086, 0.43401128486340124, -1.918673978558985, 0.38895512761013773, 0.9647875436316022, + 0.356426573504554, 0.7676218046110401, 0.15946706730485438, 1.2727737024033576, -1.1428215846133938, + 0.36778995418490545, -0.41392909578544224, -1.550642999283478, -1.7016418383565188, -0.3516276355010701, + -1.6424434547903983, 1.2296355686757101, -1.3262048004001983, 1.9866748350391505, 1.9039145370701833, + -0.4605978047947623, 0.37289955561548194, 1.2909351136100344, -1.4775326769813537, 1.8608708474080071, + 0.6440656172393684, 1.4358923542702868, 1.6635530454398575, -1.7844300247360296, -1.470415868795488, + 0.21864396672047892, 1.5488664195436606, -1.0864322992770177, -1.550780881959068, -0.8331945313037004, + -0.9367699280324953, -1.9013228249406309, 0.7807264098375688, 0.06677827961955263, -0.4865949947067687, + -1.9079733463147095, -0.8445233464370387, 0.4065074139836655, -0.3310839064029283, -0.2904445573034993, + -1.1753367420636245, 1.3721435340052208, 0.36660883645931097, 1.27723053302687, -0.9359216637576937, + -1.732231846976478, 1.0644600709999477, -1.6378422934868384, -0.8826400850725795, 1.950622879844948, + -1.9911319096225792, -0.6598073662934398, 1.8955996856482145, -1.4071132961709223, -0.8795225115767629, + -0.5029228970810946, -0.7734268477225967, -1.716542237524993, 0.04010671043366898, -0.7937284158037281, + 1.030026939297609, 0.9801808342123648, 1.2953427689382302, 0.20610803631475605, -1.672761300291775, + -0.690673451769495, 1.6609033000524338, -1.897131087105456, 1.2029533984904228, -0.5681454803874688, + -1.3646956682920965, 0.22071074912276334, -1.4735886916157908, -0.9695144027680014, 1.626222864433485, + -1.8694559899308487, -1.8879003983306655, -1.0176033048635613, 1.7915586444709328, 1.8810192124623084, + 0.5319984718680058, 0.0113238596202212, -0.09805090157632446, -0.5444299501215024, 1.1135935258682927, + 1.17684427133796, -1.85426568001437, -1.7530946944132086, 0.3038089938756876, -1.5870230070820002, + 1.9106333020042747, 1.3937407603560725, 0.8591788216145968, -1.612956779000272, -1.7151209190289016, + -0.13707423626294535, -1.4389728179178984, -0.05236093609874359, 0.9751452825232896, 1.744306648935904, + 0.7254929535860901, -0.09824503868926815, -1.4925208247531838, -1.985227418605298, -1.405540454178178, + 1.692915817031472, 1.2230668144021735, 0.04262811065188643, -1.1756894733009666, 1.0222275091190456, + 0.4934708666464802, 0.08979456736565261, 0.10059671914562518, 0.7155249975927536, 0.04082949674837977, + -1.715826873724553, -0.979189481763262, -1.1065843508804214, 1.481429410565739, 0.5278383608268999, + 1.4941027771635946, 1.4151786058577498, 0.07974076288029774, 0.3167509060420519, -1.269619345887964, + 1.8667680276727765, 0.527367815431, -1.874110045435497, -1.9373013120064702, 1.5330729150450173, + -1.7833509822122444, -1.7182607692067773, 0.7561591559894678, -0.1056962696530368, 0.13014948563496898, + 1.804101947913626, 0.7276195691909635, 0.021465712639121115, -0.5163553036182069, 1.1855106734783103, + -0.532372100695512, -0.5871635445412506, 0.161643292508721, 0.61018489160484, -0.9869821416193743, + -0.5318766940780302, 0.9532631042147388, 0.4597709134353236, 1.2142228742259942, -0.8224515402258339, + -0.879922983657166, -1.4710925151016916, -0.29851124917883975, -1.6631372706933156, -1.417993373545026, + 0.6364481896978704, -0.6013255938328603, -0.046835161333119935, -1.7247175181005758, 0.9825982199711403, + -0.7264776248635592, -1.463988875635824, -0.5013956201257255, 1.1933878395314643, 1.3056455851087287, + -1.4398688148432273, 1.7038722585040453, 0.46568790654958114, 0.481485333420693, -1.2873930064513237, + -1.1475617778051763, -1.9673375617555031, 0.39874490557435127, 1.0960170584095357, -1.8987243885981488, + 0.36983554057526735, 1.9718844590478293, 0.7894176749822801, 0.06983603687412288, -0.9000466156841869, + 0.6428129286371904, 1.704798225993037, 0.5950045030048496, -1.1678955586471442, 1.237662010594594, + -1.6482921001228146, 0.7270614937877813, -1.0006186813130835, -0.5305400798817805, -1.5252716548293819, + 0.18855276048488978, 1.5437352372976703, -0.9397215004565727, -0.4258934153954881, -1.0950445559616853, + 1.1844079915434298, 0.024990774215178035, -0.16149461270780652, -1.4078837300903269, 0.09499589792836627, + 0.516842370641422, 0.4800833347119191, -0.539291197739594, -1.5117979701954605, 1.354396143788092, + 1.28278689745333, 1.8488206619245648, 1.3022599053953776, 1.6609548614809775, 0.3269713781203789, + -0.38485903666664, -1.464958277436181, -0.3992461504929734, 1.0699961189397085, -1.0135843210651023, + -1.458604697589653, 1.490121083969428, -0.2595359932483561, 0.20854389182544342, -1.7482190390121701, + 1.3007127316326326, -1.9884730509986825, -1.4952841032131454, -1.3179011133758536, -0.388478076479009, + -1.1589100485370416, 0.8387145536985532, -1.3384696651494759, 1.4683529176022008, 1.303145953986827, + -1.3041819891109316, -0.03449749547681513, -0.6608734038387656, 1.1683787754166381, 0.5655509145236746, + 0.37738607602963814, 1.2152762148898635, -0.29353655718583926, -1.0509092280694636, 0.7139081884804019, + 1.5196106141395527, -0.530586207320952, -1.558831387258346, 0.01131046295330318, 1.9344117181061735, + -0.6850503993030497, 0.9331665418290909, 1.1688357095654807, -0.42466684124295995, -0.49121961262440816, + -0.11897540791552874, 0.5942162255141863, 1.7548838522451646, 0.4438013028171106, -0.28183936745813476, + -0.07495854303862437, 0.9303587326961971, -1.3198776631733748, -0.591718773961956, 1.1127108159764676, + 1.3939520197540327, 0.6360105102962654, 1.3722503898910716, 0.1757960098808633, -1.8297470389548955, + -1.9205472381959057, 1.4666198830651629, -0.7830326162911714, -1.4564248566278515, -1.0967977812614587, + 1.164770039819981, -0.5760771475874042, -0.7667709006388028, 1.371522788043694, 1.7326600398634024, + -1.8193902025531763, -1.6090197630011929, 0.09836987546776577, 1.0677415637460363, -1.5307232030478781, + -1.5599516580470505, -0.2609675276531007, 1.0598276568453162, 1.8794380814113483, 0.5316994667209949, + -0.7552146503779023, -0.4617287817040179, -0.3819745586646004, -0.42575961119349426, -1.9237942552613312, + -0.8825058198571423, -1.8728798417790404, 1.7802885739921077, 1.8333435291969842, -0.3098252256281784, + -0.029956863413143964, -1.0772837825116914, -0.08180463340649524, -0.6945910113459792, 1.1668128443816146, + -0.02738480437430635, 0.39293059281590104, -1.6704359314356383, 1.6869956995774205, -0.2294375604199823, + -0.32757809443951746, 1.9764189201357567, 1.5201484938081151, 1.087504317388186, -1.6272710803209698, + 1.0397868469069298, -0.22176854941092294, -1.8820468396186323, -0.5303897107068192, 0.594170569473933, + 1.6960001373937432, 1.9644545152850057, -1.7960342834770175, 1.7873883299813267, 1.3489957623885935, + -1.6820391707003042, -1.5713129762775537, -1.3637851932919034, 1.5936068708950781, -1.2089638711610604, + 1.322028643928432, 0.8929678781012855, 1.8401053408016272, 1.5452683829326075, -0.9171427145163484, + -0.06745535875434427, 1.9379586035273615, 0.5855503730756357, 0.03549855059545948, 0.6527698319031092, + 1.6754602207349976, 0.7728323704391817, -0.9665588877441182, 0.6173545510334506, 1.3120695172827377, + 1.181821226786317, 0.1841309435168954, 0.32318631702986167, -0.790159398034489, -1.385019609574396, + -0.7118643666238835, -1.0439971536099275, 0.017584768122861583, 1.7536303032255187, -1.1965922071808155, + 1.4548082915973595, 1.562828560283652, 0.0920828524560271, 1.892000960124009, 0.3648061373597171, + -1.9613669287159263, 0.841563763070833, -0.9118328355251464, 1.9226565574363006, 0.3988462224271192, + -1.970188432590363, -0.8264337415665439, -1.9090851263430704, 1.8428915650288547, 0.28596991752644385, + -1.6708643684685667, -1.0762549708362332, 1.9492472760488564, -0.17802109704659852, 1.9236671687550047, + 1.9548632849049623, 1.9566450030001414, 1.3303550873049677, -0.5915124672929295, -0.0037832544010933944, + -1.6026781861800838, 1.7578833516354813, -0.2956678774270767, 1.4060455643195402, -0.7370157759032727, + 1.8789198126787916, 1.1123902493105078, -0.8769185681462304, -1.2618214191177506, 1.8674610245111278, + 0.5103075356648485, -0.020685118611023512, -1.407221324818173, -0.7491588381751608, -1.3743460812306214, + -1.7710712130536228, -0.19455369352318552, 1.3434990212660862, -0.5544338320325721, -1.324247058015053, + 0.8874849369101687, 0.6838871095643375, -1.5617313105262172, 1.204432716341258, -1.4981479923604955, + -0.06499977622096687, -0.8060264147106553, -0.36092597365795775, -1.307326777195418, 1.6399424900785, + 0.429157912433868, -0.9915570262096942, 1.5128426032058089, 1.6375586318255548, -0.1737010535017669, + -1.21285453753765, -1.8844155037723906, -0.2590630754224348, 1.7328565249414716, -1.260633142919116, + 0.3637043664955444, -0.48087365705468965, -1.7001295586898113, -1.0775016378447075, 0.2620695698901221, + 0.5015363913767086, -0.42080100290276246, 0.5338170065286052, -0.43568151602634764, -0.744286733837793, + 1.57647267103789, -1.2491109283310529, 0.10032144655805375, -0.46919353377702855, 1.9415827315644636, + 1.2111393515469855, 1.744725164783687, -0.6871612254817352, 1.406736078990102, -1.063724178982385, + -0.904699966390976, 1.5681407930006221, 0.79849818604837, -1.7759907970834616, -1.6325947440964974, + -1.0309733602826086, 0.29563414198237936, -1.7157737037653629, 0.2876568188935451, 0.21411659926835913, + -0.1601632043965786, -0.02605079418095091, 1.2041639219664182, -0.6351323647136597, 1.1149646585336592, + 0.6657515663650084, -0.4672646227384094, -0.5117766415018226, -1.4643244157794149, 0.39081520672097003, + -1.502649477455031, -1.637368884151761, 0.34542161036123176, 0.060151105688381, -0.5045040651104555, + -1.761988723037204, -1.9197872473179176, -1.3665270161331975, 1.3928026939637972, 0.39218445250695577, + -0.8470063024385848, 0.009038121027233892, 0.46871439485211575, 1.459827780771806, 1.4853766551455694, + -1.2321752545416356, 0.3748806345040103, 0.20582729258619814, -1.4266279966077402, 1.2950255786963805, + 1.7125611822808544, -1.407545517068188, 0.5169179018491512, 1.8595592751857541, 0.9487671455033482, + 1.9467423989905699, -1.5919149626150748, -0.4630901723451881, -0.698284068975914, 0.6197574561950008, + -0.8199869405915381, -1.7196626702920055, 0.6036024034626806, -0.8348164600145518, -1.7650166093756132, + 1.5829990521620996, -1.588645487863901, 0.7633248861408699, 0.5800948434762754, -1.7159447523887836, + -0.3836837699904496, -0.9746560067630572, 1.4480442893861705, -0.24403527878135645, -0.6397662241706819, + 0.956203271386264, -1.4601856308265049, -1.5649468816584298, 1.731664582215319, 0.9679933953420496, + -1.9722379093946607, 0.24076423675934056, -0.19244242272389211, -1.3854799949067935, 1.3744990882455346, + -1.121046645776083, -0.4342567706309435, -1.7159646482293107, -0.9317859666979054, -1.698219647396134, + -0.8288368620433939, -1.3875410583085985, -0.16399331338641066, 1.4667160798353667, -0.020345764369364083, + -1.9585520591695529, 0.9886716666217517, -1.0701744437434098, -1.3248249591382057, -1.3272246201915312, + 0.906046259715148, -0.9554587301398687, 0.16744253332193715, 0.40874734944503466, -0.7237514235199383, + -0.8028952942996463, 0.9478548199038599, 1.6268191787625108, -0.7376232063503751, -0.6874490141085632, + 0.20469380737641973, -1.940886393624119, -0.9715176541080677, -0.11409081023343237, 1.8208884259904847, + -0.05753377002269655, -1.2533113228725696, 0.18235199190840046, 1.3670427559403047, 0.7183594427747524, + -1.9834311091439476, 0.09488256814231644, 0.07406140049599319, 1.0427950622016802, -0.7928805141629418, + 1.7221214208634228, -0.06548459693275177, 1.6984102601031559, 1.8777510809050026, -1.735259524674964, + -1.1416240368033357, -1.7612022583614682, 1.721880360655705, 0.14372177475853665, 1.9269311955654835, + 0.19978809107127216, 1.0299566806165856, -1.7617419918814026, -1.2737765895488096, 0.7789099525859564, + 1.9816257384474012, 0.4482897887627919, -0.9051913536142644, -1.152506387584042, -1.8817136441487783, + 1.1054935295772461, 1.577999662025542, 1.600449927735128, -1.4919075331081064, -1.9550057574515671, + 0.1306184124670624, 1.4754764229533928, -0.808023880270273, -0.21695285993080393, -0.539628797891055, + -0.7836468765498132, -0.8815668388678288, 1.8917264703112755, 0.028934119940069003, 0.06879472114883711, + 0.4407647131615784, -1.702696302284755, 1.7815067716148931, 1.7950168026349171, 1.1405438335719111, + -1.1434283018085534, -1.720238715793207, -0.7729623733152229, 0.17672006075090962, -0.14942755614865622, + -1.6229777838891115, 0.3793725781830055, -1.0113407389657345, -1.9280392460441265, 0.7422498462017861, + 0.8331559663193939, 1.3063596659922263, 1.8113167679814106, -0.1401093760534291, -0.24674083884906395, + 0.15509679692376732, 1.8667087827355466, 1.1906398286118094, -1.673307806924564, 0.41063702592861695, + 0.2436862014560477, 0.24067383021132027, 0.22686603382511628, -1.7295093225806442, 0.6565075922933001, + -0.5514412373381097, -0.5236684516031653, 1.8733248509057603, 1.082970345098504, 0.3340204503283841, + 0.5450315229688343, 1.0954041212853163, 1.565649272477267, -0.5469992182522905, 0.7193108029242588, + -0.9050070254533322, -0.5121370718971949, 0.962566205706815, -0.24631520617082092, 1.3340325816997325, + -0.8820024080231894, 0.22736077826137535, 0.2252389330707789, -1.947448723525529, -0.9518843625899782, + 1.7502182429516546, 1.558646352665332, -0.838440251378624, 1.541757246903681, 0.44677553405529213, + 0.9918545507928869, 1.060951650228274, -1.3653319701374311, 0.2635328688559797, -1.6894618652561055, + -1.9316398959917604, -1.6545844047461316, -0.8374565390669062, 0.5467667551875302, -1.1703334497283162, + 0.79898936445238, -0.48742537394255603, 0.05126348262407365, -1.0630031367885744, -1.6755563384807575, + -1.7470496911251123, -1.4045037572548411, 1.697678496203098, 0.541058257415223, 1.9355948975325852, + -0.8470115353500489, -1.2030885197848056, 0.8919323754916997, 0.0702516207867685, 0.5155253592422371, + -0.3579514965668338, 1.7112737380442171, 1.9947965056065957, 1.2741397687110538, 0.09885151767767297, + -0.9770807797341039, -0.11682236263324342, 0.7272198637411007, -1.987824039940028, -1.1358258258310752, + -0.11090836034305251, -1.9915135816887366, 0.39056056844969866, 1.2932859858303178, 1.7109978939050503, + 1.1846928384025448, -1.7330449982026206, -1.1525984164585106, 1.104166927781134, -0.28565377527521196, + -0.9685059019914002, 1.7093828969134002, 1.9709107005494806, 0.049031526597832276, 0.4472417265612556, + 1.0921859039999235, 0.8763632205063123, 0.8161138478535914, -1.0720275802414108, 0.7266994153226873, + 1.233185460886041, -1.127435043988318, 1.0918127239321773, -1.8540096367958645, 1.9681192361925266, + -1.176325917090126, -0.30265014266672097, -0.44524467812690727, -0.9978154618024702, -0.667478816738317, + 0.15065079333379305, -1.0302715841959227, 0.829863553229278, 0.8134089689909292, -0.6474889076993566, + -1.079618527738825, -1.783292588379826, 0.748112963221554, -0.6286053844150628, 0.48331041409284303, + 1.663305348437456, 0.18479680885937455, 0.6293791717008288, -0.6005275360880811, 1.5747695362530774, + 1.5708476785905807, 0.5755861487097542, 1.2041008720516082, 1.6685888824542738, -0.10278064261508835, + 0.9057150675313927, 0.6510335974298398, -0.10744692672758216, -1.7129461062136837, -1.4064873182457918, + 0.4781316094642234, -0.37189635012197275, 0.16614793992804522, 0.03645962620683285, -1.6224894209420242, + -1.8138940006820983, 0.5069783696842336, 0.6849365239989318, 0.8037589654894051, 1.979213666270276, + 1.6861127134381242, 1.2233661916798626, -0.3986509966310168, 1.274497735591801, 1.605857523477285, + -1.2118797206236485, -1.0066619307873124, -1.358968189183389, -1.9510798049888383, -0.9808314235618916, + 1.742926920936518, -1.1022645984613817, 1.0929594394621382, 0.48488158650621127, -0.32877770055973077, + -0.47650834081572935, -0.5160849885006016, 1.3738126318494883, 0.8827072110361662, 0.48644110690758247, + 1.0382179714335322, 1.6721919595070132, 1.341715329629717, 1.7295025892939409, -1.522344995293861, + -1.5965490751871654, -1.7983927509857223, -1.0759710407128011, -1.3793282201703079, -1.443902375079908, + -0.9426382639949571, 0.5602210832754357, -1.0965977429851064, -0.19857124750589872, -0.7431182359930233, + -1.2699260459939286, 1.4876549876992726, 1.6274319173214051, -0.3309529465344534, -1.9454352826883534, + 0.12935585140981676, -1.0093456723551508, 1.7243377444859647, 0.10127369924344443, 0.697537788222963, + -1.521134755613331, -0.442714777461525, -1.6896188579102178, -1.9330985764980921, 1.9140786772267155, + -1.2925077312416482, -0.9509978830442094, -1.1889787216631778, 0.795835379830006, 1.4837581515063887, + -0.8597344665233413, -0.7448025823499504, 1.7455825639820093, 0.33723505300912304, -1.8208678041990423, + -0.12753920031860666, -1.757360720716986, 0.8256076807737855, -1.9972549760931448, 0.844750409785961, + -0.9594803067513551, 0.7862268813645183, 1.7393046013815212, -1.161126895447727, -1.6347790700422697, + -1.3348870119154936, -1.1621632421015011, 1.2696646718252413, 0.4845759791644788, 1.0668299384975475, + -0.6942334010657198, 1.4734240949292259, 0.4282074397978146, 1.6699946816827183, -0.6802086123370898, + 1.9208442056043609, -1.8532082289660545, -0.1592261674427098, -1.2431462763761214, -0.7286614982674164, + -1.522868872353036, -0.3825873577199159, 1.431979005569458, 0.43719966684470446, 1.6478260330278633, + -0.06620691473965401, -0.36945868484144917, -0.3809516652838498, 1.6855172591399752, 0.31969027259376137, + 0.09157179754578149, -1.3138870107882425, 1.4208147318276607, -0.03157398665509881, 0.03702456744844529, + 1.4819698957492982, -1.6015809663105944, 1.8331399913105164, -0.6094891041007129, 0.9393020005799118, + 0.6313754553821562, -0.3128111370670492, -1.324295564232262, 1.7609361120800635, 1.5935407847968914, + -1.280640014867119, 1.4668684416985176, 1.460389700948717, 1.0299991397017587, 1.2266139129378075 + ], + "dims": [2, 2, 320], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 1.1067028045654297, -1.6136860847473145, 1.261694312095642, 0.5976569056510925, 3.0194122791290283, + 0.6702871322631836, 0.10282492637634277, 0.3735429048538208, 5.239527702331543, 0.9048061370849609, + -2.0244898796081543, -2.0475926399230957, 0.08633577823638916, -0.7819606065750122, 3.528346300125122, + 0.796191930770874, 1.819821834564209, 3.565944194793701, -1.9118807315826416, -3.2814927101135254, + 0.6710286140441895, -1.8751076459884644, -1.4806747436523438, -1.796984076499939, -0.7681794166564941, + -3.1337573528289795, 2.5466723442077637, -0.710807740688324, -1.3013415336608887, -2.010349988937378, + -0.5839980244636536, -3.2363412380218506, 0.34071028232574463, 1.8070162534713745, -2.2573466300964355, + 2.2563209533691406, -2.721301555633545, -1.7052738666534424, 0.020109310746192932, 1.8161461353302002, + -1.2339590787887573, -0.3481786847114563, 0.14459623396396637, -0.2792869210243225, 0.3242926299571991, + 0.8492016792297363, -0.9436420798301697, 2.8654322624206543, 0.018474817276000977, -0.33994853496551514, + -1.4109879732131958, -0.17846572399139404, 2.4232289791107178, -1.366482138633728, 0.8393897414207458, + -0.06912710517644882, -0.1260005384683609, 0.14877259731292725, 0.8378958106040955, 0.2032637596130371, + 1.6857033967971802, 1.683059811592102, 0.020472168922424316, 1.4638370275497437, -0.2821274995803833, + -4.081655025482178, 3.739361524581909, -0.5193736553192139, -0.5895893573760986, 0.13349902629852295, + -0.8314229249954224, -1.955358862876892, 2.2120022773742676, 2.9806900024414062, -2.594862461090088, + -2.5279524326324463, 2.2374868392944336, 1.651476263999939, -1.7969365119934082, 0.5194283723831177, + 3.9661808013916016, -0.6912452578544617, -1.615968108177185, 3.2498717308044434, 1.6176962852478027, + 2.765564441680908, -2.780879497528076, -0.943089485168457, -2.211867332458496, 1.5465649366378784, + 1.5535883903503418, -3.7050137519836426, 0.04439067840576172, 0.36327028274536133, 0.9592444896697998, + 2.2070963382720947, -4.852853298187256, 0.6495039463043213, -1.601344108581543, 3.303436756134033, + 3.5125482082366943, -3.4023211002349854, 0.7367992997169495, 2.8616340160369873, 0.557513952255249, + -1.2049691677093506, -0.19210883975028992, 1.6728510856628418, -3.3260436058044434, 4.706552028656006, + 2.2566816806793213, 2.644940137863159, -0.9390113353729248, 1.6396602392196655, -1.3574936389923096, + -3.6357059478759766, -1.3324334621429443, 0.8182520866394043, -3.782191753387451, 2.5362539291381836, + -2.8861687183380127, 2.2147746086120605, -0.7912830114364624, -1.3549126386642456, 2.932422637939453, + 0.6247330904006958, 1.6168872117996216, 0.9066742658615112, -1.156375527381897, 2.196871757507324, + -1.3269041776657104, 0.7688918113708496, 0.02223837375640869, -1.3422014713287354, 2.5085129737854004, + -0.8842201828956604, -2.039457321166992, -0.0754881501197815, -0.4683438539505005, -2.3120336532592773, + 0.4231855869293213, 1.7217100858688354, -1.4091691970825195, -2.062229633331299, 2.0696098804473877, + 1.2929754257202148, -0.21851062774658203, 2.792795181274414, -0.24259614944458008, -1.6432653665542603, + 0.2709762454032898, 2.5165672302246094, 1.4215764999389648, 2.406688690185547, -3.7345216274261475, + -2.1278839111328125, 3.311349868774414, -4.237924575805664, -2.4865145683288574, -0.4375068247318268, + 1.7486937046051025, 0.9667145013809204, 0.7027313113212585, -1.8740135431289673, -0.3525621294975281, + 0.19565200805664062, 0.40774744749069214, 2.2967820167541504, -1.8403133153915405, 1.831811785697937, + 0.9851721525192261, 2.7873969078063965, 1.0879806280136108, 2.5585243701934814, 1.9414751529693604, + 1.6000714302062988, -0.12208014726638794, -2.56121826171875, -4.894813060760498, -2.881957769393921, + -2.041257381439209, 2.9550018310546875, 0.5040202736854553, -0.27999716997146606, 1.0042527914047241, + 2.926683187484741, 1.3717838525772095, -0.24589979648590088, -4.2212233543396, -2.1938352584838867, + -1.6489169597625732, -3.442727565765381, 2.948969602584839, -2.7220163345336914, -3.187354803085327, + -0.34392428398132324, 1.470370888710022, -1.630984902381897, 1.2510205507278442, 1.1136020421981812, + -3.759488344192505, 1.4942673444747925, 3.067783832550049, 3.345754384994507, 2.6331236362457275, + 0.9775646328926086, 1.2827643156051636, -2.623198986053467, -1.1612101793289185, 1.7932779788970947, + -0.332869291305542, 2.42099666595459, -0.9636011123657227, 4.5822649002075195, 0.8944255113601685, + -3.2404866218566895, 2.7085609436035156, -0.4827519655227661, -2.3480019569396973, -3.114384174346924, + -0.8162459135055542, -0.9214845895767212, 1.2764832973480225, -3.152130603790283, 1.567040205001831, + -1.699249505996704, 0.3841613531112671, 1.300299048423767, -2.6244685649871826, 0.1572742760181427, + -2.503662586212158, -4.367088317871094, -0.9085763692855835, -1.3322471380233765, -1.8894531726837158, + 0.7199447751045227, -2.851144790649414, 4.080941200256348, -0.541861891746521, -1.1072325706481934, + -0.6561694145202637, 0.40478527545928955, -0.8838909864425659, -0.5028785467147827, 0.7957435250282288, + -3.4829330444335938, 1.046553611755371, 2.5124118328094482, 0.3735085725784302, -1.0879991054534912, + -0.09173977375030518, -3.4051504135131836, -2.2644267082214355, -0.9162223935127258, -3.4872522354125977, + -2.355233669281006, -1.8244541883468628, 4.704746246337891, 0.4475516676902771, -1.2546875476837158, + -0.44408249855041504, -1.924820065498352, -2.729738235473633, 4.391683101654053, -1.0688762664794922, + 2.174078941345215, 2.718625068664551, 0.7366507053375244, -1.9571187496185303, 1.1222915649414062, + 2.276261806488037, 2.0843756198883057, 1.5358469486236572, -0.14141148328781128, -1.5720349550247192, + -2.324619770050049, 2.1180672645568848, -0.757960319519043, 1.402897596359253, -2.846881628036499, + 2.5358057022094727, -1.274275541305542, -0.14995357394218445, -1.3371965885162354, -0.4439084529876709, + 1.4503703117370605, -1.6082179546356201, 3.0019733905792236, -0.9571952819824219, -1.6500767469406128, + 1.6778243780136108, 2.374703884124756, 3.5679006576538086, 0.20018166303634644, -1.0103645324707031, + -1.480147123336792, 0.19532525539398193, -2.786205530166626, 1.3784115314483643, -1.2978419065475464, + -3.5328032970428467, 3.0164525508880615, 2.0895931720733643, 3.5052330493927, -1.9349405765533447, + -1.311628818511963, -1.7713117599487305, 2.886934280395508, -1.5496007204055786, -0.8046841025352478, + 1.5652999877929688, 0.1403025984764099, -2.6447391510009766, -2.0337233543395996, -0.22587481141090393, + 1.8628309965133667, -3.466338634490967, 2.2385408878326416, -0.858932614326477, 0.6435102820396423, + -1.1014156341552734, 0.6221705675125122, 1.3742595911026, -0.24308213591575623, 1.8533508777618408, + 0.14410161972045898, 3.0187618732452393, -0.33525052666664124, 0.290519118309021, -1.2579193115234375, + -1.3335667848587036, 0.4902459979057312, -2.2434842586517334, -0.6882419586181641, 2.9724576473236084, + -1.2139863967895508, -1.9754445552825928, 0.11754357814788818, -3.2436463832855225, -0.29947084188461304, + 0.9013328552246094, 0.025318264961242676, 0.9405116438865662, -2.2489869594573975, -1.2323944568634033, + 1.4659011363983154, 1.380167841911316, -5.245995044708252, 3.716740131378174, -1.5962101221084595, + 1.7039341926574707, -0.24453751742839813, -0.47277745604515076, 2.6836142539978027, -4.659006595611572, + -0.2703670263290405, -2.802849054336548, -4.558082103729248, 1.3134486675262451, 1.3934195041656494, + -0.9713399410247803, 2.829873561859131, 0.5422300696372986, 2.14626407623291, -2.411435127258301, + -0.8668385744094849, -1.579006314277649, -1.0427988767623901, -0.3021366596221924, 0.7571608424186707, + -0.7852025032043457, 2.0103890895843506, 2.875030994415283, -0.6650004386901855, -4.240952491760254, + -3.7397704124450684, 0.06430482864379883, 1.2097631692886353, -1.621443510055542, -1.7518706321716309, + 2.0040979385375977, -2.1621170043945312, -3.4342057704925537, 0.4494125247001648, 1.0336246490478516, + 0.6141325235366821, 1.1723374128341675, -1.566636085510254, -0.0875391960144043, -1.5110716819763184, + 1.4554129838943481, 0.839878261089325, 2.398009777069092, -0.6458553671836853, -0.7608357667922974, + -2.0972063541412354, -0.596686601638794, 1.327064037322998, 1.2332861423492432, 0.643580973148346, + 0.2491741180419922, -1.1464729309082031, -1.6413570642471313, 0.4765915870666504, 1.1993881464004517, + 0.2358156442642212, 0.8658393621444702, 1.8936083316802979, -3.0983033180236816, 1.2818799018859863, + 1.0561144351959229, 0.18877224624156952, 2.373169422149658, -2.1537320613861084, 1.7804971933364868, + 1.7559447288513184, 0.5495958924293518, -0.29311543703079224, 0.7076770067214966, 1.3824928998947144, + 2.5599937438964844, -2.2310054302215576, -1.3870820999145508, 2.705214500427246, 2.692167282104492, + -0.3191862404346466, 2.2299273014068604, -2.8660874366760254, 0.04656076431274414, 1.0372791290283203, + 0.9051024913787842, -0.7127535343170166, -0.346563458442688, -1.8466299772262573, -1.776979684829712, + -0.7937185168266296, 2.6496312618255615, -3.1376733779907227, -0.5262937545776367, 4.203805446624756, + -3.0495786666870117, 3.059788465499878, -0.6179596185684204, 1.5632293224334717, 4.387739181518555, + 2.1877965927124023, 3.867405891418457, 1.6019251346588135, -3.1097412109375, 0.14593756198883057, + -1.4151546955108643, 2.8710670471191406, 1.281739354133606, -1.9452589750289917, 0.3256327509880066, + -1.0140762329101562, -2.1761093139648438, -0.36153650283813477, -1.9866083860397339, 0.20329490303993225, + -2.189547300338745, 2.0582122802734375, 0.44074079394340515, -3.6016898155212402, -1.0940327644348145, + 0.05166494846343994, 3.9986839294433594, 0.007254809141159058, 2.9994473457336426, -0.17313486337661743, + -2.319499969482422, 2.1396687030792236, 0.23967742919921875, -0.9820348620414734, 0.7810753583908081, + -2.565080165863037, -0.3542521595954895, -0.39312660694122314, -3.611963987350464, 0.042843759059906006, + -1.9587305784225464, -1.0954759120941162, -3.2344908714294434, 0.6816467046737671, 0.7935110926628113, + 2.3788259029388428, 0.8960800766944885, 2.7103538513183594, 1.0750906467437744, -1.3195565938949585, + 0.6368587017059326, 0.09530603885650635, -4.324446201324463, 0.31018364429473877, -1.4680615663528442, + -0.8505295515060425, -0.4297642111778259, -2.9845335483551025, -1.073625087738037, 3.111997127532959, + -1.082578420639038, -1.0510170459747314, -1.0351759195327759, 2.1196703910827637, 3.6743626594543457, + -0.10965263843536377, -2.8268239498138428, 1.9994782209396362, -2.2761340141296387, -0.037992119789123535, + -0.5068368911743164, -2.466184377670288, 0.8389816284179688, -0.9829720854759216, -0.578821063041687, + 0.3714909553527832, 3.968106746673584, -1.0078635215759277, -1.4665645360946655, -0.24487531185150146, + -3.812358856201172, 0.8614283800125122, 1.251778244972229, 4.411714553833008, 3.906099319458008, + 1.1894285678863525, 1.3625565767288208, -1.2013204097747803, -3.780947208404541, -0.7636905908584595, + -1.4467679262161255, 1.9876563549041748, 0.7255282998085022, 1.8526909351348877, 2.2311062812805176, + 0.7617504596710205, -0.3560359477996826, 1.834754467010498, -2.417194128036499, -3.032979965209961, + 0.3447788953781128, 0.193556010723114, -0.4079936146736145, 1.300100326538086, -0.19834625720977783, + 2.222346782684326, 3.1362013816833496, 1.2092983722686768, 1.5581995248794556, 1.3155611753463745, + -0.8380979299545288, -1.2280077934265137, -3.5234897136688232, -0.32684326171875, -1.4621152877807617, + 0.428300142288208, -0.5776108503341675, 0.9278461337089539, -2.4938998222351074, 1.2017678022384644, + -0.2525625228881836, 1.0117347240447998, 1.7265347242355347, -0.10318005084991455, -1.492635726928711, + -1.2622400522232056, -3.0749173164367676, 2.4151294231414795, 1.1957623958587646, -2.7823221683502197, + 0.6365658044815063, 0.18952512741088867, -2.0210397243499756, -0.3540761470794678, -2.876804828643799, + -2.8968381881713867, 0.17692947387695312, 1.728485107421875, -2.2341482639312744, -4.501170635223389, + -1.6425974369049072, -0.9404029250144958, -1.1620832681655884, -1.0455152988433838, 1.3684580326080322, + -1.4598485231399536, -1.6593886613845825, -1.0509099960327148, 1.3251757621765137, 2.258070468902588, + -3.5802016258239746, 2.863391876220703, 1.0157440900802612, -1.5516963005065918, -5.100094795227051, + -2.9607906341552734, -1.1230504512786865, 1.9419206380844116, 0.4938334822654724, -2.3842170238494873, + 0.1679488718509674, 1.827955961227417, 0.10622924566268921, 1.8168610334396362, 0.677010178565979, + 3.0694189071655273, 0.3993656635284424, 2.3529860973358154, -1.4582010507583618, -1.5138496160507202, + -1.5133174657821655, 2.0854310989379883, 1.6874661445617676, -0.6133178472518921, -0.9184160232543945, + 3.041386842727661, -0.8360755443572998, -1.2672674655914307, 0.27318108081817627, -0.906801700592041, + -0.2576174736022949, 0.8814321160316467, 0.9032235145568848, 1.5922852754592896, -0.5044339895248413, + -1.0950052738189697, 0.9084010124206543, -2.3912510871887207, -4.171522617340088, 4.554413795471191, + -0.3333394527435303, 1.5956521034240723, -1.197889804840088, 0.8468800783157349, -0.39928677678108215, + 0.7615669369697571, -3.205524444580078, 2.535108804702759, -0.4366309642791748, 2.1470766067504883, + -0.25451767444610596, 0.23135042190551758, 3.335973024368286, -0.19102385640144348, 0.8413820266723633, + 2.2614195346832275, -1.1231105327606201, -1.3756293058395386, 0.17654633522033691, 0.5028480291366577, + 0.7965704202651978, 0.867662250995636, -4.270709991455078, -1.4976004362106323, 3.333491325378418, + 1.6522053480148315, -3.4461770057678223, -1.0945802927017212, -1.1912789344787598, 0.5186694860458374, + 1.525572657585144, 0.4644775390625, -0.5472983121871948, -4.093353748321533, 1.6807860136032104, + 2.2575550079345703, 0.9947443604469299, -4.168862342834473, 0.09030676633119583, 1.3352301120758057, + 0.37972205877304077, 2.2988173961639404, 0.8671650290489197, 1.040745735168457, -3.316119432449341, + 2.3733606338500977, -2.248332977294922, 1.7465157508850098, 0.19552722573280334, -0.9690064191818237, + -1.8139621019363403, 1.9242961406707764, -1.9793150424957275, -2.789724349975586, 0.18952327966690063, + 0.5084639191627502, -0.054778456687927246, 0.2740379571914673, 2.1619865894317627, -4.095170497894287, + -3.142530918121338, 1.1796610355377197, -0.8848727345466614, -1.2477298974990845, -0.07429039478302002, + 0.9135949611663818, 0.21963024139404297, -1.9909381866455078, 1.99857497215271, 1.4466471672058105, + -1.1016359329223633, -2.8484303951263428, -3.1158666610717773, 4.74474573135376, -1.1900646686553955, + -3.1329240798950195, 2.125332832336426, 1.9798109531402588, 2.6058056354522705, -2.0495054721832275, + 0.028982579708099365, -0.5753974914550781, 2.7390692234039307, -2.3111703395843506, -5.434136390686035, + -3.3772997856140137, -0.37978899478912354, 3.2925407886505127, 3.671295642852783, 0.7639904022216797, + 3.1895627975463867, 0.15414607524871826, -0.6484872102737427, 2.18841552734375, 0.4799572825431824, + 0.1354406625032425, -2.747096300125122, -0.22751712799072266, -0.7596011161804199, 0.5766011476516724, + -0.017207175493240356, 2.4283714294433594, 0.5117142200469971, -0.8030692338943481, 0.44569623470306396, + 1.076960563659668, -1.8645645380020142, -2.2490062713623047, -1.6578664779663086, 0.6149722337722778, + 4.706758499145508, 0.38176798820495605, -4.501796722412109, 2.2427682876586914, 2.1858701705932617, + -1.8162599802017212, 0.9385958909988403, -3.889805316925049, 1.1331977844238281, 1.0191240310668945, + 2.4039511680603027, 4.155160427093506, 4.143398284912109, 0.7778210639953613, -2.2585456371307373, + -1.085227608680725, 1.7663249969482422, -2.8071107864379883, 0.5367544293403625, -0.02325284481048584, + 1.5876415967941284, 2.046140670776367, -3.832700490951538, 0.46683841943740845, 0.5545571446418762, + 1.8768221139907837, 0.7790337800979614, 0.38500359654426575, -2.3040874004364014, 1.1112343072891235, + -2.2824416160583496, -0.026048898696899414, -0.27540627121925354, 0.5449916124343872, -2.154345989227295, + 0.7431529760360718, -0.008791446685791016, -2.407325506210327, -0.46152830123901367, 1.6632401943206787, + -1.7320727109909058, 1.0486053228378296, 1.3803236484527588, 0.3680152893066406, 1.716249704360962, + -2.2865381240844727, 0.14729297161102295, 1.260400652885437, 4.922313213348389, 0.5643207430839539, + -4.42134952545166, 0.464374303817749, 0.59236741065979, 1.2845817804336548, 1.4366343021392822, + -0.0200042724609375, 1.6293566226959229, -1.3861595392227173, -3.4724128246307373, 2.1383941173553467, + -2.1009442806243896, 3.7689297199249268, -2.918327569961548, -0.27357161045074463, 0.9184791445732117, + 1.0513062477111816, 1.9957637786865234, -3.276752233505249, -1.6878246068954468, 4.714818954467773, + -0.9857031106948853, -0.6153162121772766, -3.6428263187408447, -0.30243179202079773, -0.4309789538383484, + -0.03419780731201172, -1.6013574600219727, 2.214989185333252, -1.1272412538528442, -1.6917750835418701, + 1.547987699508667, -0.5724269151687622, -0.47848212718963623, 1.742186427116394, -0.2213730812072754, + -0.8063536882400513, -3.479326009750366, 0.5662966370582581, -1.0524877309799194, 3.702444553375244, + -0.8636859059333801, -1.9768422842025757, 2.1982383728027344, 1.1405737400054932, 0.6146906614303589, + -3.5127429962158203, -0.8339279890060425, -2.914233446121216, 4.411269187927246, -2.3479251861572266, + -0.2184194028377533, 1.7971992492675781, -0.9596229791641235, 0.09081411361694336, -0.4546387791633606, + -0.38310706615448, -0.5399283170700073, -0.2518271207809448, -3.5085813999176025, 0.077769935131073, + -0.5233420133590698, 1.4064757823944092, 0.8371680378913879, 1.9668782949447632, 1.483221173286438, + 1.1757903099060059, 2.9970226287841797, 0.8735387325286865, 0.24936652183532715, -0.7718344926834106, + -0.9049572348594666, 1.9130113124847412, 1.9097952842712402, -3.3667526245117188, 1.1342090368270874, + -0.1385430097579956, -0.6781283020973206, -0.57246994972229, 0.9787319898605347, 3.6297695636749268, + -2.3075175285339355, -0.8269498944282532, 0.8386117219924927, 2.4571895599365234, -0.9069632291793823, + 3.1065073013305664, -0.786931037902832, 0.6695969700813293, -3.896576166152954, 1.6415526866912842, + -2.334099531173706, -2.991877555847168, -0.4740370512008667, -1.0762873888015747, -0.5927379131317139, + -2.2995433807373047, -3.4549155235290527, -2.033919334411621, 4.102828025817871, -2.922405481338501, + 0.9117567539215088, -2.0445048809051514, -2.740710973739624, 1.5500695705413818, -1.4105217456817627, + -3.672469139099121, -0.38663938641548157, 0.1100814938545227, 0.43851518630981445, -1.4003627300262451, + 0.8104124069213867, -2.6236252784729004, -0.40968263149261475, 4.816134929656982, -1.9591403007507324, + 1.2284891605377197, -3.4595632553100586, 2.2904000282287598, -3.7264821529388428, -1.8221375942230225, + -0.44476717710494995, -3.0978899002075195, -1.0302362442016602, 0.49443352222442627, -4.997615814208984, + 2.7403292655944824, -0.9162295460700989, -0.8933342695236206, 0.8124216198921204, -1.433485746383667, + 2.224909782409668, 0.6821490526199341, 4.009047508239746, 2.2991182804107666, 2.677088499069214, + 4.353694915771484, -2.2315590381622314, -1.5339090824127197, 1.626473307609558, 1.9017658233642578, + 0.9766815900802612, 2.563782215118408, -0.2381199598312378, -1.5801150798797607, 2.601571559906006, + 1.7336980104446411, 0.8148760795593262, -3.7112083435058594, -1.3511030673980713, -1.8034800291061401, + -2.950260877609253, -0.4626041054725647, 2.9033288955688477, 2.629671812057495, 0.1508030742406845, + -0.6016277074813843, 1.9043893814086914, -2.119884967803955, -3.048208236694336, 1.3438254594802856, + 1.039656400680542, 0.0982772707939148, 0.29122409224510193, 1.8302289247512817, -1.3148270845413208, + 1.16330885887146, 1.4336774349212646, 3.057568073272705, -0.2635994255542755, 0.3290955424308777, + 0.09837156534194946, -0.4767574071884155, 1.7474840879440308, 0.036291711032390594, 2.057096004486084, + 1.5909944772720337, -1.5554354190826416, 0.45638948678970337, 1.8485369682312012, 1.1001496315002441, + -0.448333740234375, 0.12344249337911606, -2.2758660316467285, -0.18728435039520264, -1.0710699558258057, + 4.759674072265625, -2.308614730834961, 1.3315553665161133, -1.7046191692352295, -1.4248977899551392, + 0.31045258045196533, 0.4682546854019165, -0.5506991147994995, -1.167902946472168, -0.033889174461364746, + 1.2611976861953735, 3.584254264831543, -2.8943490982055664, -1.0763990879058838, -1.9304077625274658, + 1.6052935123443604, -2.1086959838867188, 0.16277271509170532, 1.087416172027588, -4.894248962402344, + 1.6908477544784546, 2.445591688156128, -0.8808413743972778, 1.1168533563613892, -0.35605037212371826, + 1.0386931896209717, 1.4661989212036133, -3.5512571334838867, -1.364258050918579, -2.697364568710327, + -2.279731035232544, -2.2294161319732666, 2.072256088256836, -1.7556955814361572, 1.5964255332946777, + -1.102902889251709, 0.25835680961608887, 0.41171061992645264, 2.6033897399902344, 1.931307077407837, + -3.2382147312164307, -0.6146683692932129, -0.7102252244949341, 2.314451217651367, -0.697837233543396, + -1.0293060541152954, 1.020631194114685, -3.6274075508117676, -1.869258165359497, 0.8600494861602783, + -3.1400227546691895, 2.785465717315674, 1.5925650596618652, 0.5685530304908752, -2.385671377182007, + -2.064885377883911, 1.7148135900497437, 4.472785472869873, -1.6981213092803955, -2.8710732460021973, + -1.5905208587646484, 1.7118372917175293, -0.0628599226474762, -0.024645566940307617, 3.5579423904418945, + 0.043785035610198975, 0.1394520401954651, 0.705904483795166, 0.5603584051132202, 1.9532432556152344, + 0.17339491844177246, -0.5621342658996582, 2.166140079498291, -2.675852060317993, 3.4783935546875, + 0.005500435829162598, -3.0466365814208984, 2.9333863258361816, -3.0374748706817627, 3.3186678886413574, + 1.4340720176696777, -3.4607982635498047, -0.5809862613677979, -1.8670074939727783, 0.31782853603363037, + 5.3407721519470215, -0.11647498607635498, -1.127880573272705, 1.9607118368148804, 1.6104774475097656, + 1.291121006011963, 1.3090026378631592, -4.346621990203857, 0.9649640321731567, -0.3321942090988159, + -0.40106678009033203, 0.6202027797698975, 0.737052857875824, -0.6949820518493652, -1.6063098907470703, + -1.335120677947998, 1.2487294673919678, -1.206929087638855, -3.946974754333496, -0.038611650466918945, + -2.730283737182617, 1.3170448541641235, 2.586440086364746, 0.9609778523445129, 0.9639753699302673, + -1.0613958835601807, 2.2582998275756836, -0.341133713722229, -2.37121844291687, 1.9750676155090332, + -3.339621067047119, 3.8494720458984375, 1.4416704177856445, 1.2632763385772705, 0.49889105558395386, + -1.358637809753418, -0.9810472130775452, -2.008033514022827, -4.180141925811768, -1.8064439296722412, + -2.4055490493774414, 0.8236247301101685, 0.08107048273086548, 0.2551090717315674, 1.6475603580474854, + 2.3599026203155518, 1.4368640184402466, 0.11961638927459717, 2.1902809143066406, 2.4481637477874756, + -3.700338363647461, 1.4484524726867676, 2.2457919120788574, 2.7918150424957275, -0.3214719891548157, + 1.1412039995193481, 2.3801662921905518, -1.1772977113723755, -1.080161690711975, -0.10960137844085693, + 0.23755615949630737, -1.6492403745651245, 1.5414122343063354, -3.5301568508148193, 0.8169034719467163, + -2.7907910346984863, 1.27809476852417, -1.339566946029663, 0.24068212509155273, 1.980497121810913, + -1.103304386138916, -0.9938054084777832, -1.6851425170898438, 1.5913416147232056, -1.6278176307678223, + 0.07544970512390137, -3.0562868118286133, 0.909795343875885, -3.3362603187561035, -0.16795992851257324, + 0.654058575630188, -0.7783452868461609, 3.801968574523926, -2.160207748413086, 2.5146727561950684, + 3.337472915649414, -0.7981523871421814, 0.7141947746276855, -0.44521379470825195, 0.7240567803382874, + 2.8061447143554688, -1.9281837940216064, 0.33615267276763916, -1.3853528499603271, 0.08603048324584961, + -1.1986116170883179, 0.8860710859298706, 0.22947335243225098, 0.5199316740036011, -2.0464673042297363, + -1.6756978034973145, -0.6069002151489258, 2.3337855339050293, 1.6094681024551392, 2.3820090293884277, + 0.8262057304382324, 4.417818546295166, 1.2495514154434204, 0.5728256702423096, 1.7155016660690308, + -1.9175611734390259, 0.8456315994262695, -1.3211780786514282, 0.7857818603515625, -1.7515380382537842, + -1.1971937417984009, -2.808197259902954, 0.7553633451461792, -0.1306423544883728, -3.6146838665008545, + -1.5321255922317505, 1.676008939743042, 0.07836294174194336, -0.9960349798202515, 0.8668472766876221, + 2.137298345565796, 1.2637362480163574, 2.0481185913085938, 0.06509196758270264, -1.6683127880096436, + -3.718193531036377, -1.9311506748199463, -3.367814064025879, 0.012331843376159668, -4.227578639984131, + -0.5755635499954224, 1.583990454673767, -0.86231529712677, -1.809625506401062, -0.3123876452445984, + -3.4403862953186035, 1.0367351770401, 1.1761391162872314, 0.16112631559371948, -4.721268653869629, + 0.07461494207382202, 0.003129124641418457, 4.358157157897949, -0.5005528330802917, 0.24370229244232178, + 0.4912611246109009, -0.6851872205734253, -2.718902587890625, -2.6848104000091553, -0.7679593563079834, + -1.1002662181854248, -0.3475220203399658, 0.002980828285217285, 0.4288327097892761, 1.283578634262085, + -2.613717794418335, -3.3328800201416016, 1.9472748041152954, 1.3897069692611694, -5.0092363357543945, + 2.518123149871826, -2.1634795665740967, 3.7558064460754395, -3.4893569946289062, -1.289191484451294, + -3.036714792251587, 1.6393659114837646, 3.2178215980529785, -2.083487033843994, -1.6242481470108032, + -0.39087945222854614, -2.675858497619629, 2.548295497894287, 0.6715638041496277, 1.0268739461898804, + 1.9520831108093262, 0.2520883083343506, 0.3550087511539459, 3.1073975563049316, 1.3005826473236084, + 3.4222006797790527, -1.5301063060760498, -0.19925060868263245, -0.7549142241477966, -1.7461345195770264, + 0.7768814563751221, -0.8777822256088257, 2.757373332977295, -1.4982573986053467, 2.4916207790374756, + 2.4712767601013184, -1.3903863430023193, -3.3709828853607178, 1.6688079833984375, -0.3028743267059326, + -1.8443574905395508, 1.8113127946853638, 1.7428357601165771, 1.2092206478118896, -0.5405560731887817, + 0.97336745262146, 0.36485183238983154, 0.17854809761047363, -1.6080548763275146, 4.408480644226074, + 0.6553859114646912, 0.9348583221435547, -3.3448331356048584, 2.513455867767334, 0.03000164031982422, + 0.705142617225647, 1.1035417318344116, 3.448455572128296, 0.7622013092041016, 2.463888168334961 + ], + "dims": [2, 2, 320], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/bias-split-gelu.jsonc b/js/web/test/data/ops/bias-split-gelu.jsonc new file mode 100644 index 0000000000000..23fcb488ca4d9 --- /dev/null +++ b/js/web/test/data/ops/bias-split-gelu.jsonc @@ -0,0 +1,1332 @@ +[ + { + "name": "BiasSplitGelu", + "operator": "BiasSplitGelu", + "attributes": [], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "bias split gelu [1,1,2560]x[2560]", + "inputs": [ + { + "data": [ + -0.2565546032426438, -0.4308542731494329, 0.7196725122004919, 0.049034255098233004, -0.6348569496555116, + 1.5952359184580631, 0.8451805251092992, -0.31838310590966934, 0.8187686985927041, 0.5208222841347814, + -1.36437690702164, 1.9883897131851596, 1.5564695381513953, 1.6179857166855847, -1.6925162826813818, + 1.7367654350206285, 0.7210294356326798, -0.16665830399749915, 0.7502629374213177, 0.49174373254887005, + 0.05228599187298144, -1.8492674426703974, -0.13963175206123601, 0.3689713052652124, -0.539726857153676, + 0.4047910328979869, -1.5364223249042084, 1.5401613243237753, -1.5533776810895263, -1.503725108181306, + 1.4980778036916167, 0.3954579111685037, 0.6720461504101429, 1.2082071287930374, -1.9053414457057452, + 0.5161807133471896, 1.876123128617479, 1.4026319996913186, -1.0578832074721918, -0.5667177472093234, + 0.7427152553125103, 1.3309336325346122, -1.250604470605933, -1.5581827636425984, 1.6094888834151977, + 0.8346945311402969, 1.964734881798945, -1.1480686695741893, -1.4692802823913134, 1.443324467502948, + -1.5059857923356113, -0.4324256636772654, 0.47787327572253613, 0.010830153648134555, -0.552568788591528, + 0.976425831078207, -0.3964598095089391, -0.8550178676766054, 1.7829908691025604, -1.4386851327204164, + 1.2586393384114123, 0.6282746108828974, -1.9228284428211788, -1.723501329370512, 0.07875604943128245, + -0.6805248059926141, -0.09444157676724618, -1.2578296605984862, 1.4748594927305314, -1.1326833840447499, + 1.2899356994911972, 0.36076376234182295, 1.0687687434117796, 0.016260539139403285, 0.8790330834375988, + 1.958147414219165, -0.6744379666162423, -0.11384066716133123, -0.8573301336636696, -1.6026243400276634, + 0.947103941481469, -1.9714924616237095, 1.4248854638955484, 0.20743958585530908, 1.1632144973105554, + -1.755170686870751, -0.7194149118428639, 0.14148466161285445, -0.2721811711092048, 1.5603318294054445, + 0.4281402047349676, 0.9777768965070042, -0.4340528216185948, -1.4027853075880978, -1.786451668170261, + -1.0399446061083921, 0.5016682459975055, 0.47912646038882833, -1.4994325065634833, -0.5813174761669684, + -1.7675659877499355, -0.5295448113440351, -0.185677936488422, 0.6133721122115032, 0.8342835422965917, + -1.8851704197237282, 1.8145028466620214, -0.6693817855471478, -1.2825600860901352, -1.9614837268208936, + 1.7498654134131364, -0.9137222020716864, -0.3419963114086366, 0.4718774290086678, -1.087807747894451, + 1.7131835407851748, -1.3262901364654693, -1.3116975227056313, -1.3834647528821646, -1.6786541344436587, + 1.7754973215348864, -0.2608287138885421, 0.7649932766417926, -1.620043859042676, 1.6639963441569705, + 0.8487054193287991, -0.0440434794759188, -1.7095242256175753, 1.285898353046436, -0.8541456852270608, + -1.0975953883519463, 1.6456192019739255, -0.9829899938198103, 1.5115291047537394, -0.1356101800325593, + 0.3163471293050524, -0.7350393764769096, -0.2635909692970628, -0.8961556079942969, 1.35594272476665, + 1.5346085770720466, -1.561332626472022, -0.7482744786530882, 0.7415275042470979, -1.3455806550694929, + -0.17403414417677876, 0.13174239791099396, 1.617942860808114, -0.21372789574955942, 0.1252618392038798, + 0.8356601282506393, -1.4645614042274238, -1.2750059490075998, 0.30114108212769697, 0.718759572394414, + -1.7476419557101002, 1.3774845116886816, -1.5548787928223922, 1.2843809545104072, 0.40166155990403407, + 0.7182073117046643, -0.9178049843649907, 0.23022861413351325, -0.16782915827518163, 0.809538498880066, + 1.5368543357363587, -1.9144180623212463, 0.22370787233170475, 1.754935494591769, -0.300812686337701, + 1.6324016835327004, 0.42942625235999277, -1.821577324695582, 0.7065235323323327, 1.6359230558031488, + -1.6698325562461855, -1.0322767487026727, 0.230813812551105, 0.10476742241238135, 1.8969745358419585, + 1.6692111883958614, -0.13495950982519744, -1.3187891429974368, 0.33479728114140084, -0.3567915710179923, + -0.757581852807693, 1.0471229062934597, -0.15293599163291116, 0.8603181880710826, 1.5915256233625152, + -1.146248153475943, 1.4333844348196614, -0.4549776123350959, -0.24793256123671004, -1.1103266883866416, + -1.8671773666157563, -1.2229905095092333, 0.06266819913029753, -1.0062234212958696, 0.8035035300131232, + -0.7305732410705756, -0.18203039831679124, -0.31014428798975313, 0.3622048186467879, 0.45985097885235504, + 1.4809283400531248, -0.43833106689297807, 0.04401934126909346, 1.5587830597086887, -1.6615450592134344, + -1.0549727490794032, 0.10619243811841983, 1.6457168234317425, 0.8542761080095067, 1.0005810856142707, + -1.5553877748221447, -0.9212057116120924, -0.8894180684757362, 1.2348207479698639, -1.867513894231462, + -0.3703762605562, 0.6409515313866843, 0.07519639585361748, -0.6450469579811191, -1.1954961398506443, + -1.1204107450288525, -1.7282576070578317, 1.3286514210595293, 0.9221546638897049, 0.3164348496087497, + 1.1124119320099153, -0.1657137161424238, 1.8878952460812632, 0.1293077579568429, 0.23358535468938246, + 1.6758942267934263, -1.4201635109571313, -0.15778983434253657, 0.7416701308494114, 0.5883762821820557, + 1.4927448431941999, 0.6935037789859804, -0.7735384747712253, -0.003653232296928266, 1.016177625769398, + 1.663989224847544, -1.860463076082154, 0.586684836283415, 0.35116588158246387, 1.1544623264859073, + 1.900422405885342, -1.4569177282686354, 0.04845063476308198, 1.3176042838894286, 0.7208418723320795, + -1.2204473994940885, -1.4000622968397813, 0.21650376984443565, 1.060631946423273, 1.5077365306547108, + -1.0212630681047514, -1.9198532775330452, 0.7154901442715236, -0.2676631096599271, 1.3808441670529703, + -0.09885367904367648, 1.4777610290353689, 0.056817122949397, -0.21052160911698614, 1.163420787671865, + -0.9278068609037406, 1.6649534188139832, 0.651092543699634, -1.9558728896680586, -0.17393499677517266, + 1.4142540331221616, -0.6414265768865439, -1.796476122987272, 1.3019592923170675, 1.8063975233357983, + -0.09846336861938809, -0.44669537007711746, -1.8509683249130777, -1.2535113000292872, + -0.26272617023889033, 1.618948771545468, -0.24333144562387243, 1.684567206060013, -0.0684671993602386, + -1.2513800864625626, -0.42209759253339296, -0.5204624726492675, -1.2354597847242017, -1.8741410257015954, + -0.4694216976292269, -0.6222814585320586, 0.21189389441688178, -0.07825818775306104, 1.7163253327182595, + -0.5828872666159741, 1.2859942412787504, -1.6213500187165488, 1.3292413342594758, -1.8919626255232016, + -0.5376250474966104, 1.8689278008038874, -1.7144714737277686, -1.0258522713372873, 1.7337707335280257, + 0.7736026666583085, 0.5294771168151202, 0.018158706442022776, 0.06181604275240726, 0.39921503387415935, + 1.1876014889699817, -0.3054069392890115, 0.16369159901033914, -1.7827759163094807, -1.972081887839714, + 0.6481860379313247, -0.35352594703309315, -1.3849570316603321, 0.623606163848514, -0.05008698162477465, + -0.7604090786231117, 0.018133973172791862, 1.4085369489144304, 0.0006579664080073044, -0.945901823508092, + 1.9768973551218298, 0.8031916889887203, 1.787407754969756, 1.4724476919448328, 0.4025600955058959, + -0.8188218839399735, 1.9450326825091242, -0.7081203957970317, 0.2756639708632225, 0.5758621242189532, + -1.3400477780804723, 0.9643250467598792, 1.6902983107628202, -0.9456991797498446, -1.214924995943468, + -1.7096508074160557, 0.6280377071410248, 0.043386564374169545, 1.4018487692877626, 1.9387558240379494, + 0.6404760132931342, 1.9907639837009397, -0.9629506658962566, 0.44160372290058625, 0.5311677453788013, + -1.98443782839251, -0.8098531120098231, 0.6492275695305256, -0.8908778962502977, 1.3356991895363226, + 0.22938352294803988, -0.7322473322214602, 1.5046263929054788, 1.3645550724874669, 0.9961702988823644, + 0.5139435687003742, 1.4501909090527283, -0.7800082224423388, 0.35291863289434566, -1.146090063932382, + 0.30350456267852977, 0.6439700368534709, -1.9972080095325557, -1.6099893116712707, 0.8052080484401971, + -1.0954234253610906, -0.7337556567039494, -0.31964730584596346, -0.0775965423638807, 1.9583401341749038, + -1.2551023178822023, 0.8562333066218226, 0.7147503883958679, -0.7448967487912448, 1.935551654646078, + -0.8599046782377142, 0.2970146461500134, 0.8930246286483605, 1.1182036059156406, -1.197201583834416, + 0.07161352111517694, 1.9408265571293892, -1.0932315538260982, -1.599500716435764, -1.3924326798020594, + 0.11983524432330928, -0.7679553338056699, 1.8370796383852186, -1.2704307623609719, -1.1997771708610943, + -1.888947914144012, 0.55632142669847, -0.34454450486308286, 1.4998341008691707, 1.5090724650893934, + -0.39294559939271156, -0.5834694562549148, -0.5369747247041268, 0.09250855982787431, 1.2568497810992714, + 0.3053435200201413, 1.9380606411981267, -1.7300384067590375, 1.8430090411386875, -0.09705704918122926, + -0.9748664840480341, -0.13310751475233928, 1.8695019507218804, 0.6192212840529407, -1.8594506513740239, + -0.5527467955831558, 0.42368164563035027, -0.3623000694050589, 1.4671722318441134, -1.9926015337743834, + 0.6390178007629022, 0.3930378886525858, -1.999173071586763, -0.10690790413254625, -1.8526590162958945, + -1.7626087675511721, -0.37087181562117433, 0.8387521856704474, -0.19069686876676517, -0.8188898303583647, + 0.12674602912986455, 1.1094188774808789, -1.3557441716004606, -0.6605764711080546, 1.5447276204253857, + -0.8769307424824397, -0.2545000118362921, -1.0291738091907403, 1.7938016119427465, -1.9984894749397677, + -1.2827802315328531, 0.9376743174843778, 1.6630184567373671, 1.3959796712419958, 0.9312602680557571, + 1.1022358675981447, -1.8339474355859497, -0.4903489710565081, -0.8895348660922506, 0.6526243952924284, + -1.3544145682501627, 1.7693021030400526, 0.16080394788547636, -0.012568942703699904, -1.15491020041351, + 1.2537510862653143, -1.0969062449531704, -1.2759047604132565, 1.5734543561168284, -0.10691180117407129, + -0.26934435340525376, -0.9933427134070101, 1.7043012450253494, 0.02549752010787465, -0.08701034102069993, + -0.2674176298500974, -0.08499649536913978, -1.9065384537204872, -1.8548124018210599, 1.4114174077197168, + -0.1003529232852527, 1.4168803569119728, 1.2546078925285231, 0.2836297019652072, -1.741854513232557, + 0.4335127344631058, -1.7722508801895138, 0.8906617987453824, -1.1907389082962334, -0.4570346317067928, + -1.65487466866219, -1.6799510105998428, -0.7345507991250928, 1.0450995087483994, 1.6718914380761127, + 1.9492487109001955, -1.603110762963735, -1.1386896023097748, -1.2314080609240623, 0.9155801473532383, + -1.7678615945378713, -1.2302067186131085, -0.008026479019882515, 1.7076934044219962, 1.2976204179783615, + 0.23979601472188072, -0.31452550079775676, 1.237656674724093, -0.5073967065142844, -1.4684831151454167, + -0.31751104037333455, -0.18632018989018295, 1.9021272175218185, -1.605133256102862, 0.7057429809617588, + 1.131130449558933, -0.37696933268230914, -1.9999062126002407, 1.2947486718643644, -1.5494680428486927, + -0.619492250822522, -0.7119383385868714, -1.2267826883342403, -0.05700963867832698, -1.6002949090802074, + -1.235668500595783, 1.0038351588940309, 0.9591804998296345, 0.3500270456646062, 0.9985378178389439, + 0.37800440518144107, -1.2954921521047265, -0.06652125976806822, -0.5529207993209635, 0.799277104478481, + -0.3195922296721463, 1.278513926080179, 1.7667762251394539, 1.6741322500537104, -0.7688735332563814, + 0.3718086223078023, -1.8559862073263451, 0.6460500799921309, 0.9538174223011966, -0.9599765113376018, + 0.2363354873707415, 1.1507344321172663, -1.5062754878549311, -0.6521456232305631, 0.26796953563726245, + -1.6029382384422037, -0.13330793525728346, -0.4700350464240364, 1.2426614036874923, 1.8410794059261688, + -0.29943068600730616, 0.23080019611542912, 1.1289988352417586, -0.9307260459011948, 1.5075366130642047, + -0.75430635423512, 0.7786360347486365, 0.43281144997343457, -1.9253410424489141, -1.7129308932126097, + -0.7115090993554389, 0.09728563961266712, 0.9612524095722907, 1.2677697737234288, 0.17580981533895557, + -0.07099813019347945, 0.6420914403628677, 1.929710466274961, 0.8224262711223203, 1.197715527419664, + -1.5306699759366271, -0.4379817520480751, -0.6193479789854166, -1.4632019512465932, -1.259260972161262, + 0.8218577982330961, -0.5872036188804506, 0.9619144554878618, 1.9136356497081524, -0.0054634952730427955, + 0.918959735070624, -0.5711442865104566, -1.7438335683612998, -0.09911336983085484, 0.9296115985185578, + 1.4616523373203654, 0.46274661162826547, 1.0360815568791644, 1.2212946749865532, 0.15313173395745583, + 0.822676634908933, 0.4284785502131063, -1.003861397431101, -0.1736541682765118, 0.6790160543077883, + -0.14098755750254632, 0.3123446972161563, -1.1989246192108354, -1.8129870085835993, 0.32605203312965525, + -1.1125265808505294, 1.7561842733362214, 1.7739812007749807, -1.4043301791912688, -0.8943531190371674, + -0.6297367970706924, -0.7981479808539875, 0.8794315514521189, -0.3994743473829594, -1.401573672109106, + -0.8912997118364743, 0.6495541504851907, 1.9230113410471903, -0.6785063630402766, 0.5089535993984642, + -1.8974271070828808, -0.2620702810213116, 0.1991815065696141, -1.207848636028186, -0.8986511371196979, + -0.7061642892592088, 1.9332451504308255, -0.9830761395756227, -1.244182804715014, -0.017993897992697683, + 0.8786296671983029, 0.540637680299664, 0.8969343374913903, 0.7955401708949532, -1.529657800511421, + 1.8639189521351138, 1.067177379712641, 0.7401683858489143, 0.7294620382453241, 0.5904153067983531, + -0.49843243603051146, -1.6804352670210738, -1.0693615378859294, -1.9090165306219946, -1.4036168596502474, + -1.2392401335977272, -0.7753989839018987, 1.8934648648263472, 0.3583307892541949, 0.1338886754962525, + 1.8690213745085558, 1.455690685933316, 1.1699346906107797, 1.6059636790463596, -1.2445702023047094, + -0.26540435557319864, 0.7367777535407951, -0.06848824820724975, -0.35847294758367365, -1.8422274132326661, + -1.0257628610133205, -1.6669678072109697, 0.37716994278958094, 0.238235954755047, -0.2999780255928437, + -1.8840499434617515, -0.2146515363892023, 1.683440503823257, 0.11413961404491424, -1.6059835147695, + 1.1951777078503847, -0.5174290916747344, 1.3440252628311207, 0.21500397686099326, -0.9421891710579917, + 0.0593410986318057, 0.34006850000048683, 1.644097326727004, 1.74676874874047, -1.6742906231992345, + -0.33329412020743643, -1.6254771174048148, -1.6581767039373991, 0.791705460659057, 0.9383035214148672, + 1.7805390345327456, 0.5776366760158806, 1.587860436382865, -1.5903762069130911, 0.16034052878776794, + -0.2414627652627388, 1.2751236768892227, 0.3209997221960421, 0.31176177950076234, 0.6234148156263783, + 1.144504541840126, 1.1535423529138784, 0.11665655599341473, 1.9697764827003628, 0.14558312336598078, + 0.36578124791517297, -1.769415346962682, -0.8303724165129278, 0.44703666963932154, 0.35095942056362794, + -1.197063815711486, -1.5390788457973201, -0.8313129454097989, 1.952907404456571, 0.30612523411761394, + -0.9380264922530621, 0.9822286847681259, -0.12269281399330456, 0.6557752769532215, 0.48870196679108435, + 1.347011507573625, 1.6519563808835915, -0.7385795429014586, -1.048048723619047, -1.0859902684402032, + 1.7187556784188924, -1.8663335499394762, 1.3448325108921972, 1.5973182779732955, -1.9246562924691855, + -0.896157435511153, 0.5381287932952938, -1.7180528810790427, 1.6998135385965343, -1.99001646343652, + 1.7545133161159834, 1.9333853932807186, 0.509746393965961, -0.0675591816104566, -0.5487596842885525, + 1.1654829640998834, 1.2418276538086106, 1.1046551528849635, 1.1541227103848648, 1.3688916432757363, + 1.8682574813085164, -0.9855654818965549, 0.5606398966348412, 1.9470096222814695, -1.9903658034181957, + -0.48344153717960925, 0.6353768282033556, -0.0513621772163555, 1.097950015862886, 0.24567995461843584, + 0.8059244051805585, 1.3521742423136072, 0.47230481196519936, 0.6133818327341913, 1.9135550317308576, + -1.8499763818701833, 0.5894967733509713, -1.920326871370679, 1.6405640996262383, -1.1695280806136248, + 1.4241423736983672, -1.67349053378969, -0.8405344375321118, 1.4551541878625045, -0.9884334361969964, + -1.274246357510064, -0.5067514453663833, 0.5492300144792424, -1.063358889380532, 0.3303561470509724, + -0.9579434760073298, 0.057620205260851876, -1.4742255794760677, -1.6308108759776871, 0.050309644830520917, + -0.4069039662995735, -0.046184249642958086, 0.07949429360802096, 0.12024378717844808, -0.7878588154579758, + 1.472302474285109, -0.29877233638272127, 0.3469939919860652, 1.0030914450113793, -0.033399023382078674, + 1.337733463292298, 1.322404140400704, -0.9344843531853249, -0.7431360466648114, 1.6844959250870701, + 1.361728870488129, 0.9343674315961268, -0.027258361373763584, 1.838512510349391, 1.6503912849225708, + -0.4099533673620952, 1.0611278108159041, -1.6127864143155026, 0.24393478918735934, -0.44450380494664365, + -1.6597118492450527, -0.6490315790577066, -1.7612694260424497, -1.4276602313620197, 0.940659328987655, + -1.263148442418843, 0.5147963537125868, -1.3438202624356137, -1.2843267233008957, -1.2655953556300048, + 1.8008689611198623, -1.0529595001338974, -0.8423440348538298, 1.4068621033910276, 0.007761032105414678, + 1.0935924640537378, -0.09622562614772612, 1.5915223843314257, -0.6702176163118629, 0.773624899360053, + -1.0994916609949348, 1.8469641173001987, 0.8777243936892871, -0.7039703784245424, -0.6127726098737645, + -0.8456557917329288, -0.21546583787614892, -0.9360415939958875, 0.5705354560747677, 1.4330326228154542, + 1.5996006008539831, -0.07367645458791117, -1.5628497471538445, 0.6808136908868017, -0.6639347562269995, + 0.9713336233567063, 0.914773615335891, -1.4236420984162779, -0.02242621653471044, -1.898261801965826, + 1.8572128094883267, 1.5674319640876826, -0.825255227018233, -1.2984530858686814, -1.8947198959796596, + 0.587966233462053, 0.7276082291256891, -0.7118337876069631, -1.9491744048376294, -1.9749840187715382, + 0.6636341374951193, -0.8911042565498537, 1.1064802004273853, 0.5698018751524465, 0.31161591659845556, + -1.4542202550786563, -0.4973178781500307, -1.261090074552551, 1.202428109165787, -0.6502124259985935, + 0.875011031676804, -1.0459657545700463, -0.42208724851625856, -0.056326319293169114, 1.033820082809723, + -1.7520891696720033, 0.8087980713817391, 1.4090272899721707, 1.7934272795205217, 1.625604511965781, + -1.1093024540507015, -1.1836309931833142, 1.2726582815916636, -1.348432635892622, -1.3318038843301023, + 1.8025664988508767, -0.9447266569517438, -0.5044961285117484, -1.6608791173499942, -0.5404247801346038, + -1.8417295161873213, 1.3703119773103376, -1.762654389000578, 0.654516908110395, 1.7444654879377754, + 0.06700313474848674, -0.11772438211787417, -1.4346934108510236, -1.3471639350541027, -1.8846130174975517, + 1.4341206352848488, -0.1781484331915868, 1.8660723422398098, -0.14067413462914313, 0.8690053039391579, + 0.14439977774786517, 0.39859516757572777, 1.338326821281048, -0.3540839451368738, -0.3799467769042604, + -1.6627115692361185, -0.27677435345149703, -0.9817434354820271, -1.7561599198792992, -0.09529627982909172, + 1.0892976295039958, 0.8013191504371955, -1.6378286446982058, -0.2422910554223776, 1.192900022577974, + 1.552603422632739, -1.9305123473184267, -0.20948858983751073, 1.9616644595376727, 1.6353585267617419, + 0.8302939534322658, -1.3824308029410997, 1.4671017215695894, 1.075525400633838, -0.8516882354421424, + -1.764852757300301, -0.6508588489877889, -0.5864767969763536, 1.1318668208548859, -1.2774769138035325, + 0.28054164298575834, -0.7225628233832841, 1.0930686104208345, 1.3757645119448334, -0.8436086923316299, + 0.15317111418019813, 1.0696425581898357, 1.2977798557148477, -0.9586985895916751, -0.9572006108424791, + -0.5173013825178812, 0.9558842805311762, 1.4588133070282732, 1.4300025626225263, 0.9236625391445488, + 0.6245988744867041, -1.4555968079681971, -1.9081528171774265, 1.4969305861374398, 1.8013611610843867, + -1.215264739954538, -0.6460349115684512, -1.7385666591205897, 0.4952215434575882, 1.7813317720717237, + 1.4779804705753685, -1.6448372710981527, -1.7914377349318746, -0.8701351514587072, -0.6086582613844049, + -1.7386736832736416, -1.6630595398426538, 0.06404471619092522, 0.12472538361116214, 0.25173024087120677, + -1.4925192750724054, 1.1326320024353294, -0.6723638830002354, 0.9276409891081574, 0.5160132920113458, + -1.7619800169226787, 0.7729306209435105, 1.4852229135771582, -1.4043797098218525, -0.26979894999654697, + -1.5676449182307888, -0.8285263778416416, 1.7968643376162596, 1.1503154963149846, 1.3682632091264955, + 1.9014644171021402, 1.2035803257561772, 1.5807589620662768, 0.6681530530278019, 1.7430055744872055, + 1.4516895295108698, -0.6636088362253298, -1.3265544726243066, 0.7260245399794885, 1.0068518018712478, + 0.2208570840730273, 0.8459119656002851, -1.8915833254450725, 1.2162158433713186, 0.9752766886721753, + -1.2607916054285697, 1.9003684087234687, -1.6824694825149118, -1.3545700227621689, 0.5912583336167279, + -0.7008114462062913, -0.022208913414461406, 1.4871167887266532, -0.47220337297808346, + 0.001402495408155957, -1.6337795432062068, 1.6557142707874517, -0.21911880468117495, 1.994215681564901, + -1.4675327906481472, 1.9744129850181764, -0.10991781070844464, -1.5582267736493964, 0.9509729601966104, + 0.27383527366630744, -1.0109967848840293, -0.2652951445752292, -0.31773126890169845, 1.9347214689284513, + -0.48900940865557896, 1.0348946328564068, 1.6101647718098393, 1.2224869337691553, -0.24528963586677577, + -0.8282995437227312, -0.74214677104667, 1.9022077987920571, 1.065511429772795, 0.5557978714258756, + -0.8552846035431614, 0.14131421568194025, 1.6415849500887356, -1.0979455229862056, -0.1899406250116744, + 1.882935380340533, -0.47086245203046495, 1.1173180765349162, -0.38373005169056196, -1.9204322585926832, + -1.9947555620271595, 0.25610856180969677, 1.672078838942273, -1.9323275104124873, -0.7955457526088345, + 0.9709167319971037, -0.6356155194456985, 1.6590519014057472, 0.2576466445092809, 0.06826734181142502, + -0.8077944062086111, 1.6116753503934271, -0.360670679232701, 0.665216992360695, -1.3443131827622485, + 1.004588252277319, 1.4090953833875757, -1.6465558763926547, 1.5390983488230878, -1.3071804786368446, + 0.6990024492525402, 1.0093027361688671, 1.097146827869869, -1.1912001568366906, 1.630219037886489, + 0.5744645582929406, 0.6113886454941007, 1.9913520749447864, 0.44093025553359233, 0.08768839917963245, + 0.8058107685571905, -1.033178784078995, -0.8225351475861098, -1.033391812065922, -1.9518439963887113, + -1.4013652343418483, -0.3575891026292508, 0.5391177792156094, -0.589608397581415, -0.8183723923536403, + 1.2663801923646707, 0.2780137273448853, -0.6282934713194672, -0.3515950748704757, 0.32366450506041744, + 1.0378037684314076, -1.4606869437722425, -1.1777315953586012, 0.20077305098986375, 1.1132819638088014, + -1.3403580753455646, -0.8471624166708764, -0.78670588443378, 0.411075409745707, -1.1567621787661597, + -1.644706858279382, -1.6923798602467333, 0.7295344690619672, -1.6826221173466598, -0.5883451091973848, + 0.401937953067649, -1.9630301043637237, 0.9577847247199625, 1.2919304030896734, 1.225292029861409, + 1.0377797459888836, -1.4163731758310805, -1.1848812939548496, -0.6833661697787141, 1.8924034115938815, + 1.0815655745578274, 1.6514910861689147, 0.13638193195368675, -0.7978052236681465, 1.534221720841157, + -0.9153275868493269, 1.8162414196916217, -0.18899261449512395, -0.41163974536752157, -1.7273888453908217, + -0.2259883780746561, -1.5648477168095223, 1.483239033708478, -1.3942275569974942, -0.47490997296002035, + -1.7533267576457146, 1.885960859958221, -0.2403666825901727, 0.5586086137789854, -1.0921597903975728, + 0.31058170316775424, 1.3277885858021783, 0.7148623072607876, 1.557240774068096, -0.37204592237067047, + 1.296024489070251, 1.0544578912529046, 1.6854705704653057, 1.511922732706835, -1.7363032773317322, + -0.2092964598440208, 0.4481058099152593, -1.6737966633530679, -1.090792800318261, 0.6450021408332125, + 0.20592552116749374, 0.8156159228622641, -1.568521435345816, -1.9264895113126421, 1.3354691459199177, + 0.5736596388344699, -1.1328195208816618, -0.9748407260828982, 0.12745284809908775, 0.12755406203831754, + -0.402313518792381, -1.2084251106645967, -0.766591500306526, 0.5831135977679764, -1.9502950151238094, + 0.02368696030780093, 0.9474278651806163, 1.200170490724969, -0.23626565831359603, -0.00155175398078633, + -1.4962274409033034, 1.4172670942494623, 1.6336646415640272, 0.9015591359514898, 0.0028804858578190817, + 1.8517288251219313, -0.5845464836204135, 1.4764716372203344, -1.3353282541360834, 1.2177029719595032, + 1.990053698845105, -1.5060495488639427, 0.3106662987296689, 0.2870946491789992, -1.0746935708007666, + 0.895601302779335, -1.5268871259283276, 0.2678390558787376, 1.864513144299627, -1.509732497733654, + -1.5904836915336782, 1.677740063904718, -1.5200319590355935, -0.7931901348349131, -0.057131201288623146, + -0.5299918958977132, 1.5844048394762984, -0.9074859645216646, -0.6250324408575869, -1.7483344039680153, + -1.648406830815511, 1.790525414138699, 1.1087046319299025, 0.42706999037150517, -1.1252987787944102, + -1.7926103436944185, -1.661929519642099, 1.6441197132107996, -0.7994467312378033, 0.4924994603048294, + -0.08241669477843683, 1.3633290277505736, -1.5418640660707128, 0.05946840692833444, -1.6001798975438088, + 1.1048132720023895, 1.0870814820606292, 0.5831092120375949, -0.4133915802678638, -0.9142815961432058, + 1.9472407631770015, -1.4219510012962004, -1.1714172709742634, 1.4079274160525728, -0.34506772652017137, + -1.0158502812481522, -0.8168949547547397, -1.1457275079452227, 0.12910364073916902, -0.4662867248454887, + -1.9437241965128846, 0.07261938805201762, 1.600763502162561, 1.0777174066413018, -1.2723083195923186, + 1.3113857387077417, -0.5228664205198017, 0.4450488409424249, -0.5683762553586282, 1.8256298065201282, + -0.5555324792306022, -0.5028443495682451, 1.550965251170056, -0.47857481650681066, -1.008285169693293, + 1.9029801635553145, 0.7739617661198315, 1.8099201835531415, -0.3994817059435505, -0.8127756747385817, + 1.4033307810724054, -1.359844813448376, -1.3846355261466767, 1.8540201060398234, 1.0970430821179962, + 1.3778953118217157, -0.6311210216871839, -1.8270928773238353, -0.29073753592093343, 1.063407723752193, + 0.769348666705115, 1.069807859635052, -0.13297318999054397, -0.7495627942438086, 0.4278305696495597, + -1.7534013899377605, 0.20503621122624516, 0.8877917416885026, 1.9219368972107151, -0.7795858126195832, + 1.8045722365205155, 0.01848994995789255, 1.9822395081707462, 0.5682615557282436, 1.096590333951183, + -1.1060317246730396, 1.0869871276155, -1.8681569307257382, 1.9498301468214843, -0.5725242199723457, + -0.754441782550737, -0.0400922717249097, 0.010590885689596874, -0.6969977940409491, 0.6620666861327669, + -1.5969982725685883, 0.17340909047153819, 0.47755863566996126, -1.6291589696000264, 0.5780359168220688, + -0.5173306807336635, -0.9514848124225095, 1.6169705288679577, -0.42893373795490586, -1.1528283547930025, + -1.0400955977716713, 1.9827399061918518, -1.327376858845513, 1.6043081593845727, -0.11039938533269034, + -0.24997046912904874, 0.47014693724974954, 0.729145170631436, 0.7014015249399357, -0.210704593378912, + -0.8898579600723409, 1.127679820439794, 1.379686041589359, -0.781681363239616, 0.2858562618428895, + -0.7131287792008063, 1.7878252016142655, -1.0588662101910593, 0.4893786031875278, -0.8406649057821527, + -0.1534439859760095, 0.2331374640695536, 0.469582749149386, 0.4137313563589631, -0.9769266749700227, + -0.37870419628070984, -0.778681560279427, 0.9076631441596534, 0.9624606550623103, -0.10250544734763523, + -0.2518054637298146, -1.8389925418951307, -1.7523906632013846, -1.8251290048632933, -0.49859600328817244, + -1.5645964425382966, -0.46692843392327266, 0.3025335867203003, -0.8785897670006184, -1.7537151065566459, + 1.9360855593038684, 0.03479121265661611, -1.9402430169146303, -0.8981491188900641, -0.9720525655542991, + -1.2872361339345169, 1.9657361835316278, -1.9654227152525223, 1.4590349841798576, 1.7417951527725704, + -0.7636287264036836, 1.6938802231015364, 0.3969017158868704, -0.9308527980280088, 0.44396078845267084, + 0.8114974124677037, 0.22323733905690712, -1.157000049324795, -1.2116172131012615, -0.9832275983234169, + 1.773233656033006, -1.6481062641009663, -1.9471951041445985, 1.1654342998679619, -1.8679076405021187, + 1.9134708504745168, 1.9270958489182615, -0.4877809076980144, -1.4674512268342745, 0.006115322418878577, + 1.5523105881305073, 1.008791751555858, 1.7292932521498168, -1.2446660428848375, -1.32058622408507, + -1.6942157582592943, 0.9218514004458749, -0.823621328629307, -1.0203195530541063, 0.07206341884947509, + 1.214451058931207, -0.40454188129729296, 1.0066091638178039, 0.0801907243894604, -1.3250420419558173, + 1.6740542746900093, 1.1525840942242223, -1.7538751715267296, 1.2289357346449874, 0.44273632243705485, + -1.480515264351281, 0.7203216915034076, 1.736757457268701, -1.6126702540429125, 1.3353291202473017, + -0.3386246414186953, -0.15824184756675486, -0.9082701908024067, 0.19739090770367174, -0.6056382353292964, + -1.5021205949442713, -0.13004055376809376, -1.9680841369920756, -1.0085004366482355, -1.4660753620146698, + 0.5310372051600734, 0.6252656799282139, 0.28834705715856845, 1.9582690133559009, 1.0284365097891248, + 1.5791162728965862, 1.5890315798131152, -1.5740074592032895, -1.2346249557276874, 0.09869464843015763, + -1.9888659899819219, 1.6245510003031853, -1.8240356816949115, -1.1160775047609688, -0.9717920085031029, + 0.026054846056297265, -1.5990410562524549, 0.6191498034442082, 1.4318181978005144, 1.1449640789498945, + 1.435002701727269, -1.8991365043582489, 0.9679929619919578, -1.9806014397574234, 1.4536549482052994, + 1.2369898149063783, 0.0942097559544548, -0.44988290575276135, 0.6393419762034132, -0.6790983093222227, + -0.33932133359722183, -1.3323692893954657, 0.051614649124500644, -1.4850113224086279, -0.6288795685974646, + 1.6283138375987471, -1.7789341475324125, 0.9641353407036455, -0.8859758584446853, -1.3905521072955613, + 1.248121111974715, -0.148429098769566, -0.4602995590886545, 1.1003026112521521, -0.6827850627285645, + 1.9771039126131766, 1.227966495460028, -1.9054616874551513, -0.17250407406937818, -1.5976498513957544, + 0.04439303257666172, 1.51163887169748, 0.646404644877455, -1.721204116533527, -1.491970732730513, + 0.2989738629420451, -0.9550137794523614, 1.2025118418310514, 0.17278672215112767, -1.7073672613962216, + -0.6869493466963412, 1.4221175269968072, 0.8136078027936655, -0.09457917718456343, -0.6500956499585442, + -1.0403599385456666, -1.4197371934035345, 1.2190245805380533, 1.4761477510781154, 0.8005307377799946, + 1.5249741316932477, 1.0855648961952538, 1.9224829076060752, 0.2869108546051615, -1.5091598215762332, + 1.4501191712612789, 1.8870984229831107, -1.7302116354922958, -0.3545753940936196, 1.442437478066842, + -0.7823996675200249, -1.131407428223576, -0.5013183847734197, -1.4979613863674297, -0.24415013272007524, + -0.6627159008601602, -1.0181273908903332, 1.0328362758216585, 1.260250651113478, 1.0138156752044916, + 0.5371615864333288, 1.8553073783399325, 1.951357678034606, 0.7194607934648083, -1.0589630923826236, + -0.9620500497996112, 0.67763129815489, -1.6212754527879456, 0.8824272362591019, -1.8034693359581802, + 1.7422096047831568, -1.010660238079712, -1.8120346724245788, 0.7326343612924848, 0.8912492318672749, + 0.669928902828123, 1.5191540472733163, 1.9408210365763843, -0.2675149665538239, -1.5368478010395314, + 1.1378158248590227, -0.14340274845268297, -1.0420396266499052, 1.4238975359786146, 1.0186548966434374, + -0.37302282044451296, -1.6665521929489486, -1.9414179347758749, 1.9845099037831098, 0.25190468995929116, + -1.3565033558826896, -1.430084412385555, -1.9049229412063653, 1.221645795300951, -0.5219823891712627, + 1.9368378730562315, 0.7035479701902823, -0.3754047433995442, -0.14526093695990294, 0.17885379911634125, + -1.7453384743240203, 0.8465052490601339, -1.6823293227525946, -0.8161516604165566, 1.8312443096922442, + -1.9192510920287678, 0.37723942145518574, 1.725107407940751, 0.21381615826643507, -1.1716001801855649, + -1.0611345679669162, -1.9732355910502637, -0.4461777828323239, -0.23052068805350423, 1.218942575723136, + -0.22572812681057552, 1.9428383183668947, -0.06997110252961747, 0.2238383470247678, -1.178747770654164, + 1.289511874981887, -1.9756722906104285, -0.42650188553983703, -0.5494388087356263, 0.6619450518994654, + -0.8233262341940337, -0.551580552700945, 0.5817278377322888, 1.8685269618613036, -1.3000953227319512, + -1.4283838275880294, 0.5999358561474102, 1.0645958312240245, 0.014353697508937557, 1.2413161019277963, + -0.6897291610817131, 1.12297456609934, -1.8432752527139797, 0.9084575027035138, -0.1243916597867818, + 0.33130042087381195, 0.66895259903393, 1.6557830983974302, 0.08287276029065538, -1.5357530396168295, + 1.2343785017416087, 0.12902705924095592, -1.3495271524839367, -0.8728580889008626, -0.5244139433889465, + -0.3879006179068245, 0.9188768930207276, 0.7793805226934829, -1.4393071368258976, 1.613705376668844, + -1.0843958887438374, 1.7124014800074736, -0.9964777990871019, -0.8670020603088275, -1.779081809150778, + 0.12646904349631516, 1.6571805108463433, 0.12600417813102993, 0.49996900162664826, 1.9525284377831484, + -1.9652455426976498, 0.9494015016667543, 0.38443667884960586, -1.7724098362330505, -0.4082684743558227, + -1.7879879942659969, 0.08275123875162116, 1.7475525868209036, 0.6792010104118127, 0.039437186277630154, + 1.2836761397306624, 0.43674745284168726, -0.758347022092078, -1.7870493658991657, -1.7978809807072968, + 1.0584586913661846, -1.1157899056211305, 1.2701741054323072, 0.4374863843249699, -0.8214325870445185, + -0.7728421127264511, -1.5577282427526624, -0.23594830217496732, -0.5391955305244664, -0.5624792843906574, + -0.1938002353617021, 0.5223396103139288, 1.538375944848121, -1.9694136326451632, 0.8633772521280756, + 0.7208433609240901, 1.7666620769537023, -0.43532694305019426, 0.411879568237544, -1.2098790166305973, + -0.36423075107954883, 1.8891146830740784, 1.748054742284471, -0.24142532135673545, 0.5927554266793296, + -0.9152877032152276, -1.8063969967782656, 0.48207466271184707, 1.213706670479203, 1.6604107491663145, + 1.7558340275933135, -1.7932219074543312, 1.9856023833866967, 0.4913618217526574, 0.6789406744471576, + -1.2508048490603496, -1.3830358311750315, 1.217800508245622, -1.363187904834711, -1.9078698738816984, + -0.28177773798420525, 0.5582837083561785, -1.3049002248929513, 0.08913900413869147, 1.6980987887729269, + -1.1912978387544912, -0.8931915280846656, -0.18855808865450108, -0.45803053755017054, -0.8712638141863449, + 1.7094200997790034, -1.663338331843227, 1.2015315071432422, 1.7919441084478605, 0.31409629120692095, + -0.2597694212249051, -0.9168044364289045, 1.216621830723616, -1.6662413518480443, -0.8080880254762635, + 1.4073145445825483, 0.545639136243703, 1.4013255263406696, -0.5546097014600111, 1.0974522886316729, + 1.5316488607435108, 1.834526451907207, -1.9144935270991548, -0.7645822781961176, 0.06654825465795611, + 0.33190077251492234, 1.5544507881908975, -0.3538304760036741, -0.9572664339018262, 0.9528358413200575, + -1.6745715540252952, 0.507318989134915, -0.20165453900363595, -0.785857858404909, 0.8550126703052818, + -0.0012624172438071568, -0.9598028440876272, 1.2244820355277088, 0.2255077411064299, 1.8056720264962882, + 0.037756919076751494, -1.343540974005494, -0.5094599890534655, -1.9469287581463544, -1.9209187756131971, + -1.9807700974543714, 1.6311835876110612, 1.3993893579005068, -0.43399651753379054, 0.5299120477626555, + -1.9670191857782173, -0.5248122259552108, 1.7175353497229144, 1.7289010829074938, 0.5729177740933844, + 1.5352457320381498, 0.4553403180302915, 0.29793068106143483, -1.5797876923914638, -0.40357306624290423, + 1.265178047524694, 1.0072829549062359, -0.1769350094041604, 0.6540198866379798, -0.568023907344485, + -0.9385372669098198, -0.9547551373309231, 0.9690083247728642, -0.49448270214331114, -0.23248401423341658, + 0.8559950134661438, 1.5269704580997994, -0.6731018952946073, 1.4821498012885321, -0.086653738716163, + 1.5513273624777053, 0.4908490299053412, 0.608900001002695, 0.9347951454225933, -1.8764423825502883, + -0.5383639900183281, -0.9057735867914714, -0.4471390378994089, -0.9345110033957944, -1.2485918565222498, + 1.3231857428964897, 0.14498731120726926, 1.2862055594556612, 1.3345878394286323, -1.2756222907563037, + -1.3083824064683531, 1.5045780173000916, -0.3988401099234551, -1.8345796826462557, 0.9521512093444242, + 1.37476939261499, 1.7326371780327987, -1.1085753439494548, -1.9216876869475268, 0.32444794625794593, + -0.35462190108024316, -1.6069584401996142, -0.6840247431355166, -0.6774334428214628, -0.6167820243082769, + -0.7218673863781015, 0.4950955340603276, -1.9104643839392264, -1.8658682955390713, 1.6439842658028825, + -1.9665747219603489, 0.877772486257161, -1.447763157358751, -1.602252241120052, 0.44016487530366266, + 1.010332237856809, 1.1627057025546774, 1.3589381505180222, 1.6636605093405308, -0.30457004904529494, + 1.2269612708937165, 0.6454558582632632, -0.3281883947244264, 0.8131649340163314, -1.0598316735137594, + -0.42299178838843243, 1.2938482948248247, -1.9502364601006974, 0.33584798243720115, 1.960552892008372, + -1.5357606602986342, 0.7325579192102154, -1.7068807453175427, -0.27946385324092926, -1.8639767385643475, + 0.026428176272640158, 1.8041446105579393, -1.9753854510739615, -1.2061334845302873, 0.7372916274629793, + 1.2179795802972748, -1.3143781251446196, 0.682430888983494, 1.765010566861033, 0.7860745956875865, + 1.5788062358182353, -1.575242377037637, -1.6479554611466636, 1.2726425469891218, 1.3214340571366705, + -1.1176914021551418, -0.28104090150176475, -0.8644022366981297, -1.0077418811855443, 0.015316356685453059, + -1.8951979061169633, -0.5196653139834044, 0.8125953302477926, -1.823298987094664, -1.772714760583118, + -0.42385130914514946, 0.8535556140476261, -0.21002977910714105, -1.0489096773038966, 1.693320576211045, + -0.7552501574944035, -1.70051398713119, 1.7993076735058722, -1.5431556753936784, -0.44776937868310007, + -0.0034131973262736537, -0.08787851222156462, 0.1325414438300383, -1.9039783696720445, 0.9518867483514475, + -0.9603800917394327, -0.0648448360480609, 0.5218985531710549, -1.6635189965029777, 1.9924108002877654, + -0.26339506136937363, 1.1928641143061727, 0.3810845521550812, 1.1123779174325765, 1.2834151041594684, + 1.7494397068262924, 0.7406099190962179, -1.0134737497144135, 1.055889762638893, -0.8394462038526926, + -0.28477578596133046, -1.3393678408508727, -0.012496365797505682, -0.03748743063321758, + -0.09254003926345167, -0.19080760280617248, 0.34473162850229677, 1.3043786112928455, 0.33958283219176444, + -1.0099885582662163, 1.5669995819775036, -0.2714415980354179, 1.3346740594279858, 0.602475697679103, + -0.27309507620431717, -0.4346043241348738, 1.12017383631626, -1.6915194091713959, -1.1923394655016413, + -1.8515166462293298, -0.591963112025546, 0.5023793945680568, -0.4921032884548957, -0.4819170803548376, + -1.4966363280675248, -1.421677402946293, -1.6850285211377178, -1.9944566759402251, 0.23292198758882154, + 0.8553269921007791, -1.948229679796313, -0.13143026366451505, -1.6538966244541689, -0.9216045649630527, + 1.7264765451863706, 1.8222181199058705, 1.9306597140261852, -0.42071370405153363, -0.994060577839833, + 1.3461298706581069, 0.6155835113186514, -0.4233898713404436, 1.0628290316641253, -1.882099316617622, + 1.0589656835123407, -1.1563854484115064, 0.06709310867074869, 0.6384730812451886, 0.10574897066907951, + -1.702659493082189, 1.532461866391916, -1.4099555819463054, 0.9274818832867693, 1.5037867997803227, + 0.7264189517035895, 1.739697756396633, -0.6088742760292636, -1.6856099921156718, -1.772602909046058, + 0.3867733521889116, 0.43414404525556716, -1.4477323883785695, -0.5550413636141851, -1.0684298188676138, + 0.49963854275697717, -0.6428666361617692, -1.2461600431717166, 0.139396634056558, -0.8713484091078527, + 0.9279324928307542, -0.860886224082325, 0.6052214303020085, -0.07407732844211701, 0.8385159528403303, + -1.2883611258333367, -1.594397303034099, -1.288952302189137, -1.5793499078181599, -0.6082222788183742, + 0.35912648776686495, -0.607260423160894, -1.253926812856471, 0.5102013763540754, -1.6947762434136582, + -1.675889179914285, -1.7314605301275563, 1.5448045270744153, -0.6686717741699395, -0.5356827185210964, + 0.12358215343481938, -0.6340124361185859, -0.2817760753415781, 1.164433875881067, -1.0118173150641265, + 1.8268866563720367, 1.7413521495755342, -1.3276318599091654, -1.7238317321272323, -1.7921370418547173, + 0.3558056543449375, 1.6918022479955095, 0.9222053317838856, -0.052028097382029515, 0.5122787494307435, + -0.32626022494247575, 0.15032399126484997, 0.8080660425425252, 0.8796007677651145, 1.6881141214849356, + -1.9923518519205325, 1.1594510791129853, 1.8780756936993601, -1.7055973395367428, 0.7022611016052407, + -1.2390916075946476, 1.489326601979406, 0.09129782431581734, -1.0503781321438632, -1.830361985060848, + -1.6234239674810462, -1.9400691667730579, -1.775525052900898, 0.49354389704319335, -0.294835573216341, + 1.6593428114830902, 1.1178351651841174, 1.392043569467992, -0.23893431171828805, -1.4623844403945752, + -1.5343016105239773, 1.1865854046143252, -0.030035182509464242, -0.23319636854325143, 0.16118837392623053, + -1.4249606736799842, -0.10348851980420637, -0.512221808349385, 1.3638877591460936, 1.6510345373891697, + 0.5403817230171999, -1.7568167360776457, 1.582586423119313, 1.8459420743954364, -0.0937741201677813, + -0.07681514921378785, -1.6485771749332665, 1.405045845579937, -1.380639120572158, 1.3920918243796292, + 1.8133826455276356, -1.6212352715972855, -1.6619339109060558, 0.5860174449496753, 0.8058352539157623, + 1.2357143629814047, 0.22041032204430522, 0.7808171688721632, -1.8606664108725086, 1.8543721730560465, + -1.2809372837528121, 1.8923388155485306, 0.7298982025806318, 0.08659163513354162, -1.7687987574607682, + -0.4763155379116144, -1.5473458996242258, 1.5866611468929568, -1.5386305359062193, 0.2081267635605304, + -1.381273405047125, 0.29199706105659384, 0.5721917482420489, -1.6925391452117218, 0.7916231259738984, + -1.4231755623012, 0.3405385083200416, 0.12077341714245371, -1.203092512782769, 0.9840212059218176, + 0.8659573692212614, -1.7404045592057091, -1.200177869243232, 1.9929041194760506, 1.6176296085808595, + 1.15438166620096, 0.41001019693662766, 1.545005651638757, 1.4880131034252333, 0.6483595755736316, + 1.9934284869908785, -0.4425968131492626, 1.1431733762266791, -1.5542365498656396, -1.4707761026218984, + 1.6864275163379014, 1.8816548331946308, -1.896526297348939, -1.2767368243378723, 1.7876805565314937, + -1.9714294114028945, -0.4825746666154789, -1.7490480229694754, -1.6790035262680618, -1.89004754358142, + 0.37907912536042243, 1.2624866059593396, 0.10773840252143074, -0.6872687299448277, 1.9925905119168483, + 1.6613584791100244, 1.3497294256582544, 0.14770827679848963, -0.386793508812298, 1.6398868686034165, + -0.41758827391224607, 1.7067458043950365, 1.0278221550032498, 1.7159317753776575, 0.5096805329124043, + 0.9834561119760092, 0.09498575572277979, 1.9951803301511655, -0.378982487863734, -0.8263817670300764, + -1.0229038853801748, 0.3913502332282688, -0.6996774883339203, 1.8060537458340287, -1.5307723308680314, + 1.8274255271921316, 0.3374273766186473, 0.504303338314112, -1.6080869793956403, -1.158919549492441, + 1.5813373512787257, 0.9432537197048445, 0.7842306706160116, 1.3546056018625912, -0.9933122900788947, + 1.0300108626811424, -1.8830535281855107, 1.2388444130694714, 1.5820432556020423, 1.7375339367053177, + 1.6102618826367072, -0.8665078865132916, -0.6732392396345555, -0.6529773281138276, 0.10547233393450206, + -0.6941730912707627, 0.910541741327461, 0.9098240508320066, -1.6363195940333455, 0.5463776877988993, + -0.33425194180314666, 1.0461836058570837, -0.9752798490633703, 0.8152462082002794, -1.5875859561094057, + -1.9077898642385165, -0.6579383417965277, -1.3678615233634694, 1.285752623992095, -1.258025248987475, + -1.734869037268262, 1.0429220147507374, 0.7326509890514723, 1.9986880192758498, 0.6509156089882868, + 1.3547313230945077, -0.5885137816765393, 0.17851497147549722, -0.48206201174941743, 0.7757155223228231, + 1.5797039076642765, -1.8365491055202483, -0.5557915214323543, 0.8098649490633258, -0.9482838270042748, + 1.7240038491867988, -1.727136737786945, 1.504375516257281, 1.8508968488624733, -1.781048252815335, + 0.29406576456043076, 1.3929052384310125, -0.5672101671346796, 1.1612683698036195, 1.0365664134132775, + 0.52678187508665, 0.3071827595777501, 0.7041562182045773, -1.1798720159212701, -0.3576278976339484, + -1.816947298419736, 1.1348446198592024, -1.5847397567297339, 0.3604508734069167, 0.46593859852538255, + -1.7777395155730495, 1.6287619990007203, 0.05390138452125459, -1.6910054822794427, -1.1955797005332123, + 0.1527150449913588, 0.7396002223852989, -0.18108018677808335, 1.5581070045362102, -1.8797418765958547, + 0.8650557162994215, -1.611948653627663, 1.1857525788641103, -1.0890741767763679, -0.7623497273737962, + -0.5452689372225361, 1.4851221090925062, -1.5985252342075071, 0.020384445673001572, -0.9046335836897486, + 0.553216643980825, -1.9081035245013105, 1.1735365048310928, 0.1195026810678641, -0.13729942676911122, + -1.5259805046535062, 0.18920225926555112, 0.6705947977532496, 1.0030469424172317, 0.7610916391668443, + -1.2463291870058537, -1.0249840040963818, -0.19173741736380734, -0.6665958901111466, 0.4839557261866174, + -0.3485528787244352, -1.5242944187930236, -0.14435052683644933, 1.4529331704829067, 0.6458683077621643, + -0.20459459850901585, 0.631062246518785, -0.8791486672647375, -1.2386308323664368, -0.7619627858914138, + 1.2107017989790583, -0.33473019147897354, -1.9130923107708755, -1.7056520026340571, 0.8880970113501672, + 1.3775061802643274, 1.6842143744397662, 0.8956663807740703, -0.14207737810877585, 1.15151215723502, + 1.9554721163330937, -0.7192433376047749, 1.9055716193191357, -1.7324310669463001, -0.9426609664206191, + 1.4304692916943642, -1.4208060554924176, -1.6703510665846304, -0.11789403764850537, 1.1035841327225944, + 1.19193146120024, -1.735646676252749, 1.9517684584286812, 1.800446091627931, 1.2643875465217338, + -1.2016658470474093, 1.096389364153712, 1.5201049232112753, 0.30449137858893494, -1.1869295719311896, + -0.8315768582782672, 0.4346498942421677, 1.8028550790017723, 1.8857086005150672, 1.1520127103731062, + -0.1519769293665103, -1.9810749314865586, 0.6722940736723881, -0.8875375147759952, -0.7891504144449426, + -1.6946353432210914, -0.4359020089062122, -0.6980599672246557, -0.07982954815920884, 0.19084722828248868, + 0.6845828835105898, 1.361502619024912, -0.6198314937755827, 0.43206457200540616, -0.4275542957416594, + -0.5942951907386576, -1.2680930727916406, 0.8768595741043272, -1.621734829642608, -0.5341763533991353, + 0.5480403433778003, 0.04432939753449716, 0.6820566847717622, 1.9586624282689744, 0.32087743739982866, + -0.6031550453761696, -1.9650369588045296, -0.01282494430097092, 0.39651817911017506, -1.3672271702505698, + 0.5253796884817046, -0.9414239928629815, -0.8381706914260612, -1.9383642756327601, -1.035705422287875, + -1.206785655560016, 1.2381881481085397, 1.2976332585562043, 1.6406249791463123, 0.10308205600753784, + -0.5475238811744738, 0.29899086302238675, -0.8482038855976217, -0.5137345687600776, -0.9065955517315878, + 1.3104099207076656, 0.5025101714926583, -0.19511985691542133, -1.7979503287642702, -1.902372134245554, + -0.8024648307653406, -0.8296523085064926, -0.9765719040493881, -0.32443266986702035, 0.6539733476893286, + 0.5973369569721942, 0.9620679364522129, -1.1326092851006697, -0.5203786703966227, -1.9445684352154542, + 1.120817472266788, -0.7832944208518358, 1.7728569415904172, 0.4746637992232019, -0.6331056497466134, + -0.9560458440421451, 0.774229832283905, 0.9430691283749084, 1.7415084291941643, 1.7617290689369067, + -1.241891773067345, -0.4745318258222131, 0.170886449126332, 0.280332317095926, 1.797798588989031, + 1.4955462166754243, -1.9714454758512092, -1.0474834990256507, -1.4515647324997412, 0.5231202814417619, + -0.8247619693573318, 0.842919528017017, 0.3442132040114725, -0.7990310505752118, -0.3629900475529535, + 0.3295033891840484, 1.7075620006843595, -1.7886707971887938, -0.4229652804748518, -1.2955282842854743, + -1.178898100821427, 1.0539302135104265, 0.07818911036924803, -1.572166269506143, -0.9256373147747796, + -0.4523743452766471, 0.870552332090627, -0.14850807893624118, 0.4287510203485754, 0.7234083474785322, + -1.8786486629917283, -1.3259155700503378, -0.347806801338562, 1.3631756709064557, -1.8448244492904076, + 1.0352554853447788, 1.2734756957434143, -1.8995503975543802, 1.3929138458956896, 0.4683038046394792, + -0.2679051570033124, 0.8144995104717001, 1.4932193857012948, -1.6953990593425603, 1.402756253554589, + 0.3808795564720828, 0.8990244719837941, 0.40911995577962923, -1.1361078826263524, 0.17385823702487802, + -0.8186745793721117, 1.7076078317166345, -1.236408907822983, 1.6848698673082323, 1.1057305891182168, + 0.684172881489129, 0.6339043483376958, -0.5196970144327304, 0.3521577727883862, 1.650116506377742, + 0.2682799961840745, -0.4735884200746394, -0.37060613769970896, 1.1147917782653982, 1.7887865852421685, + 1.6978458641851324, 0.27152326646567104, -0.7479948142297932, -1.2016639963842959, 0.49444674806749767, + -1.3781311297932053, 1.6512147519128018, 1.7416186949120336, 0.045136221165612334, 0.5209559439848652, + 1.259816622507529, -1.4608419771090535, 0.044609757110229076, 0.9224148305655282, 1.001777160557932, + 0.6923994259395219, -1.7332027007479764, 0.1740919172843176, -1.6691024022127667, 0.020765839823362775, + -0.5411621161873468, -1.9512100832757255, 1.2205713100210618, -1.1974515945322022, -1.8219003860967407, + -0.20156369520917128, 0.8953615730514599, -0.6851191939445931, -0.9784599914417944, 0.6357787809646211, + 1.2543221592701617, 1.154804559643825, -1.3175939582140597, -0.6021071761808416, -1.3594268073464137, + 1.445454919786025, 1.6373224793127816, 0.4773095450916047, -0.5842274784088719, -0.7097055492598594, + 1.221870791596726, 0.5841175201744218, -0.6088866249685312, -1.0239540299091434, -1.5776221817975893, + -1.0528530592443435, -1.7599610522525682, -0.009359693820525372, -1.5849216381687343, 1.1123817614821956, + 0.48090947307860965, 0.5755896514757204, -0.49340183569225626, 0.38999288587119985, 1.9076108980193105, + 1.7293269863481244, -1.3443206598919062, 0.001652373471082491, -0.3612982015066102, -1.666179661360145, + -1.5652606963428637, 1.2847799597537284, -1.8746367656966019, -0.5051948920106453, 1.480902450916811, + -0.6574579095355038, 0.41299775479925227, 0.0038296277371880905, -0.49511696025555096, + 0.33076498894339057, 0.07531759452269782, 0.3925650285564686, 0.04553443036132965, -1.2835239354793773, + 1.022597041467466, -0.14452142564653414, -0.6834698246070756, -0.26296934920727555, -0.1553100097881268, + -0.7486942835272403, 1.6436947866485108, -1.453108500403955, -1.142814165110103, 0.8561776338025142, + 0.9114113939243937, -1.6289171667242108, -0.8765736035046707, -1.3197920400322358, -0.3734912789205094, + 1.5370462570492807, 1.0649239894263536, 0.8535930082816217, -1.6961129271355224, 1.4179270256764838, + 1.4041314424292972, -0.0700614584636794, -0.5005397839840926, 0.23183520631057597, -0.10990588024738823, + -1.147089599229715, 1.8964022235002318, 0.03743472262799674, 0.15742108441832148, -0.7185738786020632, + -1.1856642140796065, -1.044955084073341, 0.6700862122512099, -1.6025569167524276, 0.5441522921772082, + 1.756702564317389, 0.1676615923964384, 1.9361822628230643, 1.7064197592361863, 1.6880082988048626, + 1.4936942432907419, -0.5799520435760606, 0.990660775656627, -1.9229488942439295, 0.3155933315458741, + 0.27140426916735727, -0.5632551338077167, -1.4590940918768025, 1.5983827694310806, -1.4330645192158764, + 0.8003275119389475, -1.9841806470403869, -1.8395944973457397, -1.5552105471701392, 1.6842460594625583, + -1.7017713102134175, -1.3427034266215978, 0.6775219782398878, -0.3351302524035349, -0.022448660326078063, + -1.3500598715209335, -0.5202469644373027, 1.921619202644619, 1.3872706825409837, 1.396529188979435, + 0.07083965303226236, 0.6914877624468483, -0.29224764274698156, -1.751452537465684, 1.0064318828131888, + 1.5123878406285325, 0.4033025380645654, -0.9820341018265148, 1.7861648401668955, -1.0181815370408254, + -1.903647522119213, -0.6561248619205413, -0.8699254533593299, 1.7173299560789754, -1.9229485367712744, + -0.690953419069066, 1.2737422710672446, 1.1066566790733132, 0.8606336717720096, 0.8621209988717888, + -0.10083318929855434, 1.3453155024821513, 0.10215147045998396, 1.3751499812523518, -1.2192061255665605, + -1.0236206808587962, -1.533123935767816, 0.15060837984637665, 0.20593075721578913, -0.1420114293640644, + 0.17633888228692118, -0.5175320445499185, 1.3087124943884705, 1.8324200965520019, -0.4738758668846961, + 0.5584879952099433, 1.9477363800804062, -1.8628690990218209, 1.42344054481267, 1.9288039965215793, + -1.8566049760751664, -1.2884135845361993, -1.19038243691399, 0.7683237483024614, 0.4802985717653341, + 1.4329557197969898, 0.9128985393650337, 0.9461407015680665, -1.65078779921165, 0.4192615449543542, + 0.32096438114763437, -0.5129450234494479, 0.631644434863107, -0.47335834996781756, 1.3100891597589213, + -1.4790022015467565, 1.725986057822876, 0.8701026300053698, -0.6819596038865807, 0.4664857353751577, + -0.7112013772427535, 0.12572492597099938, -1.7769733290901675, -1.6440093209585278, 1.125533492349331, + 0.19218031047016204, 0.02403105386135085, -1.7090825450471296, -1.312565935939367, 0.03858791868454148, + -1.2641036387182139, 0.904155292993349, 1.3561478854255968, -1.539635445223043, 1.9235783529452943, + -1.874310049415433, -1.7773733191470518, 0.11118483759793829, -1.0886413092825036 + ], + "dims": [1, 1, 2560], + "type": "float32" + }, + { + "data": [ + -1.0123639551600512, -0.1262791332695521, -0.5528788189121379, -0.9891722578914903, -0.8541177111117033, + -1.842327110585285, -0.8753504664996115, 1.645494290881846, 0.12876899266900654, 0.5739158499727086, + 1.1027206966256946, -1.3155458981065662, -1.1433211051679475, 1.4367855916029315, 1.4674402192278242, + -1.3373231059554618, 0.8170172046647917, 1.1074697240531268, -1.6004007249577086, 1.4644646696571568, + -1.827927680383385, 1.4548611857965401, -0.8614990138298273, 1.6706016048131325, 1.5096827794979042, + -0.4651782953949448, -0.8577219028996828, -1.299490913831483, -1.339145989117756, -1.0017632367283333, + -1.3586772742922255, -1.8799261724983776, 0.059417093938870735, -0.6646734157727456, 0.5388638764003799, + -0.6378909942726629, -1.1647516486356562, -0.058721485739401835, -1.814477796499844, -1.189167849669282, + -0.6380012350722168, 0.5285662507696332, -0.534701982091482, 0.5570990437739303, -1.755585696977759, + 1.27238726136709, -1.9571057028071568, 0.04195657651183726, -1.9047024137637942, -0.5116506294039525, + 0.9926189908254566, 1.7759871807627992, -1.301591689492625, -0.9524123108659142, 0.524043944088068, + 0.5401946307447156, 1.2036398911372181, 0.3219194319137575, 1.9433711281899884, 0.33648919584354076, + 0.9772519308773946, -0.5575502080369272, -0.7345410843307336, 0.5778449333097511, -1.9408006240005902, + 1.9819202164371932, 0.8468700855540847, -1.3899691404202503, 0.15850835128301544, 0.49059781858603113, + 1.7764286491001, 1.7946165130578535, -1.16168298050607, 1.032789240397304, -0.7706982026329863, + -1.1038032957670127, 1.1838096309519006, -1.6323318982074788, 1.1064057844588042, 0.391546384023683, + 0.8726810963884386, -0.28563916045025906, -1.960002018275822, -1.3867181381833626, 1.921900210546779, + 0.23545042298506935, 1.84976268271061, 1.0525315891685612, -0.7472942377034979, 0.7577710804753881, + 0.3083793892222504, 1.327301485825914, -1.5659874146221533, -0.8978311834083046, -0.62907789098844, + 0.4870403076019034, 0.630783738162723, -1.9216326727288857, 0.9012236985202584, -0.7852645198565309, + -0.6937624671663629, -0.17653350051115613, -0.5561473457869717, 0.6698782481609369, 0.23362849702887267, + -1.2054404784680655, 0.7698259603739315, -1.1714183267177214, 1.1184512448333477, -1.6690588449303805, + 0.9147154841361029, -0.8955722282828678, 0.5231382201829353, -1.0773616491852902, -1.8410207920577744, + -1.1337885739036437, 0.5219438601865445, -0.5731516829203676, 1.5757208446602124, -1.8470713081756802, + -0.14985389360149082, -1.9692981274100303, -0.9532002408436595, -0.9539842568916299, 0.17378022347200517, + -0.41623035688267596, 1.6481595173848254, -1.026339675363599, -0.9699532421892103, 1.9461597162595536, + 1.3640648912196438, 1.1391500889364954, -0.5595329932725566, 0.43069270118926184, 0.9216291855227485, + -1.440051073691266, -1.47987236009598, 0.9087155921703598, 1.2787682017568338, 1.2363303394454128, + -0.49711585260182556, -0.9387884495809562, -0.19011148798368716, -0.08611760612433539, 0.8656095244085549, + -1.4316244544821588, -1.558915825119862, -0.6738541489070196, 0.7358531504667214, -1.5716157542957179, + -0.3210549969144969, -0.20072745672949832, -0.2416365577548918, 1.5081023632879136, -1.547604775100064, + -0.6334244403609635, 0.14810360033380032, -0.7978288773497635, 0.6204344672116999, 0.4773642826492761, + 1.2087249539596163, -0.6643075818626354, -0.4170560884596144, -1.6192321024457321, 0.7844847722786517, + -0.629690651133866, -0.7380758723814482, 0.303414620658204, 1.5479875220490822, -1.198103302774089, + 0.9760982659095188, -0.7574001500859886, -1.2724614749813545, 1.0658176639069543, -0.8843666652730136, + -0.6427064732600343, -1.4416669867869603, 1.473657450100351, -0.8344994942004691, -1.4224435385472942, + -1.0338023533751777, -1.933568422908838, -1.0802998481520287, -0.38091309180010224, 1.6199945117010506, + -1.702101910236685, -1.4725385504086255, -1.413591341417039, 0.540278745507603, 0.3517718642795238, + -1.590795907883174, -1.765499823368284, 1.7366923492439614, 0.4582221558773192, -0.10581682008337268, + -0.18516227544796227, -0.21097779387988158, -0.1428735544745896, -1.97510241493318, 0.32449001731988947, + -0.1832218746003349, 0.26181337286546746, -1.1227369967165552, 0.35351574098454375, -0.21205956319428498, + -1.3866497212089195, 0.5946688412485415, -0.3417425538750871, -0.33633058083047906, -1.7852940300594273, + 1.9919312461265548, -0.7629882135863388, 1.4620310920385196, 1.1115061446942711, -0.9057539166302142, + 1.775862903430335, -1.1324751374031425, -1.5851970376790732, 0.9843604936500894, 1.5734177841900427, + 0.9515914445205862, 0.034323622285483246, 0.8075573695504703, 0.5332420240003275, 0.7767308358623826, + 1.0329131214994085, -0.9838298807872725, -1.7429868813963063, 0.03740922197745089, -1.3794671490283932, + 0.9772124799843054, 1.6546060756751624, -1.345806362676182, -1.620585515255308, 1.3498272448019941, + -0.25283974040314394, 1.8309785362540882, 0.8336568766196351, -1.407378961144727, -0.8870392599067882, + 1.5455801491463914, 1.1404840595611354, 1.9778865957841072, 1.9026326233043243, -1.2919286899267508, + 1.5194536255103763, -0.40024201189426734, -0.38767200629106124, 0.37883119550528654, -0.8971148848399899, + 1.1472506966060552, 1.6769048658537242, 0.39834963390946676, -0.8584979863189526, 1.4851684858856853, + -1.9898922021489387, 1.7271323520062838, 1.8848497191146922, 1.7439889790675318, 0.5311134881425099, + 0.2302960173495876, -1.6217864910988453, 0.28492260856667784, 1.3550969896689358, 1.6762026245924515, + 0.45464402605973575, 1.8447468286497628, 1.7489125819896838, -0.7745650567526248, 0.6473255323813127, + 1.88270574713889, -1.4231865592800421, 1.406236181817195, -0.05820536366515672, -1.9830176146937077, + -0.927096735728453, -1.37521200952383, 1.6293084827507869, 0.18916714867483186, -0.3559388864834574, + -0.0626685044384443, 1.4510888124049117, 0.00665671994549033, 0.5852250937009087, -1.964947150735517, + 1.086994355276114, -1.5545621604146378, -0.6702039017668291, 0.15273130009205538, -1.354848404243989, + 0.8081822753111396, -0.2990136329330131, -0.1334268300545549, 1.2936295445817017, -0.5276761138383153, + 0.06209853112125252, -0.35227980331045927, -0.6683541541821878, 1.5365781152706175, -0.5227637702649135, + -0.43751245897261537, 0.39166051967309556, 0.6145502882685348, 0.6764920150128493, -0.46478346293163764, + 0.40093484640123567, -1.4385605602950564, 1.5318810200296449, -0.7902920012169599, -0.22815329205907098, + 1.5159148518766017, 1.7440445423086697, -0.7705868478778743, -1.0446035338845894, 1.4407728607631372, + -1.7690868678646723, -1.9956594357087072, 0.9165504950260104, 1.1647979922386025, 1.7626373785022524, + -0.3262003585763962, -0.7291462643423232, 1.2691368673965409, 0.9833027096614515, 0.7052758987187504, + -1.4080008451270958, 0.2004861907693547, 1.92413536100345, 1.8633978379666134, -1.5597901041000588, + -1.232525418601906, -1.9326471509575835, -0.23851047841947803, -1.745957663852197, -1.027455630245683, + 1.7842373183009093, 0.7098705198166604, -1.3523419086313861, 0.2493915779920206, 0.5836072040016118, + 0.8452857075528275, -0.6044200471234227, 1.335947146234287, 1.9535634253874816, 1.8737477649440653, + 1.6787256628480751, 1.3475059469256392, -0.9023420902836907, -1.815324493360138, -1.7487338231501415, + -0.08107787176718784, 1.178869071718574, 0.5869021791922719, 0.1289991861916615, 0.9871714466975163, + -0.7828180891664971, -0.9162265218952319, 0.7883323334301799, -0.7738825207321494, 1.0578051800827781, + -0.48483804389576335, -1.7003938250158095, -1.7474401518911815, -0.6024807198720463, 1.470072074418848, + -0.809852698248462, -1.8087803758512981, 1.1275613461510172, 0.9110052554791794, -1.4827836388852713, + -1.3641845213240744, -1.9188108209559402, 0.8859208024949954, 1.7438050845669144, 0.14476912919518536, + 0.6121128834023981, -1.7692670213619586, 0.023661688752081744, 1.1007625036098432, 1.1758330104763122, + 0.4546664062325476, 0.022499008403786824, 0.8120850018523171, -0.7886059301759083, 0.8107426171843777, + 0.015751759753425354, 1.9515003227015306, -0.3285612629290764, 1.758730602588753, -1.9178063288185045, + -1.3319225925070368, 0.5970900239608552, 1.8634221473873263, -0.7483844730402502, 0.0851383845623852, + -0.10037389959678844, 1.8601880295663413, -0.5358906627108242, 1.5311027975069011, -0.7567148434480719, + 1.7810484758849983, 1.5004941791198378, 1.238866744077014, 0.019238796977725237, 0.7314924609545477, + -0.6404106749076366, -0.30544348502988683, -0.7754562102568752, -1.1903829239480253, -0.7557972926946839, + 1.418804956107497, 1.6841275666684883, -0.35403092145419013, 0.3072276436064163, 0.4941160183076647, + -0.010460638654985033, -0.7496577784263767, -0.05957826320949966, -0.5349743628709929, + 0.44780861823397355, -1.9548584156880642, -0.6834407857845042, -0.6574778495500677, -0.2568872307434864, + -0.8179424074332058, -1.9399599284052886, -1.7438777236599172, -1.9046697213047699, 0.33576417528481173, + 1.0390831565369494, -1.2867357835981865, -0.9105779330773034, -1.2600968940701254, 1.7546113033912878, + 0.8638193166816803, 1.915934439034265, -0.18936860893703056, 1.6490561383957179, -1.7404200826424407, + 0.15942118157817387, 1.174512061322961, 1.087287672904493, 1.182852158431765, -0.12430741231511089, + 1.6711861366379157, 0.7124940145742862, -0.3946246470773911, -0.7754542725640272, 0.9539784330716907, + -0.5716889107776746, 1.4262896723570924, 1.4675456840569163, -1.7077488525524833, -1.6888666810589683, + 1.2108429896458865, -0.30524840414522547, -0.18167408305726607, -0.2569749511019337, 1.2912167486614727, + 0.6208472747047127, 0.9472464500515958, -1.1302136544927563, -1.478282134349313, 0.4848945322578242, + 0.8298435424742152, 1.6932133553283863, -1.4458048451455756, 1.4088139925156833, 0.505348371415975, + -0.21105001864112882, -0.04858175142791943, 1.3570555900503694, 1.2673714070205957, 1.1469844853077413, + 1.2000011591622064, -1.0780533577358637, 0.37698814259115565, 0.6997609434842227, -0.7604196150995675, + 1.8410835681246196, -1.6836663805912915, -1.6352482015283725, -0.4811456756273813, 1.3045848106454878, + -0.823583139102726, -0.646527859067727, 1.5092372843244393, -1.0424659042584983, -1.9448695809676995, + 0.2678845821239566, -0.6194354091338141, -0.3172475643478627, 0.16481577119936563, -1.1026554846901258, + -0.8352503899270465, -1.8149755146432849, 0.4839677208020152, -1.9367901959501284, 0.8680859459275245, + -1.2035834761537227, -0.8748603808576707, -1.8417628093555134, -1.0429294120821577, -0.2520578638761588, + 1.7833216800296539, 1.4367696159460968, 1.8669976567111535, -1.405562858069989, 1.0576377264778563, + -0.4569713929987014, 1.3255011842556819, 1.6171166029195225, 1.0403552739195874, 1.2264321768656314, + -0.47396544132443275, 0.7118492170263346, 0.6260191876547241, -1.2179712214091793, 1.5120789908822676, + -1.657525645319189, -1.2991286032659461, 0.22202239748400387, -0.5389051448124622, -1.594992260705033, + -0.0487688918807363, 1.1512759563478916, 1.4679486318272383, 0.8813613284468369, 1.4328674044139706, + 1.9999268579039367, -0.47950159568339323, -1.2281571927849866, -0.4554451947054856, -0.15429012922622043, + 0.19786052464284776, -0.3680312279906497, -0.18825645901610866, 0.13700608028054084, 0.11417316734012051, + -0.6463349589003959, -0.3770634118502656, -1.9950465240002844, 1.4676192894281632, -1.6060448800367215, + -0.6182395877160713, 1.2695963682598732, 1.4459649727588744, -1.88317964468765, 1.4240536704934144, + -1.5317465035623874, 1.9497915396745666, -1.991995390985421, -1.4828801801030478, -0.03471214257257316, + 1.0775554630217314, 0.38086611278178495, -0.1958126129950628, 0.3711869657718534, 1.7307000355063105, + 1.3370240564962907, -0.6892941432270163, -0.6176252997554714, -1.2761391889511922, -1.9463694295411074, + -0.820841598715079, -0.42139807880407787, -0.7976620746519121, 1.934915457164431, -0.4497028214466532, + 0.5258289450636102, -0.3002339930414486, 1.4317770429007526, -0.10670432773391081, 1.477187387671167, + 1.6422292268536607, 0.30393726544465505, -0.16649028812518285, -1.1690831963968895, -0.7043985973685203, + 0.47350023974648625, 1.657836032561474, 0.16219091081621606, 1.2307861721904754, -0.7270831655242516, + -1.0574142264570137, -1.9134290652692378, 0.1585901752075669, -1.922458865955397, -0.5216475421535529, + 1.4438431375673586, -0.2874803852531551, 0.3370681487492808, 0.9173203757850725, -1.3751125170831138, + -1.014212305918492, -1.5475568694685897, -0.5834419852252983, -0.022709263779811195, 1.4145035718404255, + -1.9267536984073965, -1.871498186038706, 0.525620783057489, -0.48663912480579086, -1.0308451661848164, + -0.1369351560707246, -0.7876105221422698, 0.6955249722293555, 0.29453585260072757, -0.06514154829755281, + -1.6429080966047698, 0.15520901396599296, -1.518429991046033, 0.6839405241853216, -1.8300625086431346, + -0.15898426442765246, 1.9278290352285792, -0.7150445644634695, 0.34034454186145346, -0.2506887167667946, + 0.2912251513885442, 0.10434269155791664, -0.5637887420304368, 0.031008416285043694, -1.0816174134360272, + 0.05203114530680164, 0.03172694813978172, -0.7646387549793285, 0.36213414786228526, 0.0060869909269349876, + -1.0367311632092022, -1.0684702277942222, 1.1874407786461294, -1.9290032593242623, -1.8550268296137276, + -1.161269082907582, -0.18656240236501098, 1.1070044180055767, 1.6261946461581385, -0.5698373516978554, + 0.9631347920513802, -1.0985201941795912, -0.45509125721838917, 1.6535643092193615, 1.9696290288271951, + 1.0266341388473261, -0.23790773845234447, 0.2828088466454748, -1.6537561622154024, -0.2286353308074256, + 0.8766588558049788, -0.8195788808423616, 0.7718354518479451, 0.46796484124644344, -1.7212327769837446, + 0.0658435971926874, -0.6407624425160288, -0.5647885630487526, -0.5284202936353299, 1.8818650199438771, + 0.29252062160862025, 0.09136912052125101, 0.4321630239196912, -1.1485094277982304, 1.6036235307678686, + 0.3318334927588493, 1.7219946936827109, -0.09166860362313312, -0.9321185046623404, -0.5842230824766759, + 0.5762857089716649, 0.6237761258836967, 1.3257989135149089, 1.65675048758645, -1.0060167288419342, + -0.08448091333478214, -1.2793076427969634, 0.7514972175750367, -0.4193024725154899, 1.427794959994305, + 0.9558973375817734, -0.00039143542951691757, 1.7030425931606343, 1.8219801925309609, 1.7260980421968792, + -1.9249357614979115, 1.6285038041870772, -1.6118493301059527, -0.4294907666236245, -0.1659993953929053, + -0.5722726383494532, 1.2935829105972072, -0.06859448172655114, 0.6177602091273879, -1.3370886026529494, + -0.8003712871381898, 1.4776462171750593, 1.4184671800982604, 0.5433276418773598, -1.1103872044287346, + -0.7572146251109908, 1.1710857107940438, -1.8705799333769377, -0.31024900334903194, 0.34866139709491595, + 1.4866168061361806, -1.9774625782466435, -1.9891386648785518, -1.7735018436923857, -1.5751766748778406, + 0.7218521749338684, -1.9390531947989702, 0.10502871018805493, -0.6908737000286438, -0.583334654840761, + 1.465181746808006, 0.9232784443998208, -0.37400862804876933, 0.6661364913985333, 1.8688403166211724, + -0.8717922772171374, 1.40123243258902, -1.9913513342271694, 0.7369262986601814, 1.2562396521830914, + 0.7638029152143444, -1.8164465814226718, 1.7184240047901733, -0.911895923498772, -0.43161449539234464, + -1.09011215721989, -1.865570383600069, -1.2232212962752618, -1.943030725162366, -1.3198980808588407, + 0.0564583162685901, -1.1298601037432707, 0.6392469941959655, -1.8442933136946733, -1.0692331296192306, + -0.29525834417416963, 1.1184299311108798, -1.6129180448223925, -1.2727580333965411, -0.9415967651718447, + -0.2597646669604643, -1.511150740860189, -1.769101860129168, -0.9185489030191736, 1.6841872338621604, + 1.7266136417112579, -0.6047332956995355, -0.4784036452377798, -0.6987121488961696, -0.3950169895430573, + 0.29099877073820757, -1.5250611710167732, -0.5876293953125105, -0.5938486494168753, -1.0021656820999798, + -1.9666708037201044, -1.272140943592933, 0.7880251982149247, 0.11964755378902137, -0.3901422866566291, + 0.7163616669643105, -0.04395212244207691, -1.0791955402155438, -1.2675936648298745, 1.0655012879795382, + -1.2960016160353804, 1.1268487593724137, -1.4561611267402474, 1.7853064994708516, 0.9817883639607627, + -1.509195648143982, -0.5791158763214721, 0.2952226835939813, 0.978029471962218, -1.7020877610480865, + 1.2949154364706787, -1.8669978207007674, 1.9804087372291983, -0.1920592681769353, -1.3129464854527964, + -0.13084211958976688, -0.25279655392730405, -1.8414897577067046, -0.7363208735547797, -0.6909260581968182, + -1.8811392695885178, -1.5901068180742568, 1.758878672856656, 1.5387787983193055, 1.6713805051822828, + 0.28500585759464503, 1.3914306792968247, -1.112480424362695, 0.43326162263712487, -1.8142315585546145, + 0.04023793859339708, -0.29805331377366073, -1.5940199056342657, 0.598129666067309, 0.625004812581027, + -0.911960460437454, -1.973405398299871, -0.7574758313972065, 1.6261060948310595, -1.0639316504874738, + -0.8549167612511983, -1.6781924250988283, -1.2461164334164385, -1.4396767476893544, 0.588676315376695, + -1.3923513282535058, 1.960640665522404, -1.1216598084556173, 0.29702774865635373, -1.1441990771482464, + -0.9733601129567919, -0.13533496827900127, 1.2875809157665516, -1.004348467034836, 1.501216437295625, + -1.7257128690349832, 1.3038540536955745, -0.23514094567048183, -1.0545846838443325, -1.5628126421353628, + 1.758225843292891, -1.752217343717482, 0.8827182187480176, 1.9633500079396518, 1.4124055174643644, + -1.6009057792139894, -0.124257691420528, -1.6361563376854855, -0.7163857270237415, 0.5991086423774714, + 1.7584781739562239, -1.625063774845441, -0.37572359945414213, -1.9506995916793395, -1.951072499257542, + -0.8315895595505731, -0.7002195813028456, 0.31048848147933406, 0.19118037223773499, -0.836819966187166, + 0.8992259497849764, 0.2769262236848036, -0.8190887502725346, 0.5908744335005158, 0.8309308801063224, + 1.828667115346997, 0.050270920754735826, -1.4675310474376078, 1.6474157726281966, 1.4412270025481906, + -0.071799132580602, 1.2723657902431542, 0.9847345744230269, -1.7967618044169429, -0.38502605748464447, + -0.9911154903546722, -0.9911306363398822, -0.9822148063420846, 1.404660627884402, -1.7223894428279198, + -1.906376932077218, 1.1267093944315762, 0.005733157947431344, 1.5499657009743384, -1.2918427917389055, + 1.8878866260898972, -1.450986879628605, -1.0670858020560647, 1.012435839252528, -1.0904895043171203, + -0.7636238525274122, 1.8215658692720762, 0.802604215057829, 1.6666057955071523, -0.6857256630224935, + -0.5501356674470292, 0.810089459752044, -1.6169276394201413, -1.7364078843810304, -1.1867030927097977, + -0.5172860730134063, -0.8556026745046195, 1.7395171980402528, 0.8977518661224195, -0.715248100272647, + 0.42642471199620147, -0.1359154018671509, -1.017818228497351, -0.00905895348889274, -1.6541703500137865, + 0.5001119548133026, 1.7346626988667904, -1.1674654589916598, -0.735219697011062, -0.1962670855393469, + 0.21602710767932987, 1.634800475118543, -1.614549402431435, -1.86469031751599, -1.0722793725135675, + -1.6751255258166085, -0.34784828468612705, -1.333517864681233, 0.2286247270719235, -0.9686327464308748, + 1.8030391927221219, -1.260653677947687, -1.3291707209191737, -1.7464317874557151, -0.6022055677476486, + 0.4234332187817236, -1.3942184957445614, -0.49460322127068856, 0.846493215661404, -1.3779621390020038, + 1.6170651737934367, -1.5949106936017516, -1.993666650794867, 0.517274246668932, 0.636335391123283, + 0.09728792236496986, 0.21871120388837983, 1.9033150412817141, 0.3761639225094582, -1.7448084623773052, + 1.001122609393847, -0.44673552515686765, -0.6566748795770616, 1.0022029703662332, 0.49152185517814306, + -0.19632140501675632, 0.6593309755584418, -1.42607069406814, -0.2499327686935615, -1.4645970035455589, + -1.8214929827258137, 1.7849263457214155, -0.46930999932929396, 0.930852011498847, 1.4657054090327062, + -0.8598379219960437, 0.21923117934684644, -0.2719980917718665, 1.1382814204088554, -1.4437234293121408, + -0.08437654030814734, 0.9230522551879243, 0.17552818859532504, -1.3982719952892557, 0.7727609217240659, + -1.3512364654797944, 0.4217546307725639, 1.8959084151076748, -1.9009721763514387, -0.2084686501701638, + 1.34507209606525, 1.4812500373319821, -0.25452451515050356, -0.7547650359872655, -1.5901912558006952, + -0.617303822265475, -1.568626713241147, 1.3511951372228292, 0.46680417658438866, -0.9843992974612537, + 0.21141544095185427, 1.9555502838890186, -0.5622924926922526, -0.7074175111692584, 1.6856741408764497, + 1.4329492504371748, -0.6904233032298688, -0.16570616327044707, -0.5819404191250754, 1.5298308400435117, + 1.2873904282242874, 1.75253332340609, -0.3969229369696805, 0.9712496560090953, -0.984449102903409, + 1.845837132921134, -1.23000834955623, 0.25823037305241403, 0.6562595586377551, 1.426434937488695, + 1.2050365141327637, 0.35112023386450986, -1.423781157416867, -1.22442697877245, 0.8857806584751993, + -0.27941851495344316, 0.383573200806417, -1.6546531309712131, -1.0620419037179136, -1.6487673588042684, + -1.1583303816085477, 0.11883432462925647, 1.8623629910270258, 0.7814730455397738, 1.3892839510915138, + -1.2109955091247775, -1.1531820955625154, 1.4249824872214445, 1.878872910977651, 1.96640914460413, + -1.7574195668520565, -0.17080472539130565, -1.2334517024508829, -0.042203796430729135, + -0.5119900340796173, 1.8245076363940829, -1.9504677488809792, -0.15838646478579133, 1.8653261691127963, + 0.9818831615417336, -1.1498049891062543, 1.8177635453973222, 1.0307183336158117, -0.9459693602670747, + -0.6967235469867861, -0.6765683802163993, -0.25117711682611343, 1.7085642032810782, 1.2354232326963057, + 0.16304953424899615, 0.8288006493055686, 0.7426908331423769, -1.8567733398412107, 1.9363614866712426, + -1.1491819532508236, 0.4456745704746172, -1.5200940589370164, -1.2921806881240192, -1.8425344342113679, + 1.385258197718982, -1.4468713069952912, -1.7194028306161009, 0.11112342320496005, -0.46993304971516636, + -0.3532497592673005, -0.20705867891006324, 1.1503266436423507, -0.8615322336734987, 1.4243814245202273, + 0.44605877897824087, 1.4291899741133527, -0.5503339260133986, 1.6845285529242942, -0.11203488733585854, + 0.7838364265704332, -0.3478757678382811, 1.4617786521240204, 0.6797786349237454, -1.4070556983478548, + 1.5368835556999283, 0.7270943692880465, -0.3943204095728534, -1.0159540349760245, 1.3420827647018054, + 1.4944969205796248, -1.3846348875424548, -1.1070204938214987, -1.5623431163391235, -0.7285340046827864, + -1.6146739312377756, 0.7914876850628412, -1.4285275663878165, -0.2102551967847095, -0.6036343290031185, + -0.7863519667198808, -0.5232027574551195, -1.5248951664568793, 0.6403374226115135, -1.3332654904167054, + -1.6013714748847017, 1.5264343369773945, -1.4659567584783701, -0.1255742527269854, -0.05787364846985188, + 1.0630262325298858, -1.0251739144160261, 0.6341878529485214, -0.37708094500223766, 1.5820017651752964, + -0.9754158194342759, 0.021470140570154506, -0.270780715342573, 0.8513662098629382, 1.812913501299759, + -0.5507306480213909, 1.6252304329937193, 1.6166807780171624, -1.0015815783816109, -1.6525598491008084, + -1.7640141719015405, 1.4567526655551655, 1.4314017477260261, 1.1147464550993922, -0.8609069210724725, + -0.8835445478716384, -0.8807052221352922, -1.054638273386189, -0.9307483447707048, 0.6532758793412583, + 0.2251469563444708, -1.4983410691150256, -1.355149970924261, 1.825265403016327, 0.3882557252462826, + -1.2005370275411478, 0.5167632305655108, 1.258505468633845, 0.09615276706317388, -0.11253109876707157, + 1.301050271249565, -0.5926127701907813, 1.6730534476647492, 1.2040207312610214, -1.5220497985479415, + 1.7064305519329883, -0.7793546956972763, 1.0964366887160475, -0.8642894018251441, -1.9411871186407836, + 0.074716954760353, -0.14369204309870298, 1.8393941069756865, -1.3769308839327685, -0.30862200557660646, + -1.993793823331563, 0.3709006233167482, -1.6557247604820402, 0.32053951400820324, 1.418554947267494, + -0.14801920801346657, 0.25882446183466357, -0.30227778350472967, -1.4281993644549162, -1.2907922764091362, + 0.9110864171884971, 0.613974184551024, 0.6697305289032087, 0.8489421527088039, 1.498148243683703, + -1.4269397350154973, 0.6132189565263042, 1.8741137765083877, -0.05705777446194382, -0.8855796810622429, + -0.8656995527854097, -0.5082467483431357, 0.4332387677470342, -0.7541429782381526, 1.305940642158534, + 1.3554774725998202, -1.2111195490457929, -1.4381676776657422, -0.5207599119467634, -1.8228914923142483, + -0.1877702958456453, 0.36230894851909135, 0.17851993376959374, 1.7927379289246463, 1.1368406732884377, + -0.9950802434664396, 1.6043789056058433, 0.6484661390901874, 0.933455445947418, 1.0965484754420745, + -0.6939481831648528, 1.2328397545284568, 0.9378872541025238, 0.34445900641106775, 0.1278365294249939, + -1.0258774152176224, 1.0892898808290514, -1.067964672512037, 0.44689053769430487, 1.9413539674195102, + 0.9528679183933839, 1.7587111895733223, 0.5576643641512362, -1.9897364013123235, 0.47036230007388813, + -1.7897636522061902, 1.5010421043239726, -0.6901446605239645, 0.45393761765945406, 0.2829324542693943, + 0.7192447802434581, -1.1331672455879191, 0.17337802502719768, 0.49963004068736083, -0.7815511951172738, + 0.9636628323801011, -0.8323427533238403, 1.2095900671544157, 1.934951536397083, 0.5200093238971917, + 1.158643992086569, -1.5891421117437572, 1.1934443042649656, 0.8276249819736909, 1.4212639972350827, + -0.8419641775597322, 1.2661381994885206, 0.06158022749540493, 0.46246628371777465, 0.7573951522202043, + 0.7619102543908491, -1.9354772481952196, -0.08541307148121735, 0.32286401465719994, 1.170684883004844, + 0.7749781307966064, 0.07936620410936168, 0.2315148010116328, 1.263046672689227, 1.031153500041193, + -1.4759863901212418, 0.9873606549190814, 0.7777062346878534, -1.7228292836209809, -1.5782061917261343, + -1.9224252408205036, 1.6523756827872802, 0.2993123495389698, -1.0626357773248643, 1.0197875082835361, + 1.9808878982391658, -1.7335070696659765, -1.3148099108168987, -1.9494880550804297, 0.17268926636423831, + 1.242098464223531, 1.1180045679110648, -0.6736825181756254, 1.0154763824746373, -0.19957976202073535, + 1.1193883202782464, -1.1478784045364039, -1.1330343367225462, -0.8689855158445736, 1.0116971724295913, + 1.2506562934116516, 0.9967792918784815, -1.0610583651118048, -0.532297674201299, -0.7150119770938783, + -1.2232688695213954, -0.4743547399060102, -0.46434726873258025, 1.0115464987548544, -1.0177528333698103, + 1.733606846210808, -0.8155729251191701, 0.5232276995816267, 1.7914758177269183, 1.0157926433716051, + -1.8351739772356694, 1.2012667794467289, 0.8545667843161366, 1.294851140223634, -1.3948286719024399, + -1.1466399375859053, 1.3041194739119701, 0.7377935849631942, -0.7995103884432107, 1.6159710967964172, + 1.1620542442509896, 0.0055755030499886615, -1.0334307489311154, -1.5841651529973335, 1.4111075346503084, + -1.2009959445559417, 1.9720845657548844, -0.8035695524416591, -0.14331494930025102, -1.4264826197967428, + 1.7586100751820988, -0.739427319576853, 1.0666567765029669, -1.5916344843881065, -1.4350385771936356, + 0.42668754843352463, -1.6484548280031257, -0.37728075139212613, -0.7687371980269919, -0.8179781982796062, + 0.465506361112749, 1.3840790073759015, 1.0880219954714985, -0.7036856119504717, -0.6963478247947235, + 1.0979632285661367, -1.3810178288903288, 0.8156134298411253, 0.10032276394202011, -1.6008032936016479, + 0.18932441617959217, 1.3551150619453578, 1.3534267176398442, -1.1276635081675348, 0.6608005967002919, + 0.793182461268664, 1.398769261405878, -1.2123058565244103, -0.08803245110073288, -0.2893447181014084, + -0.9972961711861705, 1.5618332897004663, 1.1591927593779072, 0.511047619279922, -1.970138349713964, + -1.3628804504805752, 0.2782295809685751, 0.30358322411230354, 1.9514398542744873, 0.25960063763317454, + 0.4976234205537926, 0.012047558099780531, 1.79534431915478, -1.7723391117592922, -1.9992623394682916, + -0.4524505481322034, 1.3804610881366495, 1.1587664810582172, 0.5111716739430667, -1.6928217537440542, + -1.0278751605013383, 0.20893412968684988, 0.5871739815665329, -0.8412950581167742, -0.4077765748738731, + 1.7498754266646204, 0.8583271186920243, 0.5762482317954367, -1.8599099610537024, -0.19242912582490845, + 1.2512291284228754, -0.8441763329152305, 0.26980735485206075, 1.4044456507515894, -0.8516268695811835, + 1.4493090144656193, -1.3915783403234894, 0.35557624127716814, 0.17226516309619733, -0.5021504124493701, + -0.766188811190383, 1.1332244159180078, 0.011135590774230764, 1.8851307343362258, 0.9148262788018782, + -0.8299956158151707, -1.6057691996197043, 0.5678238711359924, 1.8008767630518667, -0.586304639586193, + 0.47029839294118503, -1.463460016599707, 0.20856103503853962, 1.2545845494965118, -1.0729668619560213, + -0.14947337388785709, -0.20035875199434283, -0.07202935566940027, 0.05533721254453372, 1.3677731442776313, + 1.7011893855187177, -1.7202195328563636, 1.9488792451860384, 1.3096386167232117, -0.5132153326702822, + 0.5616165083457831, 0.4157359121447879, 1.8006481124839855, 0.230442477572935, 1.1686013774607265, + -1.7879670674147912, -0.842370723742838, -0.7927388944332199, -0.9586442598316518, -1.54708954114047, + -1.2956507442445577, -1.4031204732951874, -1.6120562181795481, 0.6283505387959369, 1.9223686678649798, + 0.12298814371626143, 1.6278360280654836, -1.6223557147160461, 0.43054457669015456, -1.7908842288361821, + 0.5775385836169233, -0.15097004219870414, -1.2290692851647318, -0.620782793926316, 1.1062043604891931, + -1.9746433547898716, 0.6382174626600765, 1.749692571795931, -1.8339775549081967, -1.464038954875173, + -0.9639795224425223, -0.990228592162139, -0.9403728223487793, -1.6685943188697578, 0.07041255085387288, + 1.5308882897823413, -0.47253846241489494, 0.8106189739961147, -0.2270261582976314, 1.8799983165866934, + -1.2472611795739992, -0.15798247303785384, 1.1021702350596438, 1.554345756888078, -0.555551779439396, + -0.2519441527853594, 1.6402787480890648, 1.0543147407284197, 1.6335920152443828, 1.6300453015691483, + -1.0481818985064661, -0.7279789339473091, 0.4888000242214119, 0.48732772667936697, 0.8584068837243475, + -0.2507103566018847, -0.4408909179300613, 1.1233796101976283, -0.4632288097116861, 0.2476082637987398, + 1.0440154957174563, -0.8869375992382329, -1.4091245527531013, 1.1244796013617524, 0.6892639663640718, + -0.022508581368523295, -0.1947662213695871, -0.03308275174489772, -1.5224685354120124, -1.749461912732114, + -0.8082369514527556, -0.25696601764524907, -0.29739460489203395, -0.7761972897539025, 0.11904461166545932, + 0.12301125764257481, -0.42555313462267197, 1.193050472577613, 0.41286096527825666, 0.19893639265378305, + 0.8311360551038502, 0.07310473388978878, 0.6685202630904961, 1.3372821732384246, -0.12066792119153114, + 0.05609759368011602, 1.0451437108149522, 1.662035784060775, -1.1922702285472315, 0.6416834232307735, + 0.6055133359661493, 0.7942484561725927, 1.7276358502409623, -1.2864330076965746, -1.0918326505522646, + 1.1998891150473536, -0.28628797182602295, -0.5716447940943157, -0.5689380516868878, -0.12795378416769676, + -0.6791193619117468, 0.10038501986448622, 1.7063771482852008, 1.194191767284221, 0.6647694331145066, + -1.7348358815181681, 0.2817755499501846, -1.8525340342841066, 0.9851618492112006, -1.1059877587552673, + -1.5295669583116638, -1.0505227522752554, -0.9487207502273058, 0.809796237150997, -1.8498482833968302, + 0.48019511430536, 1.1226579267420975, 0.742760641350614, 0.29074715598422696, 1.3377685110252564, + 1.0867564725571857, 0.18478374884500237, -0.3164295444608616, 1.622419371685548, -1.1918941739540632, + -1.85979992788012, -0.12624583269268985, -0.7280639736725645, 1.7648123921702625, -0.2906781194024406, + -0.09262137204418774, 0.4138200526064937, 0.04327935477419942, 0.9797051661471672, 1.6815036980821159, + 1.1825812431197686, -0.4448889424945657, -1.7444130322963716, -1.2294019413222204, 0.050868336661046065, + -0.28346687593229625, 1.76715189128082, 1.611534349398629, -1.1862000782338198, -0.18760651204489776, + -0.6355147050302099, -1.0693954841485978, -1.3343928935517813, -0.027843745181175272, 1.8949354550001445, + 1.1947313752564401, -1.705187723165774, 0.7263276294399761, -1.7205254544798727, 1.3224746522778261, + 1.1708741480141782, -0.521723626523988, 0.14495711998638772, 0.1539508663177509, 1.3474086945015884, + 1.629428728782094, -0.4727414254013107, 0.6353417064465656, -0.03451981234038648, -0.5836884681507799, + -0.20106395276242583, -1.7734283674131142, 0.32803960400104515, 1.9133097530382273, -0.24540943135655446, + 0.5453897283838289, -0.5427539051784143, -0.6173702928258038, 0.9933437009783477, -0.3068382522087365, + -0.9721176288546554, -0.2527049737425271, -1.273295878249181, 1.0233037474978097, -0.3711521481099638, + 0.5595202642561681, 0.8557760663482039, -1.8239035041840888, 0.9147544038362989, 1.7978278449861724, + -1.4194290997912544, 1.963345131841125, 0.23090855457281645, -1.3144424697360408, -0.1798082266955383, + -1.2650654602972473, -0.18679991460173095, -0.17629662209988428, -0.18349530363472422, + -0.5475428774726536, -1.04729696306776, -0.6411208082297168, -0.7308839648617003, -1.3455838296857312, + 1.3597877917216312, -1.8340993142240434, -0.6466367039451653, -1.8498009040527492, 1.50126492806735, + -0.8473870152631546, -0.566754366425382, -0.11291722444101016, -1.204932180016593, -1.3014629360740422, + 0.8120263473720843, -1.8930799484123453, 0.5895700763514027, 1.7580882681692787, 1.477127335430259, + 1.2047061381469248, 0.16485767274321272, 0.23788298792881424, 1.2386497309565767, -0.05281444305869076, + 0.17334853252953497, -0.9480401485516037, -0.12016072866190708, -1.5973543471706693, 0.9977273313654775, + -1.859593065673394, -1.22159503592711, 1.1674399144996856, 1.2941842856062413, -1.1135548933000354, + -0.9788839012477277, 1.5718286586532253, -0.1759982975417227, 1.1031262106075763, -0.8778538891016368, + -1.0912009563680698, -1.4342020269338933, 0.7224191131816049, -1.2061468497546022, 1.5502197738160985, + -0.4251181088474105, -0.3206510906592923, -0.20288873333289725, -0.3560455119064194, -1.7508196425056557, + 1.5986301909131457, -1.4438601542872123, -1.7892166215086185, 0.7616375057554672, 1.6373087559633968, + 0.384206177448279, 0.4567136418012572, 1.2972172920819673, -1.1093690558377691, 1.6750979445730616, + 0.90908880550879, 1.2770950805973325, -0.30072973792624325, -0.30858954557714835, -1.9794583286836787, + 1.4463537661478734, 1.0183415951718624, 0.19309738632155593, -1.9449394663020856, 0.7108312298345272, + -1.131148529447878, 1.5259401710667637, -1.2736225271446795, 1.449852979043179, -0.5704566653927854, + -0.3074713127723667, 1.3993673001014262, 1.7718101827800963, -1.4492416161385497, 1.204820691151994, + 0.537241600964693, -1.5810854626566924, -1.5966679426112087, -0.6946214128685, 0.7401275132249641, + -1.7160959560539784, 0.4050021853660404, -1.2105835684442203, 1.7944565918560151, -0.6818406209243566, + -1.2359644949792958, 0.49907683448423157, -0.8405824207686488, 0.4689270476935219, -1.4699087331797918, + 1.43264169315706, 0.10499119180317251, -1.9830520821430992, -1.5927469176409472, -0.7632048947095695, + 1.303303968850459, 0.5773554905161333, 0.6761130632322967, 0.9023788989770569, 1.8960847479504084, + 0.8846144527507596, 0.9891128512774987, -0.7137249307539442, 1.267181508288493, -1.5113665535004523, + -0.17564834166799148, 0.9032525164747707, -0.25795858643541525, -0.2153155122989885, 0.14650372443070392, + -0.9984704344431465, 0.19261182317116177, -1.568804815745514, -1.8496400745041237, 0.7219284702138937, + 0.47816525621959105, -0.3273800096758981, -1.3390312793543542, -1.9838474983553418, 1.4066433632074071, + 1.9258390074704312, 0.4520281509654822, -0.30119846185025256, -1.8086624212334463, 0.851460380433025, + -0.4149432793401253, 1.4970655678791776, -0.9139304175330043, -1.1721517169571731, -1.9882366923130537, + 0.20630155701555886, 0.0891351853539204, -0.18485046053672338, -0.5253430902979694, 1.1136150007281822, + -1.072256739674419, 0.5677711226994742, -1.6986682182236068, -1.373143853609375, 0.5391705517446521, + 1.615483488858379, -0.18222418110590155, 1.5270125615115786, 1.4186450525284275, 0.6856859039097802, + 0.5948037341597869, -1.0097732940745248, -0.7260016082299225, -0.1705798585617213, -1.4460592122059417, + 0.2804912469966476, -0.0574570618149588, -1.4226509159038505, -1.3490817825559507, 0.7561887451342573, + -1.0315310280500372, 0.7865868802852711, -1.7955739447423928, -0.20476732094967787, 0.6532859024525468, + -1.398626809307176, 0.31416850473475755, -1.0474173751356446, 0.049027524534579925, 1.335442264825483, + 0.5839485880852768, 0.8416818491436544, 0.7729008830376998, 1.7957935152184445, -0.20047560204525272, + 1.9653799460331678, -0.756998178675067, -0.12357101901807699, -1.4272827743751613, 0.7149414745051672, + 1.4783565252719182, -1.2368177109511205, -1.4571248051607144, -0.7948678149157731, -0.6295946982419727, + -0.022851757488315805, -0.07947620035768654, -1.3106359681202076, 0.1591438592300909, -1.4970586188027868, + 1.3181273904865316, -1.508591213967403, 0.5722257787143228, 0.774539967054146, -0.5579675263215638, + -0.801690277809052, -0.8966439545169163, -0.2168181087774288, 1.8549965661558616, 0.7870136331314779, + 1.0426166176054243, 1.2052992540989846, -0.6116512580549873, -1.7800528483131748, 0.6162047118916432, + -1.1406795391578877, -1.3126212462178328, -0.1255252753148266, -0.048214851156274996, -1.7513823416941525, + -1.9966724157135571, 1.468282137353885, 1.1596808879879097, 1.848952713577705, 1.9276331797246486, + -0.6082295997412146, 1.9590194651252002, -0.6705403599782791, 0.8982591946264264, -1.8582005994721253, + -0.6224103416017206, 1.3118474535601639, 1.927285880838153, -1.3435831019941835, -0.02035775798119932, + -0.258091815197548, 1.5685276792778557, -1.7504336743073416, -1.3270808448193447, 1.9609655615175043, + -0.5002114597187894, -0.8302889305621663, 0.6662682285835677, -1.3588868202703237, -1.4263374077454936, + -1.117653746556062, 1.6959423725848142, -0.3368698386266633, 0.6329184444264122, 1.4360518922995382, + 0.2209792086889042, -1.7826312330601093, -1.9055378329489479, 1.8363537758423742, 1.8612237061845747, + -1.163857834211714, 0.38823573714522475, 0.9933133475252713, -1.769852560129741, 0.6303163049709841, + 1.6352278260339865, -1.4220707937174062, 0.4996182092181929, 0.1748538915264719, -1.389807604688972, + -0.5041547053983226, -0.7755917479953034, 0.33822942573796055, -1.3957767536429841, -0.16066323457963172, + 1.8426173458683097, 1.0912529333551886, 0.04454407634104118, 1.4585397734066836, 1.314915917164475, + 1.0930141444320949, -0.9720567164640972, -0.5831452038265033, 0.8082335756515109, 0.4358913655339238, + 0.8310387682873994, -0.8242800720840835, -0.47497624245619896, 0.000058968841639917, -1.6746583184349388, + 0.2586283765233146, -0.03952361428650608, 1.9572062803747263, -1.364317103129661, 0.16484595584710782, + 0.6889848970954304, 0.33625779127527444, 0.28142293472509294, -1.5510992496482494, 1.6785313595707674, + 0.4921495479711657, -0.42294403727168906, -0.10192465238332815, 1.583070264702826, 0.6464143795128816, + 0.7706704090619576, -0.45316577898360944, -0.6156337052461307, 0.2949317256431403, -1.1153946167003506, + 0.23143632095143918, -1.187495465719234, -0.754948635807529, -1.090644714217727, 0.8562387289761135, + 1.4209567719285578, -1.867698005779011, 1.3320884849513037, 0.5619380450950349, 1.8886416226851166, + -1.7314027359692306, -1.0362482885730966, 0.9807231768105664, -1.3689591083054822, 1.5694772951886131, + 0.4400722090716478, 0.9539178709143741, 0.4832872148319014, 0.23471769113792984, 1.9643745055943542, + 1.1325801513292664, -0.43752654225713705, 0.4538975778222154, -1.6157155513403065, -0.961125955159364, + -1.1751535270699955, 1.1277536127856669, 0.11594556933087752, -1.9276503102738447, -1.5774089828974898, + -1.0029301039964427, 0.4455245589428616, -0.4739643281334569, -0.8513671370845639, -1.0336436816615233, + 0.6626347865920605, 0.7885413873550009, -0.013439463013608766, 0.6488139123172507, 1.4291110296253855, + -1.9761431450732667, -0.5012957954679527, -1.1585910698227027, -1.1021436093929422, 0.036919383239228054, + 1.8089329170710071, -0.005231354013648826, -1.2082234644042886, -0.3456887578591781, -0.8017405353429492, + -0.5345375675659492, 0.8534420279659507, 1.7447905633469585, 0.43817127727920724, -1.8499205219957702, + 1.4797731845522186, -1.5443888715914138, 1.0131225647292235, 0.34701885989022063, -0.41455116200553377, + 0.40209313291363724, 0.11900781713274, -0.9935386808363758, -1.8322340002578068, 0.9811839836103111, + 1.502154507354498, 0.8891949169357574, 0.899071159318308, 1.752337905147142, -0.04599799842734953, + -1.6347681983052045, -0.5522741247690393, 1.505215487771519, 0.8504281241898015, 1.8693941265525265, + -1.1512863792577441, -1.8748160415118837, -1.879939448107617, 0.9353149913960506, -1.077101932112896, + 1.2322050595012843, 1.2672982902122678, 0.9384368132472858, 1.7274119921052788, -0.9601726232935137, + 0.19420343716687505, -0.7830049935581602, -1.9099470296794694, -1.213386784368356, 1.6800660417837605, + -0.9282638481321719, 0.5088239004955142, -0.5528513330962577, -0.4235136044745138, 1.6316021980530238, + -0.3087654696690505, -0.10527992793999896, -1.4364007982935343, -0.4455364976497762, -1.3433044303003099, + 0.6517505064656408, 1.6050250028051813, 1.6490276577492855, 1.9140119353414144, -0.7684496098140174, + -1.01738188731548, -1.1250647161193914, -1.6586222112755102, 1.1599068196677091, 0.795751774794466, + 0.5733174614685748, -1.3655937932875277, -1.507254849973065, -0.2831083653801638, 1.3241227396573514, + 1.574957068221127, -0.31194765030973937, 0.4008126582755933, 0.43635579619776443, -1.3214048572867325, + -0.8447194221435215, -1.1526249262582748, 0.7073544609451421, 1.17078844004073, 0.2425026449956018, + -1.7518561882120753, 1.6591407848437605, 0.06616038448738504, -1.928680221520497, -1.0504809684365677, + -1.0974712342176778, 1.6344494477175475, -0.4129201382527832, -1.7111789594333953, -1.6070549808753904, + -1.2456702084965565, -0.012663680475193395, -1.1305840149083926, 0.734392120651302, 0.18651679771884844, + -0.22974141381305735, 1.9415149817194726, 1.9280078232850126, -1.2072428658632042, -0.14782869942839927, + -1.6523593328098034, 0.4844141001145905, 1.0492278525622805, 0.5924539450553175, 0.848097235977705, + 1.8881210898619676, 0.20004070245023797, 1.4305799893712425, -0.9082660328332564, -0.14268688754147085, + -1.1201061991671262, 1.1399839045134712, -1.5579448101377515, 0.5516078322124933, 0.3365679579810825, + -0.636402425334972, 0.7364990374614768, -1.2657328423109204, 1.4084870144147636, -0.5538274490613713, + 0.43684201943536305, -0.706532199493215, -1.7678543182116737, 0.5086879667154935, -0.8888826793267235, + -1.0510640830474856, -1.775013227468511, -0.7345226367397419, 0.9474796694127203, -0.649964939391042, + -1.5189099534245498, -1.9260526549789567, -0.457330394781688, -0.9340352374682741, -1.3868164983748459, + 0.553888202560878, 0.36818698767921365, -0.9382183717778192, -1.0829596839250488, 1.3646658325042367, + -1.3240476722940633, 1.9816923192707012, 0.5300123477141927, 0.0790085088366057, -0.4455760575475125, + 0.48463653297167486, -0.7788483158746828, 0.8253771773416885, -0.6823431576948558, 1.7776737534704772, + 0.5497713214586923, 1.4452464852137838, 0.23037431004796094, 0.31188142524786766, -1.7543267850797761, + 0.6063452856820941, 1.207122989395999, -1.926907332363795, 0.45038239145265724, 1.0988284911574286, + -1.6007436457047142, -0.4728890687538678, 0.3195037474199047, 0.8855124961325762, -0.2555993730577626, + 1.8813620496087493, 1.8900177166377103, 0.09592367474164032, 1.8974568987778628, -1.1058972953708501, + -1.1512017435907103, 0.40201549011430693, -0.1831060132280804, 0.22245091899613723, -1.1866541831479385, + -0.5451040730392975, 0.9199451579519691, -0.42060255461704177, -1.3791747236441925, -0.3024448490768936, + -1.6611455107283604, -0.3106541240888907, -0.9498356682876157, 1.769660410309836, 1.598216213741022, + 0.4623205859503434, 0.03664778458072249, -0.6655252973523123, -0.7325818423601653, -1.591871681771024, + 0.9451427301297981, -0.8203468674560934, -1.5069504221011005, 0.7243170638862324, 1.1839749702725175, + -0.7128348341329511, -1.5076965090949397, -1.59865172895221, 0.08910680490749368, 1.5717880586471278, + 1.8951504684652152, 0.42550207471805646, -0.128409822054123, 0.9896766313315162, 0.34808462644009275, + -1.7082990487472571, -1.0459982270685435, -1.1132292311691874, -0.3022325459842996, -1.7274216318536348, + -0.11775921716410043, 0.7403577685290532, 1.188824227090608, 1.387282721223393, 1.0688709331799577, + -0.6395615121564866, 0.8142138261269114, -1.4467483751545576, 0.8996177593321626, -1.8193866881462766, + 0.08924208518874632, -1.405297919708996, 0.31754790231458685, 0.9823851818369507, -0.49590144528424585, + 1.4194064220328588, 0.9729299634967079, -0.46170090347918613, -1.634203532024186, -1.3139980454214522, + -1.8469876250843802, 0.710926322864931, -0.4029599381569682, 1.7246833539931403, -1.4088169680807, + -1.9165388068372708, 0.21804317359714798, 0.3898186987610348, -1.6118063668363405, 0.28583673086194583, + 1.7015683175211391, -1.2836168642070582, 0.8463494611619371, -1.1625839799245696, 0.2640138032690018, + 0.5041551687310717, -1.755824925370514, -1.177748867346489, -0.3829444120449246, 0.8360805034202388, + 0.05022254918868896, 0.4276469609032256, -0.8235567451730139, -1.94062145827791, 0.35097020666890355, + -0.6636358495150212, -1.2587452298002964, -1.6575362996910616, 1.060467707494971, 0.9661831305087887, + -1.77149852066976, 0.18844425542251564, 0.0897431507375952, -0.8281105706620897, 1.5632959207508623, + 0.07141359825195703, -1.0844701043735414, -1.905443802421968, -0.27588161311845294, -0.40342423607415956, + -0.34332304727825136, 0.0176022295550462, 1.94359831375926, 0.09702777017089215, -0.11695098349068722, + -1.2810374187149858, 0.37597456306160204, 1.7631374725308877, -0.7830266259108773, -0.5605784036815882, + -0.4409773606270875, -0.49636250754717803, -1.549108447216227, 0.6261185797820117, -1.260881611110821, + 1.691411217905281, 0.899655093658585, -1.0875528162122174, -0.7120948701980732, 1.8214705523154269, + 1.3010380076854968, -0.4492643980075144, 0.9914465230608682, 1.7590027691290615, 0.8514661670055963, + 1.5263431803492642, -1.7260779024351258, -0.14589666108296218, 0.18011804793376918, 0.7175880982696512, + -1.0399388762140145, -1.0480376846250712, 1.5656146512942648, 0.4435540525930799, -0.4175857955829816, + -1.8218436496980575, -0.9346408060646185, 1.40089015453285, 1.6926667426168764, 0.4187248632147291, + -0.6755275086264145, -0.7011229448771363, -0.9528087286614646, 0.0730922604589237, 0.6252467328216804, + -0.5573518555770702, 0.9864121888624755, 0.6646486706800783, -0.8405364163020792, -1.0505815213688878, + 0.8989991238262265, -0.5022516947851985, -1.784806766373471, 1.2637708002659025, 0.5065772818030325, + 0.9973024415787677, 0.08348064671549338, -1.757630249437522, -0.2016005631449005, -1.0360120513086803, + 1.786128872822113, -1.2720919225213843, 1.3430514506545927, -1.0762325516117865, -1.0578995596104255, + -0.47242526972271204, 0.05539697038437996, -1.08558757453563, -1.404337586710036, -0.059247790489052043, + 1.3998034069978171, 1.6067367856721155, -0.5185826391883994, -0.051896682542695416, 0.11112005542023429, + -1.2231398633939348, 0.2886372299242277, 0.6564469519248641, 1.67404118804063, -1.538487261793886, + 0.18551945213331145, -1.1342837192256061, -0.3318725405647953, -1.4531152273595227, -0.6934713826285721, + -0.24436235286417052, 0.6776292484438171, 0.8871814678850702, 0.41826798275898014, 1.0161513742931785, + -0.13947907673300097, -0.7736759327375049, -0.43981678279829683, 1.635191807530191, 1.3044401854805878, + -0.3097446711021723, 1.8125726195847056, -0.26127912212234694, 0.8564630403854094, -1.519521818793156, + -0.5727391479884938, 1.7015469847109976, 0.663240965083145, 0.31064120951508656, 1.4030451184981052, + 0.3325065959732836, -0.7178902057747756, 0.6090652378284442, 0.8426138183122633, -0.580146652112278, + -0.6076938097212707, -1.6599273271373782, 0.29960912457791444, -1.6741835731853065, 1.5428301790607195, + 0.8970548194971704, 1.6066845600081736, 0.5404165757730146, -1.9537941867764292, -1.5234595572340748, + 0.5293735217702951, -0.64620260665742, 1.8818640992235771, -1.7237764606754276, -0.8040024538741264, + 0.0642546885214017, 1.4395299343641659, -1.462587128675942, 0.011882540823848764, 0.12033421748154716, + 0.5458210215408705, -1.5141295301316422, 1.8809343680577308, -1.8801856621666753, -1.901376259472575, + -1.4374202976060095, 0.8473513507453765, -0.896351895119154, 0.457751001832321, 1.876657552919962, + 1.267733433184599, -0.30894648094866195, -1.8178016120669414, -0.7711776919446018, 0.29038564786361576, + 1.6396189720781438, 0.9597929848181161, -0.34227788522140834, -1.4450087753233527, -1.7068508353679626, + 0.8426935759536303, 0.7173810205823674, -0.6580236891322881, 0.8663322021812405, -0.38112472550089915, + -0.3331447946260786, 1.8551673806318556, -0.6731525492100126, -0.009001319785657103, 0.41039833755685784, + 0.025091358839535616, -0.49823213412251555, -0.9827448714264726, 1.1077851800046377, 0.5740585983905078, + -0.7235926614954762, 0.5059901875180826, -0.850177898505664, -0.05453987892121592, 0.8633840127545733, + -0.3153969644106205, -1.3028092681229868, 0.7523083527030074, 1.413775575813558, 1.0697458650110754, + -1.1839403319780022, 1.5167022074836893, -0.36486781099211996, -1.7835462010879564, -0.6061285803342944, + 1.9969466536022722, 1.9531672204642883, 0.7967381388222403, 1.0934589095880973, 1.6405590176012312, + 0.3501113568054244, -0.8786338692108497, -0.0545508019996932, 1.4464849975584952, -1.8853956921596513, + 0.25983013847132774, 0.8440414107184964, -1.8818826057620326, -0.22674619971532906, -0.35951513414106007, + 1.6757192364875237, -0.15819503713874195, 1.6691357866915144, 1.8534771980207108, -1.8297709996602967, + -1.6514392305036534, -0.2343385561012088, -1.820925987823773, -1.2556074451315578, 1.0621490016055715, + 1.2109203955756476, -1.1500152481919903, 1.0466452723330733, -0.5833431814034746, -1.0817348313127493, + -0.40295742758029984, 1.0986483593098368, -0.42704879342020696, 0.7065668399686658, -1.7278295290305223, + -0.7183051438456021, 0.15944089245801774, -1.0276995188812146, -1.6398474518024653, -1.2318071869917695, + 0.23333723988807886, 0.4626060172442088, 1.2255286520425894, 0.9652309086097803, -1.5254810192280601, + -0.6683416541099767, -1.9000628944332894, -0.7244291249780632, 1.0347731523086239, -0.9009629081875952, + -0.11734057204667625, -0.6698286923748489, -0.9472592207823913, -1.8286232413501864, 0.2898215560382935, + -1.422921306299517, 0.6696091233175077, -1.014229444534946, 0.7139087775492996, -0.23241859018004174, + 0.9620272535910068, 1.1497473812544134, -0.8640235723632799, -1.3121563547519584, -1.0396316763878488, + 1.7769188425035614, -1.014039070903868, 0.4489833453895331, 1.0763456360970807, 0.6229672422660295, + -1.7464692834364435, -0.3663893352922596, -0.8769410861076103, 1.0608706111705635, -0.9262810932490861, + -0.5468726859924029, 1.7966956706569919, -1.0663859467239112, 0.7378193848221777, 1.068192632208282, + 1.3312476842417755, -1.5902814653913628, -1.0131078061920382, -0.6235296781383814, 0.37504339068020176, + 0.9126508132330242, 0.3999532385546267, 0.6552059838941, -1.6053942866342332, 1.7900258342291853, + 0.08171912833062756, 1.6137883979635745, 1.3466843147948442, 1.8505801094553158, 1.3728813966930913, + -0.4473279660140852, -0.20009909626620814, 1.4067472413245437, 0.36658966226851764, 1.4566800897303072, + -0.11633958899045194, -1.9458410018060368, 0.5651869174802018, -0.9885077925334622, 0.24385043055374833, + -0.4407908079663816, -1.7015252126482139, 1.5396273477916198, -0.9801103159055833, 0.9331708410017399, + 0.058036076446482454, 0.29277070369481883, 1.6896333641554682, -0.2872886303585469, 0.2981100430160728, + 0.1670720357805502, -1.6828245857476496, -1.4681960401028125, -0.9933436100210251, -1.827639383468739, + 0.08433714147463611, -1.1318904274562795, 0.9840669856671846, 0.8204547128989219, 0.5959008566248984, + -0.22424536381303728, 1.765380932910376, 1.050492887173749, 0.8249285352430338, 1.5823516671950122, + -1.4695844512182843, 1.4009128159343485, 1.0886951647082785, -0.4963319371911856, -1.8633848779197413, + 0.660465126445569, -1.2319891082878298, 1.6547000157065659, -1.6403428022350113, -1.2308283749125177, + 0.9142339764828238, 0.18691349086990705, -1.148271069003111, -1.266859733272054, -1.4482873768560758, + 1.6888579757850914, 1.5392518897570104, 0.41499451567073464, -1.0517290742419663, -0.9856143466540894, + 0.704611691207357, -0.27871441123648655, 0.445828139270918, -0.8125969294930622, 1.521716695437079, + 0.5657668386735519, -1.813374372841099, 1.0076529676525672, -0.5864288471977783, 0.5855480422270194, + -1.8330974772064481, 0.9782157266479414, -1.6230556142249775, 0.5265126362718373, -1.6878701852107563, + -0.3955226747487526, -1.3888929741627605, 0.2905034183357449, 1.0489208524387843, -0.3118857187498678, + -0.6289506096761981, -0.05735383950307149, 1.8668941791416147, 0.8898345005884769, 1.7147482078759548, + -0.12387314928310289, 0.2298818139402634, 1.9294076224252024, -0.43580099597679656, 1.7512542893273144, + -1.258214124547644, 0.9779750741630782, -0.2566261319632144, -1.9813300069235993, -1.3498734101224414, + 0.7506344777083953, 1.8867470646651894, -1.918953273635191, 1.7429571494233906, 0.7638060343526085, + -0.44782770384121484, -1.1300950570142518, -1.4753506380821149 + ], + "dims": [2560], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 0.027313262224197388, -0.005701353773474693, 0.1959753781557083, 0.10011828690767288, -3.6098804473876953, + 0.00864929985255003, -0.011981655843555927, -0.11036527156829834, 0.6647213101387024, 0.276733934879303, + -0.6354819536209106, -0.014075735583901405, -0.059462033212184906, -0.3388662040233612, + -0.017422985285520554, -0.043299876153469086, -0.16756349802017212, -0.07582926005125046, + -0.16514767706394196, 2.9962074756622314, -2.600733757019043, 0.04413439333438873, 0.07896167039871216, + 1.1207873821258545, -0.032255738973617554, -0.09964963793754578, -2.1782073974609375, + -0.01814177632331848, 0.08586198836565018, 0.380964457988739, -0.01918521337211132, 0.006902141962200403, + 0.0669674500823021, -0.09234043955802917, -1.0496017932891846, 0.020094068720936775, -0.11474193632602692, + -0.056350305676460266, -2.2275612354278564, -2.648808240890503, -0.017779357731342316, + -0.2514607608318329, 0.008559616282582283, 0.010673644952476025, -0.32376542687416077, + -0.16903237998485565, 0.026010606437921524, -2.163571357727051, 0.35461699962615967, 1.5194188356399536, + -1.1094666719436646, -0.012471643276512623, 0.0767873078584671, -0.21644049882888794, + -0.043257202953100204, 0.001341399853117764, -0.1367240697145462, 0.005313852336257696, 2.144134759902954, + 0.11904949694871902, -0.26428619027137756, 0.014375614002346992, 0.06913577765226364, -4.196413516998291, + -5.172718524932861, 0.06162356957793236, -0.0010976337362080812, 0.21020400524139404, 4.567638397216797, + -0.059758733958005905, 5.990215301513672, 0.19405193626880646, 0.003011247143149376, -0.1036064475774765, + -0.016247211024165154, -0.12790939211845398, -0.08561908453702927, 0.25051021575927734, + 0.07514326274394989, -0.6767844557762146, 1.5661166906356812, -4.326471328735352, 0.07481537014245987, + -0.7969828248023987, -0.45468205213546753, 0.21233250200748444, 3.420551061630249, -0.0759267508983612, + 0.16086462140083313, -0.3939729928970337, 1.3957020044326782, -0.2972649931907654, -0.31666669249534607, + 0.35118427872657776, 0.3117898404598236, 0.088602215051651, 0.17165301740169525, -0.8542330265045166, + -0.06893759965896606, 0.08126193284988403, 0.02327258512377739, -0.1314769983291626, 0.035079699009656906, + 1.2096712589263916, 0.9461245536804199, -6.337772846221924, 5.575413703918457, -3.9876515865325928, + 0.01430205162614584, -2.093717098236084, -0.056584782898426056, 0.05612698942422867, -0.01935030147433281, + 0.0010159225203096867, -0.38109132647514343, -0.000587565591558814, 0.12273997068405151, + -1.4854758977890015, 0.016024703159928322, -0.05192752555012703, -0.257480651140213, -4.023406982421875, + 0.03150588274002075, -5.065948486328125, -0.07601942121982574, -0.04482676833868027, -0.01937261037528515, + -6.9667582511901855, -0.05368780344724655, 0.6142992377281189, 0.3128206431865692, -0.3862888216972351, + 0.053061991930007935, -0.24360240995883942, -0.018439287319779396, 0.1868235021829605, + 0.005632609128952026, -0.10385553538799286, 1.2077943086624146, -0.07107469439506531, 1.771382212638855, + 0.3696843981742859, -0.31587034463882446, 0.0002820117224473506, 0.055834002792835236, + -0.7621694803237915, -3.773604393005371, -0.12602387368679047, 0.8626934289932251, -4.139935493469238, + -0.08643748611211777, -2.25795841217041, -0.025201046839356422, -0.28647178411483765, -0.5088312029838562, + -2.3566224575042725, -0.20447342097759247, 0.3922976553440094, 0.047735944390296936, -0.09984598308801651, + 4.436963081359863, 0.17725177109241486, 0.01968466490507126, -0.4080508351325989, 0.600350558757782, + -0.1489681750535965, -2.3178586959838867, 0.010645782575011253, -0.5052445530891418, 0.12876634299755096, + 2.72904109954834, 0.007315368857234716, 0.503023624420166, 0.9695355892181396, 0.4959081709384918, + -3.562389612197876, 0.3780525028705597, -0.194877028465271, -1.0815603733062744, 0.6436595320701599, + -0.10088582336902618, -0.06308454275131226, -3.7394943237304688, -0.0011674398556351662, + -0.19378826022148132, -0.2329375147819519, 0.029814809560775757, 0.20438098907470703, + -0.23114298284053802, 0.026816120371222496, -7.350013256072998, -0.011900502257049084, 2.0180928707122803, + 0.20987474918365479, -3.209254503250122, -7.5496602058410645, 0.232008695602417, 0.0027162893675267696, + -0.4211888611316681, 0.287914901971817, 0.028367964550852776, 0.015583046711981297, 0.07393462210893631, + 0.6514078974723816, -0.04090245068073273, -0.004561522509902716, 0.2931022346019745, -0.4355356991291046, + -0.1867547184228897, -5.984931945800781, 0.044270407408475876, -0.35987964272499084, + -0.033762961626052856, -0.2677021622657776, -0.013161826878786087, -0.010206296108663082, + -4.528798580169678, 0.4174078106880188, 0.12906667590141296, 0.04690857604146004, -0.08034832775592804, + 1.5188398361206055, 1.3247699737548828, 0.011872933246195316, 0.055544108152389526, 0.0025585023686289787, + -7.696174621582031, 0.030730921775102615, 0.039231084287166595, -0.4407111704349518, -0.3110845386981964, + 2.284346342086792, -0.027610689401626587, 0.09054349362850189, 1.7885178327560425, -0.11802572757005692, + 0.03795969486236572, 2.373623847961426, 0.11311819404363632, 0.009557336568832397, -0.02887658029794693, + -0.28853726387023926, -0.17708882689476013, -0.22821268439292908, 0.0237746462225914, 3.257477283477783, + 0.2507217526435852, 0.17421714961528778, -0.12231585383415222, 0.18179824948310852, -0.3428541123867035, + 0.024907970800995827, 0.2441745400428772, 1.13312828540802, -0.0009440237190574408, -0.594701886177063, + -0.008615869097411633, 4.071537017822266, 0.6198470592498779, -0.3097928464412689, -0.4404515027999878, + -7.008431911468506, 0.024559520184993744, 0.1267288327217102, 0.2140975296497345, 0.5778637528419495, + -0.03296203166246414, 0.8842242360115051, 0.16367295384407043, -0.3035202920436859, -0.09384048730134964, + 0.6805808544158936, -0.2706672251224518, -1.429656982421875, -0.1497703641653061, 0.4302230775356293, + -1.864505648612976, 0.01007054653018713, 0.23598365485668182, -0.08086620271205902, 0.001842734171077609, + 0.08458849042654037, 0.3059651553630829, -0.06515960395336151, -3.803208589553833, -0.41865429282188416, + 0.2828770875930786, 3.459416151046753, -0.00129605398979038, -9.578699111938477, -0.06560757756233215, + 0.026055261492729187, 0.0672057718038559, 0.08423102647066116, -1.3624160289764404, 0.013521464541554451, + -0.027731282636523247, -0.9650477766990662, -0.012694457545876503, -0.2116907835006714, + -0.10714730620384216, 0.0034909709356725216, 1.5338910818099976, 0.0006434338865801692, + 0.1618947833776474, -0.10659407079219818, -6.774624347686768, -0.08567759394645691, 0.5162889361381531, + 0.11074300855398178, -0.09961605817079544, -0.005474632140249014, 0.1132681593298912, 0.10878968983888626, + -0.4140564203262329, -6.274385452270508, -3.410104274749756, -0.2155490219593048, 0.13330507278442383, + -0.2973288297653198, -0.5738739371299744, 0.3465871810913086, -0.2567448318004608, -0.13507360219955444, + -0.014550707302987576, 0.039058394730091095, -0.25891509652137756, -0.30598220229148865, + -0.14163219928741455, 1.2217881679534912, -0.2967555820941925, 0.024605438113212585, -0.03864026814699173, + -1.4379907846450806, -3.0257911682128906, 10.609665870666504, -0.0002576113329268992, 0.1658751666545868, + -0.01822504773736, 0.1141287237405777, -0.3072766363620758, 2.9927172660827637, 0.42983293533325195, + 0.9799204468727112, -0.007520963903516531, 3.565046787261963, -0.18206597864627838, 1.1247198581695557, + -0.0011717785382643342, 0.0026591955684125423, 3.689824104309082, 0.03598639369010925, + -0.09997520595788956, 0.06576227396726608, -1.7916548252105713, 0.030312752351164818, -0.4527510106563568, + -0.26613515615463257, 0.025749003514647484, -0.17866003513336182, 0.18729515373706818, + 0.003528681118041277, -0.1579633355140686, 1.070467472076416, -0.20637144148349762, -0.10882926732301712, + -6.439236640930176, -0.25033196806907654, -0.26708030700683594, 0.036800775676965714, -1.9130735397338867, + 0.11082696169614792, 0.10686857253313065, 7.136363506317139, -2.1805343627929688, 0.002802944276481867, + -1.0081117153167725, -0.08366546779870987, -0.07263432443141937, -0.011199882254004478, + -0.015524221584200859, -0.008838756941258907, -0.005488056223839521, 0.6502953767776489, + -0.010726823471486568, 0.41685575246810913, -0.23590049147605896, -0.0868658497929573, + -0.07914192229509354, 0.22732190787792206, 0.0985199362039566, 0.013477811589837074, 0.5970719456672668, + -0.12020514905452728, -0.0009808604372665286, 0.1139480322599411, -0.5872443914413452, -5.610537528991699, + 0.14893069863319397, 0.44541916251182556, 0.599539041519165, -0.028194887563586235, -0.42580458521842957, + -0.24352392554283142, 0.25486475229263306, -3.251058578491211, 0.042388759553432465, -0.15446388721466064, + -0.01016155257821083, 0.07647261768579483, 4.707305431365967, -0.0834866315126419, -0.21240641176700592, + 0.34789028763771057, 0.0710633248090744, 0.013448074460029602, -0.18779638409614563, 0.022113602608442307, + -1.8543815612792969, 0.012882213108241558, 0.0508059561252594, -2.1125378608703613, 1.00347900390625, + 0.34287792444229126, 0.023498279973864555, 0.2604916989803314, -0.854418158531189, 0.3368889391422272, + -3.5361156463623047, -0.3238249719142914, 0.09940877556800842, 0.011137581430375576, -0.09505806118249893, + 1.4575674533843994, 0.18798890709877014, 0.13481135666370392, -3.1009016036987305, -0.0046508763916790485, + -0.002944883657619357, 0.008598391897976398, -0.05753857269883156, -0.007956058718264103, + 0.7023902535438538, -2.114570140838623, 0.6187217235565186, 4.448208808898926, 2.5069539546966553, + -0.10476846992969513, -0.04466601461172104, 0.32297447323799133, 0.06604880094528198, + -0.0016604098491370678, -2.8530216217041016, -1.2369211912155151, -0.02766953594982624, + -0.025159431621432304, 0.0029653196688741446, 0.04569535329937935, 0.03927958756685257, + -0.0021295847836881876, 0.024881726130843163, 0.028491219505667686, -0.0042065782472491264, + -0.05266435816884041, -0.08988969027996063, -0.04083021357655525, -0.040847159922122955, + 3.154191732406616, -0.06132543459534645, -0.7507759928703308, -0.029571423307061195, 0.03537856787443161, + -1.4017058610916138, 0.3888748586177826, 0.9719987511634827, -0.010947618633508682, 3.847195863723755, + 1.015498161315918, 0.012234801426529884, -0.3849196434020996, 0.5072981119155884, -0.07829593122005463, + 0.2524659037590027, -0.13102610409259796, 0.020525088533759117, 0.15267324447631836, -0.11044808477163315, + 0.008630136027932167, 0.0009689829312264919, 2.615210771560669, 0.3638320863246918, 0.2452821582555771, + 0.01092306338250637, 0.03127167001366615, -3.899691104888916, 0.16573800146579742, -0.06733611971139908, + -0.39246127009391785, 5.207749843597412, 0.05021298676729202, 0.17778877913951874, -1.4260956048965454, + 0.19870443642139435, -0.27705708146095276, -0.11092191934585571, -0.09528861939907074, + -0.5703115463256836, 0.5077508687973022, 0.3938415050506592, 0.47991737723350525, 0.7821948528289795, + -0.2891596853733063, -0.3829837143421173, -0.010832893662154675, 0.15224608778953552, + -0.00014581253344658762, 0.00025647180154919624, 0.02536843903362751, -0.06366542726755142, + -8.023703575134277, 0.027589797973632812, -0.1799485832452774, -0.2505863904953003, -0.3841714859008789, + -0.00031740960548631847, -0.04642002657055855, 0.38759565353393555, -0.05341910943388939, + -0.37632811069488525, 0.6983012557029724, -0.10781889408826828, -0.0007781427120789886, + -0.0877101942896843, -0.5221861600875854, -0.07871037721633911, -2.2496471405029297, + -0.042697690427303314, 2.38197922706604, 0.035262834280729294, 0.0695495679974556, 1.6927565336227417, + -10.396214485168457, 0.05338706448674202, -2.813828468322754, -3.691652536392212, -0.008508461527526379, + 1.570121169090271, -0.4011033773422241, -0.24479898810386658, 0.30835238099098206, 0.1998486965894699, + 0.0945337787270546, 0.39656326174736023, -0.23758645355701447, 0.1661674976348877, -0.04912934452295303, + 0.024212513118982315, -1.0319569110870361, 0.04704924300312996, -0.058226123452186584, + -1.6492913961410522, 0.6406868100166321, -0.005447663366794586, 0.19865849614143372, -0.3373563289642334, + -0.03675329312682152, -0.19241032004356384, -0.43262550234794617, -0.08300381153821945, + -0.014068910852074623, 0.11309102177619934, -0.02719084918498993, 0.2096254676580429, + -0.02292095310986042, 0.4072689712047577, 0.0003724964044522494, 0.20711149275302887, 1.0793871879577637, + 0.06120060756802559, 0.11688049882650375, -0.0023522432893514633, -0.9283630847930908, 2.477475881576538, + 0.26047614216804504, 0.7143173813819885, -1.4795730113983154, -0.15119962394237518, -1.4587875604629517, + -0.03378799930214882, -0.3518248498439789, 0.1747346669435501, 0.002720446093007922, 0.865147590637207, + 0.015568590722978115, 0.1952929049730301, 0.1818414330482483, -1.4265116453170776, 0.2012042999267578, + -0.2151491343975067, 0.11098571866750717, -0.16003955900669098, 7.798532962799072, 0.299221396446228, + -1.0280503034591675, -0.1838797777891159, 0.005458994302898645, -0.13420982658863068, + -0.06905319541692734, -1.8678100109100342, 0.40917104482650757, -0.09650467336177826, 0.2953720986843109, + 0.008414564654231071, 0.1998010128736496, 0.34882158041000366, -0.17196929454803467, 0.031611330807209015, + 0.08629407733678818, -0.1856321394443512, -0.22879824042320251, -0.09241079539060593, 0.2628664970397949, + 0.03050280548632145, 0.15829861164093018, -0.06391621381044388, -0.01048242673277855, + -0.010927671566605568, 1.0013593435287476, -0.15796290338039398, -0.10746872425079346, + -0.0013137627393007278, -0.2024063915014267, 0.0005700114998035133, 0.03609214723110199, + -0.4168614447116852, 0.12660957872867584, -0.005800928454846144, -0.5319929718971252, 0.32967525720596313, + 0.028021158650517464, -1.217489242553711, -0.09096843004226685, 2.344956636428833, 5.432365894317627, + 0.7993219494819641, -0.15543963015079498, -0.0007157247746363282, -0.08481337875127792, + -0.12131065130233765, 1.1516586542129517, -0.01504613272845745, 0.03704383224248886, 0.004402496851980686, + -0.581766664981842, 0.07592090964317322, 0.3745843768119812, -0.4989067614078522, 0.04438084363937378, + 3.175107479095459, -4.4077911376953125, -0.002988559426739812, 1.0332038402557373, 0.006027049385011196, + -0.0018258332274854183, 2.3033533096313477, -0.10134122520685196, 0.02520211599767208, + 0.005497010890394449, 0.0003968894889112562, -0.00029831190477125347, 1.2718113660812378, + -0.34650272130966187, 9.8225736618042, -6.33831787109375, -0.9639870524406433, -4.028343677520752, + 0.016925739124417305, -0.0449683852493763, 0.6271276473999023, -0.13903772830963135, -0.06179821118712425, + 5.860689640045166, -0.005071749445050955, 0.5026626586914062, 0.3309268057346344, 2.2567200660705566, + -4.23521089553833, -0.01613122597336769, -0.02665529027581215, 0.04668727517127991, 3.150425434112549, + -0.0052042026072740555, 1.067151427268982, 0.025044972077012062, -1.325771689414978, -0.1094195619225502, + 0.26904425024986267, 0.6204038262367249, 0.006285298615694046, 0.002915250603109598, -0.6165238618850708, + -4.090943813323975, 2.8669519424438477, -0.09453117847442627, -0.09316729754209518, 0.034191157668828964, + -6.707476615905762, -0.20231620967388153, -1.6191682815551758, 2.0373117923736572, -0.10501966625452042, + 0.0019581259693950415, 0.21420015394687653, 0.0156276635825634, 7.224427223205566, 0.1236666664481163, + 0.294806569814682, -0.0061331382021307945, -0.10612531006336212, -0.8333144187927246, + -0.001029952079989016, 0.38204053044319153, -0.03597458079457283, -1.41422438621521, -0.2833155691623688, + -0.0006075138808228076, -0.3701440095901489, 0.1309424191713333, 0.06839437037706375, + -0.0017361472127959132, -1.69569993019104, -0.20629459619522095, -0.5999218225479126, 0.114132359623909, + 6.6828436851501465, -0.36263933777809143, 0.41539111733436584, 0.022192703559994698, -0.06610587984323502, + -1.683022141456604, -0.2835130989551544, 0.27643388509750366, -0.6247501373291016, -1.421617865562439, + -0.08159351348876953, 0.005017416086047888, -2.026592493057251, 0.0009393739746883512, 1.760980486869812, + 0.00019237841479480267, 0.0022294363006949425, 0.22415778040885925, -0.09657209366559982, + -3.056180953979492, -0.24515365064144135, -1.7638490200042725, -1.900456428527832, 1.7747641801834106, + -0.9473960399627686, 0.27619242668151855, -0.11893711239099503, 0.7769895792007446, 0.09835439175367355, + 0.0019296495011076331, -0.043601375073194504, 0.03626292571425438, 0.1591210663318634, + 0.45964139699935913, -0.06853707879781723, -2.4563350677490234, -0.13421472907066345, + 0.040955424308776855, -0.2855738699436188, 0.11433675140142441, 0.00306497560814023, -0.48573875427246094, + -0.046301428228616714, 0.6907474994659424, -3.983771800994873, 2.3131954669952393, 0.05256381258368492, + -0.0911293551325798, -1.8945766687393188, 0.03453084081411362, 1.8747694492340088, 0.3433213233947754, + -1.1485600471496582, 1.6418366432189941, -0.13894057273864746, -3.1275031566619873, -0.9726752638816833, + -0.0012102212058380246, 3.898921251296997, 0.2646528482437134, 0.01665414497256279, -0.06312943994998932, + -3.0655016899108887, 0.024803519248962402, -0.25584203004837036, -1.3387784957885742, + -0.03684002161026001, -0.524848461151123, -0.9969499707221985, -1.8777778148651123, -0.14820723235607147, + -8.182234764099121, 0.015234949067234993, -0.010302969254553318, 0.10785042494535446, + -0.02237902209162712, -2.500221014022827, -0.006121153011918068, 0.054380882531404495, 2.607618808746338, + -0.48403942584991455, 1.7271841764450073, -0.054084569215774536, 0.04733904451131821, 0.29113033413887024, + 0.3090323507785797, -0.4069989323616028, 0.827186644077301, -0.8676308393478394, -0.18980173766613007, + 0.017093969509005547, 0.05046425387263298, 0.025303032249212265, -0.9938563704490662, 0.0307749193161726, + -0.003506980137899518, -2.145794153213501, 0.08889109641313553, -0.2744760513305664, 0.02533753775060177, + -0.008416163735091686, 1.6139867305755615, -6.39102840423584, -4.842134952545166, -0.7291613817214966, + 0.9694556593894958, 0.07247202843427658, 0.005149913020431995, 0.0029090321622788906, -0.1867554932832718, + 0.0015806574374437332, -0.04847263917326927, -0.18512502312660217, -0.04184968024492264, + -1.2331782579421997, -0.6159178018569946, 0.025481248274445534, 0.11850030720233917, -0.2734290063381195, + 0.26392263174057007, -0.2278929203748703, -2.4300358295440674, 0.0007563972030766308, 1.2603007555007935, + -0.009525062516331673, -2.5598459243774414, -0.1015859916806221, -0.3136966824531555, + -0.0023580349516123533, 3.0281076431274414, 0.0851983055472374, 0.18700700998306274, + -0.008541906252503395, -0.007119827438145876, 0.42274990677833557, -0.06235692277550697, + 0.3246764540672302, 0.047069381922483444, 0.00011004792031599209, -0.49105241894721985, + 0.041874051094055176, 0.013326031155884266, -2.525364875793457, 0.5126351714134216, -0.01582668349146843, + -0.3391125500202179, -0.1049877479672432, -0.36534854769706726, 0.027926098555326462, + 0.004374756012111902, -0.10876958072185516, 0.579942524433136, 0.34367814660072327, 0.12710949778556824, + -0.28762391209602356, 0.028134111315011978, 0.001072783605195582, -0.430772066116333, 0.21052536368370056, + 0.09690351784229279, 0.000786184798926115, 0.06906910240650177, -3.5896573066711426, 0.24118882417678833, + -3.176041841506958, 1.3121603727340698, -0.40836477279663086, -7.590582370758057, -1.9390276670455933, + -0.06406442821025848, 0.00011302023631287739, 0.013246525079011917, 0.21886053681373596, + 0.090825155377388, -0.06342892348766327, -0.14027893543243408, 0.017751706764101982, 0.11045858263969421, + -0.05397825688123703, 0.2152465432882309, 0.14184458553791046, -1.6443814039230347, -1.023624300956726, + 0.050706081092357635, -0.8185511231422424, -0.009972692467272282, -1.6231411695480347, + -0.010527506470680237, 1.5382870435714722, -2.6943516731262207, 0.965884804725647, -0.5423170924186707, + -2.0661613941192627, -0.4436858892440796, 0.0058816042728722095, -0.665194034576416, 0.8273401260375977, + 0.10996203124523163, -0.1316700130701065, 0.027179520577192307, -0.2735114097595215, -0.10301132500171661, + -1.906333565711975, -0.32074108719825745, 0.4478001892566681, -1.1052520275115967, 0.009423047304153442, + 0.5322814583778381, -0.004648196045309305, -0.009632693603634834, -0.7735386490821838, + 0.005249344743788242, 0.11850841343402863, -0.0034776863176375628, -0.1439099758863449, + 0.2767007648944855, -2.8716399669647217, -0.16290035843849182, -0.1801692247390747, 0.19117145240306854, + -0.7634338736534119, -0.29985561966896057, 0.009378351271152496, -0.6186265349388123, + -0.13845475018024445, 0.03558935597538948, -0.20145508646965027, -0.5337783694267273, 0.28876203298568726, + -0.5732369422912598, 0.03304499760270119, 0.8687714338302612, -0.2524224817752838, 0.4371426999568939, + 0.03568745777010918, 0.4382450580596924, 0.03245728462934494, -0.14247629046440125, 0.7598915696144104, + 0.30114904046058655, -0.21331092715263367, -0.0028205476701259613, 0.09227168560028076, + 0.008056613616645336, 1.635034203529358, -0.166751429438591, -0.020675446838140488, -1.2244166135787964, + -0.10547340661287308, 2.802537441253662, -0.004014655947685242, 3.690307140350342, 0.0017954192589968443, + -0.45281466841697693, 0.020796259865164757, 1.5265557765960693, 0.20084713399410248, -0.21376214921474457, + -0.025406286120414734, 1.9211277961730957, -0.7583361268043518, 4.267587661743164, -4.551294803619385, + -0.08887865394353867, 0.07695532590150833, -0.17959536612033844, 0.5096666216850281, -9.957548141479492, + -0.11618410050868988, 0.09543278813362122, 0.270590603351593, 0.024046115577220917, -0.245524600148201, + -0.307966023683548, 2.2781827449798584, -0.14958485960960388, -0.06977607309818268, 2.3428077697753906, + 0.8067795038223267, 0.9448233842849731, 0.35110360383987427, 0.4814533293247223, 0.00956026278436184, + -0.03395213186740875, 0.2255835384130478, 0.4806722402572632, 0.0005861452082172036, -0.5671629309654236, + 0.5004423260688782, 7.04985237121582, -0.41439759731292725, -0.2847898304462433, -0.10965365916490555, + 0.3427604138851166, 0.12897160649299622, -0.6046913266181946, -0.1840457171201706, 0.002393739065155387, + 0.41798320412635803, 3.0662004947662354, -0.0002512158825993538, 3.1039047241210938, -0.6795744895935059, + 0.3395420014858246, 0.33144497871398926, 6.939244270324707, 0.0011752373538911343, 0.09660591185092926, + 0.4894343912601471, 4.00507116317749, 0.005761822685599327, 0.5680277943611145, 3.315598726272583, + -0.5373169779777527, 0.44444069266319275, -0.0027646978851407766, 1.6829670667648315, -2.6327450275421143, + 7.026787757873535, 1.5361366271972656, -0.3418486714363098, -0.21611177921295166, -0.05756537616252899, + -0.023443035781383514, 0.23109422624111176, 0.21568432450294495, 1.236294150352478, -0.4581376612186432, + -4.188883304595947, 0.28338590264320374, -0.04076884314417839, -0.0198111142963171, -1.1872543096542358, + -0.062372151762247086, 2.7870607376098633, -0.6517040729522705, -1.7516529560089111, -1.8091800212860107, + -0.15286022424697876, 0.09354468435049057, 0.11334053426980972, -1.8259668350219727, + -0.017136069014668465, 0.12429703027009964, 0.00773972412571311, -1.0446529388427734, 0.2356342375278473, + 0.9886437654495239, 0.18150633573532104, 0.4682118594646454, -0.11415664851665497, 0.003153775818645954, + 0.03332524746656418, -1.1180988550186157, 4.163827896118164, -0.18159173429012299, 0.0999181717634201, + 2.058738946914673, 4.2595014572143555, 0.010485258884727955, 0.16270849108695984, -0.9842506051063538, + -0.0003203578235115856, 6.040156364440918, 1.308574914932251, -0.36853301525115967, -0.29669252038002014, + 0.2741331160068512, -0.3422505855560303, -0.7587988972663879, 2.9686927795410156, 0.8209773898124695, + -0.0007857486489228904, 0.02395622618496418, 0.05102493613958359, 0.15041151642799377, + -0.002741719363257289, -4.980742454528809, 0.2880830764770508, 0.24828404188156128, 0.15224777162075043, + 0.13059845566749573, 1.5662918090820312, -0.8474074006080627, 0.08810389041900635, -0.2630780041217804, + 0.43268874287605286, -0.001932788989506662, 0.012391841039061546, 3.4719245433807373, + -0.0024825239088386297, -0.11508434265851974, 0.40480512380599976, 0.4639693796634674, 1.0610097646713257, + -0.3626452386379242, -0.18480324745178223, 0.2711804509162903, 0.21260038018226624, -0.02785545215010643, + -0.11340377479791641, -0.027071641758084297, -0.22035427391529083, -0.10667331516742706, + 0.16908612847328186, 0.10428506135940552, 0.1233300194144249, -0.06643304973840714, 0.2004912942647934, + 0.09342359751462936, 0.03175133094191551, -1.5805491209030151, -0.4311752915382385, 0.5455954074859619, + 0.15516425669193268, -0.04940091818571091, -0.1447768211364746, 0.21044854819774628, 4.243553161621094, + 0.2046128362417221, -0.2688649892807007, 8.694140434265137, 1.6753796339035034, 0.05555065721273422, + -0.16497156023979187, 0.1828891634941101, 1.505862832069397, 0.08261094242334366, 3.104039430618286, + 5.931700706481934, 0.4487259089946747, -0.011016261763870716, 0.012441087514162064, -0.5082470774650574, + -0.11641799658536911, 0.01356368139386177, -0.01659458689391613, 7.667582988739014, -0.346441388130188, + -5.981383323669434, 7.6806206703186035, -0.0383722260594368, 0.4603561460971832, -0.16010521352291107, + -3.173022985458374, -0.4581749737262726, 0.06485684961080551, -0.4382486939430237, 0.25564998388290405, + 0.21537242829799652, 8.642731666564941, -0.00621763477101922, 1.9488575458526611, -0.030582817271351814, + 0.00024394349020440131, -0.015434199944138527, 1.1591683626174927, 0.29453498125076294, + -0.06397286057472229, 0.2931598126888275, 3.056126594543457, 0.35364240407943726, -0.07242457568645477, + 0.19602720439434052, 3.9850471019744873, -0.141653373837471, 0.7269659042358398, 0.020487932488322258, + -1.1716344356536865, -0.13872867822647095, 0.004658155608922243, 0.0002524183946661651, + -0.3965473771095276, 0.058615222573280334, -0.005210256204009056, -6.64661169052124, + -0.003978024236857891, -0.11946024745702744, -0.15917553007602692, -2.0323069095611572, 3.694988250732422, + -0.2484176903963089, 1.117063283920288, 0.04259220138192177, 5.167333126068115, -4.669308185577393, + 0.0845528095960617, 0.011936561204493046, 5.875088691711426, -0.032051537185907364, -1.2815053462982178, + -0.010812301188707352, -0.021371034905314445, -0.0037929099053144455, 0.09110360592603683, + -0.2871546745300293, -0.2895969748497009, 0.2745248079299927, -2.462489366531372, -1.4751100540161133, + 0.6493473052978516, 0.105231374502182, -0.11311662197113037, 5.5921311378479, -0.11590596288442612, + -4.284017562866211, -3.0435032844543457, 0.2767241597175598, -0.2098703682422638, -0.008056361228227615, + 3.738365888595581, 0.03918733820319176, 0.5250627994537354, -0.03754068538546562, -1.1869362592697144, + 0.0016376300482079387, -0.19201913475990295, -0.12353025376796722, -0.038338642567396164, + -0.5192811489105225, -0.07935589551925659, 2.1531660556793213, -0.002360888756811619, 0.3615436553955078, + -1.499021291732788, 0.6402538418769836, -2.886809825897217, 2.502922296524048, -0.0014745928347110748, + -0.09228605031967163, -0.14953434467315674, 0.2779182493686676, 2.071781635284424, -0.2248198240995407, + 0.5830495357513428, -0.1257641464471817, -0.06734845042228699, 0.003910396713763475, -1.285413146018982, + -5.392889976501465, -0.0003311980399303138, 8.632763862609863, 0.1819709688425064, -0.013432486914098263, + -0.019152071326971054, -0.026376325637102127 + ], + "dims": [1, 1, 1280], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/conv-transpose.jsonc b/js/web/test/data/ops/conv-transpose.jsonc index a249dc807fa0b..7038e2a4f8766 100644 --- a/js/web/test/data/ops/conv-transpose.jsonc +++ b/js/web/test/data/ops/conv-transpose.jsonc @@ -28,6 +28,37 @@ } ] }, + { + "name": "ConvTranspose without bias addition A - NHWC", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "operator": "ConvTranspose", + "attributes": [{ "name": "kernel_shape", "data": [2, 2], "type": "ints" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, 30, 40], + "dims": [1, 1, 2, 2], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [10, 40, 40, 60, 200, 160, 90, 240, 160], + "dims": [1, 1, 3, 3], + "type": "float32" + } + ] + } + ] + }, { "name": "ConvTranspose without bias addition B", "operator": "ConvTranspose", @@ -74,26 +105,22 @@ }, { "data": [ - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64 ], "dims": [4, 4, 2, 2], "type": "float32" }, { - "data": [0.1, 0.2, 0.3, 0.4], + "data": [65, 66, 67, 68], "dims": [4], "type": "float32" } ], "outputs": [ { - "data": [ - 100.0999984741211, 100.0999984741211, 100.0999984741211, 100.0999984741211, 100.19999694824219, - 100.19999694824219, 100.19999694824219, 100.19999694824219, 100.30000305175781, 100.30000305175781, - 100.30000305175781, 100.30000305175781, 100.4000015258789, 100.4000015258789, 100.4000015258789, - 100.4000015258789 - ], + "data": [3365, 3465, 3565, 3665, 3766, 3866, 3966, 4066, 4167, 4267, 4367, 4467, 4568, 4668, 4768, 4868], "dims": [1, 4, 2, 2], "type": "float32" } @@ -115,7 +142,43 @@ "type": "float32" }, { - "data": [1, 1, 1, 1], + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + }, + { + "data": [5], + "dims": [1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [11, 25, 28, 19, 32, 86, 99, 55, 40, 114, 131, 67, 29, 73, 80, 41], + "dims": [1, 1, 4, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "ConvTranspose with bias addition B - NHWC", + "operator": "ConvTranspose", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "attributes": [{ "name": "kernel_shape", "data": [2, 2], "type": "ints" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [6, 8, 7, 9, 15, 11, 8, 12, 9], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], "dims": [1, 1, 2, 2], "type": "float32" }, @@ -127,7 +190,7 @@ ], "outputs": [ { - "data": [11, 19, 20, 12, 20, 43, 46, 23, 22, 49, 52, 25, 13, 25, 26, 14], + "data": [11, 25, 28, 19, 32, 86, 99, 55, 40, 114, 131, 67, 29, 73, 80, 41], "dims": [1, 1, 4, 4], "type": "float32" } @@ -251,7 +314,6 @@ } ] }, - { "name": "ConvTranspose- pointwise", "operator": "ConvTranspose", @@ -285,5 +347,50 @@ ] } ] + }, + { + "name": "ConvTranspose with bias addition C", + "operator": "ConvTranspose", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "attributes": [{ "name": "kernel_shape", "data": [1, 1], "type": "ints" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64 + ], + "dims": [1, 4, 4, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "dims": [4, 4, 1, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 1021, 1049, 1077, 1105, 1133, 1161, 1189, 1217, 1245, 1273, 1301, 1329, 1357, 1385, 1413, 1441, 1122, + 1154, 1186, 1218, 1250, 1282, 1314, 1346, 1378, 1410, 1442, 1474, 1506, 1538, 1570, 1602, 1223, 1259, + 1295, 1331, 1367, 1403, 1439, 1475, 1511, 1547, 1583, 1619, 1655, 1691, 1727, 1763, 1324, 1364, 1404, + 1444, 1484, 1524, 1564, 1604, 1644, 1684, 1724, 1764, 1804, 1844, 1884, 1924 + ], + "dims": [1, 4, 4, 4], + "type": "float32" + } + ] + } + ] } ] diff --git a/js/web/test/data/ops/where.jsonc b/js/web/test/data/ops/where.jsonc new file mode 100644 index 0000000000000..047fd6fd7511b --- /dev/null +++ b/js/web/test/data/ops/where.jsonc @@ -0,0 +1,172 @@ +[ + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + "name": "T[3] T[3] T[3] float32 T[3] ", + "inputs": [ + { + "data": [true, false, true, false, true, false, true, false], + "dims": [8], + "type": "bool" + }, + { + "data": [4.0, 8.0, 7.0, 2.0, 4.0, 8.0, 7.0, 1.0], + "dims": [8], + "type": "float32" + }, + { + "data": [1.0, 3.0, 9.0, 6.0, 1.0, 3.0, 9.0, 2.0], + "dims": [8], + "type": "float32" + } + ], + "outputs": [ + { + "data": [4.0, 3.0, 7.0, 6.0, 4.0, 3.0, 7.0, 2.0], + "dims": [8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + "name": "T[3] T[3] T[3] int32 T[3] ", + "inputs": [ + { + "data": [true, false, true, false, true, false, true, false], + "dims": [8], + "type": "bool" + }, + { + "data": [4, 8, 7, 2, 4, 8, 7, 1], + "dims": [8], + "type": "int32" + }, + { + "data": [1, 3, 9, 6, 1, 3, 9, 2], + "dims": [8], + "type": "int32" + } + ], + "outputs": [ + { + "data": [4, 3, 7, 6, 4, 3, 7, 2], + "dims": [8], + "type": "int32" + } + ] + } + ] + }, + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + "name": "T[3] T[3] T[3] uint32 T[3] ", + "inputs": [ + { + "data": [true, false, true, false, true, false, true, false], + "dims": [8], + "type": "bool" + }, + { + "data": [4, 8, 7, 2, 4, 8, 7, 1], + "dims": [8], + "type": "uint32" + }, + { + "data": [1, 4294967295, 9, 6, 1, 3, 9, 2], + "dims": [8], + "type": "uint32" + } + ], + "outputs": [ + { + "data": [4, 4294967295, 7, 6, 4, 3, 7, 2], + "dims": [8], + "type": "uint32" + } + ] + } + ] + }, + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + "name": "T[3] T[3] T[3] bool T[3] ", + "inputs": [ + { + "data": [true, false, true, false, true, false, true, false], + "dims": [8], + "type": "bool" + }, + { + "data": [true, true, true, true, true, true, true, true], + "dims": [8], + "type": "float32" + }, + { + "data": [true, false, true, false, true, false, true, false], + "dims": [8], + "type": "float32" + } + ], + "outputs": [ + { + "data": [true, false, true, false, true, false, true, false], + "dims": [8], + "type": "float32" + } + ] + } + ] + }, + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + "name": "T[3 3] T[3 3] T[1] float32 broadcast", + "inputs": [ + { + "data": [true, true, true, true, true, false, false, false, false], + "dims": [3, 3], + "type": "bool" + }, + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8], + "dims": [3, 3], + "type": "float32" + }, + { + "data": [-1.0], + "dims": [1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 1, 2, 3, 4, -1, -1, -1, -1], + "dims": [3, 3], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/data/ops/where_broadcast.jsonc b/js/web/test/data/ops/where_broadcast.jsonc new file mode 100644 index 0000000000000..ad97177bb101b --- /dev/null +++ b/js/web/test/data/ops/where_broadcast.jsonc @@ -0,0 +1,84 @@ +[ + { + "name": "Where with no attributes", + "operator": "Where", + "attributes": [], + "cases": [ + { + // This failed due to: https://github.com/microsoft/onnxruntime/issues/17405. + "name": "T[3 6] T[3 6] T[1] float32 broadcast", + "inputs": [ + { + "data": [ + true, + true, + true, + true, + true, + false, + false, + false, + false, + false, + false, + true, + true, + true, + true, + true, + true, + true + ], + "dims": [3, 6], + "type": "bool" + }, + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], + "dims": [3, 6], + "type": "float32" + }, + { + "data": [-1.0], + "dims": [1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 1, 2, 3, 4, -1, -1, -1, -1, -1, -1, 11, 12, 13, 14, 15, 16, 17], + "dims": [3, 6], + "type": "float32" + } + ] + }, + { + // This failed due to: https://github.com/microsoft/onnxruntime/issues/17405. + "name": "T[3 1] T[3 6] T[1] float32 broadcast", + "inputs": [ + { + "data": [true, false, true], + "dims": [3, 1], + "type": "bool" + }, + { + "data": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], + "dims": [3, 6], + "type": "float32" + }, + { + "data": [-1.0], + "dims": [1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [0, 1, 2, 3, 4, 5, -1, -1, -1, -1, -1, -1, 12, 13, 14, 15, 16, 17], + "dims": [3, 6], + "type": "float32" + } + ] + } + ] + } +] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 6e65645ef4756..c80f0b04a9abc 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -885,10 +885,10 @@ // // "test_qlinearmatmul_3D", // // "test_quantizelinear_axis", // // "test_quantizelinear", - // "test_range_float_type_positive_delta_expanded", - // "test_range_float_type_positive_delta", - // "test_range_int32_type_negative_delta_expanded", - // "test_range_int32_type_negative_delta", + "test_range_float_type_positive_delta_expanded", + "test_range_float_type_positive_delta", + "test_range_int32_type_negative_delta_expanded", + "test_range_int32_type_negative_delta", "test_reciprocal_example", "test_reciprocal", "test_reduce_l1_default_axes_keepdims_example", @@ -1336,6 +1336,8 @@ "add_int32.jsonc", //"and.jsonc", "asin.jsonc", + "bias-add.jsonc", + "bias-split-gelu.jsonc", "ceil.jsonc", "concat.jsonc", "concat_int32.jsonc", @@ -1386,7 +1388,10 @@ "tan.jsonc", "tile.jsonc", "transpose.jsonc", - "transpose_int32_uint32.jsonc" + "transpose_int32_uint32.jsonc", + "where.jsonc" + // Turn on this when https://github.com/microsoft/onnxruntime/issues/17405 is fixed. + //"where_broadcast.jsonc", //"xor.jsonc" ] }, diff --git a/js/web/test/test-main.ts b/js/web/test/test-main.ts index 49d0ac225be2f..d3592875bb6c7 100644 --- a/js/web/test/test-main.ts +++ b/js/web/test/test-main.ts @@ -57,6 +57,9 @@ if (options.globalEnvFlags) { if (flags.webgpu?.profilingMode !== undefined) { ort.env.webgpu.profilingMode = flags.webgpu.profilingMode; } + if (flags.webgpu?.validateInputContent !== undefined) { + ort.env.webgpu.validateInputContent = flags.webgpu.validateInputContent; + } } // Set logging configuration diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index 46d80a9f56f35..628e5408150f8 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -14,7 +14,8 @@ import {Operator} from '../lib/onnxjs/operators'; import {onnx} from '../lib/onnxjs/ort-schema/protobuf/onnx'; import {Tensor} from '../lib/onnxjs/tensor'; import {ProtoUtil} from '../lib/onnxjs/util'; -import {tensorDataTypeStringToEnum} from '../lib/wasm/wasm-common'; +import {createView} from '../lib/wasm/jsep/tensor-view'; +import {getTensorElementSize, isGpuBufferSupportedType, tensorDataTypeStringToEnum} from '../lib/wasm/wasm-common'; import {base64toBuffer, createMockGraph, readFile} from './test-shared'; import {Test} from './test-types'; @@ -136,8 +137,8 @@ async function loadTensors( } async function initializeSession( - modelFilePath: string, backendHint: string, profile: boolean, sessionOptions: ort.InferenceSession.SessionOptions, - fileCache?: FileCacheBuffer): Promise { + modelFilePath: string, backendHint: string, ioBindingMode: Test.IOBindingMode, profile: boolean, + sessionOptions: ort.InferenceSession.SessionOptions, fileCache?: FileCacheBuffer): Promise { const preloadModelData: Uint8Array|undefined = fileCache && fileCache[modelFilePath] ? fileCache[modelFilePath] : undefined; Logger.verbose( @@ -146,8 +147,14 @@ async function initializeSession( preloadModelData ? ` [preloaded(${preloadModelData.byteLength})]` : ''}`); const profilerConfig = profile ? {maxNumberEvents: 65536} : undefined; - const sessionConfig = - {...sessionOptions, executionProviders: [backendHint], profiler: profilerConfig, enableProfiling: profile}; + const sessionConfig = { + ...sessionOptions, + executionProviders: [backendHint], + profiler: profilerConfig, + enableProfiling: profile, + preferredOutputLocation: ioBindingMode === 'gpu-location' ? ('gpu-buffer' as const) : undefined + }; + let session: ort.InferenceSession; try { @@ -181,6 +188,7 @@ export class ModelTestContext { readonly session: ort.InferenceSession, readonly backend: string, readonly perfData: ModelTestContext.ModelTestPerfData, + readonly ioBinding: Test.IOBindingMode, private readonly profile: boolean, ) {} @@ -232,8 +240,8 @@ export class ModelTestContext { this.initializing = true; const initStart = now(); - const session = - await initializeSession(modelTest.modelUrl, modelTest.backend!, profile, sessionOptions || {}, this.cache); + const session = await initializeSession( + modelTest.modelUrl, modelTest.backend!, modelTest.ioBinding, profile, sessionOptions || {}, this.cache); const initEnd = now(); for (const testCase of modelTest.cases) { @@ -244,6 +252,7 @@ export class ModelTestContext { session, modelTest.backend!, {init: initEnd - initStart, firstRun: -1, runs: [], count: 0}, + modelTest.ioBinding, profile, ); } finally { @@ -481,6 +490,130 @@ export class TensorResultValidator { } } +function createGpuTensorForInput(cpuTensor: ort.Tensor): ort.Tensor { + if (!isGpuBufferSupportedType(cpuTensor.type) || Array.isArray(cpuTensor.data)) { + throw new Error(`createGpuTensorForInput can not work with ${cpuTensor.type} tensor`); + } + const device = ort.env.webgpu.device as GPUDevice; + const gpuBuffer = device.createBuffer({ + // eslint-disable-next-line no-bitwise + usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE, + size: Math.ceil(cpuTensor.data.byteLength / 16) * 16, + mappedAtCreation: true + }); + const arrayBuffer = gpuBuffer.getMappedRange(); + new Uint8Array(arrayBuffer) + .set(new Uint8Array(cpuTensor.data.buffer, cpuTensor.data.byteOffset, cpuTensor.data.byteLength)); + gpuBuffer.unmap(); + + // TODO: how to "await" for the copy to finish, so that we can get more accurate performance data? + + return ort.Tensor.fromGpuBuffer( + gpuBuffer, {dataType: cpuTensor.type, dims: cpuTensor.dims, dispose: () => gpuBuffer.destroy()}); +} + +function createGpuTensorForOutput(type: ort.Tensor.Type, dims: readonly number[]) { + if (!isGpuBufferSupportedType(type)) { + throw new Error(`createGpuTensorForOutput can not work with ${type} tensor`); + } + + const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(type))!; + const size = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes; + + const device = ort.env.webgpu.device as GPUDevice; + const gpuBuffer = device.createBuffer({ + // eslint-disable-next-line no-bitwise + usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE, + size: Math.ceil(size / 16) * 16 + }); + + return ort.Tensor.fromGpuBuffer(gpuBuffer, { + dataType: type, + dims, + dispose: () => gpuBuffer.destroy(), + download: async () => { + const stagingBuffer = device.createBuffer({ + // eslint-disable-next-line no-bitwise + usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST, + size: gpuBuffer.size + }); + const encoder = device.createCommandEncoder(); + encoder.copyBufferToBuffer(gpuBuffer, 0, stagingBuffer, 0, gpuBuffer.size); + device.queue.submit([encoder.finish()]); + + await stagingBuffer.mapAsync(GPUMapMode.READ); + const arrayBuffer = stagingBuffer.getMappedRange().slice(0, size); + stagingBuffer.destroy(); + + return createView(arrayBuffer, type) as ort.Tensor.DataTypeMap[ort.Tensor.GpuBufferDataTypes]; + } + }); +} + +export async function sessionRun(options: { + session: ort.InferenceSession; feeds: Record; + outputsMetaInfo: Record>; + ioBinding: Test.IOBindingMode; +}): Promise<[number, number, ort.InferenceSession.OnnxValueMapType]> { + const session = options.session; + const feeds = options.feeds; + const fetches: Record = {}; + + // currently we only support IO Binding for WebGPU + // + // For inputs, we create GPU tensors on both 'gpu-tensor' and 'gpu-location' binding testing mode. + // For outputs, we create GPU tensors on 'gpu-tensor' binding testing mode only. + // in 'gpu-device' binding mode, outputs are not pre-allocated. + const shouldUploadInput = options.ioBinding === 'gpu-tensor' || options.ioBinding === 'gpu-location'; + const shouldUploadOutput = options.ioBinding === 'gpu-tensor'; + try { + if (shouldUploadInput) { + // replace the CPU tensors in feeds into GPU tensors + for (const name in feeds) { + if (Object.hasOwnProperty.call(feeds, name)) { + feeds[name] = createGpuTensorForInput(feeds[name]); + } + } + } + + if (shouldUploadOutput) { + for (const name in options.outputsMetaInfo) { + if (Object.hasOwnProperty.call(options.outputsMetaInfo, name)) { + const {type, dims} = options.outputsMetaInfo[name]; + fetches[name] = createGpuTensorForOutput(type, dims); + } + } + } + + const start = now(); + Logger.verbose('TestRunner', `Timestamp before session run: ${start}`); + const outputs = await ( + shouldUploadOutput ? session.run(feeds, fetches) : + session.run(feeds, Object.getOwnPropertyNames(options.outputsMetaInfo))); + const end = now(); + Logger.verbose('TestRunner', `Timestamp after session run: ${end}`); + + // download each output tensor if needed + for (const name in outputs) { + if (Object.hasOwnProperty.call(outputs, name)) { + const tensor = outputs[name]; + // Tensor.getData(true) release the underlying resource + await tensor.getData(true); + } + } + + return [start, end, outputs]; + } finally { + // dispose the GPU tensors in feeds + for (const name in feeds) { + if (Object.hasOwnProperty.call(feeds, name)) { + const tensor = feeds[name]; + tensor.dispose(); + } + } + } +} + /** * run a single model test case. the inputs/outputs tensors should already been prepared. */ @@ -491,12 +624,11 @@ export async function runModelTestSet( const validator = new TensorResultValidator(context.backend); try { const feeds: Record = {}; + const outputsMetaInfo: Record = {}; testCase.inputs!.forEach((tensor, i) => feeds[context.session.inputNames[i]] = tensor); - const start = now(); - Logger.verbose('TestRunner', `Timestamp before session run: ${start}`); - const outputs = await context.session.run(feeds); - const end = now(); - Logger.verbose('TestRunner', `Timestamp after session run: ${end}`); + testCase.outputs!.forEach((tensor, i) => outputsMetaInfo[context.session.outputNames[i]] = tensor); + const [start, end, outputs] = + await sessionRun({session: context.session, feeds, outputsMetaInfo, ioBinding: context.ioBinding}); if (context.perfData.count === 0) { context.perfData.firstRun = end - start; } else { @@ -575,6 +707,7 @@ export class ProtoOpTestContext { private readonly loadedData: Uint8Array; // model data, inputs, outputs session: ort.InferenceSession; readonly backendHint: string; + readonly ioBindingMode: Test.IOBindingMode; constructor(test: Test.OperatorTest, private readonly sessionOptions: ort.InferenceSession.SessionOptions = {}) { const opsetImport = onnx.OperatorSetIdProto.create(test.opset); const operator = test.operator; @@ -713,6 +846,7 @@ export class ProtoOpTestContext { model.graph.name = test.name; this.backendHint = test.backend!; + this.ioBindingMode = test.ioBinding; this.loadedData = onnx.ModelProto.encode(model).finish(); // in debug mode, open a new tab in browser for the generated onnx model. @@ -729,8 +863,11 @@ export class ProtoOpTestContext { } } async init(): Promise { - this.session = await ort.InferenceSession.create( - this.loadedData, {executionProviders: [this.backendHint], ...this.sessionOptions}); + this.session = await ort.InferenceSession.create(this.loadedData, { + executionProviders: [this.backendHint], + preferredOutputLocation: this.ioBindingMode === 'gpu-location' ? ('gpu-buffer' as const) : undefined, + ...this.sessionOptions + }); } async dispose(): Promise { @@ -739,10 +876,11 @@ export class ProtoOpTestContext { } async function runProtoOpTestcase( - session: ort.InferenceSession, testCase: Test.OperatorTestCase, validator: TensorResultValidator): Promise { + session: ort.InferenceSession, testCase: Test.OperatorTestCase, ioBindingMode: Test.IOBindingMode, + validator: TensorResultValidator): Promise { const feeds: Record = {}; - const fetches: string[] = []; - testCase.inputs!.forEach((input, i) => { + const fetches: Record> = {}; + testCase.inputs.forEach((input, i) => { if (input.data) { let data: number[]|BigUint64Array|BigInt64Array = input.data; if (input.type === 'uint64') { @@ -756,7 +894,7 @@ async function runProtoOpTestcase( const outputs: ort.Tensor[] = []; const expectedOutputNames: string[] = []; - testCase.outputs!.forEach((output, i) => { + testCase.outputs.forEach((output, i) => { if (output.data) { let data: number[]|BigUint64Array|BigInt64Array = output.data; if (output.type === 'uint64') { @@ -766,11 +904,11 @@ async function runProtoOpTestcase( } outputs.push(new ort.Tensor(output.type, data, output.dims)); expectedOutputNames.push(`output_${i}`); - fetches.push(`output_${i}`); + fetches[`output_${i}`] = {dims: output.dims, type: output.type}; } }); - const results = await session.run(feeds, fetches); + const [, , results] = await sessionRun({session, feeds, outputsMetaInfo: fetches, ioBinding: ioBindingMode}); const actualOutputNames = Object.getOwnPropertyNames(results); expect(actualOutputNames.length).to.equal(expectedOutputNames.length); @@ -821,7 +959,8 @@ async function runOpTestcase( export async function runOpTest( testcase: Test.OperatorTestCase, context: ProtoOpTestContext|OpTestContext): Promise { if (context instanceof ProtoOpTestContext) { - await runProtoOpTestcase(context.session, testcase, new TensorResultValidator(context.backendHint)); + await runProtoOpTestcase( + context.session, testcase, context.ioBindingMode, new TensorResultValidator(context.backendHint)); } else { await runOpTestcase( context.inferenceHandler, context.createOperator(), testcase, new TensorResultValidator(context.backendHint)); diff --git a/js/web/test/test-types.ts b/js/web/test/test-types.ts index 1f95d1cd8e682..88915e7972383 100644 --- a/js/web/test/test-types.ts +++ b/js/web/test/test-types.ts @@ -43,6 +43,18 @@ export declare namespace Test { */ export type PlatformCondition = string; + /** + * The IOBindingMode represents how to test a model with GPU data. + * + * - none: inputs will be pre-allocated as CPU tensors; no output will be pre-allocated; `preferredOutputLocation` + * will not be set. + * - gpu-location: inputs will be pre-allocated as GPU tensors; no output will be pre-allocated; + * `preferredOutputLocation` will be set to `gpu-buffer`. + * - gpu-tensor: inputs and outputs will all be pre-allocated as GPU tensors. `preferredOutputLocation` + * will not be set. + */ + export type IOBindingMode = 'none'|'gpu-tensor'|'gpu-location'; + export interface ModelTestCase { name: string; dataFiles: readonly string[]; @@ -54,6 +66,7 @@ export declare namespace Test { name: string; modelUrl: string; backend?: string; // value should be populated at build time + ioBinding: IOBindingMode; platformCondition?: PlatformCondition; cases: readonly ModelTestCase[]; } @@ -82,6 +95,7 @@ export declare namespace Test { inputShapeDefinitions?: 'none'|'rankOnly'|'static'|ReadonlyArray; opset?: OperatorTestOpsetImport; backend?: string; // value should be populated at build time + ioBinding: IOBindingMode; platformCondition?: PlatformCondition; attributes?: readonly AttributeValue[]; cases: readonly OperatorTestCase[]; diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index fd147eaa11f3f..0ed7d887fc5e5 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -42,6 +42,7 @@ from onnxruntime.capi._pybind_state import get_build_info # noqa: F401 from onnxruntime.capi._pybind_state import get_device # noqa: F401 from onnxruntime.capi._pybind_state import get_version_string # noqa: F401 + from onnxruntime.capi._pybind_state import has_collective_ops # noqa: F401 from onnxruntime.capi._pybind_state import set_default_logger_severity # noqa: F401 from onnxruntime.capi._pybind_state import set_default_logger_verbosity # noqa: F401 from onnxruntime.capi._pybind_state import set_seed # noqa: F401 diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index f0385ea5abdfb..48af0a9a6adec 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -140,27 +140,29 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { #endif if (!use_flash_attention) { - if (is_unidirectional_ && enable_fused_causal_attention_) { // GPT - // GPT fused kernels requires left side padding. mask can be: - // none (no padding), 1D sequence lengths or 2d mask. - // Fused kernels don't support different sequence lengths of q and kv, so only apply to the first token - // where past state is empty. - bool is_mask_2d_key_padding = parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING; - bool use_causal_fused_runner = (nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) && - nullptr == relative_position_bias && - parameters.past_sequence_length == 0 && - parameters.hidden_size == parameters.v_hidden_size && - FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length, - enable_trt_flash_attention_, true); - if (use_causal_fused_runner) { - // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node. - if (nullptr == fused_fp16_runner_.get()) { - fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_, - enable_trt_flash_attention_, parameters.scale); + if (is_unidirectional_) { // GPT + if (enable_fused_causal_attention_) { + // GPT fused kernels requires left side padding. mask can be: + // none (no padding), 1D sequence lengths or 2d mask. + // Fused kernels don't support different sequence lengths of q and kv, so only apply to the first token + // where past state is empty. + bool is_mask_2d_key_padding = parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING; + bool use_causal_fused_runner = (nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) && + nullptr == relative_position_bias && + parameters.past_sequence_length == 0 && + parameters.hidden_size == parameters.v_hidden_size && + FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length, + enable_trt_flash_attention_, true); + if (use_causal_fused_runner) { + // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node. + if (nullptr == fused_fp16_runner_.get()) { + fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_, + enable_trt_flash_attention_, parameters.scale); + } + + // Here we assume all causal kernels can be loaded into shared memory. TODO: add a function to check. + fused_runner = fused_fp16_runner_.get(); } - - // Here we assume all causal kernels can be loaded into shared memory. TODO: add a function to check. - fused_runner = fused_fp16_runner_.get(); } } else { // BERT bool use_fused_runner = !disable_fused_self_attention_ && diff --git a/onnxruntime/contrib_ops/js/bias_add.cc b/onnxruntime/contrib_ops/js/bias_add.cc new file mode 100644 index 0000000000000..9e70dead6a5da --- /dev/null +++ b/onnxruntime/contrib_ops/js/bias_add.cc @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "bias_add.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsepSupportedFloatTypes; + +ONNX_OPERATOR_KERNEL_EX( + BiasAdd, + kMSDomain, + 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedFloatTypes()), + BiasAdd); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/bias_add.h b/onnxruntime/contrib_ops/js/bias_add.h new file mode 100644 index 0000000000000..62a4df9bcdf34 --- /dev/null +++ b/onnxruntime/contrib_ops/js/bias_add.h @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsKernel; +JSEP_KERNEL_IMPL(BiasAdd, BiasAdd); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/bias_split_gelu.cc b/onnxruntime/contrib_ops/js/bias_split_gelu.cc new file mode 100644 index 0000000000000..e16aa4367d1c7 --- /dev/null +++ b/onnxruntime/contrib_ops/js/bias_split_gelu.cc @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "bias_split_gelu.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsepSupportedFloatTypes; + +ONNX_OPERATOR_KERNEL_EX( + BiasSplitGelu, + kMSDomain, + 1, + kJsExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T", JsepSupportedFloatTypes()), + BiasSplitGelu); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/bias_split_gelu.h b/onnxruntime/contrib_ops/js/bias_split_gelu.h new file mode 100644 index 0000000000000..3b3b41c0ca1f3 --- /dev/null +++ b/onnxruntime/contrib_ops/js/bias_split_gelu.h @@ -0,0 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace js { + +using onnxruntime::js::JsKernel; +JSEP_KERNEL_IMPL(BiasSplitGelu, BiasSplitGelu); + +} // namespace js +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc index 0bf6a4a168e68..4641b006a7785 100644 --- a/onnxruntime/contrib_ops/js/js_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/js/js_contrib_kernels.cc @@ -8,6 +8,8 @@ namespace contrib { namespace js { class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, Gelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasSplitGelu); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, BiasAdd); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSDomain, 1, SkipLayerNormalization); template <> @@ -19,6 +21,8 @@ KernelCreateInfo BuildKernelCreateInfo() { Status RegisterJsContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo}; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index dede1ecc95885..1b492a3561396 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -177,9 +177,9 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) { } #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - // Run layout transformer only for EPs other than CPU EP and provided the preferred layout is NHWC + // Run layout transformer for EPs with preferred layout of NHWC // CPU EP layout transformation happens later when level 3 transformers are run. - if (params.mode != GraphPartitioner::Mode::kAssignOnly && + if (params.mode != GraphPartitioner::Mode::kAssignOnly && params.transform_layout.get() && current_ep.GetPreferredLayout() == DataLayout::NHWC) { for (auto& capability : capabilities) { TryAssignNodes(graph, *capability->sub_graph, ep_type); diff --git a/onnxruntime/core/framework/kernel_registry_manager.cc b/onnxruntime/core/framework/kernel_registry_manager.cc index 38c8a4a4e3d5e..c4eef5b27c1bb 100644 --- a/onnxruntime/core/framework/kernel_registry_manager.cc +++ b/onnxruntime/core/framework/kernel_registry_manager.cc @@ -63,8 +63,9 @@ Status KernelRegistryManager::SearchKernelRegistry(const Node& node, auto create_error_message = [&node, &status](const std::string& prefix) { std::ostringstream errormsg; errormsg << prefix << node.OpType() << "(" << node.SinceVersion() << ")"; - if (!node.Name().empty()) errormsg << " (node " << node.Name() << "). "; - if (!status.IsOK()) errormsg << status.ErrorMessage(); + errormsg << " (node:'" << node.Name() << "' ep:'" << node.GetExecutionProviderType() << "'). "; + if (!status.IsOK()) + errormsg << status.ErrorMessage(); return errormsg.str(); }; diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index aff90b8d40bde..8deeb4c2b8b64 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -148,6 +148,10 @@ struct SessionOptions { std::shared_ptr custom_op_libs; void AddCustomOpLibraryHandle(PathString library_name, void* library_handle); #endif + + // User specified logging func and param + OrtLoggingFunction user_logging_function = nullptr; + void* user_logging_param = nullptr; }; } // namespace onnxruntime diff --git a/onnxruntime/core/framework/session_state.cc b/onnxruntime/core/framework/session_state.cc index f0e5fbbd38721..6244d426450a2 100644 --- a/onnxruntime/core/framework/session_state.cc +++ b/onnxruntime/core/framework/session_state.cc @@ -1046,7 +1046,8 @@ Status SessionState::CreateSubgraphSessionState() { auto subgraph_session_state = std::make_unique(*subgraph, execution_providers_, thread_pool_, inter_op_thread_pool_, data_transfer_mgr_, - logger_, profiler_, sess_options_, nullptr, allocators_); + logger_, profiler_, sess_options_, + prepacked_weights_container_, allocators_); // Pass fused function manager to subgraph subgraph_session_state->fused_funcs_mgr_.SetFusedFuncs(fused_funcs_mgr_); diff --git a/onnxruntime/core/framework/tensor.cc b/onnxruntime/core/framework/tensor.cc index a9e5038e19e02..36f03a9b1046a 100644 --- a/onnxruntime/core/framework/tensor.cc +++ b/onnxruntime/core/framework/tensor.cc @@ -27,80 +27,68 @@ int64_t GetSizeFromStrides(const TensorShape& shape, gsl::span st } // namespace #endif -size_t Tensor::CalculateTensorStorageSize(MLDataType p_type, - const TensorShape& shape, - gsl::span strides) { -#ifdef ENABLE_STRIDED_TENSORS - int64_t shape_size = 1; - if (shape.NumDimensions() > 0 && !strides.empty()) { - ORT_ENFORCE(shape.NumDimensions() == strides.size(), "Length of strides doesn't match with tensor dimension size."); - shape_size = GetSizeFromStrides(shape, strides); - } else { - shape_size = shape.Size(); - } -#else - ORT_ENFORCE(strides.empty(), "Strided tensor is supported for training only for now."); +size_t Tensor::CalculateTensorStorageSize(MLDataType elt_type, const TensorShape& shape) { int64_t shape_size = shape.Size(); -#endif - if (shape_size < 0) ORT_THROW("shape.Size() must >=0"); + if (shape_size < 0) + ORT_THROW("shape.Size() must >=0"); if (shape_size > 0) { SafeInt len = 0; - if (!IAllocator::CalcMemSizeForArray(SafeInt(shape_size), p_type->Size(), &len)) + if (!IAllocator::CalcMemSizeForArray(SafeInt(shape_size), elt_type->Size(), &len)) ORT_THROW("tensor failed memory size calculation"); return len; } + return 0; } -Tensor::Tensor(MLDataType p_type, const TensorShape& shape, void* p_data, const OrtMemoryInfo& alloc, +Tensor::Tensor(MLDataType elt_type, const TensorShape& shape, void* p_data, const OrtMemoryInfo& location, ptrdiff_t offset, gsl::span strides) - : alloc_info_(alloc) { - ORT_ENFORCE(p_type != nullptr); - Init(p_type, shape, p_data, nullptr, offset, strides); + : alloc_info_(location) { + ORT_ENFORCE(elt_type != nullptr); + Init(elt_type, shape, p_data, nullptr, offset, strides); } -Tensor::Tensor(MLDataType p_type, const TensorShape& shape, std::shared_ptr allocator, - gsl::span strides) +Tensor::Tensor(MLDataType elt_type, const TensorShape& shape, std::shared_ptr allocator) : alloc_info_(allocator->Info()) { - ORT_ENFORCE(p_type != nullptr); - size_t len = Tensor::CalculateTensorStorageSize(p_type, shape, strides); + ORT_ENFORCE(elt_type != nullptr); + size_t len = Tensor::CalculateTensorStorageSize(elt_type, shape); void* p_data = nullptr; if (len > 0) { p_data = allocator->Alloc(len); } - Init(p_type, shape, p_data, allocator, 0L, strides); + Init(elt_type, shape, p_data, allocator, 0L); } -Tensor::Tensor(MLDataType p_type, const TensorShape& shape, void* p_data, std::shared_ptr deleter, +Tensor::Tensor(MLDataType elt_type, const TensorShape& shape, void* p_data, std::shared_ptr deleter, ptrdiff_t offset, gsl::span strides) : alloc_info_(deleter->Info()) { - ORT_ENFORCE(p_type != nullptr); - Init(p_type, shape, p_data, deleter, offset, strides); + ORT_ENFORCE(elt_type != nullptr); + Init(elt_type, shape, p_data, deleter, offset, strides); } void Tensor::InitOrtValue(MLDataType elt_type, const TensorShape& shape, std::shared_ptr allocator, - OrtValue& ort_value, gsl::span strides) { - auto p_tensor = std::make_unique(elt_type, shape, std::move(allocator), strides); + OrtValue& ort_value) { + auto p_tensor = std::make_unique(elt_type, shape, std::move(allocator)); auto ml_tensor = DataTypeImpl::GetType(); ort_value.Init(p_tensor.release(), ml_tensor, ml_tensor->GetDeleteFunc()); } -void Tensor::InitOrtValue(MLDataType p_type, const TensorShape& shape, void* p_data, const OrtMemoryInfo& location, +void Tensor::InitOrtValue(MLDataType elt_type, const TensorShape& shape, void* p_data, const OrtMemoryInfo& location, OrtValue& ort_value, ptrdiff_t offset, gsl::span strides) { auto ml_tensor = DataTypeImpl::GetType(); - auto p_tensor = std::make_unique(p_type, shape, p_data, location, offset, strides); + auto p_tensor = std::make_unique(elt_type, shape, p_data, location, offset, strides); ort_value.Init(p_tensor.release(), ml_tensor, ml_tensor->GetDeleteFunc()); } -void Tensor::InitOrtValue(MLDataType p_type, const TensorShape& shape, +void Tensor::InitOrtValue(MLDataType elt_type, const TensorShape& shape, void* p_data, std::shared_ptr allocator, OrtValue& ort_value, ptrdiff_t offset, gsl::span strides) { auto ml_tensor = DataTypeImpl::GetType(); - auto p_tensor = std::make_unique(p_type, shape, p_data, std::move(allocator), offset, strides); + auto p_tensor = std::make_unique(elt_type, shape, p_data, std::move(allocator), offset, strides); ort_value.Init(p_tensor.release(), ml_tensor, ml_tensor->GetDeleteFunc()); } @@ -123,27 +111,31 @@ size_t Tensor::SizeInBytes() const { return ret; } -void Tensor::Init(MLDataType p_type, const TensorShape& shape, void* p_raw_data, AllocatorPtr deleter, ptrdiff_t offset, - gsl::span strides) { +void Tensor::Init(MLDataType elt_type, const TensorShape& shape, void* p_raw_data, AllocatorPtr deleter, + ptrdiff_t offset, gsl::span strides) { int64_t shape_size = shape.Size(); - if (shape_size < 0) ORT_THROW("shape.Size() must >=0"); - dtype_ = p_type->AsPrimitiveDataType(); + if (shape_size < 0) + ORT_THROW("shape.Size() must >=0"); + + dtype_ = elt_type->AsPrimitiveDataType(); ORT_ENFORCE(dtype_ != nullptr, - "Tensor is expected to contain one of the primitive data types. Got: ", DataTypeImpl::ToString(p_type)); + "Tensor is expected to contain one of the primitive data types. Got: ", + DataTypeImpl::ToString(elt_type)); shape_ = shape; p_data_ = p_raw_data; - // if caller passed in a deleter, that means this tensor own this buffer - // we will release the buffer when this tensor is deconstructed. + // if caller passed in a deleter we now own p_data_ and must free it in the dtor buffer_deleter_ = std::move(deleter); // for string tensors, if this tensor own the buffer (caller passed in the deleter) // do the placement new for strings on pre-allocated buffer. if (buffer_deleter_ && IsDataTypeString()) { utils::ConstructStrings(p_data_, shape_size); } + byte_offset_ = offset; + #ifdef ENABLE_STRIDED_TENSORS if (shape.NumDimensions() > 0 && !strides.empty()) { - ORT_ENFORCE(shape.NumDimensions() == strides.size(), "Length of strides doesn't match with tensor dimension size."); + ORT_ENFORCE(shape.NumDimensions() == strides.size(), "Length of strides doesn't match tensor dimension size."); strides_.assign(strides.begin(), strides.end()); is_contiguous_ = CheckIsContiguous(); } @@ -156,13 +148,21 @@ Tensor::Tensor(Tensor&& other) noexcept : p_data_(other.p_data_), buffer_deleter_(other.buffer_deleter_), shape_(other.shape_), +#ifdef ENABLE_STRIDED_TENSORS + strides_(other.strides_), + is_contiguous_(other.is_contiguous_), +#endif dtype_(other.dtype_), alloc_info_(other.alloc_info_), byte_offset_(other.byte_offset_) { - other.dtype_ = DataTypeImpl::GetType()->AsPrimitiveDataType(); - other.shape_ = TensorShape(std::vector(1, 0)); other.p_data_ = nullptr; other.buffer_deleter_ = nullptr; + other.dtype_ = DataTypeImpl::GetType()->AsPrimitiveDataType(); + other.shape_ = TensorShape(std::vector(1, 0)); +#ifdef ENABLE_STRIDED_TENSORS + other.strides_ = {}; + other.is_contiguous_ = true; +#endif other.byte_offset_ = 0; } @@ -170,19 +170,28 @@ Tensor& Tensor::operator=(Tensor&& other) noexcept { if (this != &other) { ReleaseBuffer(); - dtype_ = other.dtype_; + p_data_ = other.p_data_; + buffer_deleter_ = other.buffer_deleter_; shape_ = other.shape_; +#ifdef ENABLE_STRIDED_TENSORS + strides_ = other.strides_; + is_contiguous_ = other.is_contiguous_; +#endif + dtype_ = other.dtype_; alloc_info_ = other.alloc_info_; byte_offset_ = other.byte_offset_; - p_data_ = other.p_data_; - buffer_deleter_ = other.buffer_deleter_; - other.dtype_ = DataTypeImpl::GetType()->AsPrimitiveDataType(); - other.shape_ = TensorShape(std::vector(1, 0)); other.p_data_ = nullptr; - other.byte_offset_ = 0; other.buffer_deleter_ = nullptr; + other.shape_ = TensorShape(std::vector(1, 0)); +#ifdef ENABLE_STRIDED_TENSORS + other.strides_ = {}; + other.is_contiguous_ = true; +#endif + other.dtype_ = DataTypeImpl::GetType()->AsPrimitiveDataType(); + other.byte_offset_ = 0; } + return *this; } diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 08ed811d9ac38..fd32aaedcc2ee 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -87,6 +87,7 @@ bool operator!=(const ONNX_NAMESPACE::TensorShapeProto_Dimension& l, } // namespace ONNX_NAMESPACE +namespace onnxruntime { namespace { // This function doesn't support string tensors @@ -162,9 +163,9 @@ static Status GetExternalDataInfo(const ONNX_NAMESPACE::TensorProto& tensor_prot // Uses the tensor_proto_dir to construct the full path for external data. If tensor_proto_dir == nullptr // then uses the current directory instead. // This function does not unpack string_data of an initializer tensor -static Status ReadExternalDataForTensor(const ONNX_NAMESPACE::TensorProto& tensor_proto, - const ORTCHAR_T* tensor_proto_dir, - std::vector& unpacked_tensor) { +Status ReadExternalDataForTensor(const ONNX_NAMESPACE::TensorProto& tensor_proto, + const ORTCHAR_T* tensor_proto_dir, + std::vector& unpacked_tensor) { std::basic_string external_file_path; onnxruntime::FileOffsetType file_offset; SafeInt tensor_byte_size; @@ -184,10 +185,52 @@ static Status ReadExternalDataForTensor(const ONNX_NAMESPACE::TensorProto& tenso return Status::OK(); } + +// TODO(unknown): Change the current interface to take Path object for model path +// so that validating and manipulating path for reading external data becomes easy +Status TensorProtoToOrtValueImpl(const Env& env, const ORTCHAR_T* model_path, + const ONNX_NAMESPACE::TensorProto& tensor_proto, + const MemBuffer* m, AllocatorPtr alloc, + OrtValue& value) { + if (m && m->GetBuffer() == nullptr) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "MemBuffer has not been allocated."); + } + + // to construct a Tensor with std::string we need to pass an allocator to the Tensor ctor + // as the contents of each string needs to be allocated and freed separately. + ONNXTensorElementDataType ele_type = utils::GetTensorElementType(tensor_proto); + if (ele_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING && (m || !alloc)) { + return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "string tensor requires allocator to be provided."); + } + + // Note: We permit an empty tensor_shape_vec, and treat it as a scalar (a tensor of size 1). + TensorShape tensor_shape = GetTensorShapeFromTensorProto(tensor_proto); + const DataTypeImpl* const type = DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType(); + + std::unique_ptr tensor; + + if (m) { + tensor = std::make_unique(type, tensor_shape, m->GetBuffer(), m->GetAllocInfo()); + + if (tensor->SizeInBytes() > m->GetLen()) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "The preallocated buffer is too small. Requires ", + tensor->SizeInBytes(), ", Got ", m->GetLen()); + } + } else { + tensor = std::make_unique(type, tensor_shape, alloc); + } + + ORT_RETURN_IF_ERROR(TensorProtoToTensor(env, model_path, tensor_proto, *tensor)); + + auto ml_tensor = DataTypeImpl::GetType(); + value.Init(tensor.release(), ml_tensor, ml_tensor->GetDeleteFunc()); + return Status::OK(); +} + } // namespace -namespace onnxruntime { namespace utils { + #if !defined(ORT_MINIMAL_BUILD) static Status UnpackTensorWithExternalDataImpl(const ONNX_NAMESPACE::TensorProto& tensor, const ORTCHAR_T* tensor_proto_dir, @@ -810,7 +853,7 @@ Status GetExtDataFromTensorProto(const Env& env, const ORTCHAR_T* model_path, * @param env * @param model_path * @param tensor_proto tensor data in protobuf format - * @param tensorp pre-allocated tensor object, where we store the data + * @param tensor pre-allocated tensor object, where we store the data * @return */ Status TensorProtoToTensor(const Env& env, const ORTCHAR_T* model_path, @@ -824,7 +867,7 @@ Status TensorProtoToTensor(const Env& env, const ORTCHAR_T* model_path, const DataTypeImpl* const source_type = DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType(); if (source_type->Size() > tensor.DataType()->Size()) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "TensorProto type ", DataTypeImpl::ToString(source_type), - " can not be writen into Tensor type ", DataTypeImpl::ToString(tensor.DataType())); + " can not be written into Tensor type ", DataTypeImpl::ToString(tensor.DataType())); } // find raw data in proto buf @@ -900,44 +943,18 @@ Status TensorProtoToTensor(const Env& env, const ORTCHAR_T* model_path, return Status::OK(); } -#ifdef _MSC_VER -#pragma warning(push) -#pragma warning(disable : 6239) -#endif -// TODO: Change the current interface to take Path object for model path -// so that validating and manipulating path for reading external data becomes easy -Status TensorProtoToMLValue(const Env& env, const ORTCHAR_T* model_path, - const ONNX_NAMESPACE::TensorProto& tensor_proto, - const MemBuffer& m, OrtValue& value) { - if (m.GetBuffer() == nullptr) { - return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, - "TensorProtoToMLValue() must take a pre-allocated MemBuffer!"); - } - - ONNXTensorElementDataType ele_type = utils::GetTensorElementType(tensor_proto); - if (ele_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) { - return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "string tensor can not use pre-allocated buffer"); - } - - // Note: We permit an empty tensor_shape_vec, and treat it as a scalar (a tensor of size 1). - TensorShape tensor_shape = GetTensorShapeFromTensorProto(tensor_proto); - const DataTypeImpl* const type = DataTypeImpl::TensorTypeFromONNXEnum(tensor_proto.data_type())->GetElementType(); - std::unique_ptr tensorp = std::make_unique(type, tensor_shape, m.GetBuffer(), m.GetAllocInfo()); - if (tensorp->SizeInBytes() > m.GetLen()) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "The preallocated buffer is too small. Requires ", - tensorp->SizeInBytes(), ", Got ", m.GetLen()); - } - - ORT_RETURN_IF_ERROR(TensorProtoToTensor(env, model_path, tensor_proto, *tensorp)); +Status TensorProtoToOrtValue(const Env& env, const ORTCHAR_T* model_path, + const ONNX_NAMESPACE::TensorProto& tensor_proto, + const MemBuffer& m, OrtValue& value) { + return TensorProtoToOrtValueImpl(env, model_path, tensor_proto, &m, nullptr, value); +} - auto ml_tensor = DataTypeImpl::GetType(); - value.Init(tensorp.release(), ml_tensor, ml_tensor->GetDeleteFunc()); - return Status::OK(); +Status TensorProtoToOrtValue(const Env& env, const ORTCHAR_T* model_path, + const ONNX_NAMESPACE::TensorProto& tensor_proto, + AllocatorPtr alloc, OrtValue& value) { + return TensorProtoToOrtValueImpl(env, model_path, tensor_proto, nullptr, alloc, value); } -#ifdef _MSC_VER -#pragma warning(pop) -#pragma warning(disable : 6239) -#endif + #define CASE_TYPE(X) \ case ONNX_NAMESPACE::TensorProto_DataType_##X: \ return ONNX_TENSOR_ELEMENT_DATA_TYPE_##X; diff --git a/onnxruntime/core/framework/tensorprotoutils.h b/onnxruntime/core/framework/tensorprotoutils.h index 6f1b60dd2982e..000502ba47594 100644 --- a/onnxruntime/core/framework/tensorprotoutils.h +++ b/onnxruntime/core/framework/tensorprotoutils.h @@ -39,13 +39,28 @@ TensorShape GetTensorShapeFromTensorShapeProto(const ONNX_NAMESPACE::TensorShape TensorShape GetTensorShapeFromTensorProto(const ONNX_NAMESPACE::TensorProto& tensor_proto); /** - * deserialize a TensorProto into a preallocated memory buffer. - * \param tensor_proto_path A local file path of where the 'input' was loaded from. Can be NULL if the tensor proto doesn't - * have any external data or it was loaded from current working dir. This path could be either a - * relative path or an absolute path. + * deserialize a TensorProto into a preallocated memory buffer on CPU. + * \param tensor_proto_path A local file path of where the 'input' was loaded from. + * Can be NULL if the tensor proto doesn't have external data or it was loaded from + * the current working dir. This path could be either a relative path or an absolute path. + * \return Status::OK on success with 'value' containing the Tensor in CPU based memory. */ -common::Status TensorProtoToMLValue(const Env& env, const ORTCHAR_T* tensor_proto_path, - const ONNX_NAMESPACE::TensorProto& input, const MemBuffer& m, OrtValue& value); +common::Status TensorProtoToOrtValue(const Env& env, const ORTCHAR_T* tensor_proto_path, + const ONNX_NAMESPACE::TensorProto& input, + const MemBuffer& m, OrtValue& value); + +/** + * deserialize a TensorProto into a buffer on CPU allocated using 'alloc'. + * \param tensor_proto_path A local file path of where the 'input' was loaded from. + * Can be NULL if the tensor proto doesn't have external data or it was loaded from + * the current working dir. This path could be either a relative path or an absolute path. + * \param alloc Allocator to use for allocating the buffer. Must allocate CPU based memory. + * \return Status::OK on success with 'value' containing the Tensor in CPU based memory. + */ +common::Status TensorProtoToOrtValue(const Env& env, const ORTCHAR_T* tensor_proto_path, + const ONNX_NAMESPACE::TensorProto& input, + AllocatorPtr alloc, OrtValue& value); + /** * @brief Deserialize a TensorProto into a preallocated empty Tensor * @param env diff --git a/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc b/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc index 3ce7c40e754dc..d3fc5873cb274 100644 --- a/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc +++ b/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc @@ -94,7 +94,7 @@ void OpSet_Internal_NHWC_ONNX::ForEachSchema(const std::function> GenerateTransformers( const InlinedHashSet cpu_ep = {onnxruntime::kCpuExecutionProvider}; #endif const InlinedHashSet dml_ep = {onnxruntime::kDmlExecutionProvider}; + AllocatorPtr cpu_allocator = std::make_shared(); + switch (level) { case TransformerLevel::Level1: { // RewriteRule optimizations are the simplest (they generally remove unnecessary nodes and are cheap to run) @@ -240,13 +242,14 @@ InlinedVector> GenerateTransformers( // run TransposeOptimizer last as it works in a slightly different way by moving Transpose nodes around. // shouldn't affect the end result - just easier to debug any issue if it's last. - // local CPU allocator is enough as this allocator is finally passed to a local tensor. - // We will also benefit by using a local allocator as we don't need to pass allocator as parameter for EP API refactor - AllocatorPtr cpu_allocator = std::make_shared(); transformers.emplace_back(std::make_unique(std::move(cpu_allocator))); } break; case TransformerLevel::Level2: { + // we run TransposeOptimizer again in Level2 for some CPU EP specific optimizations that can only be + // applied once nodes are assigned to the CPU EP (which happens between level 1 and level 2). + transformers.emplace_back(std::make_unique(std::move(cpu_allocator), kCpuExecutionProvider)); + const bool enable_quant_qdq_cleanup = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsEnableQuantQDQCleanup, "0") == "1"; #if !defined(DISABLE_CONTRIB_OPS) @@ -366,16 +369,16 @@ InlinedVector> GenerateTransformers( if (MlasNchwcGetBlockSize() > 1) { transformers.emplace_back(std::make_unique()); } - AllocatorPtr cpu_allocator = std::make_shared(); + auto cpu_registry = cpu_execution_provider.GetKernelRegistry(); auto nhwc_transformer = std::make_unique(std::move(cpu_allocator), std::move(cpu_registry)); if (nhwc_transformer->IsActive()) { transformers.emplace_back(std::move(nhwc_transformer)); } - // NCHWCtransformer should have a higher priority versus this. Because NCHWCtransformer also do the similar things - // of fusion patterns and target on CPU. However, NCHWCtransformer will reorder the layout to nchwc which is only available for - // x86-64 cpu, not edge cpu like arm. But This transformer could be used by opencl-ep/cpu-ep. So - // we will prefer NhwcTransformer once ort runs on x86-64 CPU, otherwise ConvAddActivationFusion is enabled. + + // NchwcTransformer must have a higher priority than ConvAddActivationFusion. NchwcTransformer does similar + // fusions targeting CPU but also reorders the layout to NCHWc which is expected to be more efficient but is + // only available on x86-64. // PR #6351 implemented similar fusion-pattern for CUDA only, and can only fuse conv-add-relu, // while we can fuse more activation. transformers.emplace_back(std::make_unique(cpu_ep)); diff --git a/onnxruntime/core/optimizer/initializer.cc b/onnxruntime/core/optimizer/initializer.cc index 9cdc0d9ef0473..c8da15f65a6d7 100644 --- a/onnxruntime/core/optimizer/initializer.cc +++ b/onnxruntime/core/optimizer/initializer.cc @@ -3,22 +3,23 @@ #include "core/optimizer/initializer.h" -#include "core/common/gsl.h" +#include +#include +#include "core/common/gsl.h" #include "core/common/path.h" #include "core/framework/tensorprotoutils.h" #include "core/framework/tensor_external_data_info.h" #include "core/platform/env.h" -#include - namespace onnxruntime { Initializer::Initializer(ONNX_NAMESPACE::TensorProto_DataType data_type, std::string_view name, gsl::span dims) : name_(name), - data_(DataTypeImpl::TensorTypeFromONNXEnum(data_type)->GetElementType(), dims, std::make_shared()) { + data_(DataTypeImpl::TensorTypeFromONNXEnum(data_type)->GetElementType(), dims, + std::make_shared()) { if (!data_.IsDataTypeString()) { memset(data_.MutableDataRaw(), 0, data_.SizeInBytes()); } @@ -39,7 +40,8 @@ Initializer::Initializer(const ONNX_NAMESPACE::TensorProto& tensor_proto, const auto proto_shape = utils::GetTensorShapeFromTensorProto(tensor_proto); // This must be pre-allocated - Tensor w(DataTypeImpl::TensorTypeFromONNXEnum(proto_data_type)->GetElementType(), proto_shape, std::make_shared()); + Tensor w(DataTypeImpl::TensorTypeFromONNXEnum(proto_data_type)->GetElementType(), proto_shape, + std::make_shared()); ORT_THROW_IF_ERROR(utils::TensorProtoToTensor(Env::Default(), model_path.ToPathString().c_str(), tensor_proto, w)); data_ = std::move(w); } diff --git a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc index 2d12c407e6e31..6c91949e467ae 100644 --- a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc +++ b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc @@ -13,27 +13,91 @@ using namespace onnx_transpose_optimization; namespace onnxruntime { namespace layout_transformation { +namespace { +// Cost check for aggressively pushing the Transpose nodes involved in the layout transformation further out. +CostCheckResult PostLayoutTransformCostCheck(const api::GraphRef& graph, const api::NodeRef& node, + const std::vector& perm, + const std::unordered_set& outputs_leading_to_transpose) { + // we aggressively push the layout transpose nodes. + // Exception: pushing through a Concat can result in Transpose nodes being added to multiple other inputs which + // can potentially be worse for performance. Use the cost check in that case. + if (node.OpType() != "Concat" && + (perm == ChannelFirstToLastPerm(perm.size()) || perm == ChannelLastToFirstPerm(perm.size()))) { + return CostCheckResult::kPushTranspose; + } + + // for other nodes use the default ORT cost check + return OrtEPCostCheck(graph, node, perm, outputs_leading_to_transpose); +} + +///

+/// Default function for checking if a node should have its layout changed. Allows EP specific adjustments to the +/// default set of layout sensitive operators if required. +/// +/// Longer term, if required, the EP API could allow the EP to provide a delegate to plugin EP specific logic so we +/// don't hardcode it here. +/// +/// Node to check +/// true if the node should have its layout converted to NHWC. +bool ConvertNodeLayout(const api::NodeRef& node) { + // skip if op is not an ONNX or contrib op + auto domain = node.Domain(); + if (domain != kOnnxDomain && domain != kMSDomain) { + return false; + } + + const auto& layout_sensitive_ops = GetORTLayoutSensitiveOps(); + + // handle special cases +#if defined(USE_XNNPACK) + if (node.GetExecutionProviderType() == kXnnpackExecutionProvider) { + if (node.OpType() == "Resize") { + // XNNPACK supports NCHW and NHWC for Resize so we don't need to use the internal NHWC domain and wrap the Resize + // with Transpose nodes. EPAwareHandleResize will allow an NCHW <-> NHWC Transpose to be pushed through + // the Resize during transpose optimization. + return false; + } + } +#endif + +#if defined(USE_JSEP) + // TODO(fs-eire): Remove special case handing of JSEP once NHWC Resize implementation is fixed + if (node.GetExecutionProviderType() == kJsExecutionProvider) { + if (node.OpType() == "Resize") { + // leave Resize as-is pending bugfix for NHWC implementation. this means the node will remain in the ONNX domain + // with the original input layout. + return false; + } + } +#endif + + // #if defined(USE_CUDA) + // if (node.GetExecutionProviderType() == kCudaExecutionProvider) { + // Update as per https://github.com/microsoft/onnxruntime/pull/17200 with CUDA ops that support NHWC + // } + // #endif + + return layout_sensitive_ops.count(node.OpType()) != 0; +} +} // namespace // Layout sensitive NCHW ops. TransformLayoutForEP will wrap these with Transpose nodes to convert the input // data to NHWC and output data back to NCHW, and move the op to the internal NHWC domain (kMSInternalNHWCDomain). -// The EP requesting these ops MUST be able to handle the node with the operator in the kMSInternalNHWCDomain. +// The EP requesting these ops MUST be able to handle the node with the operator in the kMSInternalNHWCDomain domain. // Once all the layout sensitive ops requested by the EP are wrapped the transpose optimizer will attempt to remove // as many of the layout transposes as possible. const std::unordered_set& GetORTLayoutSensitiveOps() { static std::unordered_set ort_layout_sensitive_ops = []() { const auto& layout_sensitive_ops = onnx_transpose_optimization::GetLayoutSensitiveOps(); std::unordered_set ort_specific_ops = - { "FusedConv", - "QLinearAveragePool", - "QLinearGlobalAveragePool" -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_QNN) || defined(USE_WEBNN) - // The CUDA/ROCM Resize kernel is layout sensitive as it only handles NCHW input. - // The CPU kernel and ONNX spec are not limited to handling NCHW input so are not layout sensitive, and - // onnx_layout_transformation::HandleResize is used. - , - "Resize" -#endif - }; + { + "FusedConv", + "QLinearAveragePool", + "QLinearGlobalAveragePool", + // Whilst the ONNX spec doesn't specify a layout for Resize, we treat it as layout sensitive by default + // as EPs tend to only support one layout. + "Resize", + }; ort_specific_ops.insert(layout_sensitive_ops.cbegin(), layout_sensitive_ops.cend()); return ort_specific_ops; @@ -42,45 +106,21 @@ const std::unordered_set& GetORTLayoutSensitiveOps() { return ort_layout_sensitive_ops; } -// Cost check for aggressively pushing the Transpose nodes involved in the layout transformation further out. -static CostCheckResult -PostLayoutTransformCostCheck(const api::GraphRef& graph, const api::NodeRef& node, - const std::vector& perm, - const std::unordered_set& outputs_leading_to_transpose) { - // we aggressively push the layout transpose nodes. - // Exception: pushing through a Concat can result in Transpose nodes being added to multiple other inputs which - // can potentially be worse for performance. Use the cost check in that case. - if (node.OpType() != "Concat" && - (perm == ChannelFirstToLastPerm(perm.size()) || perm == ChannelLastToFirstPerm(perm.size()))) { - return CostCheckResult::kPushTranspose; - } - - // for other nodes use the default ORT cost check - return OrtEPCostCheck(graph, node, perm, outputs_leading_to_transpose); -} - Status TransformLayoutForEP(Graph& graph, bool& modified, const IExecutionProvider& execution_provider, AllocatorPtr cpu_allocator, const DebugGraphFn& debug_graph_fn) { // We pass in nullptr for the new_node_ep param as new nodes will be assigned by the graph partitioner after // TransformLayoutForEP returns. - // sub graph recurse will be added later. + // sub graph recurse will be added later auto api_graph = MakeApiGraph(graph, cpu_allocator, /*new_node_ep*/ nullptr); - const auto& layout_sensitive_ops = GetORTLayoutSensitiveOps(); // to convert to NHWC we need to wrap layout sensitive nodes to Transpose from NCHW to NHWC and back. for (auto& node : api_graph->Nodes()) { - if (layout_sensitive_ops.count(node->OpType())) { - if (node->GetExecutionProviderType() != execution_provider.Type()) { - continue; - } - - auto domain = node->Domain(); - // Skip if domain is incorrect - if (domain != kOnnxDomain && domain != kMSDomain) { - continue; - } + if (node->GetExecutionProviderType() != execution_provider.Type()) { + continue; + } + if (ConvertNodeLayout(*node)) { // if already transformed then change the domain to kMSInternalNHWCDomain this way the EP // knows this op is in the expected format. if (node->GetAttributeIntDefault("channels_last", 0) == 1) { @@ -137,7 +177,6 @@ Status TransformLayoutForEP(Graph& graph, bool& modified, const IExecutionProvid WrapTransposesAroundNode(*api_graph, *node, {&input_perm}, {&output_perm}); } - // TODO: Technically Resize doesn't need to change domain as the ONNX Resize spec is not layout sensitive. SwapNodeOpTypeAndDomain(*api_graph, *node, node->OpType(), kMSInternalNHWCDomain); modified = true; } diff --git a/onnxruntime/core/optimizer/optimizer_execution_frame.cc b/onnxruntime/core/optimizer/optimizer_execution_frame.cc index fc7e694b6a69b..46041bca9dcc1 100644 --- a/onnxruntime/core/optimizer/optimizer_execution_frame.cc +++ b/onnxruntime/core/optimizer/optimizer_execution_frame.cc @@ -49,19 +49,13 @@ OptimizerExecutionFrame::Info::Info(const std::vector& nodes, InitializedTensorSet::const_iterator it = initialized_tensor_set.find(arg.Name()); if (it != initialized_tensor_set.cend()) { const auto& tensor_proto = *(it->second); - size_t cpu_tensor_length; - ORT_RETURN_IF_ERROR(utils::GetSizeInBytesFromTensorProto<0>(tensor_proto, &cpu_tensor_length)); OrtValue ort_value; - std::unique_ptr data = std::make_unique(cpu_tensor_length); - std::unique_ptr p_tensor; - ORT_RETURN_IF_ERROR(utils::TensorProtoToMLValue(Env::Default(), - model_path.IsEmpty() ? nullptr : model_path.ToPathString().c_str(), - tensor_proto, - MemBuffer(data.get(), cpu_tensor_length, allocator_ptr_->Info()), - ort_value)); - - initializers_[idx] = ort_value; - buffer_for_initialized_tensors_[idx] = std::move(data); + ORT_RETURN_IF_ERROR( + utils::TensorProtoToOrtValue(Env::Default(), + model_path.IsEmpty() ? nullptr : model_path.ToPathString().c_str(), + tensor_proto, allocator_ptr_, ort_value)); + + initializers_[idx] = std::move(ort_value); } return Status::OK(); @@ -72,7 +66,6 @@ OptimizerExecutionFrame::Info::Info(const std::vector& nodes, ort_value_name_idx_map_.Reserve(num_inputs_outputs); ort_value_idx_nodearg_map_.reserve(num_inputs_outputs); initializers_.reserve(initialized_tensor_set.size()); - buffer_for_initialized_tensors_.reserve(initialized_tensor_set.size()); for (auto* node : nodes) { ORT_THROW_IF_ERROR(onnxruntime::Node::ForEachWithIndex(node->InputDefs(), initialize_maps)); diff --git a/onnxruntime/core/optimizer/optimizer_execution_frame.h b/onnxruntime/core/optimizer/optimizer_execution_frame.h index e1b8a91545f64..13cf9e652c404 100644 --- a/onnxruntime/core/optimizer/optimizer_execution_frame.h +++ b/onnxruntime/core/optimizer/optimizer_execution_frame.h @@ -70,7 +70,6 @@ class OptimizerExecutionFrame final : public IExecutionFrame { OrtValueNameIdxMap ort_value_name_idx_map_; std::unordered_map ort_value_idx_nodearg_map_; std::unordered_map initializers_; - InlinedHashMap> buffer_for_initialized_tensors_; std::unique_ptr node_index_info_; const IExecutionProvider& execution_provider_; const std::function& is_sparse_initializer_func_; diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc index 2c11bf144999e..81b415c2e40ae 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc @@ -4,9 +4,11 @@ #include "onnx_transpose_optimization.h" #include +#include #include #include #include +#include #include #include "core/common/gsl.h" @@ -93,6 +95,55 @@ static std::unique_ptr MakeSqueezeOrUnsqueeze(int64_t opset, api:: return graph.AddNode(op_type, inputs, /*num_outputs*/ 1); } +/// +/// Return a DequantizeLinear node if it's input is a constant initializer with known consumers. +/// In this case the initializer can be updated in-place by UnsqueezeInput or TransposeInput. +/// +/// Current graph +/// Value to check if produced by a DQ node who's input is a constant initializer +/// NodeRef for DQ node if it meets the requirements. +static std::unique_ptr GetDQWithConstInitializerInput(const api::GraphRef& graph, + std::string_view dq_output_name) { + std::unique_ptr dq_node; + auto maybe_dq_node = graph.GetNodeProducingOutput(dq_output_name); + + if (maybe_dq_node && maybe_dq_node->OpType() == "DequantizeLinear") { + do { + auto dq_input = maybe_dq_node->Inputs()[0]; + auto dq_constant = graph.GetConstant(dq_input); + + // input to DQ must be a constant initializer + if (!dq_constant) { + break; + } + + // For now keep it simple and don't support per-axis quantization as that would require updating the + // scale and zero point values in the DQ node to re-order if transposing, or reshape if unsqueezing. + // the rank of the `scale` and `zero point` inputs must match so we only need to check `scale`. + auto dq_scale = graph.GetConstant(maybe_dq_node->Inputs()[1]); + if (!dq_scale || dq_scale->NumElements() != 1) { + break; + } + + // need to know all the initializer consumers as we're potentially going to modify it directly + auto initializer_consumers = graph.GetValueConsumers(dq_input); + if (!initializer_consumers->comprehensive) { + break; + } + + // DQ output is only used by the node we're modifying. + auto dq_consumers = graph.GetValueConsumers(dq_output_name); + if (!dq_consumers->comprehensive || dq_consumers->nodes.size() != 1) { + break; + } + + dq_node = std::move(maybe_dq_node); + } while (false); + } + + return dq_node; +} + // Returns whether perm is a valid permutation (contains each value from 0 to perm.size() - 1 exactly once) static bool IsValidPerm(const std::vector& perm) { size_t rank = perm.size(); @@ -345,6 +396,12 @@ static std::vector SortedAxesForTransposedInput(const std::vectorShape(); + graph.GetValueInfo(dq.Outputs()[0])->SetShape(&new_shape); +} + /////// /////// /////// /////// @@ -357,51 +414,126 @@ static std::string_view HelpHandleUnsqueeze(HandlerArgs& args, const std::vector // broadcasting. static void UnsqueezeInput(OptimizerCtx& ctx, api::NodeRef& node, size_t i, const std::vector& axes) { std::string_view input = node.Inputs()[i]; - // Remove this node as a consumer - node.SetInput(i, ""); std::unique_ptr constant = ctx.graph.GetLocalConstant(input); - auto consumers = ctx.graph.GetValueConsumers(input); + + // allow a constant initializer coming via a DQ node with a single consumer + std::unique_ptr dq_node; + std::string_view constant_dq_input; + + if (!constant) { + // look past a DQ node for a constant initializer. essentially we pretend the DQ node doesn't exist + // to enable directly making changes to the initializer. any nodes added for other consumers of the initializer + // in 'Case 1' are prior to the DQ so we don't break up any QDQ node units. + dq_node = GetDQWithConstInitializerInput(ctx.graph, input); + if (dq_node) { + // underlying string for the input name is in the Node so it's safe to store in string_view constant_dq_input + constant_dq_input = dq_node->Inputs()[0]; + constant = ctx.graph.GetLocalConstant(constant_dq_input); + // remove the DQ node as a consumer of the initializer while we modify things + dq_node->SetInput(0, ""); + } + } + + // Clear the input, which also removes this node's input as a consumer of the value. + // NOTE: the node may have multiple inputs consuming the value. + node.SetInput(i, ""); + auto value_to_modify = dq_node ? constant_dq_input : input; + auto consumers = ctx.graph.GetValueConsumers(value_to_modify); // Case 1: input is a constant with a known list of consumer nodes if (constant != nullptr && consumers->comprehensive) { - // We will reshape the initializer. If there are existing consumers, still reshape it but add Squeeze nodes + // We will reshape the initializer. If there are existing consumers, reshape it and add Squeeze nodes // to counteract its effect. If they later Unsqueeze the same input, the Squeeze nodes will simply be deleted // (see Case 2). if (consumers->nodes.size() > 0) { - auto squeeze_ptr = MakeSqueezeOrUnsqueeze(ctx.opset, ctx.graph, "Squeeze", input, axes); + // record the consumer node input as being special cased for use in Case 2 if a DQ node, and IsConstant + for (auto& consumer : consumers->nodes) { + auto& consumer_node_inputs = ctx.nodes_using_updated_shared_initializer[consumer->Id()]; + + // find input id/s for consumer + auto consumer_inputs = consumer->Inputs(); + for (size_t input_idx = 0; input_idx < consumer_inputs.size(); ++input_idx) { + if (consumer_inputs[input_idx] == value_to_modify) { + consumer_node_inputs.push_back(input_idx); + } + } + } + + auto squeeze_ptr = MakeSqueezeOrUnsqueeze(ctx.opset, ctx.graph, "Squeeze", value_to_modify, axes); api::NodeRef& squeeze = *squeeze_ptr; std::string_view sq_out = squeeze.Outputs()[0]; - ctx.graph.CopyValueInfo(input, sq_out); - ReplaceValueReferences(consumers->nodes, input, sq_out); + ctx.graph.CopyValueInfo(value_to_modify, sq_out); + ReplaceValueReferences(consumers->nodes, value_to_modify, sq_out); } + auto new_shape = UnsqueezeShape(constant->Shape(), axes); - ctx.graph.ReshapeInitializer(input, new_shape); - node.SetInput(i, input); + ctx.graph.ReshapeInitializer(value_to_modify, new_shape); + + if (dq_node) { + UpdateDQNodeInputAndShape(ctx.graph, *dq_node, constant_dq_input); + } + + node.SetInput(i, input); // restore the original connection return; } // Case 2: input is a Squeeze node with matching axes std::unique_ptr inp_node = ctx.graph.GetNodeProducingOutput(input); + + // check if this is a special-cased DQ node where we put the Squeeze on input 0 of the DQ in 'Case 1' above + if (inp_node && inp_node->OpType() == "DequantizeLinear" && + std::find_if(ctx.nodes_using_updated_shared_initializer.begin(), + ctx.nodes_using_updated_shared_initializer.end(), + [&inp_node](const auto& entry) { + const auto id = entry.first; + const auto& input_idxs = entry.second; + // check Id matches and the entry was for input 0 of the DQ node + return id == inp_node->Id() && + std::find(input_idxs.begin(), input_idxs.end(), size_t(0)) != input_idxs.end(); + }) != ctx.nodes_using_updated_shared_initializer.end()) { + // set things up so we can look past the DQ node to the Squeeze that was inserted in front of the reshaped + // constant initializer that was shared with this node. + dq_node = std::move(inp_node); + auto dq_input = dq_node->Inputs()[0]; + inp_node = ctx.graph.GetNodeProducingOutput(dq_input); + consumers = ctx.graph.GetValueConsumers(dq_input); + } + if (inp_node != nullptr && inp_node->IsOp("Squeeze")) { const std::vector& inp_node_inputs = inp_node->Inputs(); std::optional> squeeze_axes = std::nullopt; squeeze_axes = ReadFromAttrOrInput(ctx, *inp_node, "axes", /*inp_index*/ 1, /*opset*/ 13); if (squeeze_axes != std::nullopt && *squeeze_axes == axes) { + if (dq_node) { + UpdateDQNodeInputAndShape(ctx.graph, *dq_node, inp_node_inputs[0]); + node.SetInput(i, dq_node->Outputs()[0]); + } else { + node.SetInput(i, inp_node_inputs[0]); + } + // Remove the Squeeze node if possible - if (consumers->comprehensive && consumers->nodes.size() == 0) { + // if there's a DQ node the `consumers` list still includes it so allow for that. + // in that case UpdateDQNodeInputAndShape already updated the input of the DQ node so it's safe to remove it. + if (consumers->comprehensive && consumers->nodes.size() == size_t(dq_node ? 1 : 0)) { ctx.graph.RemoveNode(*inp_node); + if (ctx.opset >= 13 && !ctx.graph.HasValueConsumers(inp_node_inputs[1])) { ctx.graph.RemoveInitializer(inp_node_inputs[1]); } } - node.SetInput(i, inp_node_inputs[0]); + return; } // Axes don't match. Fall through to Case 3. } + // any DQ node special casing doesn't apply anymore, so go back to the original inp_node + if (dq_node) { + inp_node = std::move(dq_node); + } + // Case 3: Add an Unsqueeze node. auto unsqueeze_ptr = MakeSqueezeOrUnsqueeze(ctx.opset, ctx.graph, "Unsqueeze", input, axes); api::NodeRef& unsqueeze = *unsqueeze_ptr; @@ -453,83 +585,197 @@ static void Permute1DConstant(api::GraphRef& graph, api::NodeRef& node, api::Ten // Replaces ith input to node with transposed value. Might create a new Transpose node, find an existing one, // or transpose an initializer. -void TransposeInput(api::GraphRef& graph, api::NodeRef& node, size_t i, - const std::vector& perm, const std::vector& perm_inv) { +static void TransposeInputImpl(api::GraphRef& graph, + NodeIdToInputIdxsMap* nodes_using_updated_shared_initializer, + api::NodeRef& node, size_t i, const std::vector& perm, + const std::vector& perm_inv) { std::string_view input = node.Inputs()[i]; - // Remove this node as a consumer - node.SetInput(i, ""); + // Only local constants are editable std::unique_ptr constant = graph.GetLocalConstant(input); - auto consumers = graph.GetValueConsumers(input); + + // allow a constant initializer coming via a DQ node with a single consumer + std::unique_ptr dq_node; + std::string_view constant_dq_input; + + if (!constant) { + // look past a DQ node for a constant initializer. essentially we pretend the DQ node doesn't exist + // to enable directly making changes to the initializer. any nodes added for other consumers of the initializer + // in 'Case 1' are prior to the DQ so we don't break up any QDQ node units. + dq_node = GetDQWithConstInitializerInput(graph, input); + if (dq_node) { + // underlying string for the input name is in the Node so it's safe to store in string_view constant_dq_input + constant_dq_input = dq_node->Inputs()[0]; + constant = graph.GetLocalConstant(constant_dq_input); + // remove the DQ node as a consumer of the initializer while we modify things + dq_node->SetInput(0, ""); + } + } + + // Clear the input, which also removes this node's input as a consumer of the value. + // NOTE: the node may have multiple inputs consuming the value. + node.SetInput(i, ""); + + auto constant_to_modify = dq_node ? constant_dq_input : input; + auto consumers = graph.GetValueConsumers(constant_to_modify); // Case 1: input is a constant with a known list of consumer nodes if (constant != nullptr && consumers->comprehensive) { - // Input is scalar, return early. - if (constant->Shape().size() == 1 && constant->Shape()[0] == 0) { + // we modify the initializer in-place and need to reconnect things up when we're done. this helper will + // do that when it goes out of scope. if we have manually reconnected, input or constant_dq_input is + // set to an empty string. + auto reconnect_nodes = gsl::finally([i, &node, &dq_node, &input, &constant_dq_input] { + if (!input.empty()) { + node.SetInput(i, input); + } + + if (!constant_dq_input.empty()) { + dq_node->SetInput(0, constant_dq_input); + } + }); + + // If there is only one element return early as the transpose won't change the data + if (constant->NumElements() == 1) { return; } + // This is a special case where the constant is 1D with length == perm. - // TODO: TransposeInitializer should be updated to handle this case. + // e.g. it provides a set of values that are relative to the input axes like the `sizes` input for Resize // Permute1DConstant permutes the constant and adds a new initializer. The old initializer is removed only if // there are no other consumers. if (constant->Shape().size() == 1 && constant->Shape()[0] == gsl::narrow_cast(perm.size())) { - Permute1DConstant(graph, node, *constant, i, input, perm); + auto& node_to_update = dq_node ? *dq_node : node; + Permute1DConstant(graph, node_to_update, *constant, i, constant_to_modify, perm); + + // unset updated input so reconnect_nodes doesn't change it back + if (dq_node) { + constant_dq_input = ""; + } else { + input = ""; + } + return; } + if (consumers->nodes.size() > 0) { // Transpose the initializer. If there are existing consumers, add Transpose nodes to them using perm_inv // to counteract the effect. These Transposes will hopefully be optimized out later. - auto transpose_inv_ptr = MakeTranspose(graph, input, perm_inv); + + // record the consumer node's input as being special cased for use in Case 2 if a DQ node, and IsConstant + if (nodes_using_updated_shared_initializer) { + for (auto& consumer : consumers->nodes) { + auto& consumer_node_inputs = (*nodes_using_updated_shared_initializer)[consumer->Id()]; + + // find input id/s for consumer + auto consumer_inputs = consumer->Inputs(); + for (size_t input_idx = 0; input_idx < consumer_inputs.size(); ++input_idx) { + if (consumer_inputs[input_idx] == constant_to_modify) { + consumer_node_inputs.push_back(input_idx); + } + } + } + } + + auto transpose_inv_ptr = MakeTranspose(graph, constant_to_modify, perm_inv); api::NodeRef& transpose_inv = *transpose_inv_ptr; std::string_view transpose_out = transpose_inv.Outputs()[0]; - graph.CopyValueInfo(input, transpose_out); - ReplaceValueReferences(consumers->nodes, input, transpose_out); + graph.CopyValueInfo(constant_to_modify, transpose_out); + ReplaceValueReferences(consumers->nodes, constant_to_modify, transpose_out); } - graph.TransposeInitializer(input, perm); - node.SetInput(i, input); + + graph.TransposeInitializer(constant_to_modify, perm); + + if (dq_node) { + UpdateDQNodeInputAndShape(graph, *dq_node, constant_to_modify); + constant_dq_input = ""; // DQ input was already updated so we don't need reconnect_nodes to handle it + } + return; } // Case 2: input is a Transpose node std::unique_ptr inp_node = graph.GetNodeProducingOutput(input); + + // check if this is a special-cased DQ node where we put the Transpose on input 0 of the DQ in 'Case 1' above + if (inp_node && inp_node->OpType() == "DequantizeLinear" && + nodes_using_updated_shared_initializer && + std::find_if(nodes_using_updated_shared_initializer->begin(), nodes_using_updated_shared_initializer->end(), + [&inp_node](const auto entry) { + const auto id = entry.first; + const auto& input_idxs = entry.second; + // id matches and the entry is for input 0 of the DQ node + return id == inp_node->Id() && + std::find(input_idxs.begin(), input_idxs.end(), size_t(0)) != input_idxs.end(); + }) != nodes_using_updated_shared_initializer->end()) { + // set things up so we can look past the DQ node to the Transpose that was inserted in front of the reshaped + // constant initializer that was shared with this node. + dq_node = std::move(inp_node); + auto dq_input = dq_node->Inputs()[0]; + inp_node = graph.GetNodeProducingOutput(dq_input); + consumers = graph.GetValueConsumers(dq_input); + } + if (inp_node != nullptr && inp_node->IsOp("Transpose")) { std::optional> perm2 = GetPermAttrIfValid(*inp_node); if (perm2 != std::nullopt && perm2->size() == perm.size()) { // If they cancel, use pre_transpose_value and remove Transpose if possible. if (*perm2 == perm_inv) { std::string_view pre_transpose_value = inp_node->Inputs()[0]; - if (consumers->comprehensive && consumers->nodes.size() == 0) { + + if (dq_node) { + UpdateDQNodeInputAndShape(graph, *dq_node, pre_transpose_value); + node.SetInput(i, dq_node->Outputs()[0]); + } else { + node.SetInput(i, pre_transpose_value); + } + + // Remove the Transpose node if possible + // if there's a DQ node the `consumers` list still includes it so allow for that. + // in that case UpdateDQNodeInputAndShape already updated the input of the DQ node so it's safe to remove it. + if (consumers->comprehensive && consumers->nodes.size() == size_t(dq_node ? 1 : 0)) { graph.RemoveNode(*inp_node); } - node.SetInput(i, pre_transpose_value); - return; - } else if (*perm2 == perm) { - // we are trying to add a duplicate transpose. - // do nothing and return + return; } - // Otherwise, compose the perm and Transpose pre_transpose_value. Cost is the same and we may be able to remove - // the other Transpose. - const std::vector& perm_combined = ComposePerm(*perm2, perm); - auto transpose_ptr = MakeTranspose(graph, inp_node->Inputs()[0], perm_combined); - api::NodeRef& transpose = *transpose_ptr; - std::string_view transpose_out = transpose.Outputs()[0]; - graph.CopyValueInfo(input, transpose_out); - graph.GetValueInfo(transpose_out)->PermuteDims(perm); - if (consumers->comprehensive && consumers->nodes.size() == 0) { - graph.RemoveNode(*inp_node); + // NOTE: We expect the Transpose to cancel out when handling a special-cased DQ node that was originally + // connected to a shared constant initializer, so we don't expect to get here if dq_node is not nullptr. + // If there was a dq_node where the Transpose didn't cancel out we fall through to the next case + // so we retain the potential to cancel out for any other usages of the shared initializer. + assert(!dq_node); // assert in debug build to investigate. fall through to next case in release build to be safe. + + if (!dq_node) { + // Otherwise, compose the perm and Transpose pre_transpose_value. Cost is the same and we may be able to remove + // the other Transpose. + const std::vector& perm_combined = ComposePerm(*perm2, perm); + auto transpose_ptr = MakeTranspose(graph, inp_node->Inputs()[0], perm_combined); + api::NodeRef& transpose = *transpose_ptr; + std::string_view transpose_out = transpose.Outputs()[0]; + graph.CopyValueInfo(input, transpose_out); + graph.GetValueInfo(transpose_out)->PermuteDims(perm); + + if (consumers->comprehensive && consumers->nodes.size() == 0) { + graph.RemoveNode(*inp_node); + } + + node.SetInput(i, transpose_out); + + return; } - node.SetInput(i, transpose_out); - return; } } - // Case 3: A Transpose op might already exist - for (size_t j = 0; j < consumers->nodes.size(); ++j) { - api::NodeRef& consumer = *consumers->nodes[j]; - if (consumer.IsOp("Transpose") && GetPermAttrIfValid(consumer) == perm) { - node.SetInput(i, consumer.Outputs()[0]); + // any DQ node special casing doesn't apply anymore, so go back to the original inp_node + if (dq_node) { + inp_node = std::move(dq_node); + consumers = graph.GetValueConsumers(input); + } + + // Case 3: A Transpose op with the same perms might already exist + for (auto& consumer : consumers->nodes) { + if (consumer->IsOp("Transpose") && GetPermAttrIfValid(*consumer) == perm) { + node.SetInput(i, consumer->Outputs()[0]); return; } } @@ -540,11 +786,25 @@ void TransposeInput(api::GraphRef& graph, api::NodeRef& node, size_t i, std::string_view transpose_out = transpose.Outputs()[0]; graph.CopyValueInfo(input, transpose_out); graph.GetValueInfo(transpose_out)->PermuteDims(perm); + node.SetInput(i, transpose_out); } +void TransposeInput(api::GraphRef& graph, api::NodeRef& node, size_t i, + const std::vector& perm, + const std::vector& perm_inv) { + // this TransposeInput is used by the layout transformer to wrap a node in Transpose ops. there's no OptimizerCtx + // in that scenario and we're not tracking special-cased DQ nodes as we only do that when pushing Transpose nodes. + TransposeInputImpl(graph, /* nodes_using_updated_shared_initializer */ nullptr, node, i, perm, perm_inv); +} + +static void TransposeInput(OptimizerCtx& ctx, api::NodeRef& node, size_t i, const std::vector& perm, + const std::vector& perm_inv) { + TransposeInputImpl(ctx.graph, &ctx.nodes_using_updated_shared_initializer, node, i, perm, perm_inv); +} + // Unsqueezes inputs of node to have uniform rank. Returns false if input ranks are unknown or exceed the target rank. -static bool NormalizeInputRanks(OptimizerCtx ctx, api::NodeRef& node, size_t target_rank, +static bool NormalizeInputRanks(OptimizerCtx& ctx, api::NodeRef& node, size_t target_rank, const std::vector& input_indices) { auto inputs = node.Inputs(); @@ -579,7 +839,7 @@ void TransposeInputs(OptimizerCtx& ctx, api::NodeRef& node, const std::vector& input_indices) { auto perm_inv = InvertPerm(perm); for (size_t j : input_indices) { - TransposeInput(ctx.graph, node, j, perm, perm_inv); + TransposeInput(ctx, node, j, perm, perm_inv); } } @@ -670,23 +930,74 @@ static bool CanLikelyRemoveTranspose(const api::GraphRef& graph, api::NodeRef& t return true; } +// return true if +// - the value is a constant initializer +// - the value is the output of a DQ node who's input is a constant initializer +// - UnsqueezeInput/TranposeInput can look past the DQ to update the constant initializer directly +// - DQ node is currently ignored if it uses per-channel quantization +// - supporting per-channel quantization requires modifying the scales and zero point data, which can be done +// if/when there's a use-case to justify the development cost. +// - the input was originally connected to a shared constant initializer that was updated in place by UnsqueezeInput +// or TransposeInput, and usage by this node had Squeeze/Transpose nodes inserted to counteract the effect of the +// in-place update. if we push the same transpose through this node it should cancel out that Squeeze/Transpose +// +// in all these cases we expect pushing the transpose through to not require a runtime Transpose node +static bool IsConstant(const api::GraphRef& graph, const api::NodeRef& node, + size_t input_id, + std::string_view value_name, + const NodeIdToInputIdxsMap& nodes_using_updated_shared_initializer) { + std::unique_ptr producer_node = graph.GetNodeProducingOutput(value_name); + + if (!producer_node) { + // initializer. may or may not be constant depending on whether it has a matching graph input + std::unique_ptr constant = graph.GetConstant(value_name); + return constant != nullptr; + } + + auto node_id_to_check = node.Id(); + + // handle potentially looking past a DQ node + if (producer_node->OpType() == "DequantizeLinear") { + std::unique_ptr dq_node = GetDQWithConstInitializerInput(graph, value_name); + if (dq_node != nullptr) { + // DQ node pointing to an initializer that has not been updated in-place yet + return true; + } + + // could also be a DQ that was connected to a shared initializer that was updated in-place. + // update the info on the node/input index to check and fall through + node_id_to_check = producer_node->Id(); + input_id = 0; // can only be input 0 of a DQ node + } + + auto entry = nodes_using_updated_shared_initializer.find(node_id_to_check); + if (entry != nodes_using_updated_shared_initializer.end()) { + if (std::find(entry->second.begin(), entry->second.end(), input_id) != entry->second.end()) { + return true; + } + } + + return false; +} + // Estimates the cost of transposing an input. Currently uses rank heuristic. Negative if transpose is removed. // Feel free to improve as needed. -static int EstimateTransposeValueCost(const api::GraphRef& graph, std::string_view input, +static int EstimateTransposeValueCost(const api::GraphRef& graph, const api::NodeRef& node, + size_t input_id, std::string_view input, const std::vector& perm_inv, - const HandlerMap& extended_handlers) { + const HandlerMap& extended_handlers, + const NodeIdToInputIdxsMap& nodes_using_updated_shared_initializer) { // Case 1: Transposing constants probably costs nothing. - std::unique_ptr constant = graph.GetConstant(input); - if (constant != nullptr) { + if (IsConstant(graph, node, input_id, input, nodes_using_updated_shared_initializer)) { return 0; } // Case 2: Transposing a transpose either cancels it or composes the permutations. - std::unique_ptr node = graph.GetNodeProducingOutput(input); - if (node != nullptr && node->IsOp("Transpose")) { - std::optional> perm2 = GetPermAttrIfValid(*node); + std::unique_ptr producer_node = graph.GetNodeProducingOutput(input); + if (producer_node != nullptr && producer_node->IsOp("Transpose")) { + std::optional> perm2 = GetPermAttrIfValid(*producer_node); if (perm2 != std::nullopt) { - if (*perm2 == perm_inv && CanLikelyRemoveTranspose(graph, *node, extended_handlers)) { + if (*perm2 == perm_inv && CanLikelyRemoveTranspose(graph, *producer_node, extended_handlers)) { return -EstimateValueRank(graph, input); } else { return 0; @@ -702,11 +1013,13 @@ static int EstimateTransposeValueCost(const api::GraphRef& graph, std::string_vi static int EstimateTransposeInputsCost(const api::GraphRef& graph, const api::NodeRef& node, const std::vector& perm_inv, const std::vector& input_indices, - const HandlerMap& extended_handlers) { + const HandlerMap& extended_handlers, + const NodeIdToInputIdxsMap& nodes_using_updated_shared_initializer) { auto inputs = node.Inputs(); int cost = 0; for (size_t j : input_indices) { - cost += EstimateTransposeValueCost(graph, inputs[j], perm_inv, extended_handlers); + cost += EstimateTransposeValueCost(graph, node, j, inputs[j], perm_inv, extended_handlers, + nodes_using_updated_shared_initializer); } return cost; } @@ -734,8 +1047,10 @@ static bool HandleSimpleNodeBase(HandlerArgs& args, bool broadcast_inputs) { if (broadcast_inputs && !NormalizeInputRanks(args.ctx, args.node, rank, args.transposible_inputs)) { return false; } + TransposeInputs(args.ctx, args.node, args.perm_inv, args.transposible_inputs); TransposeOutputs(args.ctx, args.node, args.perm); + return true; } @@ -927,18 +1242,7 @@ static void PermuteInput(api::GraphRef& graph, api::NodeRef& node, size_t i, con node.SetInput(i, gather_output); } -static bool HandleResize([[maybe_unused]] HandlerArgs& args) { -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_QNN) || defined(USE_WEBNN) - // The CUDA Resize kernel requires that the input is NCHW, so we can't push a Transpose through a Resize - // in ORT builds with CUDA enabled. - // The ROCm EP is generated from the CUDA EP kernel so the same applies to builds with ROCm enabled. - // The QNN EP requires the input to be NHWC, so the Resize handler is also not enabled for QNN builds. - // - // TODO: Remove this special case once the CUDA Resize kernel is implemented "generically" (i.e.) aligning with the - // generic nature of the ONNX spec. - // See https://github.com/microsoft/onnxruntime/pull/10824 for a similar fix applied to the CPU Resize kernel. - return false; -#else +bool HandleResize([[maybe_unused]] HandlerArgs& args) { auto inputs = args.node.Inputs(); int64_t rank_int = gsl::narrow_cast(args.perm.size()); @@ -964,10 +1268,10 @@ static bool HandleResize([[maybe_unused]] HandlerArgs& args) { TransposeOutputs(args.ctx, args.node, args.perm); return true; -#endif } -constexpr HandlerInfo resize_handler = {&FirstInput, &HandleResize}; +// Not currently registered by default. +// constexpr HandlerInfo resize_handler = {&FirstInput, &HandleResize}; static bool HandlePad(HandlerArgs& args) { size_t rank = args.perm.size(); @@ -1719,8 +2023,11 @@ static const std::unordered_map handler_ma {"Split", split_handler}, {"Shape", shape_handler}, {"Pad", pad_handler}, - {"Resize", resize_handler}, - {"ReduceSum", reduce_op_handler}, + + // Execution providers tend to only implement Resize for specific layouts. Due to that, it's safer to not + // push a Transpose through a Resize unless the EP specifically checks that it can handle the change via an + // extended handler. + // {"Resize", resize_handler}, {"ReduceLogSum", reduce_op_handler}, {"ReduceLogSumExp", reduce_op_handler}, @@ -1728,6 +2035,7 @@ static const std::unordered_map handler_ma {"ReduceMean", reduce_op_handler}, {"ReduceMin", reduce_op_handler}, {"ReduceProd", reduce_op_handler}, + {"ReduceSum", reduce_op_handler}, {"ReduceSumSquare", reduce_op_handler}, {"ReduceL1", reduce_op_handler}, {"ReduceL2", reduce_op_handler}, @@ -1787,12 +2095,14 @@ static int CalculateCost(const api::GraphRef& graph, const api::NodeRef& node, const std::unordered_set& outputs_leading_to_transpose, const HandlerInfo& info, const std::vector& input_indices, - const HandlerMap& extended_handlers) { + const HandlerMap& extended_handlers, + const NodeIdToInputIdxsMap& nodes_using_updated_shared_initializer) { // We require the input cost (number of transposes before the op) and the total cost to strictly decrease. // Strict decrease of the input cost ensures the optimization is stable, since the total cost decrease is just an // estimate (the transpose after the op may or may not cancel with a subsequent transpose). We don't want // repeated runs of the optimizer to have a transpose toggle between two inputs of a binary op. - int cost = EstimateTransposeInputsCost(graph, node, perm, input_indices, extended_handlers); + int cost = EstimateTransposeInputsCost(graph, node, perm, input_indices, extended_handlers, + nodes_using_updated_shared_initializer); if (cost < 0 && info.transposes_outputs) { // If the output will be transposed and won't ultimately cancel, factor in that cost. @@ -1822,13 +2132,14 @@ static bool ShouldPushTranspose(const api::GraphRef& graph, const api::NodeRef& const std::unordered_set& outputs_leading_to_transpose, const HandlerInfo& info, const std::vector transposable_input_indices, - const HandlerMap& extended_handlers) { + const HandlerMap& extended_handlers, + const NodeIdToInputIdxsMap& nodes_using_updated_shared_initializer) { if (node.IsOp("Transpose")) { return true; } int cost = CalculateCost(graph, node, perm, outputs_leading_to_transpose, info, transposable_input_indices, - extended_handlers); + extended_handlers, nodes_using_updated_shared_initializer); return cost < 0; } @@ -1855,7 +2166,7 @@ bool ProcessTranspose(OptimizerCtx& ctx, api::NodeRef& transpose, api::NodeRef& if (cost == CostCheckResult::kFallThrough) { cost = ShouldPushTranspose(ctx.graph, node, perm, outputs_leading_to_transpose, *info, input_indices, - ctx.extended_handlers) + ctx.extended_handlers, ctx.nodes_using_updated_shared_initializer) ? CostCheckResult::kPushTranspose : CostCheckResult::kStop; } @@ -1889,7 +2200,7 @@ std::optional MakeOptimizerContext(api::GraphRef& graph, return std::nullopt; } - OptimizerCtx ctx{*opset, graph, provider_type, cost_check_fn, extended_handlers}; + OptimizerCtx ctx{*opset, graph, provider_type, cost_check_fn, extended_handlers, {}}; return ctx; } @@ -2067,6 +2378,8 @@ OptimizeResult OptimizeImpl(OptimizerCtx& ctx) { continue; } + // NOTE: this bleeds ORT specific logic into the base optimizer, however we justify that for now because we expect + // the types that the ORT DQ provides to be added to the ONNX spec, at which point this special case can go away. if (IsMSDomain(dq_domain) && !TransposeQuantizeDequantizeAxis(ctx.graph, perm_inv, *dq_node)) { continue; } diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h index 1a54e7834a4ae..cc1552704c187 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h @@ -3,6 +3,7 @@ #pragma once +#include #include // implementation details of the transpose optimizer API defined in optimizer_api.h. @@ -38,6 +39,8 @@ struct HandlerInfo { bool transposes_outputs = true; }; +using NodeIdToInputIdxsMap = std::unordered_map>; + struct OptimizerCtx { int64_t opset; api::GraphRef& graph; @@ -48,6 +51,32 @@ struct OptimizerCtx { // Handlers for ops that are not in the ONNX opset, or for ONNX ops where special handling is required. // If a handler is not found in this map, the default handlers will be used. const HandlerMap& extended_handlers; + + // When we update a shared constant initializer as part of pushing a transpose through a node we update the + // initializer in-place and insert Squeeze (in UnsqueezeInput if the initializer is broadcast) or + // Transpose (in TransposeInput) nodes between the updated initializer and the other usages. + // This map contains the set of nodes that had a Squeeze or Transpose added between them and the initializer. + // The entry contains the node id (key) and original input index/es (value) that were connected to the initializer + // prior to the insertion of the Squeeze/Transpose. + // + // Assuming we also transpose the other usages of the initializer in the same way (which would be expected) the + // Squeeze and Transpose nodes would be cancelled out, and the other usages will end up using the original + // initializer that was updated in-place. + // + // We use this information in two ways. + // + // 1. In the IsConstant calculation that determines the cost of pushing a transpose through a node. + // - as we expect the transpose to be making the same modification to all shared usages of the initializer we + // expect the Squeeze/Transpose nodes to be cancelled out, resulting in no runtime cost to push the transpose + // through that input. + // + // 2. To enable and track a special case in a QDQ format model where there is the added complexity of a DQ node + // between the initializer and each usage. + // - we look past a DQ node in UnsqueezeInput and TransposeInput to determine if there is a constant initializer + // that can be updated in-place as the DQ node is not sensitive to any rank or layout changes + // - NOTE we currently ignore DQ nodes with per-channel quantization as they are sensitive to changes + // - we also look past DQ nodes when processing the other usages in order to cancel out the Squeeze/Transpose + NodeIdToInputIdxsMap nodes_using_updated_shared_initializer; }; /// @@ -69,6 +98,7 @@ bool HandleSimpleNodeWithAxis(HandlerArgs& args, std::optional default_ // base handlers that are used by extended handlers. add from transpose_optimizer.cc as needed. bool HandleReduceOps(HandlerArgs& args); +bool HandleResize([[maybe_unused]] HandlerArgs& args); void TransposeInput(api::GraphRef& graph, api::NodeRef& node, size_t i, const std::vector& perm, diff --git a/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h b/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h index 40a03f24f7648..fb338be1c7f5a 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h +++ b/onnxruntime/core/optimizer/transpose_optimization/optimizer_api.h @@ -242,6 +242,12 @@ class NodeRef { /// since version or default value -1 virtual int SinceVersion() const = 0; + /// + /// Get the unique id of the node. + /// + /// Id + virtual int64_t Id() const = 0; + virtual ~NodeRef(){}; }; @@ -442,7 +448,7 @@ class GraphRef { } // namespace api constexpr int64_t kMinSupportedOpset = 7; -constexpr int64_t kMaxSupportedOpset = 19; +constexpr int64_t kMaxSupportedOpset = 20; // enum of results that a CostCheckFn can return. enum class CostCheckResult { diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc index b30c94d7b3e40..2fcb88cb0b9ba 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_optimizer_api_impl.cc @@ -95,7 +95,8 @@ class ApiNode final : public api::NodeRef { void ClearAttribute(std::string_view name) override; void SetInput(size_t i, std::string_view name) override; std::string_view GetExecutionProviderType() const override; - virtual int SinceVersion() const override; + int SinceVersion() const override; + int64_t Id() const override; private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(ApiNode); @@ -417,6 +418,10 @@ int ApiNode::SinceVersion() const { return node_.SinceVersion(); } +int64_t ApiNode::Id() const { + return node_.Index(); +} + // std::optional ApiGraph::Opset(std::string_view domain) const { diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc index ead82a6b56741..f4f3505128737 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc @@ -5,12 +5,35 @@ #include #include "core/graph/constants.h" +#include "core/framework/utils.h" #include "core/optimizer/transpose_optimization/ort_optimizer_utils.h" using namespace onnx_transpose_optimization; namespace onnxruntime { +static bool EPAwareHandleResize(HandlerArgs& args) { + // Whilst Resize is not technically layout sensitive, execution providers typically implement handling for only one + // layout. Due to that, only push a Transpose through a Resize once it is assigned and we know it's being handled + // by an EP that supports multiple layouts. Currently that's the CPU and XNNPACK EPs. + const auto ep_type = args.node.GetExecutionProviderType(); + if (ep_type == kCpuExecutionProvider || ep_type == kXnnpackExecutionProvider) { + // allow NCHW <-> NHWC for now. not clear any other sort of transpose has a valid usage in a real model + int64_t rank_int = gsl::narrow_cast(args.perm.size()); + if (rank_int == 4) { + static const std::vector nchw_to_nhwc_perm{0, 2, 3, 1}; + static const std::vector nhwc_to_nchw_perm{0, 3, 1, 2}; + if (args.perm == nchw_to_nhwc_perm || args.perm == nhwc_to_nchw_perm) { + return HandleResize(args); + } + } + } + + return false; +} + +constexpr HandlerInfo ep_aware_resize_handler = {&FirstInput, &EPAwareHandleResize}; + static bool HandleQLinearConcat(HandlerArgs& args) { return HandleSimpleNodeWithAxis(args); } @@ -62,7 +85,7 @@ static bool HandleMaxPool(HandlerArgs& args) { ORT_UNUSED_PARAMETER(args); return false; #else - if (args.node.GetExecutionProviderType() != "CPUExecutionProvider") { + if (args.node.GetExecutionProviderType() != kCpuExecutionProvider) { return false; } @@ -103,6 +126,7 @@ static bool HandleContribQuantizeDequantizeLinear(HandlerArgs& args) { } constexpr HandlerInfo max_pool_op_handler = {&FirstInput, &HandleMaxPool}; + constexpr HandlerInfo node_1_inp_handler = {&FirstInput, &HandleSimpleNode}; constexpr HandlerInfo reduce_op_handler = {&FirstInput, &HandleReduceOps}; constexpr HandlerInfo contrib_quantize_dequantize_linear_handler = {&FirstInput, @@ -113,6 +137,7 @@ const HandlerMap& OrtExtendedHandlers() { static const HandlerMap extended_handler_map = []() { HandlerMap map = { {"MaxPool", max_pool_op_handler}, + {"Resize", ep_aware_resize_handler}, {"com.microsoft.QuantizeLinear", contrib_quantize_dequantize_linear_handler}, {"com.microsoft.DequantizeLinear", contrib_quantize_dequantize_linear_handler}, {"com.microsoft.QLinearAdd", q_linear_binary_op_handler}, diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.h b/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.h index 0a5dbd6d13d06..8245d8c3b4eae 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.h +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.h @@ -10,7 +10,6 @@ namespace onnxruntime { /// /// Get the extended handlers for ORT specific transpose optimization. /// These include handlers for contrib ops, and where we have an NHWC version of a layout sensitive op. -/// Extends the handlers returned by OrtHandlers. /// /// HandlerMap const onnx_transpose_optimization::HandlerMap& OrtExtendedHandlers(); diff --git a/onnxruntime/core/optimizer/transpose_optimizer.cc b/onnxruntime/core/optimizer/transpose_optimizer.cc index 33e3f5eeaf0fa..092df9cc7dcfb 100644 --- a/onnxruntime/core/optimizer/transpose_optimizer.cc +++ b/onnxruntime/core/optimizer/transpose_optimizer.cc @@ -18,10 +18,18 @@ namespace onnxruntime { Status TransposeOptimizer::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { - auto api_graph = MakeApiGraph(graph, cpu_allocator_, /*new_node_ep*/ nullptr); - - OptimizeResult result = onnx_transpose_optimization::Optimize(*api_graph, "", /* default cost check*/ nullptr, - OrtExtendedHandlers()); + OptimizeResult result; + + if (ep_.empty()) { + // basic usage - no EP specific optimizations + auto api_graph = MakeApiGraph(graph, cpu_allocator_, /*new_node_ep*/ nullptr); + result = onnx_transpose_optimization::Optimize(*api_graph, "", /* default cost check*/ nullptr, + OrtExtendedHandlers()); + } else { + // EP specific optimizations enabled. Currently only used for CPU EP. + auto api_graph = MakeApiGraph(graph, cpu_allocator_, /*new_node_ep*/ ep_.c_str()); + result = onnx_transpose_optimization::Optimize(*api_graph, ep_, OrtEPCostCheck, OrtExtendedHandlers()); + } if (result.error_msg) { // currently onnx_layout_transformation::Optimize only fails if we hit an unsupported opset. diff --git a/onnxruntime/core/optimizer/transpose_optimizer.h b/onnxruntime/core/optimizer/transpose_optimizer.h index 1ae6d611d2f0e..97d7ab4d0e220 100644 --- a/onnxruntime/core/optimizer/transpose_optimizer.h +++ b/onnxruntime/core/optimizer/transpose_optimizer.h @@ -15,10 +15,14 @@ Push transposes through ops and eliminate them. class TransposeOptimizer : public GraphTransformer { private: AllocatorPtr cpu_allocator_; + const std::string ep_; public: - explicit TransposeOptimizer(AllocatorPtr cpu_allocator) noexcept - : GraphTransformer("TransposeOptimizer"), cpu_allocator_(std::move(cpu_allocator)) {} + explicit TransposeOptimizer(AllocatorPtr cpu_allocator, + const std::string& ep = {}) noexcept + : GraphTransformer(ep.empty() ? "TransposeOptimizer" : "TransposeOptimizer_" + ep), + cpu_allocator_(std::move(cpu_allocator)), + ep_{ep} {} Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; diff --git a/onnxruntime/core/platform/windows/debug_alloc.cc b/onnxruntime/core/platform/windows/debug_alloc.cc index b08d189f79866..ff6a059607367 100644 --- a/onnxruntime/core/platform/windows/debug_alloc.cc +++ b/onnxruntime/core/platform/windows/debug_alloc.cc @@ -12,7 +12,7 @@ // #ifndef NDEBUG #ifdef ONNXRUNTIME_ENABLE_MEMLEAK_CHECK -constexpr int c_callstack_limit = 16; // Maximum depth of callstack in leak trace +constexpr int c_callstack_limit = 32; // Maximum depth of callstack in leak trace #define VALIDATE_HEAP_EVERY_ALLOC 0 // Call HeapValidate on every new/delete #pragma warning(disable : 4073) // initializers put in library initialization area (this is intentional) @@ -223,6 +223,11 @@ Memory_LeakCheck::~Memory_LeakCheck() { // empty_group_names = new std::map; }); if (string.find("RtlRunOnceExecuteOnce") == std::string::npos && string.find("re2::RE2::Init") == std::string::npos && + string.find("dynamic initializer for 'FLAGS_") == std::string::npos && + string.find("AbslFlagDefaultGenForgtest_") == std::string::npos && + string.find("AbslFlagDefaultGenForundefok::Gen") == std::string::npos && + string.find("::SetProgramUsageMessage") == std::string::npos && + string.find("testing::internal::ParseGoogleTestFlagsOnly") == std::string::npos && string.find("testing::internal::Mutex::ThreadSafeLazyInit") == std::string::npos && string.find("testing::internal::ThreadLocalRegistryImpl::GetThreadLocalsMapLocked") == std::string::npos && string.find("testing::internal::ThreadLocalRegistryImpl::GetValueOnCurrentThread") == std::string::npos && diff --git a/onnxruntime/core/platform/windows/stacktrace.cc b/onnxruntime/core/platform/windows/stacktrace.cc index cac6f4f29043b..d7d423e4a483e 100644 --- a/onnxruntime/core/platform/windows/stacktrace.cc +++ b/onnxruntime/core/platform/windows/stacktrace.cc @@ -10,7 +10,6 @@ #include #endif #endif -#include #include "core/common/logging/logging.h" #include "core/common/gsl.h" diff --git a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h index b142db86a7902..b4132d3b770ec 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/coreml/builders/impl/base_op_builder.h @@ -51,7 +51,7 @@ class BaseOpBuilder : public IOpBuilder { virtual bool HasSupportedInputsImpl(const Node& node, const logging::Logger& logger) const; virtual int GetMinSupportedOpSet(const Node& /* node */) const { return 1; } - virtual int GetMaxSupportedOpSet(const Node& /* node */) const { return 19; } + virtual int GetMaxSupportedOpSet(const Node& /* node */) const { return 20; } private: bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const; diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 18010960e11c8..3d03abf5b7ebc 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -273,7 +273,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDoma // Opset 9 class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 10, Compress); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, ConstantOfShape); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 19, ConstantOfShape); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, MeanVarianceNormalization); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, float, Greater); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 9, 12, double, Greater); @@ -958,6 +958,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, Scan); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, Shape); +// Opset 20 +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, ConstantOfShape); + // !!PLEASE READ BELOW!! Following that, add new entries above this comment /* *** IMPORTANT! *** @@ -1332,7 +1335,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { // Opset 9 BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + // Opset 20 + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/cpu/generator/constant_of_shape.cc b/onnxruntime/core/providers/cpu/generator/constant_of_shape.cc index 920db5ed34dd1..a93da12ccf595 100644 --- a/onnxruntime/core/providers/cpu/generator/constant_of_shape.cc +++ b/onnxruntime/core/providers/cpu/generator/constant_of_shape.cc @@ -11,11 +11,16 @@ ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST_ALL_OPSETS( kCpuExecutionProvider, kOnnxDomain, ConstantOfShape, Output, 0, ConstantOfShapeDefaultOutputTypes); +ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST( + kCpuExecutionProvider, kOnnxDomain, ConstantOfShape, 20, Output, 0, + ConstantOfShapeDefaultOutputTypesOpset20); + // pytorch converter uses ConstantOfShape with int64 to create Pad input // https://github.com/pytorch/pytorch/blob/044b519a80459f6787f6723c1c091a18b153d184/torch/onnx/symbolic_opset11.py#L449 ORT_SPECIFY_OP_KERNEL_ARG_REQUIRED_TYPES_ALL_OPSETS( kCpuExecutionProvider, kOnnxDomain, ConstantOfShape, Output, 0, int64_t); + } // namespace op_kernel_type_control namespace { @@ -24,6 +29,10 @@ using EnabledOutputTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS( kCpuExecutionProvider, kOnnxDomain, ConstantOfShape, Output, 0); +using EnabledOutputTypesOpset20 = + ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST( + kCpuExecutionProvider, kOnnxDomain, ConstantOfShape, 20, Output, 0); + class ConstantOfShape final : public ConstantOfShapeBase, public OpKernel { public: explicit ConstantOfShape(const OpKernelInfo& info) : ConstantOfShapeBase(info), OpKernel(info) {} @@ -66,13 +75,22 @@ Status ConstantOfShape::Compute(OpKernelContext* ctx) const { } // namespace -ONNX_CPU_OPERATOR_KERNEL( +ONNX_CPU_OPERATOR_VERSIONED_KERNEL( ConstantOfShape, 9, + 19, KernelDefBuilder() .TypeConstraint("T1", DataTypeImpl::GetTensorType()) .TypeConstraint("T2", BuildKernelDefConstraintsFromTypeList()), ConstantOfShape); +ONNX_CPU_OPERATOR_KERNEL( + ConstantOfShape, + 20, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", + BuildKernelDefConstraintsFromTypeList()), + ConstantOfShape); } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/generator/constant_of_shape_base.h b/onnxruntime/core/providers/cpu/generator/constant_of_shape_base.h index d96ff06e3d6d8..9aa73c714daea 100644 --- a/onnxruntime/core/providers/cpu/generator/constant_of_shape_base.h +++ b/onnxruntime/core/providers/cpu/generator/constant_of_shape_base.h @@ -23,6 +23,18 @@ using ConstantOfShapeDefaultOutputTypes = uint8_t, uint16_t, uint32_t, uint64_t, bool>; +using ConstantOfShapeDefaultOutputTypesOpset20 = + TypeList< + BFloat16, + MLFloat16, + float, double, +#if !defined(DISABLE_FLOAT8_TYPES) + Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ, +#endif + int8_t, int16_t, int32_t, int64_t, + uint8_t, uint16_t, uint32_t, uint64_t, + bool>; + template class ConstantOfShapeBase { protected: diff --git a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h index c13c9d42dd392..99522cdf0759a 100644 --- a/onnxruntime/core/providers/cpu/tensor/upsamplebase.h +++ b/onnxruntime/core/providers/cpu/tensor/upsamplebase.h @@ -239,7 +239,7 @@ class UpsampleBase { if (coordinate_transform_mode_name == "half_pixel_symmetric") { return HALF_PIXEL_SYMMETRIC; } - ORT_THROW("coordinate_transform_mode:[" + coordinate_transform_mode_name + "] is not supportted!"); + ORT_THROW("coordinate_transform_mode:[" + coordinate_transform_mode_name + "] is not supported!"); } GetOriginalCoordinateFunc GetOriginalCoordinateFromResizedCoordinate( diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 72e36a161e9aa..1b1cbbf8dead9 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -229,21 +229,24 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Unsqueeze); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Unsqueeze); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 15, Where); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 16, Where); + class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Transpose); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Transpose); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, float, Conv); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, float, ConvTranspose); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, Conv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, ConvTranspose); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 11, MaxPool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 12, MaxPool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, AveragePool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, GlobalAveragePool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, GlobalMaxPool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, Conv); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, Conv); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ConvTranspose); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, ConvTranspose); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Conv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, Conv); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ConvTranspose); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, ConvTranspose); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, Gemm); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, Gemm); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Gemm); @@ -318,6 +321,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Til class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 17, float, LayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 6, float, InstanceNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, float, InstanceNormalization); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, Range); + class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, float, Einsum); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 2, 10, Pad); @@ -493,21 +499,24 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, + KERNEL_CREATE_INFO_VERSIONED(9, 15, Where), + KERNEL_CREATE_INFO(16, Where), + BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -584,7 +593,11 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -617,7 +630,8 @@ std::unique_ptr RegisterKernels() { using namespace js; JsExecutionProvider::JsExecutionProvider(const JsExecutionProviderInfo& info) - : IExecutionProvider{kJsExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0), true} { + : IExecutionProvider{kJsExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0), true}, + preferred_data_layout_{info.data_layout} { } std::vector JsExecutionProvider::CreatePreferredAllocators() { diff --git a/onnxruntime/core/providers/js/js_execution_provider.h b/onnxruntime/core/providers/js/js_execution_provider.h index 091aa2904604a..39d43498c0717 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.h +++ b/onnxruntime/core/providers/js/js_execution_provider.h @@ -19,12 +19,21 @@ KernelCreateInfo BuildKernelCreateInfo(); } // namespace js -// placeholder for future use. no options currently struct JsExecutionProviderInfo { - JsExecutionProviderInfo() = default; - JsExecutionProviderInfo(const ProviderOptions& po) { + auto it = po.find("preferred_layout"); + if (it != po.end()) { + auto& value = it->second; + if (value == "NCHW") { + data_layout = DataLayout::NCHW; + } else if (value == "NHWC") { + data_layout = DataLayout::NHWC; + } + } } + + // JSEP default preferred layout is NHWC + DataLayout data_layout = DataLayout::NHWC; }; class JsExecutionProvider : public IExecutionProvider { @@ -39,7 +48,7 @@ class JsExecutionProvider : public IExecutionProvider { std::shared_ptr GetKernelRegistry() const override; std::unique_ptr GetDataTransfer() const override; - DataLayout GetPreferredLayout() const override { return DataLayout::NHWC; } + DataLayout GetPreferredLayout() const override { return preferred_data_layout_; } FusionStyle GetFusionStyle() const override { return FusionStyle::FilteredGraphViewer; } @@ -48,6 +57,7 @@ class JsExecutionProvider : public IExecutionProvider { bool ConcurrentRunSupported() const override { return false; } std::vector CreatePreferredAllocators() override; + DataLayout preferred_data_layout_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/js_kernel.h b/onnxruntime/core/providers/js/js_kernel.h index 177c0a9e691ed..fdd5c7dee5bfc 100644 --- a/onnxruntime/core/providers/js/js_kernel.h +++ b/onnxruntime/core/providers/js/js_kernel.h @@ -196,7 +196,7 @@ class JsKernel : public OpKernel { } int status_code = EM_ASM_INT( - { return Module.jsepRunKernel($0, $1, Module.jsepSessionState); }, + { return Module.jsepRunKernel($0, $1, Module.jsepSessionState.sessionHandle, Module.jsepSessionState.errors); }, this, reinterpret_cast(p_serialized_kernel_context)); LOGS_DEFAULT(VERBOSE) << "outputs = " << context->OutputCount() << ". Y.data=" diff --git a/onnxruntime/core/providers/js/operators/conv.cc b/onnxruntime/core/providers/js/operators/conv.cc index c7c9f7f7c3f0e..2e07124dcd901 100644 --- a/onnxruntime/core/providers/js/operators/conv.cc +++ b/onnxruntime/core/providers/js/operators/conv.cc @@ -9,33 +9,27 @@ namespace onnxruntime { namespace js { -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - Conv, \ - kMSInternalNHWCDomain, \ - 11, \ - T, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Conv); \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - Conv, \ - kOnnxDomain, \ - 11, \ - T, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Conv); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - Conv, \ - kOnnxDomain, \ - 1, 10, \ - T, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Conv); - -REGISTER_KERNEL_TYPED(float) +ONNX_OPERATOR_KERNEL_EX( + Conv, + kMSInternalNHWCDomain, + 11, + kJsExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), + Conv); +ONNX_OPERATOR_KERNEL_EX( + Conv, + kOnnxDomain, + 11, + kJsExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), + Conv); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Conv, + kOnnxDomain, + 1, 10, + kJsExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), + Conv); } // namespace js } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/conv.h b/onnxruntime/core/providers/js/operators/conv.h index 22f7721276677..fdf3e5b6c6b66 100644 --- a/onnxruntime/core/providers/js/operators/conv.h +++ b/onnxruntime/core/providers/js/operators/conv.h @@ -9,7 +9,7 @@ namespace onnxruntime { namespace js { -template +template class Conv : public JsKernel { public: Conv(const OpKernelInfo& info) : JsKernel(info), conv_attrs_(info), w_is_const_(false) { diff --git a/onnxruntime/core/providers/js/operators/conv_transpose.cc b/onnxruntime/core/providers/js/operators/conv_transpose.cc index 1a2fc99eada6a..2228343e1e6e3 100644 --- a/onnxruntime/core/providers/js/operators/conv_transpose.cc +++ b/onnxruntime/core/providers/js/operators/conv_transpose.cc @@ -7,33 +7,28 @@ #include "conv_transpose.h" namespace onnxruntime { namespace js { -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - ConvTranspose, \ - kMSInternalNHWCDomain, \ - 11, \ - T, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - ConvTranspose); \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - ConvTranspose, \ - kOnnxDomain, \ - 11, \ - T, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - ConvTranspose); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - ConvTranspose, \ - kOnnxDomain, \ - 1, 10, \ - T, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - ConvTranspose); -REGISTER_KERNEL_TYPED(float) +ONNX_OPERATOR_KERNEL_EX( + ConvTranspose, + kMSInternalNHWCDomain, + 11, + kJsExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), + ConvTranspose); +ONNX_OPERATOR_KERNEL_EX( + ConvTranspose, + kOnnxDomain, + 11, + kJsExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), + ConvTranspose); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + ConvTranspose, + kOnnxDomain, + 1, 10, + kJsExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), + ConvTranspose); } // namespace js } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/conv_transpose.h b/onnxruntime/core/providers/js/operators/conv_transpose.h index a5aeae8646373..18ef73268005d 100644 --- a/onnxruntime/core/providers/js/operators/conv_transpose.h +++ b/onnxruntime/core/providers/js/operators/conv_transpose.h @@ -9,7 +9,7 @@ #include "core/providers/js/js_kernel.h" namespace onnxruntime { namespace js { -template +template class ConvTranspose : public JsKernel { public: ConvTranspose(const OpKernelInfo& info) : JsKernel(info), conv_transpose_attrs_(info), w_is_const_(false) { @@ -108,6 +108,23 @@ class ConvTranspose : public JsKernel { } } + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* /* prepacked_weights */) override { + is_packed = false; + + if (input_idx == 1) { + // Only handle the common case of conv2D + if (tensor.Shape().NumDimensions() != 4 || tensor.SizeInBytes() == 0) { + return Status::OK(); + } + + w_is_const_ = true; + } + + return Status::OK(); + } + protected: ConvTransposeAttributes conv_transpose_attrs_; bool w_is_const_; diff --git a/onnxruntime/core/providers/js/operators/matmul.cc b/onnxruntime/core/providers/js/operators/matmul.cc index ddfbb454def07..6e6f906f7b42c 100644 --- a/onnxruntime/core/providers/js/operators/matmul.cc +++ b/onnxruntime/core/providers/js/operators/matmul.cc @@ -9,11 +9,11 @@ namespace js { JSEP_KERNEL_IMPL(MatMul, MatMul) ONNX_OPERATOR_VERSIONED_KERNEL_EX(MatMul, kOnnxDomain, 1, 12, kJsExecutionProvider, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + KernelDefBuilder().TypeConstraint("T", JsepSupportedFloatTypes()), MatMul); ONNX_OPERATOR_KERNEL_EX(MatMul, kOnnxDomain, 13, kJsExecutionProvider, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + KernelDefBuilder().TypeConstraint("T", JsepSupportedFloatTypes()), MatMul); } // namespace js diff --git a/onnxruntime/core/providers/js/operators/range.cc b/onnxruntime/core/providers/js/operators/range.cc new file mode 100644 index 0000000000000..e15861f7f227a --- /dev/null +++ b/onnxruntime/core/providers/js/operators/range.cc @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/js_kernel.h" + +#include "range.h" + +namespace onnxruntime { +namespace js { +ONNX_OPERATOR_KERNEL_EX( + Range, + kOnnxDomain, + 11, + kJsExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) + .InputMemoryType(OrtMemTypeCPU, 0) + .InputMemoryType(OrtMemTypeCPU, 1) + .InputMemoryType(OrtMemTypeCPU, 2), + Range); +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/range.h b/onnxruntime/core/providers/js/operators/range.h new file mode 100644 index 0000000000000..8b32bfc3d984b --- /dev/null +++ b/onnxruntime/core/providers/js/operators/range.h @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace js { + +JSEP_KERNEL_IMPL(Range, Range); + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/where.cc b/onnxruntime/core/providers/js/operators/where.cc new file mode 100644 index 0000000000000..2f8f5e275aa98 --- /dev/null +++ b/onnxruntime/core/providers/js/operators/where.cc @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace js { + +#define REG_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE, \ + kOnnxDomain, \ + VERSION, \ + kJsExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", \ + {DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType()}), \ + KERNEL_CLASS); + +#define REG_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE, \ + kOnnxDomain, \ + VERSION_FROM, VERSION_TO, \ + kJsExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", \ + {DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType(), \ + DataTypeImpl::GetTensorType()}), \ + KERNEL_CLASS); + +JSEP_KERNEL_IMPL(Where, Where) +REG_ELEMENTWISE_VERSIONED_KERNEL(Where, 9, 15, Where); +REG_ELEMENTWISE_KERNEL(Where, 16, Where); +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 8e31124ce4c85..e2083371acca4 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -30,12 +30,20 @@ typedef Qnn_ErrorHandle_t (*QnnSystemInterfaceGetProvidersFn_t)(const QnnSystemI constexpr const char* QNN_PROVIDER = "ORTQNNEP"; +static Qnn_Version_t GetQnnInterfaceApiVersion(const QnnInterface_t* qnn_interface) { + return qnn_interface->apiVersion.coreApiVersion; +} + +static Qnn_Version_t GetQnnInterfaceApiVersion(const QnnSystemInterface_t* qnn_interface) { + return qnn_interface->systemApiVersion; +} + template -Status QnnBackendManager::GetQnnInterfaceProviders(const char* lib_path, - const char* interface_provider_name, - void** backend_lib_handle, - T*** interface_providers, - uint32_t& num_providers) { +Status QnnBackendManager::GetQnnInterfaceProvider(const char* lib_path, + const char* interface_provider_name, + void** backend_lib_handle, + Qnn_Version_t req_version, + T** interface_provider) { std::string error_msg; *backend_lib_handle = LoadLib(lib_path, static_cast(DlOpenFlag::DL_NOW) | static_cast(DlOpenFlag::DL_LOCAL), @@ -47,10 +55,36 @@ Status QnnBackendManager::GetQnnInterfaceProviders(const char* lib_path, GetInterfaceProviders = ResolveSymbol(*backend_lib_handle, interface_provider_name, *logger_); ORT_RETURN_IF(nullptr == GetInterfaceProviders, "Failed to get QNN providers!"); - auto result = GetInterfaceProviders((const T***)interface_providers, &num_providers); + T** interface_providers{nullptr}; + uint32_t num_providers{0}; + + auto result = GetInterfaceProviders((const T***)&interface_providers, &num_providers); ORT_RETURN_IF((QNN_SUCCESS != result || nullptr == *interface_providers || 0 == num_providers), "Failed to get QNN providers."); + bool found_valid_interface{false}; + for (size_t pIdx = 0; pIdx < num_providers; pIdx++) { + Qnn_Version_t interface_version = GetQnnInterfaceApiVersion(interface_providers[pIdx]); + + LOGS_DEFAULT(VERBOSE) << lib_path << " interface version: " << interface_version.major << "." + << interface_version.minor << "." << interface_version.patch; + + // Check the interface's API version against the required version. + // Major versions must match. The interface's minor version must be greater OR equal with a suitable patch version. + if (interface_version.major == req_version.major) { + bool minor_and_patch_version_ok = (interface_version.minor > req_version.minor) || + (interface_version.minor == req_version.minor && + interface_version.patch >= req_version.patch); + if (minor_and_patch_version_ok) { + found_valid_interface = true; + *interface_provider = interface_providers[pIdx]; + break; + } + } + } + + ORT_RETURN_IF_NOT(found_valid_interface, "Unable to find a valid interface for ", lib_path); + return Status::OK(); } @@ -76,38 +110,89 @@ void QnnBackendManager::SetQnnBackendType(uint32_t backend_id) { } Status QnnBackendManager::LoadBackend() { - QnnInterface_t** interface_providers{nullptr}; - uint32_t num_providers{0}; - auto rt = GetQnnInterfaceProviders(backend_path_.c_str(), - "QnnInterface_getProviders", - &backend_lib_handle_, - &interface_providers, - num_providers); + QnnInterface_t* backend_interface_provider{nullptr}; + auto rt = GetQnnInterfaceProvider(backend_path_.c_str(), + "QnnInterface_getProviders", + &backend_lib_handle_, + {QNN_API_VERSION_MAJOR, + QNN_API_VERSION_MINOR, + QNN_API_VERSION_PATCH}, + &backend_interface_provider); ORT_RETURN_IF_ERROR(rt); + qnn_interface_ = backend_interface_provider->QNN_INTERFACE_VER_NAME; + auto backend_id = backend_interface_provider->backendId; + SetQnnBackendType(backend_id); - bool found_valid_interface{false}; - LOGS_DEFAULT(VERBOSE) << "QNN_API_VERSION_MAJOR: " << QNN_API_VERSION_MAJOR - << " QNN_API_VERSION_MINOR: " << QNN_API_VERSION_MINOR; - for (size_t pIdx = 0; pIdx < num_providers; pIdx++) { - LOGS_DEFAULT(VERBOSE) << "interface_providers major: " << interface_providers[pIdx]->apiVersion.coreApiVersion.major - << " interface_providers minor: " << interface_providers[pIdx]->apiVersion.coreApiVersion.minor; - if (QNN_API_VERSION_MAJOR == interface_providers[pIdx]->apiVersion.coreApiVersion.major && - QNN_API_VERSION_MINOR <= interface_providers[pIdx]->apiVersion.coreApiVersion.minor) { - found_valid_interface = true; - qnn_interface_ = interface_providers[pIdx]->QNN_INTERFACE_VER_NAME; - auto backend_id = interface_providers[pIdx]->backendId; - SetQnnBackendType(backend_id); - - LOGS_DEFAULT(INFO) << "Found valid interface, version: " << QNN_API_VERSION_MAJOR - << "." << QNN_API_VERSION_MINOR - << " backend provider name: " << interface_providers[pIdx]->providerName - << " backend id: " << backend_id; - break; + Qnn_Version_t backend_interface_version = GetQnnInterfaceApiVersion(backend_interface_provider); + LOGS_DEFAULT(INFO) << "Found valid interface, version: " << backend_interface_version.major + << "." << backend_interface_version.minor << "." << backend_interface_version.patch + << " backend provider name: " << backend_interface_provider->providerName + << " backend id: " << backend_id; + + return Status::OK(); +} + +// Loads the intended backend (e.g., HTP, CPU, etc) to get its type, and then +// sets QNN Saver as the active backend. QNN op builders will still see the intended backend (e.g., HTP) +// as the backend type to ensure they emit the expected QNN API calls. +// +// QNN Saver is a "debugging" backend that serializes all QNN API calls (and weights) into local files. +// This information can be used to debug issues by replaying QNN API calls with another backend. +Status QnnBackendManager::LoadQnnSaverBackend() { + void* backend_lib_handle = nullptr; + + // Helper that unloads the intended backend library handle when the `unload_backend_lib` variable + // goes out of scope. Similar to `defer` in other languages. + auto unload_backend_lib = gsl::finally([&] { + if (backend_lib_handle != nullptr) { + auto result = UnloadLib(backend_lib_handle); + if (Status::OK() != result) { + ORT_THROW("Failed to unload backend library."); + } } - } + }); + + // Load the intended backend (e.g., HTP, CPU) to ensure it is valid and to get its type. + QnnInterface_t* backend_interface_provider{nullptr}; + auto rt = GetQnnInterfaceProvider(backend_path_.c_str(), + "QnnInterface_getProviders", + &backend_lib_handle, + {QNN_API_VERSION_MAJOR, + QNN_API_VERSION_MINOR, + QNN_API_VERSION_PATCH}, + &backend_interface_provider); + ORT_RETURN_IF_ERROR(rt); - ORT_RETURN_IF_NOT(found_valid_interface, "Unable to find a valid interface."); + // Set the "intended" backend type so that QNN builders still make the expected QNN API calls. + auto backend_id = backend_interface_provider->backendId; + SetQnnBackendType(backend_id); + + // Load the QNN Saver backend and set it as the activate backend. + QnnInterface_t* saver_interface_provider{nullptr}; + auto saver_rt = GetQnnInterfaceProvider(qnn_saver_path_.c_str(), + "QnnInterface_getProviders", + &backend_lib_handle_, // NOTE: QNN Saver library handle is set + {QNN_API_VERSION_MAJOR, + QNN_API_VERSION_MINOR, + QNN_API_VERSION_PATCH}, + &saver_interface_provider); + ORT_RETURN_IF_ERROR(saver_rt); + qnn_interface_ = saver_interface_provider->QNN_INTERFACE_VER_NAME; // NOTE: QNN Saver will provide the interfaces + + Qnn_Version_t backend_interface_version = GetQnnInterfaceApiVersion(backend_interface_provider); + Qnn_Version_t saver_interface_version = GetQnnInterfaceApiVersion(saver_interface_provider); + + LOGS_DEFAULT(INFO) << "Using QNN Saver version: " << saver_interface_version.major << "." + << saver_interface_version.minor << "." << saver_interface_version.patch + << " provider name : " << saver_interface_provider->providerName; + + LOGS_DEFAULT(INFO) << "Intended backend provider name: " << backend_interface_provider->providerName + << " backend id: " << backend_id + << " interface version: " << backend_interface_version.major + << "." << backend_interface_version.minor << "." << backend_interface_version.patch; return Status::OK(); } @@ -120,34 +205,22 @@ Status QnnBackendManager::LoadQnnSystemLib() { #endif // #ifdef _WIN32 std::filesystem::path lib_file_path(backend_path_.c_str()); std::string sys_file_path(lib_file_path.remove_filename().string() + system_lib_file); - QnnSystemInterface_t** system_interface_providers{nullptr}; - uint32_t num_providers = 0; - auto rt = GetQnnInterfaceProviders(sys_file_path.c_str(), - "QnnSystemInterface_getProviders", - &system_lib_handle_, - &system_interface_providers, - num_providers); + QnnSystemInterface_t* system_interface_provider{nullptr}; + auto rt = GetQnnInterfaceProvider(sys_file_path.c_str(), + "QnnSystemInterface_getProviders", + &system_lib_handle_, + {QNN_SYSTEM_API_VERSION_MAJOR, + QNN_SYSTEM_API_VERSION_MINOR, + QNN_SYSTEM_API_VERSION_PATCH}, + &system_interface_provider); ORT_RETURN_IF_ERROR(rt); + Qnn_Version_t system_interface_version = GetQnnInterfaceApiVersion(system_interface_provider); + qnn_sys_interface_ = system_interface_provider->QNN_SYSTEM_INTERFACE_VER_NAME; - bool found_valid_interface{false}; - for (size_t pIdx = 0; pIdx < num_providers; pIdx++) { - LOGS_DEFAULT(VERBOSE) << "system_interface_providers major: " << system_interface_providers[pIdx]->systemApiVersion.major - << " system_interface_providers minor: " << system_interface_providers[pIdx]->systemApiVersion.minor; - int64_t systems_version_major = static_cast(system_interface_providers[pIdx]->systemApiVersion.major); - int64_t systems_version_minor = static_cast(system_interface_providers[pIdx]->systemApiVersion.minor); - if (systems_version_major == QNN_SYSTEM_API_VERSION_MAJOR && - systems_version_minor >= QNN_SYSTEM_API_VERSION_MINOR) { - found_valid_interface = true; - qnn_sys_interface_ = system_interface_providers[pIdx]->QNN_SYSTEM_INTERFACE_VER_NAME; - LOGS_DEFAULT(INFO) << "Found valid system interface, version: " << QNN_API_VERSION_MAJOR - << "." << QNN_API_VERSION_MINOR - << " backend provider name: " << system_interface_providers[pIdx]->providerName; - break; - } - } - - ORT_RETURN_IF_NOT(found_valid_interface, "Unable to find a valid system interface."); + LOGS_DEFAULT(INFO) << "Found valid system interface, version: " << system_interface_version.major + << "." << system_interface_version.minor + << " backend provider name: " << system_interface_provider->providerName; return Status::OK(); } @@ -643,7 +716,12 @@ Status QnnBackendManager::SetupBackend(const logging::Logger& logger, bool load_ return Status::OK(); } - ORT_RETURN_IF_ERROR(LoadBackend()); + if (qnn_saver_path_.empty()) { + ORT_RETURN_IF_ERROR(LoadBackend()); + } else { + ORT_RETURN_IF_ERROR(LoadQnnSaverBackend()); + } + LOGS(logger, VERBOSE) << "LoadBackend succeed."; if (load_from_cached_context) { diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index 4ca63a042c103..402f842c7a4bf 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -25,14 +25,16 @@ class QnnModel; class QnnBackendManager { public: - QnnBackendManager(std::string backend_path, + QnnBackendManager(std::string&& backend_path, ProfilingLevel profiling_level, uint32_t rpc_control_latency, - HtpPerformanceMode htp_performance_mode) + HtpPerformanceMode htp_performance_mode, + std::string&& qnn_saver_path) : backend_path_(backend_path), profiling_level_(profiling_level), rpc_control_latency_(rpc_control_latency), - htp_performance_mode_(htp_performance_mode) { + htp_performance_mode_(htp_performance_mode), + qnn_saver_path_(qnn_saver_path) { } ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QnnBackendManager); @@ -140,6 +142,8 @@ class QnnBackendManager { Status LoadQnnSystemLib(); + Status LoadQnnSaverBackend(); + Status UnloadLib(void* handle); void* LibFunction(void* handle, const char* symbol, std::string& error_msg); @@ -155,11 +159,11 @@ class QnnBackendManager { } template - Status GetQnnInterfaceProviders(const char* lib_path, - const char* interface_provider_name, - void** backend_lib_handle, - T*** interface_providers, - uint32_t& num_providers); + Status GetQnnInterfaceProvider(const char* lib_path, + const char* interface_provider_name, + void** backend_lib_handle, + Qnn_Version_t req_version, + T** interface_provider); bool IsDevicePropertySupported(); @@ -210,6 +214,7 @@ class QnnBackendManager { #ifdef _WIN32 std::set mod_handles_; #endif + const std::string qnn_saver_path_; }; } // namespace qnn diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 7bbfe807da0f2..ec5316eb13ce1 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -104,9 +104,10 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio static const std::string BACKEND_PATH = "backend_path"; auto backend_path_pos = runtime_options_.find(BACKEND_PATH); + std::string backend_path; if (backend_path_pos != runtime_options_.end()) { - backend_path_ = backend_path_pos->second; - LOGS_DEFAULT(VERBOSE) << "Backend path: " << backend_path_; + backend_path = backend_path_pos->second; + LOGS_DEFAULT(VERBOSE) << "Backend path: " << backend_path; } else { LOGS_DEFAULT(ERROR) << "No backend path provided."; } @@ -131,10 +132,21 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio ParseHtpPerformanceMode(htp_performance_mode_pos->second); } - qnn_backend_manager_ = std::make_unique(backend_path_, - profiling_level_, - rpc_control_latency_, - htp_performance_mode_); + // Enable use of QNN Saver if the user provides a path the QNN Saver backend library. + static const std::string QNN_SAVER_PATH_KEY = "qnn_saver_path"; + std::string qnn_saver_path; + auto qnn_saver_path_pos = runtime_options_.find(QNN_SAVER_PATH_KEY); + if (qnn_saver_path_pos != runtime_options_.end()) { + qnn_saver_path = qnn_saver_path_pos->second; + LOGS_DEFAULT(VERBOSE) << "User specified QNN Saver path: " << qnn_saver_path; + } + + qnn_backend_manager_ = std::make_unique( + std::move(backend_path), + profiling_level_, + rpc_control_latency_, + htp_performance_mode_, + std::move(qnn_saver_path)); } bool QNNExecutionProvider::IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 2fe507b70a6ab..3827e2044e2b1 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -58,7 +58,6 @@ class QNNExecutionProvider : public IExecutionProvider { private: ProviderOptions runtime_options_; - std::string backend_path_; qnn::ProfilingLevel profiling_level_ = qnn::ProfilingLevel::OFF; qnn::HtpPerformanceMode htp_performance_mode_ = qnn::HtpPerformanceMode::kHtpDefault; std::unique_ptr qnn_backend_manager_; diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h index 301927d9c658f..01e4a3c60281f 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h @@ -46,7 +46,7 @@ class BaseOpBuilder : public IOpBuilder { // We still set the mininal supported opset to 1 as we couldn't // get the model opset version at this stage. virtual int GetMinSupportedOpSet(const Node& /* node */) const { return 1; } - virtual int GetMaxSupportedOpSet(const Node& /* node */) const { return 19; } + virtual int GetMaxSupportedOpSet(const Node& /* node */) const { return 20; } private: bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const; diff --git a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc index e7e3cee21c956..6a86ca7aca6e9 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc @@ -35,25 +35,79 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, std::vector input_shape; ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); const auto input_size = input_shape.size(); - // WebNN Softmax only support 2d input shape, reshape input to 2d. - if (input_size != 2) { - int32_t new_shape_0 = SafeInt(input_shape.data()[0]); - for (size_t i = 1; i < input_size - 1; i++) { - new_shape_0 *= input_shape.data()[i]; + NodeAttrHelper helper(node); + if (node.SinceVersion() < 13) { + int32_t axis = helper.Get("axis", 1); + axis = static_cast(HandleNegativeAxis(axis, input_size)); + // Coerce the input into a 2-dimensional tensor with dimensions [a_0 * ... * a_{k-1}, a_k * ... * a_{n-1}]. + if (input_size != 2) { + int32_t first_dim = static_cast(std::reduce(input_shape.begin(), input_shape.begin() + axis, + 1, std::multiplies())); + int32_t second_dim = static_cast(std::reduce(input_shape.begin() + axis, input_shape.end(), + 1, std::multiplies())); + emscripten::val new_shape = emscripten::val::array(std::vector{first_dim, second_dim}); + input = model_builder.GetBuilder().call("reshape", input, new_shape); } - emscripten::val new_shape = emscripten::val::array(); - new_shape.call("push", new_shape_0); - new_shape.call("push", static_cast(input_shape.back())); - input = model_builder.GetBuilder().call("reshape", input, new_shape); - } - output = model_builder.GetBuilder().call("softmax", input); - // Reshape output to the same shape of input. - if (input_size != 2) { - emscripten::val new_shape = emscripten::val::array(); - for (size_t i = 0; i < input_size; i++) { - new_shape.call("push", static_cast(input_shape[i])); + + output = model_builder.GetBuilder().call("softmax", input); + + // Reshape output to the same shape of input. + if (input_size != 2) { + emscripten::val new_shape = emscripten::val::array(); + for (size_t i = 0; i < input_size; i++) { + new_shape.call("push", static_cast(input_shape[i])); + } + output = model_builder.GetBuilder().call("reshape", output, new_shape); + } + } else { + int32_t axis = helper.Get("axis", -1); + axis = static_cast(HandleNegativeAxis(axis, input_size)); + // Wraparound for transpose the target axis to the last. + // WebNN compute the softmax values of the 2-D input tensor along axis 1. + // https://www.w3.org/TR/webnn/#api-mlgraphbuilder-softmax-method + if (axis != static_cast(input_shape.size() - 1)) { + emscripten::val options = emscripten::val::object(); + std::vector permutation(input_shape.size()); + std::iota(permutation.begin(), permutation.end(), 0); + permutation.erase(permutation.begin() + axis); + permutation.push_back(axis); + options.set("permutation", emscripten::val::array(permutation)); + input = model_builder.GetBuilder().call("transpose", input, options); + } + // Wraparound for reshape input tensor to 2-D. + if (input_shape.size() != 2) { + uint32_t first_dim = static_cast(std::reduce(input_shape.begin(), input_shape.begin() + axis, + 1, std::multiplies())); + first_dim *= static_cast(std::reduce(input_shape.begin() + axis + 1, input_shape.end(), + 1, std::multiplies())); + uint32_t second_dim = static_cast(input_shape[axis]); + emscripten::val new_shape = emscripten::val::array(std::vector{first_dim, second_dim}); + input = model_builder.GetBuilder().call("reshape", input, new_shape); + } + + output = model_builder.GetBuilder().call("softmax", input); + + // Transpose back to the axis. + if (input_shape.size() != 2) { + std::vector new_shape; + std::transform(input_shape.begin(), input_shape.begin() + axis, std::back_inserter(new_shape), + [](int64_t dim) -> uint32_t { return static_cast(dim); }); + std::transform(input_shape.begin() + axis + 1, input_shape.end(), std::back_inserter(new_shape), + [](int64_t dim) -> uint32_t { return static_cast(dim); }); + new_shape.push_back(static_cast(input_shape[axis])); + output = model_builder.GetBuilder().call("reshape", + output, emscripten::val::array(new_shape)); + } + // Reshape to the original shape. + if (axis != static_cast(input_shape.size() - 1)) { + emscripten::val options = emscripten::val::object(); + std::vector permutation(input_shape.size()); + std::iota(permutation.begin(), permutation.end(), 0); + permutation.pop_back(); + permutation.insert(permutation.begin() + axis, input_shape.size() - 1); + options.set("permutation", emscripten::val::array(permutation)); + output = model_builder.GetBuilder().call("transpose", output, options); } - output = model_builder.GetBuilder().call("reshape", output, new_shape); } model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); @@ -75,13 +129,6 @@ bool SoftmaxOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali << input_size << "d shape"; return false; } - NodeAttrHelper helper(node); - const int32_t axis = helper.Get("axis", 1); - // WebNN softmax only support input axis 1 - if (axis != 1 && axis != -1) { - LOGS(logger, VERBOSE) << "SoftMax only support axis 1 or -1, input axis: " << axis; - return false; - } return true; } diff --git a/onnxruntime/core/providers/xnnpack/nn/resize.cc b/onnxruntime/core/providers/xnnpack/nn/resize.cc index 672b2597279db..76c6b6acbfe32 100644 --- a/onnxruntime/core/providers/xnnpack/nn/resize.cc +++ b/onnxruntime/core/providers/xnnpack/nn/resize.cc @@ -331,7 +331,13 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(Resize, kOnnxDomain, 13, 17, kXnnpackExecution DataTypeImpl::GetTensorType()}), Resize); -ONNX_OPERATOR_KERNEL_EX(Resize, kOnnxDomain, 18, kXnnpackExecutionProvider, +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Resize, kOnnxDomain, 18, 18, kXnnpackExecutionProvider, + KernelDefBuilder().TypeConstraint("T1", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + Resize); + +ONNX_OPERATOR_KERNEL_EX(Resize, kOnnxDomain, 19, kXnnpackExecutionProvider, KernelDefBuilder().TypeConstraint("T1", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc index ba577ac38d48c..494c718cde081 100644 --- a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc +++ b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc @@ -46,7 +46,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWC class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 10, 10, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 11, 12, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 13, 17, Resize); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 18, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 18, 18, Resize); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 19, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, 11, MaxPool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 12, MaxPool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, AveragePool); @@ -84,7 +85,9 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo< ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 1, 12, Softmax)>, BuildKernelCreateInfo< - ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 18, Resize)>, + ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 19, Resize)>, + BuildKernelCreateInfo< + ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 18, 18, Resize)>, BuildKernelCreateInfo< ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 13, 17, Resize)>, BuildKernelCreateInfo< diff --git a/onnxruntime/core/session/abi_session_options.cc b/onnxruntime/core/session/abi_session_options.cc index 4fcc6de561f8c..fb314b161f1ad 100644 --- a/onnxruntime/core/session/abi_session_options.cc +++ b/onnxruntime/core/session/abi_session_options.cc @@ -143,6 +143,14 @@ ORT_API_STATUS_IMPL(OrtApis::SetSessionLogId, _In_ OrtSessionOptions* options, c return nullptr; } +///< logging function and optional logging param to use for session output +ORT_API_STATUS_IMPL(OrtApis::SetUserLoggingFunction, _In_ OrtSessionOptions* options, + _In_ OrtLoggingFunction user_logging_function, _In_opt_ void* user_logging_param) { + options->value.user_logging_function = user_logging_function; + options->value.user_logging_param = user_logging_param; + return nullptr; +} + ///< applies to session load, initialization, etc ORT_API_STATUS_IMPL(OrtApis::SetSessionLogVerbosityLevel, _In_ OrtSessionOptions* options, int session_log_verbosity_level) { options->value.session_log_verbosity_level = session_log_verbosity_level; diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 14c284d7bbdec..d8256646f9707 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -27,6 +27,10 @@ static constexpr uint32_t min_ort_version_with_optional_io_support = 8; static constexpr uint32_t min_ort_version_with_variadic_io_support = 14; #endif +#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) +static constexpr uint32_t min_ort_version_with_compute_v2_support = 17; +#endif + #if !defined(DISABLE_FLOAT8_TYPES) #define SUPPORTED_TENSOR_TYPES DataTypeImpl::AllTensorTypesIRv9() #else @@ -424,7 +428,8 @@ struct CustomOpKernel : OpKernel { ORT_THROW("Unsupported version '" + std::to_string(op_.version) + "' in custom op '" + op.GetName(&op)); } - if (op_.version > 15 && op_.KernelCompute == 0) { + if (op_.version >= min_ort_version_with_compute_v2_support && + op_.CreateKernelV2) { op_kernel_ = nullptr; Ort::ThrowOnError( op_.CreateKernelV2( @@ -443,13 +448,14 @@ struct CustomOpKernel : OpKernel { } Status Compute(OpKernelContext* ctx) const override { - if (op_.version > 15 && op_.KernelCompute == 0) { + if (op_.version >= min_ort_version_with_compute_v2_support && + op_.KernelComputeV2) { auto status_ptr = op_.KernelComputeV2(op_kernel_, reinterpret_cast(ctx)); return ToStatus(status_ptr); + } else { + op_.KernelCompute(op_kernel_, reinterpret_cast(ctx)); + return Status::OK(); } - - op_.KernelCompute(op_kernel_, reinterpret_cast(ctx)); - return Status::OK(); } private: diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 5a2a6efb6df4b..b4d47652942b7 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -56,6 +56,7 @@ #include "core/providers/dml/dml_session_options_config_keys.h" #endif #include "core/session/environment.h" +#include "core/session/user_logging_sink.h" #include "core/session/IOBinding.h" #include "core/session/inference_session_utils.h" #include "core/session/onnxruntime_session_options_config_keys.h" @@ -298,6 +299,35 @@ static Status FinalizeSessionOptions(const SessionOptions& user_provided_session return Status::OK(); } +logging::Severity GetSeverity(const SessionOptions& session_options) { + logging::Severity severity = logging::Severity::kWARNING; + if (session_options.session_log_severity_level == -1) { + severity = logging::LoggingManager::DefaultLogger().GetSeverity(); + } else { + ORT_ENFORCE(session_options.session_log_severity_level >= 0 && + session_options.session_log_severity_level <= static_cast(logging::Severity::kFATAL), + "Invalid session log severity level. Not a valid onnxruntime::logging::Severity value: ", + session_options.session_log_severity_level); + severity = static_cast(session_options.session_log_severity_level); + } + return severity; +} + +void InferenceSession::SetLoggingManager(const SessionOptions& session_options, + const Environment& session_env) { + logging_manager_ = session_env.GetLoggingManager(); + if (session_options.user_logging_function) { + std::unique_ptr user_sink = std::make_unique(session_options.user_logging_function, + session_options.user_logging_param); + user_logging_manager_ = std::make_unique(std::move(user_sink), + GetSeverity(session_options), + false, + logging::LoggingManager::InstanceType::Temporal, + &session_options.session_logid); + logging_manager_ = user_logging_manager_.get(); + } +} + void InferenceSession::ConstructorCommon(const SessionOptions& session_options, const Environment& session_env) { auto status = FinalizeSessionOptions(session_options, model_proto_, is_model_proto_parsed_, session_options_); @@ -306,6 +336,8 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, ORT_ENFORCE(status.IsOK(), "Could not finalize session options while constructing the inference session. Error Message: ", status.ErrorMessage()); + SetLoggingManager(session_options, session_env); + // The call to InitLogger depends on the final state of session_options_. Hence it should be invoked // after the invocation of FinalizeSessionOptions. InitLogger(logging_manager_); // this sets session_logger_ so that it can be used for logging after this point. @@ -427,7 +459,6 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, const #if !defined(ORT_MINIMAL_BUILD) graph_transformer_mgr_(session_options.max_num_graph_transformation_steps), #endif - logging_manager_(session_env.GetLoggingManager()), environment_(session_env) { // Initialize assets of this session instance ConstructorCommon(session_options, session_env); @@ -441,7 +472,6 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, #if !defined(ORT_MINIMAL_BUILD) graph_transformer_mgr_(session_options.max_num_graph_transformation_steps), #endif - logging_manager_(session_env.GetLoggingManager()), external_intra_op_thread_pool_(external_intra_op_thread_pool), external_inter_op_thread_pool_(external_inter_op_thread_pool), environment_(session_env) { @@ -454,7 +484,6 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, const const PathString& model_uri) : model_location_(model_uri), graph_transformer_mgr_(session_options.max_num_graph_transformation_steps), - logging_manager_(session_env.GetLoggingManager()), environment_(session_env) { auto status = Model::Load(model_location_, model_proto_); ORT_ENFORCE(status.IsOK(), "Given model could not be parsed while creating inference session. Error message: ", @@ -475,7 +504,6 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, InferenceSession::InferenceSession(const SessionOptions& session_options, const Environment& session_env, std::istream& model_istream) : graph_transformer_mgr_(session_options.max_num_graph_transformation_steps), - logging_manager_(session_env.GetLoggingManager()), environment_(session_env) { Status st = Model::Load(model_istream, &model_proto_); ORT_ENFORCE(st.IsOK(), "Could not parse model successfully while constructing the inference session"); @@ -487,7 +515,6 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, const InferenceSession::InferenceSession(const SessionOptions& session_options, const Environment& session_env, const void* model_data, int model_data_len) : graph_transformer_mgr_(session_options.max_num_graph_transformation_steps), - logging_manager_(session_env.GetLoggingManager()), environment_(session_env) { const bool result = model_proto_.ParseFromArray(model_data, model_data_len); ORT_ENFORCE(result, "Could not parse model successfully while constructing the inference session"); @@ -1005,18 +1032,22 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool layout_transformation::TransformLayoutForEP(graph_to_transform, modified, execution_provider, std::move(cpu_allocator), debug_graph_fn)); - if (modified) { - ORT_RETURN_IF_ERROR_SESSIONID_( - graph_transformer_mgr_.ApplyTransformers(graph_to_transform, TransformerLevel::Level1, *session_logger_)); - - // debug the graph after the L1 transformers have run against any layout transformation changes. - // this is prior to GraphPartitioner::GetCapabilityForEP calling IExecutionProvider::GetCapability the second - // time to validate the EP that requested the layout transformation can take all nodes using the new layout. - // if that fails, this allows debugging the graph used in that GetCapability call. - if (debug_graph_fn) { - debug_graph_fn(graph_to_transform); - } - } + // Previously we ran the L1 transformers to handle constant folding of any initializers that were transposed in + // a QDQ format model. The transpose optimizer can now look past DQ nodes to directly update initializers which + // takes care of most models without needing this. + // + // if (modified) { + // ORT_RETURN_IF_ERROR_SESSIONID_( + // graph_transformer_mgr_.ApplyTransformers(graph_to_transform, TransformerLevel::Level1, *session_logger_)); + // + // debug the graph after the L1 transformers have run against any layout transformation changes. + // this is prior to GraphPartitioner::GetCapabilityForEP calling IExecutionProvider::GetCapability the second + // time to validate the EP that requested the layout transformation can take all nodes using the new layout. + // if that fails, this allows debugging the graph used in that GetCapability call. + // if (debug_graph_fn) { + // debug_graph_fn(graph_to_transform); + //} + //} return Status::OK(); }; @@ -2811,17 +2842,7 @@ const logging::Logger& InferenceSession::CreateLoggerForRun(const RunOptions& ru void InferenceSession::InitLogger(logging::LoggingManager* logging_manager) { // create logger for session, using provided logging manager if possible if (logging_manager != nullptr) { - logging::Severity severity = logging::Severity::kWARNING; - if (session_options_.session_log_severity_level == -1) { - severity = logging::LoggingManager::DefaultLogger().GetSeverity(); - } else { - ORT_ENFORCE(session_options_.session_log_severity_level >= 0 && - session_options_.session_log_severity_level <= static_cast(logging::Severity::kFATAL), - "Invalid session log severity level. Not a valid onnxruntime::logging::Severity value: ", - session_options_.session_log_severity_level); - severity = static_cast(session_options_.session_log_severity_level); - } - + logging::Severity severity = GetSeverity(session_options_); owned_session_logger_ = logging_manager_->CreateLogger(session_options_.session_logid, severity, false, session_options_.session_log_verbosity_level); session_logger_ = owned_session_logger_.get(); diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 9259e014b9860..4db436f132d11 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -595,7 +595,8 @@ class InferenceSession { private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(InferenceSession); - + void SetLoggingManager(const SessionOptions& session_options, + const Environment& session_env); void ConstructorCommon(const SessionOptions& session_options, const Environment& session_env); @@ -698,7 +699,10 @@ class InferenceSession { SessionOptions session_options_; /// Logging manager if provided. - logging::LoggingManager* const logging_manager_; + logging::LoggingManager* logging_manager_; + + /// User specified logging mgr; logging_manager_ is simply the ptr in this unique_ptr when available + std::unique_ptr user_logging_manager_; /// Logger for this session. WARNING: Will contain nullptr if logging_manager_ is nullptr. std::unique_ptr owned_session_logger_ = nullptr; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 60b6296f7f539..67149e1e99f22 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2713,6 +2713,8 @@ static constexpr OrtApi ort_api_1_to_17 = { &OrtApis::GetCUDAProviderOptionsByName, &OrtApis::KernelContext_GetResource, // End of Version 16 - DO NOT MODIFY ABOVE (see above text for more information) + + &OrtApis::SetUserLoggingFunction, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. @@ -2742,6 +2744,7 @@ static_assert(offsetof(OrtApi, ReleaseCANNProviderOptions) / sizeof(void*) == 22 static_assert(offsetof(OrtApi, GetSessionConfigEntry) / sizeof(void*) == 238, "Size of version 14 API cannot change"); static_assert(offsetof(OrtApi, GetBuildInfoString) / sizeof(void*) == 254, "Size of version 15 API cannot change"); static_assert(offsetof(OrtApi, KernelContext_GetResource) / sizeof(void*) == 265, "Size of version 16 API cannot change"); +static_assert(offsetof(OrtApi, SetUserLoggingFunction) / sizeof(void*) == 266, "Size of version 17 API cannot change"); // So that nobody forgets to finish an API version, this check will serve as a reminder: static_assert(std::string_view(ORT_VERSION) == "1.17.0", diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 47da2fa524588..f472932e20b8a 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -491,4 +491,6 @@ ORT_API_STATUS_IMPL(GetTensorRTProviderOptionsByName, _In_ const OrtTensorRTProv ORT_API_STATUS_IMPL(UpdateCUDAProviderOptionsWithValue, _Inout_ OrtCUDAProviderOptionsV2* cuda_options, _In_ const char* key, _In_ void* value); ORT_API_STATUS_IMPL(GetCUDAProviderOptionsByName, _In_ const OrtCUDAProviderOptionsV2* cuda_options, _In_ const char* key, _Outptr_ void** ptr); ORT_API_STATUS_IMPL(KernelContext_GetResource, _In_ const OrtKernelContext* context, _In_ int resource_version, _In_ int resource_id, _Outptr_ void** stream); +ORT_API_STATUS_IMPL(SetUserLoggingFunction, _Inout_ OrtSessionOptions* options, + _In_ OrtLoggingFunction user_logging_function, _In_opt_ void* user_logging_param); } // namespace OrtApis diff --git a/onnxruntime/core/session/ort_env.cc b/onnxruntime/core/session/ort_env.cc index eb78d5d799a55..e3957baa990f8 100644 --- a/onnxruntime/core/session/ort_env.cc +++ b/onnxruntime/core/session/ort_env.cc @@ -9,6 +9,7 @@ #include "core/session/ort_apis.h" #include "core/session/environment.h" #include "core/session/allocator_adapters.h" +#include "core/session/user_logging_sink.h" #include "core/common/logging/logging.h" #include "core/framework/provider_shutdown.h" #include "core/platform/logging/make_platform_default_log_sink.h" @@ -20,17 +21,6 @@ std::unique_ptr OrtEnv::p_instance_; int OrtEnv::ref_count_ = 0; onnxruntime::OrtMutex OrtEnv::m_; -LoggingWrapper::LoggingWrapper(OrtLoggingFunction logging_function, void* logger_param) - : logging_function_(logging_function), logger_param_(logger_param) { -} - -void LoggingWrapper::SendImpl(const onnxruntime::logging::Timestamp& /*timestamp*/, const std::string& logger_id, - const onnxruntime::logging::Capture& message) { - std::string s = message.Location().ToString(); - logging_function_(logger_param_, static_cast(message.Severity()), message.Category(), - logger_id.c_str(), s.c_str(), message.Message().c_str()); -} - OrtEnv::OrtEnv(std::unique_ptr value1) : value_(std::move(value1)) { } @@ -50,8 +40,8 @@ OrtEnv* OrtEnv::GetInstance(const OrtEnv::LoggingManagerConstructionInfo& lm_inf std::unique_ptr lmgr; std::string name = lm_info.logid; if (lm_info.logging_function) { - std::unique_ptr logger = std::make_unique(lm_info.logging_function, - lm_info.logger_param); + std::unique_ptr logger = std::make_unique(lm_info.logging_function, + lm_info.logger_param); lmgr = std::make_unique(std::move(logger), static_cast(lm_info.default_warning_level), false, diff --git a/onnxruntime/core/session/ort_env.h b/onnxruntime/core/session/ort_env.h index 7d609acb2db5d..444134d0612e9 100644 --- a/onnxruntime/core/session/ort_env.h +++ b/onnxruntime/core/session/ort_env.h @@ -5,27 +5,15 @@ #include #include #include "core/session/onnxruntime_c_api.h" -#include "core/common/logging/isink.h" #include "core/platform/ort_mutex.h" #include "core/common/status.h" +#include "core/common/logging/logging.h" #include "core/framework/allocator.h" namespace onnxruntime { class Environment; } -class LoggingWrapper : public onnxruntime::logging::ISink { - public: - LoggingWrapper(OrtLoggingFunction logging_function, void* logger_param); - - void SendImpl(const onnxruntime::logging::Timestamp& /*timestamp*/, const std::string& logger_id, - const onnxruntime::logging::Capture& message) override; - - private: - OrtLoggingFunction logging_function_; - void* logger_param_; -}; - struct OrtEnv { public: struct LoggingManagerConstructionInfo { diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index 9326c6eaff240..50c9f6681a0c8 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -109,6 +109,10 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, #endif } else if (strcmp(provider_name, "JS") == 0) { #if defined(USE_JSEP) + std::string preferred_layout; + if (options->value.config_options.TryGetConfigEntry("preferredLayout", preferred_layout)) { + provider_options["preferred_layout"] = preferred_layout; + } options->provider_factories.push_back(JsProviderFactoryCreator::Create(provider_options)); #else status = create_not_supported_status(); diff --git a/onnxruntime/core/session/user_logging_sink.h b/onnxruntime/core/session/user_logging_sink.h new file mode 100644 index 0000000000000..5a9ceb21d6500 --- /dev/null +++ b/onnxruntime/core/session/user_logging_sink.h @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include + +#include "core/session/onnxruntime_c_api.h" +#include "core/common/logging/isink.h" + +namespace onnxruntime { +class UserLoggingSink : public onnxruntime::logging::ISink { + public: + UserLoggingSink(OrtLoggingFunction logging_function, void* logger_param) + : logging_function_(logging_function), logger_param_(logger_param) { + } + + void SendImpl(const onnxruntime::logging::Timestamp& /*timestamp*/, const std::string& logger_id, + const onnxruntime::logging::Capture& message) override { + std::string s = message.Location().ToString(); + logging_function_(logger_param_, static_cast(message.Severity()), message.Category(), + logger_id.c_str(), s.c_str(), message.Message().c_str()); + } + + private: + OrtLoggingFunction logging_function_{}; + void* logger_param_{}; +}; +} // namespace onnxruntime diff --git a/onnxruntime/python/onnxruntime_pybind_module.cc b/onnxruntime/python/onnxruntime_pybind_module.cc index f320707697c9e..6824a5d0bf98f 100644 --- a/onnxruntime/python/onnxruntime_pybind_module.cc +++ b/onnxruntime/python/onnxruntime_pybind_module.cc @@ -10,6 +10,12 @@ namespace onnxruntime { namespace python { namespace py = pybind11; +#if defined(USE_MPI) && defined(ORT_USE_NCCL) +static constexpr bool HAS_COLLECTIVE_OPS = true; +#else +static constexpr bool HAS_COLLECTIVE_OPS = false; +#endif + void CreateInferencePybindStateModule(py::module& m); PYBIND11_MODULE(onnxruntime_pybind11_state, m) { @@ -23,6 +29,7 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) { m.def("get_version_string", []() -> std::string { return ORT_VERSION; }); m.def("get_build_info", []() -> std::string { return ORT_BUILD_INFO; }); + m.def("has_collective_ops", []() -> bool { return HAS_COLLECTIVE_OPS; }); } } // namespace python } // namespace onnxruntime diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py index 2d1e418f9d2b4..c3c3d2a837336 100644 --- a/onnxruntime/python/tools/quantization/onnx_quantizer.py +++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py @@ -112,8 +112,8 @@ def __init__( False if "ActivationSymmetric" not in self.extra_options else self.extra_options["ActivationSymmetric"] ) - self.activation_qType = activation_qType.tensor_type - self.weight_qType = weight_qType.tensor_type + self.activation_qType = getattr(activation_qType, "tensor_type", activation_qType) + self.weight_qType = getattr(weight_qType, "tensor_type", weight_qType) """ Dictionary specifying the min and max values for tensors. It has following format: { diff --git a/onnxruntime/python/tools/quantization/shape_inference.py b/onnxruntime/python/tools/quantization/shape_inference.py index eff3dc0bcdc35..b7d4726610387 100644 --- a/onnxruntime/python/tools/quantization/shape_inference.py +++ b/onnxruntime/python/tools/quantization/shape_inference.py @@ -99,7 +99,10 @@ def quant_pre_process( sess_option = onnxruntime.SessionOptions() sess_option.optimized_model_filepath = opt_model_path sess_option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC - _ = onnxruntime.InferenceSession(input_model_path, sess_option, providers=["CPUExecutionProvider"]) + sess = onnxruntime.InferenceSession(input_model_path, sess_option, providers=["CPUExecutionProvider"]) + # Close the session to avoid the cleanup error on Windows for temp folders + # https://github.com/microsoft/onnxruntime/issues/17627 + del sess except Exception: logger.error( "ONNX Runtime Model Optimization Failed! Consider rerun with option `--skip_optimization'." diff --git a/onnxruntime/python/tools/tensorrt/perf/benchmark.py b/onnxruntime/python/tools/tensorrt/perf/benchmark.py index d440cafb23236..acd836c915910 100644 --- a/onnxruntime/python/tools/tensorrt/perf/benchmark.py +++ b/onnxruntime/python/tools/tensorrt/perf/benchmark.py @@ -2176,7 +2176,7 @@ def parse_arguments(): required=False, default=True, action="store_true", - help="Inlcude Float16 into benchmarking.", + help="Include Float16 into benchmarking.", ) parser.add_argument("--trtexec", required=False, default=None, help="trtexec executable path.") diff --git a/onnxruntime/python/tools/transformers/benchmark_helper.py b/onnxruntime/python/tools/transformers/benchmark_helper.py index 67d3c95922a87..4f898245d01bd 100644 --- a/onnxruntime/python/tools/transformers/benchmark_helper.py +++ b/onnxruntime/python/tools/transformers/benchmark_helper.py @@ -542,7 +542,7 @@ def measure_gpu_usage(self): while True: for i in range(device_count): max_gpu_usage[i] = max(max_gpu_usage[i], self.get_used_memory(i)) - time.sleep(0.005) # 2ms + time.sleep(0.005) # 5ms if not self.keep_measuring: break return [ @@ -555,7 +555,7 @@ def measure_gpu_usage(self): ] -def measure_memory(is_gpu, func, monitor_type="cuda"): +def measure_memory(is_gpu, func, monitor_type="cuda", start_memory=None): memory_monitor_type = None if monitor_type == "rocm": memory_monitor_type = RocmMemoryMonitor @@ -565,10 +565,16 @@ def measure_memory(is_gpu, func, monitor_type="cuda"): monitor = memory_monitor_type(False) if is_gpu: - memory_before_test = monitor.measure_gpu_usage() + if start_memory is not None: + memory_before_test = start_memory + else: + memory_before_test = monitor.measure_gpu_usage() if memory_before_test is None: return None + if func is None: + return memory_before_test + with ThreadPoolExecutor() as executor: monitor = memory_monitor_type() mem_thread = executor.submit(monitor.measure_gpu_usage) @@ -595,7 +601,13 @@ def measure_memory(is_gpu, func, monitor_type="cuda"): return None # CPU memory - memory_before_test = monitor.measure_cpu_usage() + if start_memory is not None: + memory_before_test = start_memory + else: + memory_before_test = monitor.measure_cpu_usage() + + if func is None: + return memory_before_test with ThreadPoolExecutor() as executor: monitor = MemoryMonitor() diff --git a/onnxruntime/python/tools/transformers/convert_to_packing_mode.py b/onnxruntime/python/tools/transformers/convert_to_packing_mode.py index 0b8dbdcdd9638..4da97f0de7bed 100644 --- a/onnxruntime/python/tools/transformers/convert_to_packing_mode.py +++ b/onnxruntime/python/tools/transformers/convert_to_packing_mode.py @@ -67,7 +67,7 @@ def _try_getting_last_layernorm(self) -> Union[NodeProto, None]: last_layernorm_node = node return last_layernorm_node - def _are_attentions_supportted(self) -> bool: + def _are_attentions_supported(self) -> bool: raise NotImplementedError() def _insert_removepadding_node(self, inputs: List[str], outputs: List[str]) -> None: @@ -105,7 +105,7 @@ def _get_input_to_remove_padding(self, first_attention_node) -> Union[str, None] def convert(self, use_symbolic_shape_infer: bool = True) -> None: logger.debug("start converting to packing model...") - if not self._are_attentions_supportted(): + if not self._are_attentions_supported(): return attention_mask = self._try_getting_attention_mask() @@ -164,7 +164,7 @@ class PackingAttention(PackingAttentionBase): def __init__(self, model: OnnxModel): super().__init__(model, Operators.ATTENTION) - def _are_attentions_supportted(self) -> bool: + def _are_attentions_supported(self) -> bool: for node in self.attention_nodes: if OnnxModel.get_node_attribute(node, "past_present_share_buffer") is not None: return False @@ -237,7 +237,7 @@ def _check_empty_output(self, node, index: int, name: str): return False return True - def _are_attentions_supportted(self) -> bool: + def _are_attentions_supported(self) -> bool: for node in self.attention_nodes: for attr in node.attribute: if attr.name not in ["num_heads", "mask_filter_value", "scale"]: diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index 40f2aee875382..a7460157ba409 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -78,7 +78,15 @@ def process_mask(self, input: str) -> str: # ReduceSum-13: axes is moved from attribute to input axes_name = "ort_const_1_reduce_sum_axes" if self.model.get_initializer(axes_name) is None: - self.add_initializer(name=axes_name, data_type=TensorProto.INT64, dims=[1], vals=[1], raw=False) + self.model.add_initializer( + helper.make_tensor( + name=axes_name, + data_type=TensorProto.INT64, + dims=[1], + vals=[1], + raw=False, + ) + ) mask_index_node = helper.make_node( "ReduceSum", inputs=[input_name, axes_name], diff --git a/onnxruntime/python/tools/transformers/io_binding_helper.py b/onnxruntime/python/tools/transformers/io_binding_helper.py index 71c1a21d8f768..de17f195c99cc 100644 --- a/onnxruntime/python/tools/transformers/io_binding_helper.py +++ b/onnxruntime/python/tools/transformers/io_binding_helper.py @@ -283,6 +283,7 @@ def infer(self, feed_dict: Dict[str, torch.Tensor]): if name in self.input_names: if self.enable_cuda_graph: assert self.input_tensors[name].nelement() == tensor.nelement() + assert self.input_tensors[name].dtype == tensor.dtype assert tensor.device.type == "cuda" # Please install cuda-python package with a version corresponding to CUDA in your machine. from cuda import cudart diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md index 7ffefdd05f215..1fbd5092a719a 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md @@ -74,15 +74,16 @@ Below is an example to optimize Stable Diffusion 1.5 in Linux. For Windows OS, p ### Setup Environment (CUDA) -It is recommended to create a Conda environment with Python 3.8, 3.9 or 3.10, and run the model with [CUDA 11.7](https://developer.nvidia.com/cuda-11-7-0-download-archive) or 11.8. +It is recommended to create a Conda environment with Python 3.8, 3.9 or 3.10, and run the model with CUDA 11.8. +If you use CUDA 12.*, you will need build onnxruntime-gpu from source. ``` conda create -n py38 python=3.8 conda activate py38 -pip install torch==1.13.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117 +pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118 +pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com pip install -r requirements-cuda.txt ``` - -ONNX Runtime requires CUDA and [cuDNN](https://developer.nvidia.com/rdp/cudnn-download) for GPU inference. CUDA 11.7 and cuDNN 8.5 are used in our tests. +ONNX Runtime requires CUDA and [cuDNN](https://developer.nvidia.com/rdp/cudnn-download) for GPU inference. CUDA 11.8 and cuDNN 8.5 or above are recommended. #### Install Nightly (Optional) @@ -233,18 +234,21 @@ Sometime, it complains ptxas not found when there are multiple CUDA versions ins Note that torch.compile is not supported in Windows: we encountered error `Windows not yet supported for torch.compile`. So it is excluded from RTX 3060 results of Windows. -### Run Benchmark with TensorRT and TensorRT execution provider +### Run Benchmark with TensorRT or TensorRT execution provider For TensorRT installation, follow https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html. ``` pip install torch==1.13.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117 -pip install --upgrade polygraphy>=0.47.0 onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com +pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com pip install -r requirements-tensorrt.txt export CUDA_MODULE_LOADING=LAZY python benchmark.py -e tensorrt -b 1 -v 1.5 python benchmark.py -e onnxruntime -r tensorrt -b 1 -v 1.5 python benchmark.py -e onnxruntime -r tensorrt -b 1 -v 1.5 --enable_cuda_graph + +python benchmark.py -e tensorrt --height 1024 --width 1024 -s 30 -b 1 -v xl-1.0 --enable_cuda_graph +python benchmark.py -e onnxruntime -r tensorrt --height 1024 --width 1024 -s 30 -b 1 -v xl-1.0 --enable_cuda_graph ``` ### Example Benchmark output diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py index 13126f648d290..f8fda13a35b93 100755 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/benchmark.py @@ -10,15 +10,18 @@ import sys import time +import __init__ # noqa: F401. Walk-around to run this script directly import coloredlogs # import torch before onnxruntime so that onnxruntime uses the cuDNN in the torch package. import torch +from benchmark_helper import measure_memory SD_MODELS = { "1.5": "runwayml/stable-diffusion-v1-5", "2.0": "stabilityai/stable-diffusion-2", "2.1": "stabilityai/stable-diffusion-2-1", + "xl-1.0": "stabilityai/stable-diffusion-xl-refiner-1.0", } PROVIDERS = { @@ -43,139 +46,13 @@ def example_prompts(): "delicate elvish moonstone necklace on a velvet background, symmetrical intricate motifs, leaves, flowers, 8k", ] - return prompts + negative_prompt = "bad composition, ugly, abnormal, malformed" - -class CudaMemoryMonitor: - def __init__(self, keep_measuring=True): - self.keep_measuring = keep_measuring - - def measure_gpu_usage(self): - from py3nvml.py3nvml import ( - NVMLError, - nvmlDeviceGetCount, - nvmlDeviceGetHandleByIndex, - nvmlDeviceGetMemoryInfo, - nvmlDeviceGetName, - nvmlInit, - nvmlShutdown, - ) - - max_gpu_usage = [] - gpu_name = [] - try: - nvmlInit() - device_count = nvmlDeviceGetCount() - if not isinstance(device_count, int): - print(f"nvmlDeviceGetCount result is not integer: {device_count}") - return None - - max_gpu_usage = [0 for i in range(device_count)] - gpu_name = [nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i)) for i in range(device_count)] - while True: - for i in range(device_count): - info = nvmlDeviceGetMemoryInfo(nvmlDeviceGetHandleByIndex(i)) - if isinstance(info, str): - print(f"nvmlDeviceGetMemoryInfo returns str: {info}") - return None - max_gpu_usage[i] = max(max_gpu_usage[i], info.used / 1024**2) - time.sleep(0.002) # 2ms - if not self.keep_measuring: - break - nvmlShutdown() - return [ - { - "device_id": i, - "name": gpu_name[i], - "max_used_MB": max_gpu_usage[i], - } - for i in range(device_count) - ] - except NVMLError as error: - print("Error fetching GPU information using nvml: %s", error) - return None - - -class RocmMemoryMonitor: - def __init__(self, keep_measuring=True): - self.keep_measuring = keep_measuring - rocm_smi_path = "/opt/rocm/libexec/rocm_smi" - if os.path.exists(rocm_smi_path): - if rocm_smi_path not in sys.path: - sys.path.append(rocm_smi_path) - try: - import rocm_smi - - self.rocm_smi = rocm_smi - self.rocm_smi.initializeRsmi() - except ImportError: - self.rocm_smi = None - - def get_used_memory(self, dev): - if self.rocm_smi is None: - return -1 - return self.rocm_smi.getMemInfo(dev, "VRAM")[0] / 1024 / 1024 - - def measure_gpu_usage(self): - device_count = len(self.rocm_smi.listDevices()) if self.rocm_smi is not None else 0 - max_gpu_usage = [0 for i in range(device_count)] - gpu_name = [f"GPU{i}" for i in range(device_count)] - while True: - for i in range(device_count): - max_gpu_usage[i] = max(max_gpu_usage[i], self.get_used_memory(i)) - time.sleep(0.002) # 2ms - if not self.keep_measuring: - break - return [ - { - "device_id": i, - "name": gpu_name[i], - "max_used_MB": max_gpu_usage[i], - } - for i in range(device_count) - ] + return prompts, negative_prompt def measure_gpu_memory(monitor_type, func, start_memory=None): - if monitor_type is None: - return None - - monitor = monitor_type(False) - memory_before_test = monitor.measure_gpu_usage() - - if start_memory is None: - start_memory = memory_before_test - if start_memory is None: - return None - if func is None: - return start_memory - - from concurrent.futures import ThreadPoolExecutor - - with ThreadPoolExecutor() as executor: - monitor = monitor_type() - mem_thread = executor.submit(monitor.measure_gpu_usage) - try: - fn_thread = executor.submit(func) - _ = fn_thread.result() - finally: - monitor.keep_measuring = False - max_usage = mem_thread.result() - - if max_usage is None: - return None - - print(f"GPU memory usage: before={memory_before_test} peak={max_usage}") - if len(start_memory) >= 1 and len(max_usage) >= 1 and len(start_memory) == len(max_usage): - # When there are multiple GPUs, we will check the one with maximum usage. - max_used = 0 - for i, memory_before in enumerate(start_memory): - before = memory_before["max_used_MB"] - after = max_usage[i]["max_used_MB"] - used = after - before - max_used = max(max_used, used) - return max_used - return None + return measure_memory(is_gpu=True, func=func, monitor_type=monitor_type, start_memory=start_memory) def get_ort_pipeline(model_name: str, directory: str, provider, disable_safety_checker: bool): @@ -256,7 +133,7 @@ def run_ort_pipeline( assert isinstance(pipe, OnnxStableDiffusionPipeline) - prompts = example_prompts() + prompts, negative_prompt = example_prompts() def warmup(): pipe("warm up", height, width, num_inference_steps=steps, num_images_per_prompt=batch_size) @@ -275,13 +152,12 @@ def warmup(): for j in range(batch_count): inference_start = time.time() images = pipe( - prompt, + [prompt] * batch_size, height, width, num_inference_steps=steps, - negative_prompt=None, + negative_prompt=[negative_prompt] * batch_size, guidance_scale=7.5, - num_images_per_prompt=batch_size, ).images inference_end = time.time() latency = inference_end - inference_start @@ -320,7 +196,7 @@ def run_torch_pipeline( start_memory, memory_monitor_type, ): - prompts = example_prompts() + prompts, negative_prompt = example_prompts() # total 2 runs of warm up, and measure GPU memory for CUDA EP def warmup(): @@ -342,13 +218,12 @@ def warmup(): for j in range(batch_count): inference_start = time.time() images = pipe( - prompt=prompt, + prompt=[prompt] * batch_size, height=height, width=width, num_inference_steps=steps, guidance_scale=7.5, - negative_prompt=None, - num_images_per_prompt=batch_size, + negative_prompt=[negative_prompt] * batch_size, generator=None, # torch.Generator ).images @@ -427,7 +302,7 @@ def run_ort( def export_and_run_ort( - model_name: str, + version: str, provider: str, batch_size: int, disable_safety_checker: bool, @@ -443,15 +318,19 @@ def export_and_run_ort( assert provider == "CUDAExecutionProvider" from diffusers import DDIMScheduler + from diffusion_models import PipelineInfo from onnxruntime_cuda_txt2img import OnnxruntimeCudaStableDiffusionPipeline - scheduler = DDIMScheduler.from_pretrained(model_name, subfolder="scheduler") + pipeline_info = PipelineInfo(version) + model_name = pipeline_info.name() + scheduler = DDIMScheduler.from_pretrained(model_name, subfolder="scheduler") pipe = OnnxruntimeCudaStableDiffusionPipeline.from_pretrained( model_name, scheduler=scheduler, requires_safety_checker=not disable_safety_checker, enable_cuda_graph=enable_cuda_graph, + pipeline_info=pipeline_info, ) # re-use cached folder to save ONNX models @@ -473,7 +352,7 @@ def warmup(): image_filename_prefix = get_image_filename_prefix("ort_cuda", model_name, batch_size, disable_safety_checker) latency_list = [] - prompts = example_prompts() + prompts, negative_prompt = example_prompts() for i, prompt in enumerate(prompts): if i >= num_prompts: break @@ -481,6 +360,7 @@ def warmup(): inference_start = time.time() images = pipe( [prompt] * batch_size, + negative_prompt=[negative_prompt] * batch_size, num_inference_steps=steps, ).images inference_end = time.time() @@ -514,7 +394,7 @@ def warmup(): def run_ort_trt( - model_name: str, + version: str, batch_size: int, disable_safety_checker: bool, height: int, @@ -528,8 +408,12 @@ def run_ort_trt( enable_cuda_graph: bool, ): from diffusers import DDIMScheduler + from diffusion_models import PipelineInfo from onnxruntime_tensorrt_txt2img import OnnxruntimeTensorRTStableDiffusionPipeline + pipeline_info = PipelineInfo(version) + model_name = pipeline_info.name() + assert batch_size <= max_batch_size scheduler = DDIMScheduler.from_pretrained(model_name, subfolder="scheduler") @@ -544,6 +428,7 @@ def run_ort_trt( max_batch_size=max_batch_size, onnx_opset=17, enable_cuda_graph=enable_cuda_graph, + pipeline_info=pipeline_info, ) # re-use cached folder to save ONNX models and TensorRT Engines @@ -552,7 +437,7 @@ def run_ort_trt( pipe = pipe.to("cuda") def warmup(): - pipe(["warm up"] * batch_size, num_inference_steps=steps) + pipe(["warm up"] * batch_size, negative_prompt=["negative"] * batch_size, num_inference_steps=steps) # Run warm up, and measure GPU memory of two runs # The first run has algo search so it might need more memory @@ -564,7 +449,7 @@ def warmup(): image_filename_prefix = get_image_filename_prefix("ort_trt", model_name, batch_size, disable_safety_checker) latency_list = [] - prompts = example_prompts() + prompts, negative_prompt = example_prompts() for i, prompt in enumerate(prompts): if i >= num_prompts: break @@ -572,6 +457,7 @@ def warmup(): inference_start = time.time() images = pipe( [prompt] * batch_size, + negative_prompt=[negative_prompt] * batch_size, num_inference_steps=steps, ).images inference_end = time.time() @@ -589,7 +475,7 @@ def warmup(): "model_name": model_name, "engine": "onnxruntime", "version": ort_version, - "provider": f"tensorrt{trt_version})", + "provider": f"tensorrt({trt_version})", "directory": pipe.engine_dir, "height": height, "width": width, @@ -606,7 +492,148 @@ def warmup(): } -def run_tensorrt( +def run_ort_trt_static( + work_dir: str, + version: str, + batch_size: int, + disable_safety_checker: bool, + height: int, + width: int, + steps: int, + num_prompts: int, + batch_count: int, + start_memory, + memory_monitor_type, + max_batch_size: int, + nvtx_profile: bool = False, + use_cuda_graph: bool = True, +): + print("[I] Initializing ORT TensorRT EP accelerated StableDiffusionXL txt2img pipeline (static input shape)") + + # Register TensorRT plugins + from trt_utilities import init_trt_plugins + + init_trt_plugins() + + assert batch_size <= max_batch_size + + from diffusion_models import PipelineInfo + + pipeline_info = PipelineInfo(version) + short_name = pipeline_info.short_name() + + from engine_builder import EngineType, get_engine_paths + from pipeline_txt2img import Txt2ImgPipeline + + engine_type = EngineType.ORT_TRT + onnx_dir, engine_dir, output_dir, framework_model_dir, _ = get_engine_paths(work_dir, pipeline_info, engine_type) + + # Initialize pipeline + pipeline = Txt2ImgPipeline( + pipeline_info, + scheduler="DDIM", + output_dir=output_dir, + hf_token=None, + verbose=False, + nvtx_profile=nvtx_profile, + max_batch_size=max_batch_size, + use_cuda_graph=use_cuda_graph, + framework_model_dir=framework_model_dir, + engine_type=engine_type, + ) + + # Load TensorRT engines and pytorch modules + pipeline.backend.build_engines( + engine_dir, + framework_model_dir, + onnx_dir, + 17, + opt_image_height=height, + opt_image_width=width, + opt_batch_size=batch_size, + force_engine_rebuild=False, + static_batch=True, + static_image_shape=True, + max_workspace_size=0, + device_id=torch.cuda.current_device(), + ) + + # Here we use static batch and image size, so the resource allocation only need done once. + # For dynamic batch and image size, some cost (like memory allocation) shall be included in latency. + pipeline.load_resources(height, width, batch_size) + + def warmup(): + pipeline.run( + ["warm up"] * batch_size, ["negative"] * batch_size, height, width, denoising_steps=steps, warmup=True + ) + + # Run warm up, and measure GPU memory of two runs + # The first run has algo search so it might need more memory + first_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) + second_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) + + warmup() + + image_filename_prefix = get_image_filename_prefix("ort_trt", short_name, batch_size, disable_safety_checker) + + latency_list = [] + prompts, negative_prompt = example_prompts() + for i, prompt in enumerate(prompts): + if i >= num_prompts: + break + for j in range(batch_count): + inference_start = time.time() + # Use warmup mode here since non-warmup mode will save image to disk. + images, pipeline_time = pipeline.run( + [prompt] * batch_size, + [negative_prompt] * batch_size, + height, + width, + denoising_steps=steps, + guidance=7.5, + seed=123, + warmup=True, + ) + images = pipeline.to_pil_image( + images + ) # include image conversion time to pil image for apple-to-apple compare + inference_end = time.time() + latency = inference_end - inference_start + latency_list.append(latency) + print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time:.1f} ms") + for k, image in enumerate(images): + image.save(f"{image_filename_prefix}_{i}_{j}_{k}.jpg") + + pipeline.teardown() + + from tensorrt import __version__ as trt_version + + from onnxruntime import __version__ as ort_version + + return { + "model_name": pipeline_info.name(), + "engine": "onnxruntime", + "version": ort_version, + "provider": f"tensorrt({trt_version})", + "directory": engine_dir, + "height": height, + "width": width, + "steps": steps, + "batch_size": batch_size, + "batch_count": batch_count, + "num_prompts": num_prompts, + "average_latency": sum(latency_list) / len(latency_list), + "median_latency": statistics.median(latency_list), + "first_run_memory_MB": first_run_memory, + "second_run_memory_MB": second_run_memory, + "disable_safety_checker": disable_safety_checker, + "enable_cuda_graph": use_cuda_graph, + } + + +def run_tensorrt_static( + work_dir: str, + version: str, model_name: str, batch_size: int, disable_safety_checker: bool, @@ -618,32 +645,79 @@ def run_tensorrt( start_memory, memory_monitor_type, max_batch_size: int, + nvtx_profile: bool = False, + use_cuda_graph: bool = True, ): - from diffusers import DDIMScheduler - from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline + print("[I] Initializing TensorRT accelerated StableDiffusionXL txt2img pipeline (static input shape)") + + from cuda import cudart + + # Register TensorRT plugins + from trt_utilities import init_trt_plugins + + init_trt_plugins() assert batch_size <= max_batch_size - scheduler = DDIMScheduler.from_pretrained(model_name, subfolder="scheduler") - pipe = StableDiffusionPipeline.from_pretrained( - model_name, - custom_pipeline="stable_diffusion_tensorrt_txt2img", - revision="fp16", - torch_dtype=torch.float16, - scheduler=scheduler, - requires_safety_checker=not disable_safety_checker, - image_height=height, - image_width=width, + from diffusion_models import PipelineInfo + + pipeline_info = PipelineInfo(version) + + from engine_builder import EngineType, get_engine_paths + from pipeline_txt2img import Txt2ImgPipeline + + engine_type = EngineType.TRT + onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths( + work_dir, pipeline_info, engine_type + ) + + # Initialize pipeline + pipeline = Txt2ImgPipeline( + pipeline_info, + scheduler="DDIM", + output_dir=output_dir, + hf_token=None, + verbose=False, + nvtx_profile=nvtx_profile, max_batch_size=max_batch_size, + use_cuda_graph=True, + engine_type=engine_type, ) - # re-use cached folder to save ONNX models and TensorRT Engines - pipe.set_cached_folder(model_name, revision="fp16") + # Load TensorRT engines and pytorch modules + pipeline.backend.load_engines( + engine_dir=engine_dir, + framework_model_dir=framework_model_dir, + onnx_dir=onnx_dir, + onnx_opset=17, + opt_batch_size=batch_size, + opt_image_height=height, + opt_image_width=width, + force_export=False, + force_optimize=False, + force_build=False, + static_batch=True, + static_shape=True, + enable_refit=False, + enable_preview=False, + enable_all_tactics=False, + timing_cache=timing_cache, + onnx_refit_dir=None, + ) - pipe = pipe.to("cuda") + # activate engines + max_device_memory = max(pipeline.backend.max_device_memory(), pipeline.backend.max_device_memory()) + _, shared_device_memory = cudart.cudaMalloc(max_device_memory) + pipeline.backend.activate_engines(shared_device_memory) + + # Here we use static batch and image size, so the resource allocation only need done once. + # For dynamic batch and image size, some cost (like memory allocation) shall be included in latency. + pipeline.load_resources(height, width, batch_size) def warmup(): - pipe(["warm up"] * batch_size, num_inference_steps=steps) + pipeline.run( + ["warm up"] * batch_size, ["negative"] * batch_size, height, width, denoising_steps=steps, warmup=True + ) # Run warm up, and measure GPU memory of two runs # The first run has algo search so it might need more memory @@ -655,28 +729,225 @@ def warmup(): image_filename_prefix = get_image_filename_prefix("trt", model_name, batch_size, disable_safety_checker) latency_list = [] - prompts = example_prompts() + prompts, negative_prompt = example_prompts() for i, prompt in enumerate(prompts): if i >= num_prompts: break for j in range(batch_count): inference_start = time.time() - images = pipe( + # Use warmup mode here since non-warmup mode will save image to disk. + images, pipeline_time = pipeline.run( [prompt] * batch_size, - num_inference_steps=steps, - ).images + [negative_prompt] * batch_size, + height, + width, + denoising_steps=steps, + guidance=7.5, + seed=123, + warmup=True, + ) + images = pipeline.to_pil_image( + images + ) # include image conversion time to pil image for apple-to-apple compare inference_end = time.time() latency = inference_end - inference_start latency_list.append(latency) - print(f"Inference took {latency:.3f} seconds") + print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time:.1f} ms") for k, image in enumerate(images): image.save(f"{image_filename_prefix}_{i}_{j}_{k}.jpg") - from tensorrt import __version__ as trt_version + pipeline.teardown() + + import tensorrt as trt + + return { + "engine": "tensorrt", + "version": trt.__version__, + "provider": "default", + "height": height, + "width": width, + "steps": steps, + "batch_size": batch_size, + "batch_count": batch_count, + "num_prompts": num_prompts, + "average_latency": sum(latency_list) / len(latency_list), + "median_latency": statistics.median(latency_list), + "first_run_memory_MB": first_run_memory, + "second_run_memory_MB": second_run_memory, + "enable_cuda_graph": use_cuda_graph, + } + + +def run_tensorrt_static_xl( + work_dir: str, + version: str, + batch_size: int, + disable_safety_checker: bool, + height: int, + width: int, + steps: int, + num_prompts: int, + batch_count: int, + start_memory, + memory_monitor_type, + max_batch_size: int, + nvtx_profile: bool = False, + use_cuda_graph=True, +): + print("[I] Initializing TensorRT accelerated StableDiffusionXL txt2img pipeline (static input shape)") + + import tensorrt as trt + from cuda import cudart + from trt_utilities import init_trt_plugins + + # Validate image dimensions + image_height = height + image_width = width + if image_height % 8 != 0 or image_width % 8 != 0: + raise ValueError( + f"Image height and width have to be divisible by 8 but specified as: {image_height} and {image_width}." + ) + + # Register TensorRT plugins + init_trt_plugins() + + assert batch_size <= max_batch_size + + from diffusion_models import PipelineInfo + from engine_builder import EngineType, get_engine_paths + + def init_pipeline(pipeline_class, pipeline_info): + engine_type = EngineType.TRT + + onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths( + work_dir, pipeline_info, engine_type + ) + + # Initialize pipeline + pipeline = pipeline_class( + pipeline_info, + scheduler="DDIM", + output_dir=output_dir, + hf_token=None, + verbose=False, + nvtx_profile=nvtx_profile, + max_batch_size=max_batch_size, + use_cuda_graph=use_cuda_graph, + framework_model_dir=framework_model_dir, + engine_type=engine_type, + ) + + pipeline.backend.load_engines( + engine_dir=engine_dir, + framework_model_dir=framework_model_dir, + onnx_dir=onnx_dir, + onnx_opset=17, + opt_batch_size=batch_size, + opt_image_height=height, + opt_image_width=width, + force_export=False, + force_optimize=False, + force_build=False, + static_batch=True, + static_shape=True, + enable_refit=False, + enable_preview=False, + enable_all_tactics=False, + timing_cache=timing_cache, + onnx_refit_dir=None, + ) + return pipeline + + from pipeline_img2img_xl import Img2ImgXLPipeline + from pipeline_txt2img_xl import Txt2ImgXLPipeline + + base_pipeline_info = PipelineInfo(version) + demo_base = init_pipeline(Txt2ImgXLPipeline, base_pipeline_info) + + refiner_pipeline_info = PipelineInfo(version, is_sd_xl_refiner=True) + demo_refiner = init_pipeline(Img2ImgXLPipeline, refiner_pipeline_info) + + max_device_memory = max(demo_base.backend.max_device_memory(), demo_refiner.backend.max_device_memory()) + _, shared_device_memory = cudart.cudaMalloc(max_device_memory) + demo_base.backend.activate_engines(shared_device_memory) + demo_refiner.backend.activate_engines(shared_device_memory) + + # Here we use static batch and image size, so the resource allocation only need done once. + # For dynamic batch and image size, some cost (like memory allocation) shall be included in latency. + demo_base.load_resources(image_height, image_width, batch_size) + demo_refiner.load_resources(image_height, image_width, batch_size) + + def run_sd_xl_inference(prompt, negative_prompt, seed=None, warmup=False): + images, time_base = demo_base.run( + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=steps, + guidance=5.0, + warmup=warmup, + seed=seed, + return_type="latents", + ) + + images, time_refiner = demo_refiner.run( + prompt, + negative_prompt, + images, + image_height, + image_width, + denoising_steps=steps, + guidance=5.0, + warmup=warmup, + seed=seed, + ) + return images, time_base + time_refiner + + def warmup(): + run_sd_xl_inference(["warm up"] * batch_size, ["negative"] * batch_size, warmup=True) + + # Run warm up, and measure GPU memory of two runs + # The first run has algo search so it might need more memory + first_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) + second_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) + + warmup() + + model_name = refiner_pipeline_info.name() + image_filename_prefix = get_image_filename_prefix("trt", model_name, batch_size, disable_safety_checker) + + latency_list = [] + prompts, negative_prompt = example_prompts() + for i, prompt in enumerate(prompts): + if i >= num_prompts: + break + for j in range(batch_count): + inference_start = time.time() + # Use warmup mode here since non-warmup mode will save image to disk. + if nvtx_profile: + cudart.cudaProfilerStart() + images, pipeline_time = run_sd_xl_inference( + [prompt] * batch_size, [negative_prompt] * batch_size, seed=123, warmup=True + ) + if nvtx_profile: + cudart.cudaProfilerStop() + images = demo_refiner.to_pil_image( + images + ) # include image conversion time to pil image for apple-to-apple compare + inference_end = time.time() + latency = inference_end - inference_start + latency_list.append(latency) + print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time:.1f} ms") + for k, image in enumerate(images): + image.save(f"{image_filename_prefix}_{i}_{j}_{k}.png") + + demo_base.teardown() + demo_refiner.teardown() return { + "model_name": model_name, "engine": "tensorrt", - "version": trt_version, + "version": trt.__version__, "provider": "default", "height": height, "width": width, @@ -688,7 +959,178 @@ def warmup(): "median_latency": statistics.median(latency_list), "first_run_memory_MB": first_run_memory, "second_run_memory_MB": second_run_memory, - "enable_cuda_graph": False, + "enable_cuda_graph": use_cuda_graph, + } + + +def run_ort_trt_xl( + work_dir: str, + version: str, + batch_size: int, + disable_safety_checker: bool, + height: int, + width: int, + steps: int, + num_prompts: int, + batch_count: int, + start_memory, + memory_monitor_type, + max_batch_size: int, + nvtx_profile: bool = False, + use_cuda_graph=True, +): + from cuda import cudart + + # Validate image dimensions + image_height = height + image_width = width + if image_height % 8 != 0 or image_width % 8 != 0: + raise ValueError( + f"Image height and width have to be divisible by 8 but specified as: {image_height} and {image_width}." + ) + + assert batch_size <= max_batch_size + + from engine_builder import EngineType, get_engine_paths + + def init_pipeline(pipeline_class, pipeline_info): + engine_type = EngineType.ORT_TRT + + onnx_dir, engine_dir, output_dir, framework_model_dir, _ = get_engine_paths( + work_dir, pipeline_info, engine_type + ) + + # Initialize pipeline + pipeline = pipeline_class( + pipeline_info, + scheduler="DDIM", + output_dir=output_dir, + hf_token=None, + verbose=False, + nvtx_profile=nvtx_profile, + max_batch_size=max_batch_size, + use_cuda_graph=use_cuda_graph, + framework_model_dir=framework_model_dir, + engine_type=engine_type, + ) + + pipeline.backend.build_engines( + engine_dir, + framework_model_dir, + onnx_dir, + 17, + opt_image_height=height, + opt_image_width=width, + opt_batch_size=batch_size, + force_engine_rebuild=False, + static_batch=True, + static_image_shape=True, + max_workspace_size=0, + device_id=torch.cuda.current_device(), # TODO: might not work with CUDA_VISIBLE_DEVICES + ) + return pipeline + + from diffusion_models import PipelineInfo + from pipeline_img2img_xl import Img2ImgXLPipeline + from pipeline_txt2img_xl import Txt2ImgXLPipeline + + base_pipeline_info = PipelineInfo(version) + demo_base = init_pipeline(Txt2ImgXLPipeline, base_pipeline_info) + + refiner_pipeline_info = PipelineInfo(version, is_sd_xl_refiner=True) + demo_refiner = init_pipeline(Img2ImgXLPipeline, refiner_pipeline_info) + + demo_base.load_resources(image_height, image_width, batch_size) + demo_refiner.load_resources(image_height, image_width, batch_size) + + def run_sd_xl_inference(prompt, negative_prompt, seed=None, warmup=False): + images, time_base = demo_base.run( + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=steps, + guidance=5.0, + warmup=warmup, + seed=seed, + return_type="latents", + ) + images, time_refiner = demo_refiner.run( + prompt, + negative_prompt, + images, + image_height, + image_width, + denoising_steps=steps, + guidance=5.0, + warmup=warmup, + seed=seed, + ) + return images, time_base + time_refiner + + def warmup(): + run_sd_xl_inference(["warm up"] * batch_size, ["negative"] * batch_size, warmup=True) + + # Run warm up, and measure GPU memory of two runs + # The first run has algo search so it might need more memory + first_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) + second_run_memory = measure_gpu_memory(memory_monitor_type, warmup, start_memory) + + warmup() + + model_name = refiner_pipeline_info.name() + image_filename_prefix = get_image_filename_prefix("ort_trt", model_name, batch_size, disable_safety_checker) + + latency_list = [] + prompts, negative_prompt = example_prompts() + for i, prompt in enumerate(prompts): + if i >= num_prompts: + break + for j in range(batch_count): + inference_start = time.time() + # Use warmup mode here since non-warmup mode will save image to disk. + if nvtx_profile: + cudart.cudaProfilerStart() + images, pipeline_time = run_sd_xl_inference( + [prompt] * batch_size, [negative_prompt] * batch_size, seed=123, warmup=True + ) + if nvtx_profile: + cudart.cudaProfilerStop() + images = demo_refiner.to_pil_image( + images + ) # include image conversion time to pil image for apple-to-apple compare + inference_end = time.time() + latency = inference_end - inference_start + latency_list.append(latency) + print(f"End2End took {latency:.3f} seconds. Inference latency: {pipeline_time:.1f} ms") + for k, image in enumerate(images): + filename = f"{image_filename_prefix}_{i}_{j}_{k}.png" + image.save(filename) + print("Image saved to", filename) + + demo_base.teardown() + demo_refiner.teardown() + + from tensorrt import __version__ as trt_version + + from onnxruntime import __version__ as ort_version + + return { + "model_name": model_name, + "engine": "onnxruntime", + "version": ort_version, + "provider": f"tensorrt{trt_version})", + "height": height, + "width": width, + "steps": steps, + "batch_size": batch_size, + "batch_count": batch_count, + "num_prompts": num_prompts, + "average_latency": sum(latency_list) / len(latency_list), + "median_latency": statistics.median(latency_list), + "first_run_memory_MB": first_run_memory, + "second_run_memory_MB": second_run_memory, + "enable_cuda_graph": use_cuda_graph, } @@ -808,6 +1250,15 @@ def parse_arguments(): help="Directory of saved onnx pipeline. It could be the output directory of optimize_pipeline.py.", ) + parser.add_argument( + "-w", + "--work_dir", + required=False, + type=str, + default=".", + help="Root directory to save exported onnx models, built engines etc.", + ) + parser.add_argument( "--enable_safety_checker", required=False, @@ -922,28 +1373,31 @@ def main(): args = parse_arguments() print(args) - if args.enable_cuda_graph: - if not (args.engine == "onnxruntime" and args.provider in ["cuda", "tensorrt"] and args.pipeline is None): - raise ValueError("The stable diffusion pipeline does not support CUDA graph.") + if args.engine == "onnxruntime": + if args.version in ["2.1"]: + # Set a flag to avoid overflow in attention, which causes black image output in SD 2.1 model. + # The environment variables shall be set before the first run of Attention or MultiHeadAttention operator. + os.environ["ORT_DISABLE_TRT_FLASH_ATTENTION"] = "1" from packaging import version from onnxruntime import __version__ as ort_version - if version.parse(ort_version) < version.parse("1.16"): - raise ValueError( - "CUDA graph requires ONNX Runtime 1.16. You can install nightly like the following:\n" - " pip uninstall onnxruntime-gpu\n" - " pip install ort-nightly-gpu -i https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/ORT-Nightly/pypi/simple/" - ) + if version.parse(ort_version) == version.parse("1.16.0"): + # ORT 1.16 has a bug that might trigger Attention RuntimeError when latest fusion script is applied on clip model. + # The walkaround is to enable fused causal attention, or disable Attention fusion for clip model. + os.environ["ORT_ENABLE_FUSED_CAUSAL_ATTENTION"] = "1" + + if args.enable_cuda_graph: + if not (args.engine == "onnxruntime" and args.provider in ["cuda", "tensorrt"] and args.pipeline is None): + raise ValueError("The stable diffusion pipeline does not support CUDA graph.") + + if version.parse(ort_version) < version.parse("1.16"): + raise ValueError("CUDA graph requires ONNX Runtime 1.16 or later") coloredlogs.install(fmt="%(funcName)20s: %(message)s") - memory_monitor_type = None - if args.provider in ["cuda", "tensorrt"]: - memory_monitor_type = CudaMemoryMonitor - elif args.provider == "rocm": - memory_monitor_type = RocmMemoryMonitor + memory_monitor_type = "rocm" if args.provider == "rocm" else "cuda" start_memory = measure_gpu_memory(memory_monitor_type, None) print("GPU memory used before loading models:", start_memory) @@ -951,89 +1405,157 @@ def main(): sd_model = SD_MODELS[args.version] provider = PROVIDERS[args.provider] if args.engine == "onnxruntime" and args.provider == "tensorrt": - result = run_ort_trt( - sd_model, - args.batch_size, - not args.enable_safety_checker, - args.height, - args.width, - args.steps, - args.num_prompts, - args.batch_count, - start_memory, - memory_monitor_type, - args.max_trt_batch_size, - args.enable_cuda_graph, - ) + if "xl" in args.version: + print("Testing Txt2ImgXLPipeline with static input shape. Backend is ORT TensorRT EP.") + result = run_ort_trt_xl( + work_dir=args.work_dir, + version=args.version, + batch_size=args.batch_size, + disable_safety_checker=True, + height=args.height, + width=args.width, + steps=args.steps, + num_prompts=args.num_prompts, + batch_count=args.batch_count, + start_memory=start_memory, + memory_monitor_type=memory_monitor_type, + max_batch_size=args.max_trt_batch_size, + nvtx_profile=False, + use_cuda_graph=args.enable_cuda_graph, + ) + elif args.tuning: + print( + "Testing OnnxruntimeTensorRTStableDiffusionPipeline with {}.".format( + "static input shape" if args.enable_cuda_graph else "dynamic batch size" + ) + ) + result = run_ort_trt( + version=args.version, + batch_size=args.batch_size, + disable_safety_checker=not args.enable_safety_checker, + height=args.height, + width=args.width, + steps=args.steps, + num_prompts=args.num_prompts, + batch_count=args.batch_count, + start_memory=start_memory, + memory_monitor_type=memory_monitor_type, + max_batch_size=args.max_trt_batch_size, + enable_cuda_graph=args.enable_cuda_graph, + ) + else: + print("Testing Txt2ImgPipeline with static input shape. Backend is ORT TensorRT EP.") + result = run_ort_trt_static( + work_dir=args.work_dir, + version=args.version, + batch_size=args.batch_size, + disable_safety_checker=not args.enable_safety_checker, + height=args.height, + width=args.width, + steps=args.steps, + num_prompts=args.num_prompts, + batch_count=args.batch_count, + start_memory=start_memory, + memory_monitor_type=memory_monitor_type, + max_batch_size=args.max_trt_batch_size, + nvtx_profile=False, + use_cuda_graph=args.enable_cuda_graph, + ) + elif args.engine == "onnxruntime" and provider == "CUDAExecutionProvider" and args.pipeline is None: - print("Pipeline is not specified. Trying export and optimize onnx models...") + print( + "Testing OnnxruntimeCudaStableDiffusionPipeline with {} input shape. Backend is ORT CUDA EP.".format( + "static" if args.enable_cuda_graph else "dynamic" + ) + ) result = export_and_run_ort( - sd_model, - provider, - args.batch_size, - not args.enable_safety_checker, - args.height, - args.width, - args.steps, - args.num_prompts, - args.batch_count, - start_memory, - memory_monitor_type, - args.enable_cuda_graph, + version=args.version, + provider=provider, + batch_size=args.batch_size, + disable_safety_checker=not args.enable_safety_checker, + height=args.height, + width=args.width, + steps=args.steps, + num_prompts=args.num_prompts, + batch_count=args.batch_count, + start_memory=start_memory, + memory_monitor_type=memory_monitor_type, + enable_cuda_graph=args.enable_cuda_graph, ) elif args.engine == "onnxruntime": assert args.pipeline and os.path.isdir( args.pipeline ), "--pipeline should be specified for the directory of ONNX models" - - if args.version in ["2.1"]: - # Set a flag to avoid overflow in attention, which causes black image output in SD 2.1 model - # This shall be done before the first inference run. - os.environ["ORT_DISABLE_TRT_FLASH_ATTENTION"] = "1" - + print(f"Testing diffusers StableDiffusionPipeline with {provider} provider and tuning={args.tuning}") result = run_ort( - sd_model, - args.pipeline, - provider, - args.batch_size, - not args.enable_safety_checker, - args.height, - args.width, - args.steps, - args.num_prompts, - args.batch_count, - start_memory, - memory_monitor_type, - args.tuning, + model_name=sd_model, + directory=args.pipeline, + provider=provider, + batch_size=args.batch_size, + disable_safety_checker=not args.enable_safety_checker, + height=args.height, + width=args.width, + steps=args.steps, + num_prompts=args.num_prompts, + batch_count=args.batch_count, + start_memory=start_memory, + memory_monitor_type=memory_monitor_type, + tuning=args.tuning, + ) + elif args.engine == "tensorrt" and "xl" in args.version: + print("Testing Txt2ImgXLPipeline with static input shape. Backend is TensorRT.") + result = run_tensorrt_static_xl( + work_dir=args.work_dir, + version=args.version, + batch_size=args.batch_size, + disable_safety_checker=True, + height=args.height, + width=args.width, + steps=args.steps, + num_prompts=args.num_prompts, + batch_count=args.batch_count, + start_memory=start_memory, + memory_monitor_type=memory_monitor_type, + max_batch_size=args.max_trt_batch_size, + nvtx_profile=False, + use_cuda_graph=args.enable_cuda_graph, ) elif args.engine == "tensorrt": - result = run_tensorrt( - sd_model, - args.batch_size, - not args.enable_safety_checker, - args.height, - args.width, - args.steps, - args.num_prompts, - args.batch_count, - start_memory, - memory_monitor_type, - args.max_trt_batch_size, + print("Testing Txt2ImgPipeline with static input shape. Backend is TensorRT.") + result = run_tensorrt_static( + work_dir=args.work_dir, + version=args.version, + model_name=sd_model, + batch_size=args.batch_size, + disable_safety_checker=True, + height=args.height, + width=args.width, + steps=args.steps, + num_prompts=args.num_prompts, + batch_count=args.batch_count, + start_memory=start_memory, + memory_monitor_type=memory_monitor_type, + max_batch_size=args.max_trt_batch_size, + nvtx_profile=False, + use_cuda_graph=args.enable_cuda_graph, ) else: + print( + f"Testing Txt2ImgPipeline with dynamic input shape. Backend is PyTorch: compile={args.enable_torch_compile}, xformers={args.use_xformers}." + ) result = run_torch( - sd_model, - args.batch_size, - not args.enable_safety_checker, - args.enable_torch_compile, - args.use_xformers, - args.height, - args.width, - args.steps, - args.num_prompts, - args.batch_count, - start_memory, - memory_monitor_type, + model_name=sd_model, + batch_size=args.batch_size, + disable_safety_checker=not args.enable_safety_checker, + enable_torch_compile=args.enable_torch_compile, + use_xformers=args.use_xformers, + height=args.height, + width=args.width, + steps=args.steps, + num_prompts=args.num_prompts, + batch_count=args.batch_count, + start_memory=start_memory, + memory_monitor_type=memory_monitor_type, ) print(result) @@ -1068,8 +1590,9 @@ def main(): if __name__ == "__main__": + import traceback + try: main() - except Exception as e: - tb = sys.exc_info() - print(e.with_traceback(tb[2])) + except Exception: + traceback.print_exception(*sys.exc_info()) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py new file mode 100644 index 0000000000000..f6e00063a6391 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img.py @@ -0,0 +1,94 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from TensorRT demo diffusion, which has the following license: +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- + +import coloredlogs +from cuda import cudart +from demo_utils import init_pipeline, parse_arguments, repeat_prompt +from diffusion_models import PipelineInfo +from engine_builder import EngineType, get_engine_type +from pipeline_txt2img import Txt2ImgPipeline + +if __name__ == "__main__": + coloredlogs.install(fmt="%(funcName)20s: %(message)s") + + args = parse_arguments(is_xl=False, description="Options for Stable Diffusion Demo") + prompt, negative_prompt = repeat_prompt(args) + + image_height = args.height + image_width = args.width + + # Register TensorRT plugins + engine_type = get_engine_type(args.engine) + if engine_type == EngineType.TRT: + from trt_utilities import init_trt_plugins + + init_trt_plugins() + + max_batch_size = 16 + if engine_type != EngineType.ORT_CUDA and (args.build_dynamic_shape or image_height > 512 or image_width > 512): + max_batch_size = 4 + + batch_size = len(prompt) + if batch_size > max_batch_size: + raise ValueError( + f"Batch size {len(prompt)} is larger than allowed {max_batch_size}. If dynamic shape is used, then maximum batch size is 4" + ) + + pipeline_info = PipelineInfo(args.version) + pipeline = init_pipeline(Txt2ImgPipeline, pipeline_info, engine_type, args, max_batch_size, batch_size) + + if engine_type == EngineType.TRT: + max_device_memory = max(pipeline.backend.max_device_memory(), pipeline.backend.max_device_memory()) + _, shared_device_memory = cudart.cudaMalloc(max_device_memory) + pipeline.backend.activate_engines(shared_device_memory) + + pipeline.load_resources(image_height, image_width, batch_size) + + def run_inference(warmup=False): + return pipeline.run( + prompt, + negative_prompt, + image_height, + image_width, + warmup=warmup, + denoising_steps=args.denoising_steps, + guidance=args.guidance, + seed=args.seed, + return_type="images", + ) + + if not args.disable_cuda_graph: + # inference once to get cuda graph + _image, _latency = run_inference(warmup=True) + + print("[I] Warming up ..") + for _ in range(args.num_warmup_runs): + _image, _latency = run_inference(warmup=True) + + print("[I] Running StableDiffusion pipeline") + if args.nvtx_profile: + cudart.cudaProfilerStart() + _image, _latency = run_inference(warmup=False) + if args.nvtx_profile: + cudart.cudaProfilerStop() + + pipeline.teardown() diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py new file mode 100644 index 0000000000000..c3a2e4e293cc8 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_txt2img_xl.py @@ -0,0 +1,129 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from TensorRT demo diffusion, which has the following license: +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- + +import coloredlogs +from cuda import cudart +from demo_utils import init_pipeline, parse_arguments, repeat_prompt +from diffusion_models import PipelineInfo +from engine_builder import EngineType, get_engine_type +from pipeline_img2img_xl import Img2ImgXLPipeline +from pipeline_txt2img_xl import Txt2ImgXLPipeline + +if __name__ == "__main__": + coloredlogs.install(fmt="%(funcName)20s: %(message)s") + args = parse_arguments(is_xl=True, description="Options for Stable Diffusion XL Demo") + prompt, negative_prompt = repeat_prompt(args) + + image_height = args.height + image_width = args.width + + # Register TensorRT plugins + engine_type = get_engine_type(args.engine) + if engine_type == EngineType.TRT: + from trt_utilities import init_trt_plugins + + init_trt_plugins() + + max_batch_size = 16 + if args.build_dynamic_shape or image_height > 512 or image_width > 512: + max_batch_size = 4 + + batch_size = len(prompt) + if batch_size > max_batch_size: + raise ValueError( + f"Batch size {len(prompt)} is larger than allowed {max_batch_size}. If dynamic shape is used, then maximum batch size is 4" + ) + + base_info = PipelineInfo(args.version, use_vae_in_xl_base=not args.enable_refiner) + base = init_pipeline(Txt2ImgXLPipeline, base_info, engine_type, args, max_batch_size, batch_size) + + if args.enable_refiner: + refiner_info = PipelineInfo(args.version, is_sd_xl_refiner=True) + refiner = init_pipeline(Img2ImgXLPipeline, refiner_info, engine_type, args, max_batch_size, batch_size) + + if engine_type == EngineType.TRT: + max_device_memory = max(base.backend.max_device_memory(), refiner.backend.max_device_memory()) + _, shared_device_memory = cudart.cudaMalloc(max_device_memory) + base.backend.activate_engines(shared_device_memory) + refiner.backend.activate_engines(shared_device_memory) + + base.load_resources(image_height, image_width, batch_size) + refiner.load_resources(image_height, image_width, batch_size) + else: + if engine_type == EngineType.TRT: + max_device_memory = max(base.backend.max_device_memory(), base.backend.max_device_memory()) + _, shared_device_memory = cudart.cudaMalloc(max_device_memory) + base.backend.activate_engines(shared_device_memory) + + base.load_resources(image_height, image_width, batch_size) + + def run_sd_xl_inference(enable_refiner: bool, warmup=False): + images, time_base = base.run( + prompt, + negative_prompt, + image_height, + image_width, + warmup=warmup, + denoising_steps=args.denoising_steps, + guidance=args.guidance, + seed=args.seed, + return_type="latents" if enable_refiner else "images", + ) + + if enable_refiner: + images, time_refiner = refiner.run( + prompt, + negative_prompt, + images, + image_height, + image_width, + warmup=warmup, + denoising_steps=args.denoising_steps, + guidance=args.guidance, + seed=args.seed, + ) + return images, time_base + time_refiner + else: + return images, time_base + + if not args.disable_cuda_graph: + # inference once to get cuda graph + images, _ = run_sd_xl_inference(args.enable_refiner, warmup=True) + + print("[I] Warming up ..") + for _ in range(args.num_warmup_runs): + images, _ = run_sd_xl_inference(args.enable_refiner, warmup=True) + + print("[I] Running StableDiffusion XL pipeline") + if args.nvtx_profile: + cudart.cudaProfilerStart() + images, pipeline_time = run_sd_xl_inference(args.enable_refiner, warmup=False) + if args.nvtx_profile: + cudart.cudaProfilerStop() + + base.teardown() + + if args.enable_refiner: + print("|------------|--------------|") + print("| {:^10} | {:>9.2f} ms |".format("e2e", pipeline_time)) + print("|------------|--------------|") + refiner.teardown() diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py new file mode 100644 index 0000000000000..5fdafc463f4e2 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/demo_utils.py @@ -0,0 +1,255 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from TensorRT demo diffusion, which has the following license: +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- + +import argparse + +import torch +from diffusion_models import PipelineInfo +from engine_builder import EngineType, get_engine_paths + + +class RawTextArgumentDefaultsHelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelpFormatter): + pass + + +def parse_arguments(is_xl: bool, description: str): + parser = argparse.ArgumentParser(description=description, formatter_class=RawTextArgumentDefaultsHelpFormatter) + + parser.add_argument( + "--engine", + type=str, + default="ORT_TRT", + choices=["ORT_TRT", "TRT"], + help="Backend engine. Default is OnnxRuntime CUDA execution provider.", + ) + + supported_versions = PipelineInfo.supported_versions(is_xl) + parser.add_argument( + "--version", + type=str, + default=supported_versions[-1] if is_xl else "1.5", + choices=supported_versions, + help="Version of Stable Diffusion" + (" XL." if is_xl else "."), + ) + + parser.add_argument( + "--height", + type=int, + default=1024 if is_xl else 512, + help="Height of image to generate (must be multiple of 8).", + ) + parser.add_argument( + "--width", type=int, default=1024 if is_xl else 512, help="Height of image to generate (must be multiple of 8)." + ) + + parser.add_argument( + "--scheduler", + type=str, + default="DDIM", + choices=["DDIM", "EulerA", "UniPC"], + help="Scheduler for diffusion process", + ) + + parser.add_argument( + "--work-dir", + default=".", + help="Root Directory to store torch or ONNX models, built engines and output images etc.", + ) + + parser.add_argument("prompt", nargs="+", help="Text prompt(s) to guide image generation.") + + parser.add_argument( + "--negative-prompt", nargs="*", default=[""], help="Optional negative prompt(s) to guide the image generation." + ) + parser.add_argument( + "--repeat-prompt", + type=int, + default=1, + choices=[1, 2, 4, 8, 16], + help="Number of times to repeat the prompt (batch size multiplier).", + ) + + parser.add_argument( + "--denoising-steps", + type=int, + default=30 if is_xl else 50, + help="Number of denoising steps" + (" in each of base and refiner." if is_xl else "."), + ) + + parser.add_argument( + "--guidance", + type=float, + default=5.0 if is_xl else 7.5, + help="Higher guidance scale encourages to generate images that are closely linked to the text prompt.", + ) + + # ONNX export + parser.add_argument( + "--onnx-opset", + type=int, + default=17, + choices=range(14, 18), + help="Select ONNX opset version to target for exported models.", + ) + parser.add_argument( + "--force-onnx-export", action="store_true", help="Force ONNX export of CLIP, UNET, and VAE models." + ) + parser.add_argument( + "--force-onnx-optimize", action="store_true", help="Force ONNX optimizations for CLIP, UNET, and VAE models." + ) + + # Framework model ckpt + parser.add_argument( + "--framework-model-dir", + default="pytorch_model", + help="Directory for HF saved models. Default is pytorch_model.", + ) + parser.add_argument("--hf-token", type=str, help="HuggingFace API access token for downloading model checkpoints.") + + # Engine build options. + parser.add_argument("--force-engine-build", action="store_true", help="Force rebuilding the TensorRT engine.") + parser.add_argument( + "--build-dynamic-batch", action="store_true", help="Build TensorRT engines to support dynamic batch size." + ) + parser.add_argument( + "--build-dynamic-shape", action="store_true", help="Build TensorRT engines to support dynamic image sizes." + ) + + # Inference related options + parser.add_argument( + "--num-warmup-runs", type=int, default=5, help="Number of warmup runs before benchmarking performance." + ) + parser.add_argument("--nvtx-profile", action="store_true", help="Enable NVTX markers for performance profiling.") + parser.add_argument("--seed", type=int, default=None, help="Seed for random generator to get consistent results.") + parser.add_argument("--disable-cuda-graph", action="store_true", help="Disable cuda graph.") + + # TensorRT only options + group = parser.add_argument_group("Options for TensorRT (--engine=TRT) only") + group.add_argument("--onnx-refit-dir", help="ONNX models to load the weights from.") + group.add_argument( + "--build-enable-refit", action="store_true", help="Enable Refit option in TensorRT engines during build." + ) + group.add_argument( + "--build-preview-features", action="store_true", help="Build TensorRT engines with preview features." + ) + group.add_argument( + "--build-all-tactics", action="store_true", help="Build TensorRT engines using all tactic sources." + ) + + # Pipeline options + if is_xl: + parser.add_argument( + "--enable-refiner", action="store_true", help="Enable refiner and run both base and refiner pipelines." + ) + + args = parser.parse_args() + + # Validate image dimensions + if args.height % 8 != 0 or args.width % 8 != 0: + raise ValueError( + f"Image height and width have to be divisible by 8 but specified as: {args.height} and {args.width}." + ) + + if (args.build_dynamic_batch or args.build_dynamic_shape) and not args.disable_cuda_graph: + print("[I] CUDA Graph is disabled since dynamic input shape is configured.") + args.disable_cuda_graph = True + + print(args) + + return args + + +def repeat_prompt(args): + if not isinstance(args.prompt, list): + raise ValueError(f"`prompt` must be of type `str` or `str` list, but is {type(args.prompt)}") + prompt = args.prompt * args.repeat_prompt + + if not isinstance(args.negative_prompt, list): + raise ValueError( + f"`--negative-prompt` must be of type `str` or `str` list, but is {type(args.negative_prompt)}" + ) + if len(args.negative_prompt) == 1: + negative_prompt = args.negative_prompt * len(prompt) + else: + negative_prompt = args.negative_prompt + + return prompt, negative_prompt + + +def init_pipeline(pipeline_class, pipeline_info, engine_type, args, max_batch_size, batch_size): + onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths( + args.work_dir, pipeline_info, engine_type + ) + + # Initialize demo + pipeline = pipeline_class( + pipeline_info, + scheduler=args.scheduler, + output_dir=output_dir, + hf_token=args.hf_token, + verbose=False, + nvtx_profile=args.nvtx_profile, + max_batch_size=max_batch_size, + use_cuda_graph=not args.disable_cuda_graph, + framework_model_dir=framework_model_dir, + engine_type=engine_type, + ) + + if engine_type == EngineType.ORT_TRT: + # Build TensorRT EP engines and load pytorch modules + pipeline.backend.build_engines( + engine_dir, + framework_model_dir, + onnx_dir, + args.onnx_opset, + opt_image_height=args.height, + opt_image_width=args.height, + opt_batch_size=batch_size, + force_engine_rebuild=args.force_engine_build, + static_batch=not args.build_dynamic_batch, + static_image_shape=not args.build_dynamic_shape, + max_workspace_size=0, + device_id=torch.cuda.current_device(), + ) + elif engine_type == EngineType.TRT: + # Load TensorRT engines and pytorch modules + pipeline.backend.load_engines( + engine_dir, + framework_model_dir, + onnx_dir, + args.onnx_opset, + opt_batch_size=batch_size, + opt_image_height=args.height, + opt_image_width=args.height, + force_export=args.force_onnx_export, + force_optimize=args.force_onnx_optimize, + force_build=args.force_engine_build, + static_batch=not args.build_dynamic_batch, + static_shape=not args.build_dynamic_shape, + enable_refit=args.build_enable_refit, + enable_preview=args.build_preview_features, + enable_all_tactics=args.build_all_tactics, + timing_cache=timing_cache, + onnx_refit_dir=args.onnx_refit_dir, + ) + + return pipeline diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py new file mode 100644 index 0000000000000..951cd66005f4c --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_models.py @@ -0,0 +1,858 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from stable_diffusion_tensorrt_txt2img.py in diffusers and TensorRT demo diffusion, +# which has the following license: +# +# Copyright 2023 The HuggingFace Inc. team. +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import tempfile +from typing import List, Optional + +import onnx +import onnx_graphsurgeon as gs +import torch +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from onnx import GraphProto, ModelProto, shape_inference +from ort_optimizer import OrtStableDiffusionOptimizer +from polygraphy.backend.onnx.loader import fold_constants +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +from onnxruntime.transformers.onnx_model import OnnxModel + +logger = logging.getLogger(__name__) + + +class TrtOptimizer: + def __init__(self, onnx_graph): + self.graph = gs.import_onnx(onnx_graph) + + def cleanup(self): + self.graph.cleanup().toposort() + + def get_optimized_onnx_graph(self): + return gs.export_onnx(self.graph) + + def select_outputs(self, keep, names=None): + self.graph.outputs = [self.graph.outputs[o] for o in keep] + if names: + for i, name in enumerate(names): + self.graph.outputs[i].name = name + + def fold_constants(self): + onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True) + self.graph = gs.import_onnx(onnx_graph) + + def infer_shapes(self): + onnx_graph = gs.export_onnx(self.graph) + if onnx_graph.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF: + with tempfile.TemporaryDirectory() as temp_dir: + input_onnx_path = os.path.join(temp_dir, "model.onnx") + onnx.save_model( + onnx_graph, + input_onnx_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + convert_attribute=False, + ) + output_onnx_path = os.path.join(temp_dir, "model_with_shape.onnx") + onnx.shape_inference.infer_shapes_path(input_onnx_path, output_onnx_path) + onnx_graph = onnx.load(output_onnx_path) + else: + onnx_graph = shape_inference.infer_shapes(onnx_graph) + + self.graph = gs.import_onnx(onnx_graph) + + +class PipelineInfo: + def __init__( + self, version: str, is_inpaint: bool = False, is_sd_xl_refiner: bool = False, use_vae_in_xl_base=False + ): + self.version = version + self._is_inpaint = is_inpaint + self._is_sd_xl_refiner = is_sd_xl_refiner + self._use_vae_in_xl_base = use_vae_in_xl_base + + if is_sd_xl_refiner: + assert self.is_sd_xl() + + def is_inpaint(self) -> bool: + return self._is_inpaint + + def is_sd_xl(self) -> bool: + return "xl" in self.version + + def is_sd_xl_base(self) -> bool: + return self.is_sd_xl() and not self._is_sd_xl_refiner + + def is_sd_xl_refiner(self) -> bool: + return self.is_sd_xl() and self._is_sd_xl_refiner + + def use_safetensors(self) -> bool: + return self.is_sd_xl() + + def stages(self) -> List[str]: + if self.is_sd_xl_base(): + return ["clip", "clip2", "unetxl"] + (["vae"] if self._use_vae_in_xl_base else []) + + if self.is_sd_xl_refiner(): + return ["clip2", "unetxl", "vae"] + + return ["clip", "unet", "vae"] + + def vae_scaling_factor(self) -> float: + return 0.13025 if self.is_sd_xl() else 0.18215 + + @staticmethod + def supported_versions(is_xl: bool): + return ["xl-1.0"] if is_xl else ["1.4", "1.5", "2.0-base", "2.0", "2.1", "2.1-base"] + + def name(self) -> str: + if self.version == "1.4": + if self.is_inpaint(): + return "runwayml/stable-diffusion-inpainting" + else: + return "CompVis/stable-diffusion-v1-4" + elif self.version == "1.5": + if self.is_inpaint(): + return "runwayml/stable-diffusion-inpainting" + else: + return "runwayml/stable-diffusion-v1-5" + elif self.version == "2.0-base": + if self.is_inpaint(): + return "stabilityai/stable-diffusion-2-inpainting" + else: + return "stabilityai/stable-diffusion-2-base" + elif self.version == "2.0": + if self.is_inpaint(): + return "stabilityai/stable-diffusion-2-inpainting" + else: + return "stabilityai/stable-diffusion-2" + elif self.version == "2.1": + return "stabilityai/stable-diffusion-2-1" + elif self.version == "2.1-base": + return "stabilityai/stable-diffusion-2-1-base" + elif self.version == "xl-1.0": + if self.is_sd_xl_refiner(): + return "stabilityai/stable-diffusion-xl-refiner-1.0" + else: + return "stabilityai/stable-diffusion-xl-base-1.0" + + raise ValueError(f"Incorrect version {self.version}") + + def short_name(self) -> str: + return self.name().split("/")[-1].replace("stable-diffusion", "sd") + + def clip_embedding_dim(self): + # TODO: can we read from config instead + if self.version in ("1.4", "1.5"): + return 768 + elif self.version in ("2.0", "2.0-base", "2.1", "2.1-base"): + return 1024 + elif self.version in ("xl-1.0") and self.is_sd_xl_base(): + return 768 + else: + raise ValueError(f"Invalid version {self.version}") + + def clipwithproj_embedding_dim(self): + if self.version in ("xl-1.0"): + return 1280 + else: + raise ValueError(f"Invalid version {self.version}") + + def unet_embedding_dim(self): + if self.version in ("1.4", "1.5"): + return 768 + elif self.version in ("2.0", "2.0-base", "2.1", "2.1-base"): + return 1024 + elif self.version in ("xl-1.0") and self.is_sd_xl_base(): + return 2048 + elif self.version in ("xl-1.0") and self.is_sd_xl_refiner(): + return 1280 + else: + raise ValueError(f"Invalid version {self.version}") + + +class BaseModel: + def __init__( + self, + pipeline_info: PipelineInfo, + model, + device, + fp16: bool = False, + max_batch_size: int = 16, + embedding_dim: int = 768, + text_maxlen: int = 77, + ): + self.name = self.__class__.__name__ + + self.pipeline_info = pipeline_info + + self.model = model + self.fp16 = fp16 + self.device = device + + self.min_batch = 1 + self.max_batch = max_batch_size + self.min_image_shape = 256 # min image resolution: 256x256 + self.max_image_shape = 1024 # max image resolution: 1024x1024 + self.min_latent_shape = self.min_image_shape // 8 + self.max_latent_shape = self.max_image_shape // 8 + + self.embedding_dim = embedding_dim + self.text_maxlen = text_maxlen + + def get_ort_optimizer(self): + model_name_to_model_type = { + "CLIP": "clip", + "UNet": "unet", + "VAE": "vae", + "UNetXL": "unet", + "CLIPWithProj": "clip", + } + model_type = model_name_to_model_type[self.name] + return OrtStableDiffusionOptimizer(model_type) + + def get_model(self): + return self.model + + def from_pretrained(self, model_class, framework_model_dir, hf_token, subfolder, **kwargs): + model_dir = os.path.join(framework_model_dir, self.pipeline_info.name(), subfolder) + + if not os.path.exists(model_dir): + model = model_class.from_pretrained( + self.pipeline_info.name(), + subfolder=subfolder, + use_safetensors=self.pipeline_info.use_safetensors(), + use_auth_token=hf_token, + **kwargs, + ).to(self.device) + model.save_pretrained(model_dir) + else: + print(f"Load {self.name} pytorch model from: {model_dir}") + + model = model_class.from_pretrained(model_dir).to(self.device) + return model + + def load_model(self, framework_model_dir: str, hf_token: str, subfolder: str): + pass + + def get_input_names(self): + pass + + def get_output_names(self): + pass + + def get_dynamic_axes(self): + return None + + def get_sample_input(self, batch_size, image_height, image_width): + pass + + def get_profile_id(self, batch_size, image_height, image_width, static_batch, static_image_shape): + """For TensorRT EP""" + ( + min_batch, + max_batch, + min_image_height, + max_image_height, + min_image_width, + max_image_width, + _, + _, + _, + _, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) + + profile_id = f"_b_{batch_size}" if static_batch else f"_b_{min_batch}_{max_batch}" + + if self.name != "CLIP": + if static_image_shape: + profile_id += f"_h_{image_height}_w_{image_width}" + else: + profile_id += f"_h_{min_image_height}_{max_image_height}_w_{min_image_width}_{max_image_width}" + + return profile_id + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): + """For TensorRT""" + return None + + def get_shape_dict(self, batch_size, image_height, image_width): + return None + + def optimize_ort(self, input_onnx_path, optimized_onnx_path, to_fp16=True): + optimizer = self.get_ort_optimizer() + optimizer.optimize(input_onnx_path, optimized_onnx_path, to_fp16) + + def optimize_trt(self, input_onnx_path, optimized_onnx_path): + onnx_graph = onnx.load(input_onnx_path) + opt = TrtOptimizer(onnx_graph) + opt.cleanup() + opt.fold_constants() + opt.infer_shapes() + opt.cleanup() + onnx_opt_graph = opt.get_optimized_onnx_graph() + + if onnx_opt_graph.ByteSize() > onnx.checker.MAXIMUM_PROTOBUF: + onnx.save_model( + onnx_opt_graph, + optimized_onnx_path, + save_as_external_data=True, + all_tensors_to_one_file=True, + convert_attribute=False, + ) + else: + onnx.save(onnx_opt_graph, optimized_onnx_path) + + def check_dims(self, batch_size, image_height, image_width): + assert batch_size >= self.min_batch and batch_size <= self.max_batch + assert image_height % 8 == 0 or image_width % 8 == 0 + latent_height = image_height // 8 + latent_width = image_width // 8 + assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape + assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape + return (latent_height, latent_width) + + def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_image_shape): + min_batch = batch_size if static_batch else self.min_batch + max_batch = batch_size if static_batch else self.max_batch + latent_height = image_height // 8 + latent_width = image_width // 8 + min_image_height = image_height if static_image_shape else self.min_image_shape + max_image_height = image_height if static_image_shape else self.max_image_shape + min_image_width = image_width if static_image_shape else self.min_image_shape + max_image_width = image_width if static_image_shape else self.max_image_shape + min_latent_height = latent_height if static_image_shape else self.min_latent_shape + max_latent_height = latent_height if static_image_shape else self.max_latent_shape + min_latent_width = latent_width if static_image_shape else self.min_latent_shape + max_latent_width = latent_width if static_image_shape else self.max_latent_shape + return ( + min_batch, + max_batch, + min_image_height, + max_image_height, + min_image_width, + max_image_width, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) + + +class CLIP(BaseModel): + def __init__( + self, + pipeline_info: PipelineInfo, + model, + device, + max_batch_size, + embedding_dim: int = 0, + clip_skip=0, + ): + super().__init__( + pipeline_info, + model=model, + device=device, + max_batch_size=max_batch_size, + embedding_dim=embedding_dim if embedding_dim > 0 else pipeline_info.clip_embedding_dim(), + ) + self.output_hidden_state = pipeline_info.is_sd_xl() + + # see https://github.com/huggingface/diffusers/pull/5057 for more information of clip_skip. + # Clip_skip=1 means that the output of the pre-final layer will be used for computing the prompt embeddings. + self.clip_skip = clip_skip + + def get_input_names(self): + return ["input_ids"] + + def get_output_names(self): + # The exported onnx model has no hidden_state. For SD-XL, We will add hidden_state to optimized onnx model. + return ["text_embeddings"] + + def get_dynamic_axes(self): + return {"input_ids": {0: "B"}, "text_embeddings": {0: "B"}} + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): + self.check_dims(batch_size, image_height, image_width) + min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims( + batch_size, image_height, image_width, static_batch, static_image_shape + ) + return { + "input_ids": [(min_batch, self.text_maxlen), (batch_size, self.text_maxlen), (max_batch, self.text_maxlen)] + } + + def get_shape_dict(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + output = { + "input_ids": (batch_size, self.text_maxlen), + "text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim), + } + + if self.output_hidden_state: + output["hidden_states"] = (batch_size, self.text_maxlen, self.embedding_dim) + + return output + + def get_sample_input(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + return (torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device),) + + def add_hidden_states_graph_output(self, model: ModelProto, optimized_onnx_path): + graph: GraphProto = model.graph + hidden_layers = -1 + for i in range(len(graph.node)): + for j in range(len(graph.node[i].output)): + name = graph.node[i].output[j] + if "layers" in name: + hidden_layers = max(int(name.split(".")[1].split("/")[0]), hidden_layers) + + assert self.clip_skip >= 0 and self.clip_skip < hidden_layers + + node_output_name = "/text_model/encoder/layers.{}/Add_1_output_0".format(hidden_layers - 1 - self.clip_skip) + + # search the name in outputs of all node + found = False + for i in range(len(graph.node)): + for j in range(len(graph.node[i].output)): + if graph.node[i].output[j] == node_output_name: + found = True + break + if found: + break + if not found: + raise RuntimeError("Failed to find hidden_states graph output in clip") + + # Insert a Cast (fp32 -> fp16) node so that hidden_states has same data type as the first graph output. + graph_output_name = "hidden_states" + cast_node = onnx.helper.make_node("Cast", inputs=[node_output_name], outputs=[graph_output_name]) + cast_node.attribute.extend([onnx.helper.make_attribute("to", graph.output[0].type.tensor_type.elem_type)]) + + hidden_state = graph.output.add() + hidden_state.CopyFrom( + onnx.helper.make_tensor_value_info( + graph_output_name, + graph.output[0].type.tensor_type.elem_type, + ["B", self.text_maxlen, self.embedding_dim], + ) + ) + + onnx_model = OnnxModel(model) + onnx_model.add_node(cast_node) + onnx_model.save_model_to_file(optimized_onnx_path) + + def optimize_trt(self, input_onnx_path, optimized_onnx_path): + onnx_graph = onnx.load(input_onnx_path) + opt = TrtOptimizer(onnx_graph) + opt.select_outputs([0]) # delete graph output#1 + opt.cleanup() + opt.fold_constants() + opt.infer_shapes() + opt.select_outputs([0], names=["text_embeddings"]) # rename network output + opt.cleanup() + onnx_opt_graph = opt.get_optimized_onnx_graph() + if self.output_hidden_state: + self.add_hidden_states_graph_output(onnx_opt_graph, optimized_onnx_path) + else: + onnx.save(onnx_opt_graph, optimized_onnx_path) + + def load_model(self, framework_model_dir, hf_token, subfolder="text_encoder"): + return self.from_pretrained(CLIPTextModel, framework_model_dir, hf_token, subfolder) + + +class CLIPWithProj(CLIP): + def __init__( + self, + pipeline_info: PipelineInfo, + model, + device, + max_batch_size=16, + clip_skip=0, + ): + super().__init__( + pipeline_info, + model, + device=device, + max_batch_size=max_batch_size, + embedding_dim=pipeline_info.clipwithproj_embedding_dim(), + clip_skip=clip_skip, + ) + + def load_model(self, framework_model_dir, hf_token, subfolder="text_encoder_2"): + return self.from_pretrained(CLIPTextModelWithProjection, framework_model_dir, hf_token, subfolder) + + def get_shape_dict(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + output = { + "input_ids": (batch_size, self.text_maxlen), + "text_embeddings": (batch_size, self.embedding_dim), + } + + if self.output_hidden_state: + output["hidden_states"] = (batch_size, self.text_maxlen, self.embedding_dim) + + return output + + +class UNet(BaseModel): + def __init__( + self, + pipeline_info: PipelineInfo, + model, + device, + fp16=False, # used by TRT + max_batch_size=16, + text_maxlen=77, + unet_dim=4, + ): + super().__init__( + pipeline_info, + model=model, + device=device, + fp16=fp16, + max_batch_size=max_batch_size, + embedding_dim=pipeline_info.unet_embedding_dim(), + text_maxlen=text_maxlen, + ) + self.unet_dim = unet_dim + + def load_model(self, framework_model_dir, hf_token, subfolder="unet"): + options = {"variant": "fp16", "torch_dtype": torch.float16} if self.fp16 else {} + return self.from_pretrained(UNet2DConditionModel, framework_model_dir, hf_token, subfolder, **options) + + def get_input_names(self): + return ["sample", "timestep", "encoder_hidden_states"] + + def get_output_names(self): + return ["latent"] + + def get_dynamic_axes(self): + return { + "sample": {0: "2B", 2: "H", 3: "W"}, + "encoder_hidden_states": {0: "2B"}, + "latent": {0: "2B", 2: "H", 3: "W"}, + } + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + ( + min_batch, + max_batch, + _, + _, + _, + _, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) + return { + "sample": [ + (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width), + (2 * batch_size, self.unet_dim, latent_height, latent_width), + (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width), + ], + "encoder_hidden_states": [ + (2 * min_batch, self.text_maxlen, self.embedding_dim), + (2 * batch_size, self.text_maxlen, self.embedding_dim), + (2 * max_batch, self.text_maxlen, self.embedding_dim), + ], + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), + "timestep": [1], + "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim), + "latent": (2 * batch_size, 4, latent_height, latent_width), + } + + def get_sample_input(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + dtype = torch.float16 if self.fp16 else torch.float32 + return ( + torch.randn( + 2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device + ), + torch.tensor([1.0], dtype=torch.float32, device=self.device), + torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), + ) + + +class UNetXL(BaseModel): + def __init__( + self, + pipeline_info: PipelineInfo, + model, + device, + fp16=False, # used by TRT + max_batch_size=16, + text_maxlen=77, + unet_dim=4, + time_dim=6, + ): + super().__init__( + pipeline_info, + model, + device=device, + fp16=fp16, + max_batch_size=max_batch_size, + embedding_dim=pipeline_info.unet_embedding_dim(), + text_maxlen=text_maxlen, + ) + self.unet_dim = unet_dim + self.time_dim = time_dim + + def load_model(self, framework_model_dir, hf_token, subfolder="unet"): + options = {"variant": "fp16", "torch_dtype": torch.float16} if self.fp16 else {} + return self.from_pretrained(UNet2DConditionModel, framework_model_dir, hf_token, subfolder, **options) + + def get_input_names(self): + return ["sample", "timestep", "encoder_hidden_states", "text_embeds", "time_ids"] + + def get_output_names(self): + return ["latent"] + + def get_dynamic_axes(self): + return { + "sample": {0: "2B", 2: "H", 3: "W"}, + "encoder_hidden_states": {0: "2B"}, + "latent": {0: "2B", 2: "H", 3: "W"}, + "text_embeds": {0: "2B"}, + "time_ids": {0: "2B"}, + } + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + ( + min_batch, + max_batch, + _, + _, + _, + _, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) + return { + "sample": [ + (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width), + (2 * batch_size, self.unet_dim, latent_height, latent_width), + (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width), + ], + "encoder_hidden_states": [ + (2 * min_batch, self.text_maxlen, self.embedding_dim), + (2 * batch_size, self.text_maxlen, self.embedding_dim), + (2 * max_batch, self.text_maxlen, self.embedding_dim), + ], + "text_embeds": [(2 * min_batch, 1280), (2 * batch_size, 1280), (2 * max_batch, 1280)], + "time_ids": [ + (2 * min_batch, self.time_dim), + (2 * batch_size, self.time_dim), + (2 * max_batch, self.time_dim), + ], + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), + "timestep": (1,), + "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim), + "latent": (2 * batch_size, 4, latent_height, latent_width), + "text_embeds": (2 * batch_size, 1280), + "time_ids": (2 * batch_size, self.time_dim), + } + + def get_sample_input(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + dtype = torch.float16 if self.fp16 else torch.float32 + return ( + torch.randn( + 2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device + ), + torch.tensor([1.0], dtype=torch.float32, device=self.device), + torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), + { + "added_cond_kwargs": { + "text_embeds": torch.randn(2 * batch_size, 1280, dtype=dtype, device=self.device), + "time_ids": torch.randn(2 * batch_size, self.time_dim, dtype=dtype, device=self.device), + } + }, + ) + + +# VAE Decoder +class VAE(BaseModel): + def __init__(self, pipeline_info: PipelineInfo, model, device, max_batch_size): + super().__init__( + pipeline_info, + model=model, + device=device, + max_batch_size=max_batch_size, + ) + + def load_model(self, framework_model_dir, hf_token: Optional[str] = None, subfolder: str = "vae_decoder"): + model_dir = os.path.join(framework_model_dir, self.pipeline_info.name(), subfolder) + if not os.path.exists(model_dir): + vae = AutoencoderKL.from_pretrained( + self.pipeline_info.name(), + subfolder="vae", + use_safetensors=self.pipeline_info.use_safetensors(), + use_auth_token=hf_token, + ).to(self.device) + vae.save_pretrained(model_dir) + else: + print(f"Load {self.name} pytorch model from: {model_dir}") + vae = AutoencoderKL.from_pretrained(model_dir).to(self.device) + + vae.forward = vae.decode + return vae + + def get_input_names(self): + return ["latent"] + + def get_output_names(self): + return ["images"] + + def get_dynamic_axes(self): + return {"latent": {0: "B", 2: "H", 3: "W"}, "images": {0: "B", 2: "8H", 3: "8W"}} + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + ( + min_batch, + max_batch, + _, + _, + _, + _, + min_latent_height, + max_latent_height, + min_latent_width, + max_latent_width, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) + return { + "latent": [ + (min_batch, 4, min_latent_height, min_latent_width), + (batch_size, 4, latent_height, latent_width), + (max_batch, 4, max_latent_height, max_latent_width), + ] + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + "latent": (batch_size, 4, latent_height, latent_width), + "images": (batch_size, 3, image_height, image_width), + } + + def get_sample_input(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return (torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device),) + + +def get_tokenizer(pipeline_info: PipelineInfo, framework_model_dir, hf_token, subfolder="tokenizer"): + tokenizer_dir = os.path.join(framework_model_dir, pipeline_info.name(), subfolder) + + if not os.path.exists(tokenizer_dir): + model = CLIPTokenizer.from_pretrained( + pipeline_info.name(), + subfolder=subfolder, + use_safetensors=pipeline_info.is_sd_xl(), + use_auth_token=hf_token, + ) + model.save_pretrained(tokenizer_dir) + else: + print(f"[I] Load tokenizer pytorch model from: {tokenizer_dir}") + model = CLIPTokenizer.from_pretrained(tokenizer_dir) + return model + + +class TorchVAEEncoder(torch.nn.Module): + def __init__(self, vae_encoder): + super().__init__() + self.vae_encoder = vae_encoder + + def forward(self, x): + return self.vae_encoder.encode(x).latent_dist.sample() + + +class VAEEncoder(BaseModel): + def __init__(self, pipeline_info: PipelineInfo, model, device, max_batch_size): + super().__init__( + pipeline_info, + model=model, + device=device, + max_batch_size=max_batch_size, + ) + + def load_model(self, framework_model_dir, hf_token, subfolder="vae_encoder"): + vae = self.from_pretrained(AutoencoderKL, framework_model_dir, hf_token, subfolder) + return TorchVAEEncoder(vae) + + def get_input_names(self): + return ["images"] + + def get_output_names(self): + return ["latent"] + + def get_dynamic_axes(self): + return {"images": {0: "B", 2: "8H", 3: "8W"}, "latent": {0: "B", 2: "H", 3: "W"}} + + def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): + self.check_dims(batch_size, image_height, image_width) + + ( + min_batch, + max_batch, + min_image_height, + max_image_height, + min_image_width, + max_image_width, + _, + _, + _, + _, + ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) + + return { + "images": [ + (min_batch, 3, min_image_height, min_image_width), + (batch_size, 3, image_height, image_width), + (max_batch, 3, max_image_height, max_image_width), + ], + } + + def get_shape_dict(self, batch_size, image_height, image_width): + latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) + return { + "images": (batch_size, 3, image_height, image_width), + "latent": (batch_size, 4, latent_height, latent_width), + } + + def get_sample_input(self, batch_size, image_height, image_width): + self.check_dims(batch_size, image_height, image_width) + return torch.randn(batch_size, 3, image_height, image_width, dtype=torch.float32, device=self.device) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py new file mode 100644 index 0000000000000..13c450a517eba --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/diffusion_schedulers.py @@ -0,0 +1,721 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from utilities.py of TensorRT demo diffusion, which has the following license: +# +# Copyright 2022 The HuggingFace Inc. team. +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- + +from typing import List, Optional + +import numpy as np +import torch + + +class DDIMScheduler: + def __init__( + self, + device="cuda", + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + clip_sample: bool = False, + set_alpha_to_one: bool = False, + steps_offset: int = 1, + prediction_type: str = "epsilon", + ): + # this schedule is very specific to the latent diffusion model. + betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + + alphas = 1.0 - betas + self.alphas_cumprod = torch.cumprod(alphas, dim=0) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + self.steps_offset = steps_offset + self.num_train_timesteps = num_train_timesteps + self.clip_sample = clip_sample + self.prediction_type = prediction_type + self.device = device + + def configure(self): + variance = np.zeros(self.num_inference_steps, dtype=np.float32) + for idx, timestep in enumerate(self.timesteps): + prev_timestep = timestep - self.num_train_timesteps // self.num_inference_steps + variance[idx] = self._get_variance(timestep, prev_timestep) + self.variance = torch.from_numpy(variance).to(self.device) + + timesteps = self.timesteps.long().cpu() + self.alphas_cumprod = self.alphas_cumprod[timesteps].to(self.device) + self.final_alpha_cumprod = self.final_alpha_cumprod.to(self.device) + + def scale_model_input(self, sample: torch.FloatTensor, idx, *args, **kwargs) -> torch.FloatTensor: + return sample + + def _get_variance(self, timestep, prev_timestep): + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + def set_timesteps(self, num_inference_steps: int): + self.num_inference_steps = num_inference_steps + step_ratio = self.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + self.timesteps = torch.from_numpy(timesteps).to(self.device) + self.timesteps += self.steps_offset + + def step( + self, + model_output, + sample, + idx, + timestep, + eta: float = 0.0, + use_clipped_model_output: bool = False, + generator=None, + variance_noise: torch.FloatTensor = None, + ): + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> η + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + prev_idx = idx + 1 + alpha_prod_t = self.alphas_cumprod[idx] + alpha_prod_t_prev = ( + self.alphas_cumprod[prev_idx] if prev_idx < self.num_inference_steps else self.final_alpha_cumprod + ) + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + if self.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + elif self.prediction_type == "sample": + pred_original_sample = model_output + elif self.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + # predict V + model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + # 4. Clip "predicted x_0" + if self.clip_sample: + pred_original_sample = torch.clamp(pred_original_sample, -1, 1) + + # 5. compute variance: "sigma_t(η)" -> see formula (16) + # o_t = sqrt((1 - a_t-1)/(1 - a_t)) * sqrt(1 - a_t/a_t-1) + variance = self.variance[idx] + std_dev_t = eta * variance ** (0.5) + + if use_clipped_model_output: + # the model_output is always re-derived from the clipped x_0 in Glide + model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output + + # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + if eta > 0: + # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072 + device = model_output.device + if variance_noise is not None and generator is not None: + raise ValueError( + "Cannot pass both generator and variance_noise. Please make sure that either `generator` or" + " `variance_noise` stays `None`." + ) + + if variance_noise is None: + variance_noise = torch.randn( + model_output.shape, generator=generator, device=device, dtype=model_output.dtype + ) + variance = variance ** (0.5) * eta * variance_noise + + prev_sample = prev_sample + variance + + return prev_sample + + def add_noise(self, init_latents, noise, idx, latent_timestep): + sqrt_alpha_prod = self.alphas_cumprod[idx] ** 0.5 + sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[idx]) ** 0.5 + noisy_latents = sqrt_alpha_prod * init_latents + sqrt_one_minus_alpha_prod * noise + + return noisy_latents + + +class EulerAncestralDiscreteScheduler: + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + device="cuda", + steps_offset=0, + prediction_type="epsilon", + ): + # this schedule is very specific to the latent diffusion model. + betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + + alphas = 1.0 - betas + self.alphas_cumprod = torch.cumprod(alphas, dim=0) + + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = self.sigmas.max() + + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.is_scale_input_called = False + self.device = device + self.num_train_timesteps = num_train_timesteps + self.steps_offset = steps_offset + self.prediction_type = prediction_type + + def scale_model_input(self, sample: torch.FloatTensor, idx, timestep, *args, **kwargs) -> torch.FloatTensor: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + step_index = (self.timesteps == timestep).nonzero().item() + sigma = self.sigmas[step_index] + sample = sample / ((sigma**2 + 1) ** 0.5) + self.is_scale_input_called = True + return sample + + def set_timesteps(self, num_inference_steps: int): + self.num_inference_steps = num_inference_steps + + timesteps = np.linspace(0, self.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy() + sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) + sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) + sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) + self.sigmas = torch.from_numpy(sigmas).to(device=self.device) + self.timesteps = torch.from_numpy(timesteps).to(device=self.device) + + def configure(self): + dts = np.zeros(self.num_inference_steps, dtype=np.float32) + sigmas_up = np.zeros(self.num_inference_steps, dtype=np.float32) + for idx, timestep in enumerate(self.timesteps): + step_index = (self.timesteps == timestep).nonzero().item() + sigma = self.sigmas[step_index] + + sigma_from = self.sigmas[step_index] + sigma_to = self.sigmas[step_index + 1] + sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5 + sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 + dt = sigma_down - sigma + dts[idx] = dt + sigmas_up[idx] = sigma_up + + self.dts = torch.from_numpy(dts).to(self.device) + self.sigmas_up = torch.from_numpy(sigmas_up).to(self.device) + + def step( + self, + model_output, + sample, + idx, + timestep, + generator=None, + ): + step_index = (self.timesteps == timestep).nonzero().item() + sigma = self.sigmas[step_index] + + # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise + if self.prediction_type == "epsilon": + pred_original_sample = sample - sigma * model_output + elif self.prediction_type == "v_prediction": + # * c_out + input * c_skip + pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1)) + else: + raise ValueError( + f"prediction_type given as {self.prediction_type} must be one of `epsilon`, or `v_prediction`" + ) + + sigma_up = self.sigmas_up[idx] + + # 2. Convert to an ODE derivative + derivative = (sample - pred_original_sample) / sigma + + dt = self.dts[idx] + + prev_sample = sample + derivative * dt + + device = model_output.device + noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to(device) + + prev_sample = prev_sample + noise * sigma_up + + return prev_sample + + def add_noise(self, original_samples, noise, idx, timestep=None): + step_index = (self.timesteps == timestep).nonzero().item() + noisy_samples = original_samples + noise * self.sigmas[step_index] + return noisy_samples + + +class UniPCMultistepScheduler: + def __init__( + self, + device="cuda", + num_train_timesteps: int = 1000, + beta_start: float = 0.00085, + beta_end: float = 0.012, + solver_order: int = 2, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + sample_max_value: float = 1.0, + predict_x0: bool = True, + solver_type: str = "bh2", + lower_order_final: bool = True, + disable_corrector: Optional[List[int]] = None, + ): + self.device = device + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + # Currently we only support VP-type noise schedule + self.alpha_t = torch.sqrt(self.alphas_cumprod) + self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) + self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + self.predict_x0 = predict_x0 + # setable values + self.num_inference_steps = None + timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() + self.timesteps = torch.from_numpy(timesteps) + self.model_outputs = [None] * solver_order + self.timestep_list = [None] * solver_order + self.lower_order_nums = 0 + self.disable_corrector = disable_corrector if disable_corrector else [] + self.last_sample = None + self.num_train_timesteps = num_train_timesteps + self.solver_order = solver_order + self.prediction_type = prediction_type + self.thresholding = thresholding + self.dynamic_thresholding_ratio = dynamic_thresholding_ratio + self.sample_max_value = sample_max_value + self.solver_type = solver_type + self.lower_order_final = lower_order_final + + def set_timesteps(self, num_inference_steps: int): + timesteps = ( + np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1) + .round()[::-1][:-1] + .copy() + .astype(np.int64) + ) + + # when num_inference_steps == num_train_timesteps, we can end up with + # duplicates in timesteps. + _, unique_indices = np.unique(timesteps, return_index=True) + timesteps = timesteps[np.sort(unique_indices)] + + self.timesteps = torch.from_numpy(timesteps).to(self.device) + + self.num_inference_steps = len(timesteps) + + self.model_outputs = [ + None, + ] * self.solver_order + self.lower_order_nums = 0 + self.last_sample = None + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: + dtype = sample.dtype + batch_size, channels, height, width = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * height * width) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, height, width) + sample = sample.to(dtype) + + return sample + + def convert_model_output( + self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor + ) -> torch.FloatTensor: + if self.predict_x0: + if self.prediction_type == "epsilon": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = (sample - sigma_t * model_output) / alpha_t + elif self.prediction_type == "sample": + x0_pred = model_output + elif self.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + x0_pred = alpha_t * sample - sigma_t * model_output + else: + raise ValueError( + f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the UniPCMultistepScheduler." + ) + + if self.thresholding: + x0_pred = self._threshold_sample(x0_pred) + + return x0_pred + else: + if self.prediction_type == "epsilon": + return model_output + elif self.prediction_type == "sample": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + epsilon = (sample - alpha_t * model_output) / sigma_t + return epsilon + elif self.prediction_type == "v_prediction": + alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] + epsilon = alpha_t * model_output + sigma_t * sample + return epsilon + else: + raise ValueError( + f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction` for the UniPCMultistepScheduler." + ) + + def multistep_uni_p_bh_update( + self, + model_output: torch.FloatTensor, + prev_timestep: int, + sample: torch.FloatTensor, + order: int, + ) -> torch.FloatTensor: + timestep_list = self.timestep_list + model_output_list = self.model_outputs + + s0, t = self.timestep_list[-1], prev_timestep + m0 = model_output_list[-1] + x = sample + + lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0] + alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] + sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + + h = lambda_t - lambda_s0 + + rks = [] + d1s = [] + for i in range(1, order): + si = timestep_list[-(i + 1)] + mi = model_output_list[-(i + 1)] + lambda_si = self.lambda_t[si] + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + d1s.append((mi - m0) / rk) + + rks.append(1.0) + rks = torch.tensor(rks, device=self.device) + + r = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.solver_type == "bh1": + b_h = hh + elif self.solver_type == "bh2": + b_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + r.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / b_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + r = torch.stack(r) + b = torch.tensor(b, device=self.device) + + if len(d1s) > 0: + d1s = torch.stack(d1s, dim=1) # (B, K) + # for order 2, we use a simplified version + if order == 2: + rhos_p = torch.tensor([0.5], dtype=x.dtype, device=self.device) + else: + rhos_p = torch.linalg.solve(r[:-1, :-1], b[:-1]) + else: + d1s = None + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if d1s is not None: + pred_res = torch.einsum("k,bkchw->bchw", rhos_p, d1s) + else: + pred_res = 0 + x_t = x_t_ - alpha_t * b_h * pred_res + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if d1s is not None: + pred_res = torch.einsum("k,bkchw->bchw", rhos_p, d1s) + else: + pred_res = 0 + x_t = x_t_ - sigma_t * b_h * pred_res + + x_t = x_t.to(x.dtype) + return x_t + + def multistep_uni_c_bh_update( + self, + this_model_output: torch.FloatTensor, + this_timestep: int, + last_sample: torch.FloatTensor, + # this_sample: torch.FloatTensor, + order: int, + ) -> torch.FloatTensor: + timestep_list = self.timestep_list + model_output_list = self.model_outputs + + s0, t = timestep_list[-1], this_timestep + m0 = model_output_list[-1] + x = last_sample + # x_t = this_sample + model_t = this_model_output + + lambda_t, lambda_s0 = self.lambda_t[t], self.lambda_t[s0] + alpha_t, alpha_s0 = self.alpha_t[t], self.alpha_t[s0] + sigma_t, sigma_s0 = self.sigma_t[t], self.sigma_t[s0] + + h = lambda_t - lambda_s0 + + rks = [] + d1s = [] + for i in range(1, order): + si = timestep_list[-(i + 1)] + mi = model_output_list[-(i + 1)] + lambda_si = self.lambda_t[si] + rk = (lambda_si - lambda_s0) / h + rks.append(rk) + d1s.append((mi - m0) / rk) + + rks.append(1.0) + rks = torch.tensor(rks, device=self.device) + + r = [] + b = [] + + hh = -h if self.predict_x0 else h + h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1 + h_phi_k = h_phi_1 / hh - 1 + + factorial_i = 1 + + if self.solver_type == "bh1": + b_h = hh + elif self.solver_type == "bh2": + b_h = torch.expm1(hh) + else: + raise NotImplementedError() + + for i in range(1, order + 1): + r.append(torch.pow(rks, i - 1)) + b.append(h_phi_k * factorial_i / b_h) + factorial_i *= i + 1 + h_phi_k = h_phi_k / hh - 1 / factorial_i + + r = torch.stack(r) + b = torch.tensor(b, device=self.device) + + if len(d1s) > 0: + d1s = torch.stack(d1s, dim=1) + else: + d1s = None + + # for order 1, we use a simplified version + if order == 1: + rhos_c = torch.tensor([0.5], dtype=x.dtype, device=self.device) + else: + rhos_c = torch.linalg.solve(r, b) + + if self.predict_x0: + x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0 + if d1s is not None: + corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], d1s) + else: + corr_res = 0 + d1_t = model_t - m0 + x_t = x_t_ - alpha_t * b_h * (corr_res + rhos_c[-1] * d1_t) + else: + x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0 + if d1s is not None: + corr_res = torch.einsum("k,bkchw->bchw", rhos_c[:-1], d1s) + else: + corr_res = 0 + d1_t = model_t - m0 + x_t = x_t_ - sigma_t * b_h * (corr_res + rhos_c[-1] * d1_t) + x_t = x_t.to(x.dtype) + return x_t + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + return_dict: bool = True, + ): + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.device) + step_index = (self.timesteps == timestep).nonzero() + if len(step_index) == 0: + step_index = len(self.timesteps) - 1 + else: + step_index = step_index.item() + + use_corrector = step_index > 0 and step_index - 1 not in self.disable_corrector and self.last_sample is not None + + model_output_convert = self.convert_model_output(model_output, timestep, sample) + if use_corrector: + sample = self.multistep_uni_c_bh_update( + this_model_output=model_output_convert, + this_timestep=timestep, + last_sample=self.last_sample, + # this_sample=sample, + order=self.this_order, + ) + + # now prepare to run the predictor + prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] + + for i in range(self.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.timestep_list[i] = self.timestep_list[i + 1] + + self.model_outputs[-1] = model_output_convert + self.timestep_list[-1] = timestep + + if self.lower_order_final: + this_order = min(self.solver_order, len(self.timesteps) - step_index) + else: + this_order = self.solver_order + + self.this_order = min(this_order, self.lower_order_nums + 1) # warmup for multistep + assert self.this_order > 0 + + self.last_sample = sample + prev_sample = self.multistep_uni_p_bh_update( + model_output=model_output, # pass the original non-converted model output, in case solver-p is used + prev_timestep=prev_timestep, + sample=sample, + order=self.this_order, + ) + + if self.lower_order_nums < self.solver_order: + self.lower_order_nums += 1 + + if not return_dict: + return (prev_sample,) + + return prev_sample + + def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: + return sample + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + alphas_cumprod = self.alphas_cumprod.to(device=self.device, dtype=original_samples.dtype) + timesteps = timesteps.to(self.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + def configure(self): + pass + + def __len__(self): + return self.num_train_timesteps diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py new file mode 100644 index 0000000000000..64c3c5bc80ecb --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder.py @@ -0,0 +1,181 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import os +from enum import Enum + +import torch +from diffusion_models import CLIP, VAE, CLIPWithProj, PipelineInfo, UNet, UNetXL + + +class EngineType(Enum): + ORT_CUDA = 0 # ONNX Runtime CUDA Execution Provider + ORT_TRT = 1 # ONNX Runtime TensorRT Execution Provider + TRT = 2 # TensorRT + TORCH = 3 # PyTorch + + +def get_engine_type(name: str) -> EngineType: + name_to_type = { + "ORT_CUDA": EngineType.ORT_CUDA, + "ORT_TRT": EngineType.ORT_TRT, + "TRT": EngineType.TRT, + "TORCH": EngineType.TORCH, + } + return name_to_type[name] + + +class EngineBuilder: + def __init__( + self, + engine_type: EngineType, + pipeline_info: PipelineInfo, + device="cuda", + max_batch_size=16, + hf_token=None, + use_cuda_graph=False, + ): + """ + Initializes the Engine Builder. + + Args: + pipeline_info (PipelineInfo): + Version and Type of pipeline. + device (str | torch.device): + device to run engine + max_batch_size (int): + Maximum batch size for dynamic batch engine. + hf_token (str): + HuggingFace User Access Token to use for downloading Stable Diffusion model checkpoints. + use_cuda_graph (bool): + Use CUDA graph to capture engine execution and then launch inference + """ + self.engine_type = engine_type + self.pipeline_info = pipeline_info + self.max_batch_size = max_batch_size + self.hf_token = hf_token + self.use_cuda_graph = use_cuda_graph + self.device = torch.device(device) + self.torch_device = torch.device(device, torch.cuda.current_device()) + self.stages = pipeline_info.stages() + self.vae_torch_fallback = self.pipeline_info.is_sd_xl() + + self.models = {} + self.engines = {} + self.torch_models = {} + + def teardown(self): + for engine in self.engines.values(): + del engine + self.engines = {} + + def get_cached_model_name(self, model_name): + if self.pipeline_info.is_inpaint(): + model_name += "_inpaint" + return model_name + + def get_onnx_path(self, model_name, onnx_dir, opt=True): + engine_name = self.engine_type.name.lower() + onnx_model_dir = os.path.join( + onnx_dir, self.get_cached_model_name(model_name) + (f".{engine_name}" if opt else "") + ) + os.makedirs(onnx_model_dir, exist_ok=True) + return os.path.join(onnx_model_dir, "model.onnx") + + def get_engine_path(self, engine_dir, model_name, profile_id): + return os.path.join(engine_dir, self.get_cached_model_name(model_name) + profile_id) + + def load_models(self, framework_model_dir: str): + # Disable torch SDPA since torch 2.0.* cannot export it to ONNX + if hasattr(torch.nn.functional, "scaled_dot_product_attention"): + delattr(torch.nn.functional, "scaled_dot_product_attention") + + # For TRT or ORT_TRT, we will export fp16 torch model for UNet. + # For ORT_CUDA, we export fp32 model first, then optimize to fp16. + export_fp16_unet = self.engine_type in [EngineType.ORT_TRT, EngineType.TRT] + + if "clip" in self.stages: + self.models["clip"] = CLIP( + self.pipeline_info, + None, # not loaded yet + device=self.torch_device, + max_batch_size=self.max_batch_size, + clip_skip=0, + ) + + if "clip2" in self.stages: + self.models["clip2"] = CLIPWithProj( + self.pipeline_info, + None, # not loaded yet + device=self.torch_device, + max_batch_size=self.max_batch_size, + clip_skip=0, + ) + + if "unet" in self.stages: + self.models["unet"] = UNet( + self.pipeline_info, + None, # not loaded yet + device=self.torch_device, + fp16=export_fp16_unet, + max_batch_size=self.max_batch_size, + unet_dim=(9 if self.pipeline_info.is_inpaint() else 4), + ) + + if "unetxl" in self.stages: + self.models["unetxl"] = UNetXL( + self.pipeline_info, + None, # not loaded yet + device=self.torch_device, + fp16=export_fp16_unet, + max_batch_size=self.max_batch_size, + unet_dim=4, + time_dim=(5 if self.pipeline_info.is_sd_xl_refiner() else 6), + ) + + # VAE Decoder + if "vae" in self.stages: + self.models["vae"] = VAE( + self.pipeline_info, + None, # not loaded yet + device=self.torch_device, + max_batch_size=self.max_batch_size, + ) + + if self.vae_torch_fallback: + self.torch_models["vae"] = self.models["vae"].load_model(framework_model_dir, self.hf_token) + + def load_resources(self, image_height, image_width, batch_size): + # Allocate buffers for I/O bindings + for model_name, obj in self.models.items(): + if model_name == "vae" and self.vae_torch_fallback: + continue + self.engines[model_name].allocate_buffers( + shape_dict=obj.get_shape_dict(batch_size, image_height, image_width), device=self.torch_device + ) + + def vae_decode(self, latents): + if self.vae_torch_fallback: + latents = latents.to(dtype=torch.float32) + self.torch_models["vae"] = self.torch_models["vae"].to(dtype=torch.float32) + images = self.torch_models["vae"](latents)["sample"] + else: + images = self.run_engine("vae", {"latent": latents})["images"] + + return images + + +def get_engine_paths(work_dir: str, pipeline_info: PipelineInfo, engine_type: EngineType): + root_dir = work_dir or "." + short_name = pipeline_info.short_name() + + # When both ORT_CUDA and ORT_TRT/TRT is used, we shall make sub directory for each engine since + # ORT_CUDA need fp32 torch model, while ORT_TRT/TRT use fp16 torch model. + onnx_dir = os.path.join(root_dir, engine_type.name, short_name, "onnx") + engine_dir = os.path.join(root_dir, engine_type.name, short_name, "engine") + output_dir = os.path.join(root_dir, engine_type.name, short_name, "output") + timing_cache = os.path.join(root_dir, engine_type.name, "timing_cache") + framework_model_dir = os.path.join(root_dir, engine_type.name, "torch_model") + + return onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py new file mode 100644 index 0000000000000..253cdcc45bf2e --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_ort_trt.py @@ -0,0 +1,263 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import gc +import logging +import os +import shutil + +import torch +from cuda import cudart +from diffusion_models import PipelineInfo +from engine_builder import EngineBuilder, EngineType + +import onnxruntime as ort +from onnxruntime.transformers.io_binding_helper import CudaSession + +logger = logging.getLogger(__name__) + + +class OrtTensorrtEngine(CudaSession): + def __init__(self, engine_path, device_id, onnx_path, fp16, input_profile, workspace_size, enable_cuda_graph): + self.engine_path = engine_path + self.ort_trt_provider_options = self.get_tensorrt_provider_options( + input_profile, + workspace_size, + fp16, + device_id, + enable_cuda_graph, + ) + + session_options = ort.SessionOptions() + session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL + print("creating TRT EP session for ", onnx_path) + ort_session = ort.InferenceSession( + onnx_path, + session_options, + providers=[ + ("TensorrtExecutionProvider", self.ort_trt_provider_options), + ], + ) + print("created TRT EP session for ", onnx_path) + + device = torch.device("cuda", device_id) + super().__init__(ort_session, device, enable_cuda_graph) + + def get_tensorrt_provider_options(self, input_profile, workspace_size, fp16, device_id, enable_cuda_graph): + trt_ep_options = { + "device_id": device_id, + "trt_fp16_enable": fp16, + "trt_engine_cache_enable": True, + "trt_timing_cache_enable": True, + "trt_detailed_build_log": True, + "trt_engine_cache_path": self.engine_path, + } + + if enable_cuda_graph: + trt_ep_options["trt_cuda_graph_enable"] = True + + if workspace_size > 0: + trt_ep_options["trt_max_workspace_size"] = workspace_size + + if input_profile: + min_shapes = [] + max_shapes = [] + opt_shapes = [] + for name, profile in input_profile.items(): + assert isinstance(profile, list) and len(profile) == 3 + min_shape = profile[0] + opt_shape = profile[1] + max_shape = profile[2] + assert len(min_shape) == len(opt_shape) and len(opt_shape) == len(max_shape) + + min_shapes.append(f"{name}:" + "x".join([str(x) for x in min_shape])) + opt_shapes.append(f"{name}:" + "x".join([str(x) for x in opt_shape])) + max_shapes.append(f"{name}:" + "x".join([str(x) for x in max_shape])) + + trt_ep_options["trt_profile_min_shapes"] = ",".join(min_shapes) + trt_ep_options["trt_profile_max_shapes"] = ",".join(max_shapes) + trt_ep_options["trt_profile_opt_shapes"] = ",".join(opt_shapes) + + logger.info("trt_ep_options=%s", trt_ep_options) + + return trt_ep_options + + def allocate_buffers(self, shape_dict, device): + super().allocate_buffers(shape_dict) + + +class OrtTensorrtEngineBuilder(EngineBuilder): + def __init__( + self, + pipeline_info: PipelineInfo, + max_batch_size=16, + hf_token=None, + device="cuda", + use_cuda_graph=False, + ): + """ + Initializes the ONNX Runtime TensorRT ExecutionProvider Engine Builder. + + Args: + pipeline_info (PipelineInfo): + Version and Type of pipeline. + max_batch_size (int): + Maximum batch size for dynamic batch engine. + hf_token (str): + HuggingFace User Access Token to use for downloading Stable Diffusion model checkpoints. + device (str): + device to run. + use_cuda_graph (bool): + Use CUDA graph to capture engine execution and then launch inference + """ + super().__init__( + EngineType.ORT_TRT, + pipeline_info, + max_batch_size=max_batch_size, + hf_token=hf_token, + device=device, + use_cuda_graph=use_cuda_graph, + ) + + def has_engine_file(self, engine_path): + if os.path.isdir(engine_path): + children = os.scandir(engine_path) + for entry in children: + if entry.is_file() and entry.name.endswith(".engine"): + return True + return False + + def get_work_space_size(self, model_name, max_workspace_size): + gibibyte = 2**30 + workspace_size = 4 * gibibyte if model_name == "clip" else max_workspace_size + if workspace_size == 0: + _, free_mem, _ = cudart.cudaMemGetInfo() + # The following logic are adopted from TensorRT demo diffusion. + if free_mem > 6 * gibibyte: + workspace_size = free_mem - 4 * gibibyte + return workspace_size + + def build_engines( + self, + engine_dir, + framework_model_dir, + onnx_dir, + onnx_opset, + opt_image_height, + opt_image_width, + opt_batch_size=1, + force_engine_rebuild=False, + static_batch=False, + static_image_shape=True, + max_workspace_size=0, + device_id=0, + ): + self.torch_device = torch.device("cuda", device_id) + self.load_models(framework_model_dir) + + if force_engine_rebuild: + if os.path.isdir(onnx_dir): + logger.info("Remove existing directory %s since force_engine_rebuild is enabled", onnx_dir) + shutil.rmtree(onnx_dir) + if os.path.isdir(engine_dir): + logger.info("Remove existing directory %s since force_engine_rebuild is enabled", engine_dir) + shutil.rmtree(engine_dir) + + if not os.path.isdir(engine_dir): + os.makedirs(engine_dir) + + if not os.path.isdir(onnx_dir): + os.makedirs(onnx_dir) + + # Export models to ONNX + for model_name, model_obj in self.models.items(): + if model_name == "vae" and self.vae_torch_fallback: + continue + + profile_id = model_obj.get_profile_id( + opt_batch_size, opt_image_height, opt_image_width, static_batch, static_image_shape + ) + engine_path = self.get_engine_path(engine_dir, model_name, profile_id) + if not self.has_engine_file(engine_path): + onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False) + onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True) + if not os.path.exists(onnx_opt_path): + if not os.path.exists(onnx_path): + logger.info(f"Exporting model: {onnx_path}") + model = model_obj.load_model(framework_model_dir, self.hf_token) + with torch.inference_mode(), torch.autocast("cuda"): + inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width) + torch.onnx.export( + model, + inputs, + onnx_path, + export_params=True, + opset_version=onnx_opset, + do_constant_folding=True, + input_names=model_obj.get_input_names(), + output_names=model_obj.get_output_names(), + dynamic_axes=model_obj.get_dynamic_axes(), + ) + del model + torch.cuda.empty_cache() + gc.collect() + else: + logger.info("Found cached model: %s", onnx_path) + + # Optimize onnx + if not os.path.exists(onnx_opt_path): + logger.info("Generating optimizing model: %s", onnx_opt_path) + model_obj.optimize_trt(onnx_path, onnx_opt_path) + else: + logger.info("Found cached optimized model: %s", onnx_opt_path) + + built_engines = {} + for model_name, model_obj in self.models.items(): + if model_name == "vae" and self.vae_torch_fallback: + continue + + profile_id = model_obj.get_profile_id( + opt_batch_size, opt_image_height, opt_image_width, static_batch, static_image_shape + ) + + engine_path = self.get_engine_path(engine_dir, model_name, profile_id) + onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True) + + if not self.has_engine_file(engine_path): + logger.info( + "Building TensorRT engine for %s from %s to %s. It can take a while to complete...", + model_name, + onnx_opt_path, + engine_path, + ) + else: + logger.info("Reuse cached TensorRT engine in directory %s", engine_path) + + input_profile = model_obj.get_input_profile( + opt_batch_size, + opt_image_height, + opt_image_width, + static_batch=static_batch, + static_image_shape=static_image_shape, + ) + + engine = OrtTensorrtEngine( + engine_path, + device_id, + onnx_opt_path, + fp16=True, + input_profile=input_profile, + workspace_size=self.get_work_space_size(model_name, max_workspace_size), + enable_cuda_graph=self.use_cuda_graph, + ) + + built_engines[model_name] = engine + + self.engines = built_engines + + return built_engines + + def run_engine(self, model_name, feed_dict): + return self.engines[model_name].infer(feed_dict) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_tensorrt.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_tensorrt.py new file mode 100644 index 0000000000000..4a924abfb8600 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/engine_builder_tensorrt.py @@ -0,0 +1,507 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from TensorRT demo diffusion, which has the following license: +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- + +import gc +import os +import pathlib +from collections import OrderedDict + +import numpy as np +import onnx +import onnx_graphsurgeon as gs +import tensorrt as trt +import torch +from cuda import cudart +from diffusion_models import PipelineInfo +from engine_builder import EngineBuilder, EngineType +from polygraphy.backend.common import bytes_from_path +from polygraphy.backend.trt import ( + CreateConfig, + ModifyNetworkOutputs, + Profile, + engine_from_bytes, + engine_from_network, + network_from_onnx_path, + save_engine, +) +from trt_utilities import TRT_LOGGER + +# Map of numpy dtype -> torch dtype +numpy_to_torch_dtype_dict = { + np.int32: torch.int32, + np.int64: torch.int64, + np.float16: torch.float16, + np.float32: torch.float32, +} + + +def _cuda_assert(cuda_ret): + err = cuda_ret[0] + if err != cudart.cudaError_t.cudaSuccess: + raise RuntimeError( + f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t" + ) + if len(cuda_ret) > 1: + return cuda_ret[1] + return None + + +class TensorrtEngine: + def __init__( + self, + engine_path, + ): + self.engine_path = engine_path + self.engine = None + self.context = None + self.buffers = OrderedDict() + self.tensors = OrderedDict() + self.cuda_graph_instance = None + + def __del__(self): + del self.engine + del self.context + del self.buffers + del self.tensors + + def refit(self, onnx_path, onnx_refit_path): + def convert_int64(arr): + if len(arr.shape) == 0: + return np.int32(arr) + return arr + + def add_to_map(refit_dict, name, values): + if name in refit_dict: + assert refit_dict[name] is None + if values.dtype == np.int64: + values = convert_int64(values) + refit_dict[name] = values + + print(f"Refitting TensorRT engine with {onnx_refit_path} weights") + refit_nodes = gs.import_onnx(onnx.load(onnx_refit_path)).toposort().nodes + + # Construct mapping from weight names in refit model -> original model + name_map = {} + for n, node in enumerate(gs.import_onnx(onnx.load(onnx_path)).toposort().nodes): + refit_node = refit_nodes[n] + assert node.op == refit_node.op + # Constant nodes in ONNX do not have inputs but have a constant output + if node.op == "Constant": + name_map[refit_node.outputs[0].name] = node.outputs[0].name + # Handle scale and bias weights + elif node.op == "Conv": + if node.inputs[1].__class__ == gs.Constant: + name_map[refit_node.name + "_TRTKERNEL"] = node.name + "_TRTKERNEL" + if node.inputs[2].__class__ == gs.Constant: + name_map[refit_node.name + "_TRTBIAS"] = node.name + "_TRTBIAS" + # For all other nodes: find node inputs that are initializers (gs.Constant) + else: + for i, inp in enumerate(node.inputs): + if inp.__class__ == gs.Constant: + name_map[refit_node.inputs[i].name] = inp.name + + def map_name(name): + if name in name_map: + return name_map[name] + return name + + # Construct refit dictionary + refit_dict = {} + refitter = trt.Refitter(self.engine, TRT_LOGGER) + all_weights = refitter.get_all() + for layer_name, role in zip(all_weights[0], all_weights[1]): + # for specialized roles, use a unique name in the map: + if role == trt.WeightsRole.KERNEL: + name = layer_name + "_TRTKERNEL" + elif role == trt.WeightsRole.BIAS: + name = layer_name + "_TRTBIAS" + else: + name = layer_name + + assert name not in refit_dict, "Found duplicate layer: " + name + refit_dict[name] = None + + for n in refit_nodes: + # Constant nodes in ONNX do not have inputs but have a constant output + if n.op == "Constant": + name = map_name(n.outputs[0].name) + print(f"Add Constant {name}\n") + add_to_map(refit_dict, name, n.outputs[0].values) + + # Handle scale and bias weights + elif n.op == "Conv": + if n.inputs[1].__class__ == gs.Constant: + name = map_name(n.name + "_TRTKERNEL") + add_to_map(refit_dict, name, n.inputs[1].values) + + if n.inputs[2].__class__ == gs.Constant: + name = map_name(n.name + "_TRTBIAS") + add_to_map(refit_dict, name, n.inputs[2].values) + + # For all other nodes: find node inputs that are initializers (AKA gs.Constant) + else: + for inp in n.inputs: + name = map_name(inp.name) + if inp.__class__ == gs.Constant: + add_to_map(refit_dict, name, inp.values) + + for layer_name, weights_role in zip(all_weights[0], all_weights[1]): + if weights_role == trt.WeightsRole.KERNEL: + custom_name = layer_name + "_TRTKERNEL" + elif weights_role == trt.WeightsRole.BIAS: + custom_name = layer_name + "_TRTBIAS" + else: + custom_name = layer_name + + # Skip refitting Trilu for now; scalar weights of type int64 value 1 - for clip model + if layer_name.startswith("onnx::Trilu"): + continue + + if refit_dict[custom_name] is not None: + refitter.set_weights(layer_name, weights_role, refit_dict[custom_name]) + else: + print(f"[W] No refit weights for layer: {layer_name}") + + if not refitter.refit_cuda_engine(): + print("Failed to refit!") + exit(0) + + def build( + self, + onnx_path, + fp16, + input_profile=None, + enable_refit=False, + enable_preview=False, + enable_all_tactics=False, + timing_cache=None, + update_output_names=None, + ): + print(f"Building TensorRT engine for {onnx_path}: {self.engine_path}") + p = Profile() + if input_profile: + for name, dims in input_profile.items(): + assert len(dims) == 3 + p.add(name, min=dims[0], opt=dims[1], max=dims[2]) + + config_kwargs = {} + if not enable_all_tactics: + config_kwargs["tactic_sources"] = [] + + network = network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM]) + if update_output_names: + print(f"Updating network outputs to {update_output_names}") + network = ModifyNetworkOutputs(network, update_output_names) + engine = engine_from_network( + network, + config=CreateConfig( + fp16=fp16, refittable=enable_refit, profiles=[p], load_timing_cache=timing_cache, **config_kwargs + ), + save_timing_cache=timing_cache, + ) + save_engine(engine, path=self.engine_path) + + def load(self): + print(f"Loading TensorRT engine: {self.engine_path}") + self.engine = engine_from_bytes(bytes_from_path(self.engine_path)) + + def activate(self, reuse_device_memory=None): + if reuse_device_memory: + self.context = self.engine.create_execution_context_without_device_memory() + self.context.device_memory = reuse_device_memory + else: + self.context = self.engine.create_execution_context() + + def allocate_buffers(self, shape_dict=None, device="cuda"): + for idx in range(self.engine.num_io_tensors): + binding = self.engine[idx] + if shape_dict and binding in shape_dict: + shape = shape_dict[binding] + else: + shape = self.engine.get_binding_shape(binding) + dtype = trt.nptype(self.engine.get_binding_dtype(binding)) + if self.engine.binding_is_input(binding): + self.context.set_binding_shape(idx, shape) + tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device) + self.tensors[binding] = tensor + + def infer(self, feed_dict, stream, use_cuda_graph=False): + for name, buf in feed_dict.items(): + self.tensors[name].copy_(buf) + + for name, tensor in self.tensors.items(): + self.context.set_tensor_address(name, tensor.data_ptr()) + + if use_cuda_graph: + if self.cuda_graph_instance is not None: + _cuda_assert(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream)) + _cuda_assert(cudart.cudaStreamSynchronize(stream)) + else: + # do inference before CUDA graph capture + noerror = self.context.execute_async_v3(stream) + if not noerror: + raise ValueError("ERROR: inference failed.") + # capture cuda graph + _cuda_assert( + cudart.cudaStreamBeginCapture(stream, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeGlobal) + ) + self.context.execute_async_v3(stream) + self.graph = _cuda_assert(cudart.cudaStreamEndCapture(stream)) + + from cuda import nvrtc + + result, major, minor = nvrtc.nvrtcVersion() + assert result == nvrtc.nvrtcResult(0) + if major < 12: + self.cuda_graph_instance = _cuda_assert( + cudart.cudaGraphInstantiate(self.graph, b"", 0) + ) # cuda < 12 + else: + self.cuda_graph_instance = _cuda_assert(cudart.cudaGraphInstantiate(self.graph, 0)) # cuda >= 12 + else: + noerror = self.context.execute_async_v3(stream) + if not noerror: + raise ValueError("ERROR: inference failed.") + + return self.tensors + + +class TensorrtEngineBuilder(EngineBuilder): + """ + Helper class to hide the detail of TensorRT Engine from pipeline. + """ + + def __init__( + self, + pipeline_info: PipelineInfo, + max_batch_size=16, + hf_token=None, + device="cuda", + use_cuda_graph=False, + ): + """ + Initializes the ONNX Runtime TensorRT ExecutionProvider Engine Builder. + + Args: + pipeline_info (PipelineInfo): + Version and Type of pipeline. + max_batch_size (int): + Maximum batch size for dynamic batch engine. + hf_token (str): + HuggingFace User Access Token to use for downloading Stable Diffusion model checkpoints. + device (str): + device to run. + use_cuda_graph (bool): + Use CUDA graph to capture engine execution and then launch inference + """ + super().__init__( + EngineType.TRT, + pipeline_info, + max_batch_size=max_batch_size, + hf_token=hf_token, + device=device, + use_cuda_graph=use_cuda_graph, + ) + + self.stream = None + self.shared_device_memory = None + + def load_resources(self, image_height, image_width, batch_size): + super().load_resources(image_height, image_width, batch_size) + + self.stream = _cuda_assert(cudart.cudaStreamCreate()) + + def teardown(self): + super().teardown() + + if self.shared_device_memory: + cudart.cudaFree(self.shared_device_memory) + + cudart.cudaStreamDestroy(self.stream) + del self.stream + + def load_engines( + self, + engine_dir, + framework_model_dir, + onnx_dir, + onnx_opset, + opt_batch_size, + opt_image_height, + opt_image_width, + force_export=False, + force_optimize=False, + force_build=False, + static_batch=False, + static_shape=True, + enable_refit=False, + enable_preview=False, + enable_all_tactics=False, + timing_cache=None, + onnx_refit_dir=None, + ): + """ + Build and load engines for TensorRT accelerated inference. + Export ONNX models first, if applicable. + + Args: + engine_dir (str): + Directory to write the TensorRT engines. + framework_model_dir (str): + Directory to write the framework model ckpt. + onnx_dir (str): + Directory to write the ONNX models. + onnx_opset (int): + ONNX opset version to export the models. + opt_batch_size (int): + Batch size to optimize for during engine building. + opt_image_height (int): + Image height to optimize for during engine building. Must be a multiple of 8. + opt_image_width (int): + Image width to optimize for during engine building. Must be a multiple of 8. + force_export (bool): + Force re-exporting the ONNX models. + force_optimize (bool): + Force re-optimizing the ONNX models. + force_build (bool): + Force re-building the TensorRT engine. + static_batch (bool): + Build engine only for specified opt_batch_size. + static_shape (bool): + Build engine only for specified opt_image_height & opt_image_width. Default = True. + enable_refit (bool): + Build engines with refit option enabled. + enable_preview (bool): + Enable TensorRT preview features. + enable_all_tactics (bool): + Enable all tactic sources during TensorRT engine builds. + timing_cache (str): + Path to the timing cache to accelerate build or None + onnx_refit_dir (str): + Directory containing refit ONNX models. + """ + # Create directory + for directory in [engine_dir, onnx_dir]: + if not os.path.exists(directory): + print(f"[I] Create directory: {directory}") + pathlib.Path(directory).mkdir(parents=True) + + self.load_models(framework_model_dir) + + # Export models to ONNX + for model_name, obj in self.models.items(): + if model_name == "vae" and self.vae_torch_fallback: + continue + profile_id = obj.get_profile_id( + opt_batch_size, opt_image_height, opt_image_width, static_batch, static_shape + ) + engine_path = self.get_engine_path(engine_dir, model_name, profile_id) + if force_export or force_build or not os.path.exists(engine_path): + onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False) + onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True) + if force_export or not os.path.exists(onnx_opt_path): + if force_export or not os.path.exists(onnx_path): + print(f"Exporting model: {onnx_path}") + model = obj.load_model(framework_model_dir, self.hf_token) + with torch.inference_mode(), torch.autocast("cuda"): + inputs = obj.get_sample_input(1, opt_image_height, opt_image_width) + torch.onnx.export( + model, + inputs, + onnx_path, + export_params=True, + opset_version=onnx_opset, + do_constant_folding=True, + input_names=obj.get_input_names(), + output_names=obj.get_output_names(), + dynamic_axes=obj.get_dynamic_axes(), + ) + del model + torch.cuda.empty_cache() + gc.collect() + else: + print(f"Found cached model: {onnx_path}") + + # Optimize onnx + if force_optimize or not os.path.exists(onnx_opt_path): + print(f"Generating optimizing model: {onnx_opt_path}") + obj.optimize_trt(onnx_path, onnx_opt_path) + else: + print(f"Found cached optimized model: {onnx_opt_path} ") + + # Build TensorRT engines + for model_name, obj in self.models.items(): + if model_name == "vae" and self.vae_torch_fallback: + continue + profile_id = obj.get_profile_id( + opt_batch_size, opt_image_height, opt_image_width, static_batch, static_shape + ) + engine_path = self.get_engine_path(engine_dir, model_name, profile_id) + engine = TensorrtEngine(engine_path) + onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True) + + if force_build or not os.path.exists(engine.engine_path): + engine.build( + onnx_opt_path, + fp16=True, + input_profile=obj.get_input_profile( + opt_batch_size, + opt_image_height, + opt_image_width, + static_batch, + static_shape, + ), + enable_refit=enable_refit, + enable_preview=enable_preview, + enable_all_tactics=enable_all_tactics, + timing_cache=timing_cache, + update_output_names=None, + ) + self.engines[model_name] = engine + + # Load TensorRT engines + for model_name in self.models: + if model_name == "vae" and self.vae_torch_fallback: + continue + self.engines[model_name].load() + if onnx_refit_dir: + onnx_refit_path = self.get_onnx_path(model_name, onnx_refit_dir, opt=True) + if os.path.exists(onnx_refit_path): + self.engines[model_name].refit(onnx_opt_path, onnx_refit_path) + + def max_device_memory(self): + max_device_memory = 0 + for _model_name, engine in self.engines.items(): + max_device_memory = max(max_device_memory, engine.engine.device_memory_size) + return max_device_memory + + def activate_engines(self, shared_device_memory=None): + if shared_device_memory is None: + max_device_memory = self.max_device_memory() + _, shared_device_memory = cudart.cudaMalloc(max_device_memory) + self.shared_device_memory = shared_device_memory + # Load and activate TensorRT engines + for engine in self.engines.values(): + engine.activate(reuse_device_memory=self.shared_device_memory) + + def run_engine(self, model_name, feed_dict): + return self.engines[model_name].infer(feed_dict, self.stream, use_cuda_graph=self.use_cuda_graph) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/models.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/models.py deleted file mode 100644 index 0f7688a3df9f6..0000000000000 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/models.py +++ /dev/null @@ -1,368 +0,0 @@ -# ------------------------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. -# -------------------------------------------------------------------------- -# -# Copyright 2023 The HuggingFace Inc. team. -# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Models used in Stable diffusion. -""" -import logging - -import onnx -import onnx_graphsurgeon as gs -import torch -from onnx import shape_inference -from ort_optimizer import OrtStableDiffusionOptimizer -from polygraphy.backend.onnx.loader import fold_constants - -logger = logging.getLogger(__name__) - - -class TrtOptimizer: - def __init__(self, onnx_graph): - self.graph = gs.import_onnx(onnx_graph) - - def cleanup(self): - self.graph.cleanup().toposort() - - def get_optimized_onnx_graph(self): - return gs.export_onnx(self.graph) - - def select_outputs(self, keep, names=None): - self.graph.outputs = [self.graph.outputs[o] for o in keep] - if names: - for i, name in enumerate(names): - self.graph.outputs[i].name = name - - def fold_constants(self): - onnx_graph = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True) - self.graph = gs.import_onnx(onnx_graph) - - def infer_shapes(self): - onnx_graph = gs.export_onnx(self.graph) - if onnx_graph.ByteSize() > 2147483648: - raise TypeError("ERROR: model size exceeds supported 2GB limit") - else: - onnx_graph = shape_inference.infer_shapes(onnx_graph) - - self.graph = gs.import_onnx(onnx_graph) - - -class BaseModel: - def __init__(self, model, name, device="cuda", fp16=False, max_batch_size=16, embedding_dim=768, text_maxlen=77): - self.model = model - self.name = name - self.fp16 = fp16 - self.device = device - - self.min_batch = 1 - self.max_batch = max_batch_size - self.min_image_shape = 256 # min image resolution: 256x256 - self.max_image_shape = 1024 # max image resolution: 1024x1024 - self.min_latent_shape = self.min_image_shape // 8 - self.max_latent_shape = self.max_image_shape // 8 - - self.embedding_dim = embedding_dim - self.text_maxlen = text_maxlen - - self.model_type = name.lower() if name in ["CLIP", "UNet"] else "vae" - self.ort_optimizer = OrtStableDiffusionOptimizer(self.model_type) - - def get_model(self): - return self.model - - def get_input_names(self): - pass - - def get_output_names(self): - pass - - def get_dynamic_axes(self): - return None - - def get_sample_input(self, batch_size, image_height, image_width): - pass - - def get_profile_id(self, batch_size, image_height, image_width, static_batch, static_image_shape): - """For TensorRT EP""" - ( - min_batch, - max_batch, - min_image_height, - max_image_height, - min_image_width, - max_image_width, - _, - _, - _, - _, - ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) - - profile_id = f"_b_{batch_size}" if static_batch else f"_b_{min_batch}_{max_batch}" - - if self.name != "CLIP": - if static_image_shape: - profile_id += f"_h_{image_height}_w_{image_width}" - else: - profile_id += f"_h_{min_image_height}_{max_image_height}_w_{min_image_width}_{max_image_width}" - - return profile_id - - def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): - """For TensorRT""" - return None - - def get_shape_dict(self, batch_size, image_height, image_width): - return None - - def optimize_ort(self, input_onnx_path, optimized_onnx_path, to_fp16=True): - self.ort_optimizer.optimize(input_onnx_path, optimized_onnx_path, to_fp16) - - def optimize_trt(self, input_onnx_path, optimized_onnx_path): - onnx_graph = onnx.load(input_onnx_path) - opt = TrtOptimizer(onnx_graph) - opt.cleanup() - opt.fold_constants() - opt.infer_shapes() - opt.cleanup() - onnx_opt_graph = opt.get_optimized_onnx_graph() - onnx.save(onnx_opt_graph, optimized_onnx_path) - - def check_dims(self, batch_size, image_height, image_width): - assert batch_size >= self.min_batch and batch_size <= self.max_batch - assert image_height % 8 == 0 or image_width % 8 == 0 - latent_height = image_height // 8 - latent_width = image_width // 8 - assert latent_height >= self.min_latent_shape and latent_height <= self.max_latent_shape - assert latent_width >= self.min_latent_shape and latent_width <= self.max_latent_shape - return (latent_height, latent_width) - - def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_image_shape): - min_batch = batch_size if static_batch else self.min_batch - max_batch = batch_size if static_batch else self.max_batch - latent_height = image_height // 8 - latent_width = image_width // 8 - min_image_height = image_height if static_image_shape else self.min_image_shape - max_image_height = image_height if static_image_shape else self.max_image_shape - min_image_width = image_width if static_image_shape else self.min_image_shape - max_image_width = image_width if static_image_shape else self.max_image_shape - min_latent_height = latent_height if static_image_shape else self.min_latent_shape - max_latent_height = latent_height if static_image_shape else self.max_latent_shape - min_latent_width = latent_width if static_image_shape else self.min_latent_shape - max_latent_width = latent_width if static_image_shape else self.max_latent_shape - return ( - min_batch, - max_batch, - min_image_height, - max_image_height, - min_image_width, - max_image_width, - min_latent_height, - max_latent_height, - min_latent_width, - max_latent_width, - ) - - -class CLIP(BaseModel): - def __init__(self, model, device, max_batch_size, embedding_dim): - super().__init__( - model=model, - name="CLIP", - device=device, - max_batch_size=max_batch_size, - embedding_dim=embedding_dim, - ) - - def get_input_names(self): - return ["input_ids"] - - def get_output_names(self): - return ["text_embeddings"] - - def get_dynamic_axes(self): - return {"input_ids": {0: "B"}, "text_embeddings": {0: "B"}} - - def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): - self.check_dims(batch_size, image_height, image_width) - min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims( - batch_size, image_height, image_width, static_batch, static_image_shape - ) - return { - "input_ids": [(min_batch, self.text_maxlen), (batch_size, self.text_maxlen), (max_batch, self.text_maxlen)] - } - - def get_shape_dict(self, batch_size, image_height, image_width): - self.check_dims(batch_size, image_height, image_width) - return { - "input_ids": (batch_size, self.text_maxlen), - "text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim), - } - - def get_sample_input(self, batch_size, image_height, image_width): - self.check_dims(batch_size, image_height, image_width) - return torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device) - - def optimize_trt(self, input_onnx_path, optimized_onnx_path): - onnx_graph = onnx.load(input_onnx_path) - opt = TrtOptimizer(onnx_graph) - opt.select_outputs([0]) # delete graph output#1 - opt.cleanup() - opt.fold_constants() - opt.infer_shapes() - opt.select_outputs([0], names=["text_embeddings"]) # rename network output - opt.cleanup() - onnx_opt_graph = opt.get_optimized_onnx_graph() - onnx.save(onnx_opt_graph, optimized_onnx_path) - - -class UNet(BaseModel): - def __init__( - self, - model, - device="cuda", - fp16=False, # used by TRT - max_batch_size=16, - embedding_dim=768, - text_maxlen=77, - unet_dim=4, - ): - super().__init__( - model=model, - name="UNet", - device=device, - fp16=fp16, - max_batch_size=max_batch_size, - embedding_dim=embedding_dim, - text_maxlen=text_maxlen, - ) - self.unet_dim = unet_dim - - def get_input_names(self): - return ["sample", "timestep", "encoder_hidden_states"] - - def get_output_names(self): - return ["latent"] - - def get_dynamic_axes(self): - return { - "sample": {0: "2B", 2: "H", 3: "W"}, - "encoder_hidden_states": {0: "2B"}, - "latent": {0: "2B", 2: "H", 3: "W"}, - } - - def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - ( - min_batch, - max_batch, - _, - _, - _, - _, - min_latent_height, - max_latent_height, - min_latent_width, - max_latent_width, - ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) - return { - "sample": [ - (2 * min_batch, self.unet_dim, min_latent_height, min_latent_width), - (2 * batch_size, self.unet_dim, latent_height, latent_width), - (2 * max_batch, self.unet_dim, max_latent_height, max_latent_width), - ], - "encoder_hidden_states": [ - (2 * min_batch, self.text_maxlen, self.embedding_dim), - (2 * batch_size, self.text_maxlen, self.embedding_dim), - (2 * max_batch, self.text_maxlen, self.embedding_dim), - ], - } - - def get_shape_dict(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - return { - "sample": (2 * batch_size, self.unet_dim, latent_height, latent_width), - "timestep": [1], - "encoder_hidden_states": (2 * batch_size, self.text_maxlen, self.embedding_dim), - "latent": (2 * batch_size, 4, latent_height, latent_width), - } - - def get_sample_input(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - dtype = torch.float16 if self.fp16 else torch.float32 - return ( - torch.randn( - 2 * batch_size, self.unet_dim, latent_height, latent_width, dtype=torch.float32, device=self.device - ), - torch.tensor([1.0], dtype=torch.float32, device=self.device), - torch.randn(2 * batch_size, self.text_maxlen, self.embedding_dim, dtype=dtype, device=self.device), - ) - - -class VAE(BaseModel): - def __init__(self, model, device, max_batch_size, embedding_dim): - super().__init__( - model=model, - name="VAE Decoder", - device=device, - max_batch_size=max_batch_size, - embedding_dim=embedding_dim, - ) - - def get_input_names(self): - return ["latent"] - - def get_output_names(self): - return ["images"] - - def get_dynamic_axes(self): - return {"latent": {0: "B", 2: "H", 3: "W"}, "images": {0: "B", 2: "8H", 3: "8W"}} - - def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - ( - min_batch, - max_batch, - _, - _, - _, - _, - min_latent_height, - max_latent_height, - min_latent_width, - max_latent_width, - ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape) - return { - "latent": [ - (min_batch, 4, min_latent_height, min_latent_width), - (batch_size, 4, latent_height, latent_width), - (max_batch, 4, max_latent_height, max_latent_width), - ] - } - - def get_shape_dict(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - return { - "latent": (batch_size, 4, latent_height, latent_width), - "images": (batch_size, 3, image_height, image_width), - } - - def get_sample_input(self, batch_size, image_height, image_width): - latent_height, latent_width = self.check_dims(batch_size, image_height, image_width) - return torch.randn(batch_size, 4, latent_height, latent_width, dtype=torch.float32, device=self.device) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_cuda_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_cuda_txt2img.py index 6134fa7bddcf4..37785869a355b 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_cuda_txt2img.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_cuda_txt2img.py @@ -43,16 +43,14 @@ StableDiffusionSafetyChecker, ) from diffusers.schedulers import DDIMScheduler -from diffusers.utils import DIFFUSERS_CACHE -from huggingface_hub import snapshot_download -from models import CLIP, VAE, UNet -from ort_utils import Engines +from diffusion_models import CLIP, VAE, PipelineInfo, UNet +from ort_utils import Engines, StableDiffusionPipelineMixin from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer logger = logging.getLogger(__name__) -class OnnxruntimeCudaStableDiffusionPipeline(StableDiffusionPipeline): +class OnnxruntimeCudaStableDiffusionPipeline(StableDiffusionPipelineMixin, StableDiffusionPipeline): r""" Pipeline for text-to-image generation using CUDA provider in ONNX Runtime. This pipeline inherits from [`StableDiffusionPipeline`]. Check the documentation in super class for most parameters. @@ -70,11 +68,12 @@ def __init__( requires_safety_checker: bool = True, # ONNX export parameters onnx_opset: int = 14, - onnx_dir: str = "raw_onnx", + onnx_dir: str = "onnx_ort", # Onnxruntime execution provider parameters - engine_dir: str = "onnxruntime_optimized_onnx", + engine_dir: str = "ORT_CUDA", force_engine_rebuild: bool = False, enable_cuda_graph: bool = False, + pipeline_info: PipelineInfo = None, ): super().__init__( vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker @@ -96,51 +95,38 @@ def __init__( self.fp16 = False - def __load_models(self): - self.embedding_dim = self.text_encoder.config.hidden_size + self.pipeline_info = pipeline_info - self.models["clip"] = CLIP( - self.text_encoder, - device=self.torch_device, - max_batch_size=self.max_batch_size, - embedding_dim=self.embedding_dim, - ) + def load_models(self): + assert self.pipeline_info.clip_embedding_dim() == self.text_encoder.config.hidden_size - self.models["unet"] = UNet( - self.unet, - device=self.torch_device, - fp16=self.fp16, - max_batch_size=self.max_batch_size, - embedding_dim=self.embedding_dim, - unet_dim=(9 if self.inpaint else 4), - ) + stages = self.pipeline_info.stages() + if "clip" in stages: + self.models["clip"] = CLIP( + self.pipeline_info, + self.text_encoder, + device=self.torch_device, + max_batch_size=self.max_batch_size, + clip_skip=0, + ) - self.models["vae"] = VAE( - self.vae, device=self.torch_device, max_batch_size=self.max_batch_size, embedding_dim=self.embedding_dim - ) + if "unet" in stages: + self.models["unet"] = UNet( + self.pipeline_info, + self.unet, + device=self.torch_device, + fp16=False, + max_batch_size=self.max_batch_size, + unet_dim=(9 if self.pipeline_info.is_inpaint() else 4), + ) - @classmethod - def set_cached_folder(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs): - cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) - resume_download = kwargs.pop("resume_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", False) - use_auth_token = kwargs.pop("use_auth_token", None) - revision = kwargs.pop("revision", None) - - cls.cached_folder = ( - pretrained_model_name_or_path - if os.path.isdir(pretrained_model_name_or_path) - else snapshot_download( - pretrained_model_name_or_path, - cache_dir=cache_dir, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, + if "vae" in stages: + self.models["vae"] = VAE( + self.pipeline_info, + self.vae, + device=self.torch_device, + max_batch_size=self.max_batch_size, ) - ) def to( self, @@ -156,7 +142,7 @@ def to( # load models self.fp16 = torch_dtype == torch.float16 - self.__load_models() + self.load_models() # build engines self.engines.build( @@ -180,88 +166,6 @@ def to( return self - def __encode_prompt(self, prompt, negative_prompt): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - """ - # Tokenize prompt - text_input_ids = ( - self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - .input_ids.type(torch.int32) - .to(self.torch_device) - ) - - # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt - text_embeddings = ( - self.engines.get_engine("clip").infer({"input_ids": text_input_ids})["text_embeddings"].clone() - ) - - # Tokenize negative prompt - uncond_input_ids = ( - self.tokenizer( - negative_prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - .input_ids.type(torch.int32) - .to(self.torch_device) - ) - - uncond_embeddings = self.engines.get_engine("clip").infer({"input_ids": uncond_input_ids})["text_embeddings"] - - # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16) - - return text_embeddings - - def __denoise_latent(self, latents, text_embeddings, timesteps=None, mask=None, masked_image_latents=None): - if not isinstance(timesteps, torch.Tensor): - timesteps = self.scheduler.timesteps - - for _step_index, timestep in enumerate(timesteps): - # Expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep) - if isinstance(mask, torch.Tensor): - latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) - - timestep_float = timestep.to(torch.float16) if self.fp16 else timestep.to(torch.float32) - - # Predict the noise residual - noise_pred = self.engines.get_engine("unet").infer( - {"sample": latent_model_input, "timestep": timestep_float, "encoder_hidden_states": text_embeddings}, - )["latent"] - - # Perform guidance - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - - latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample - - latents = 1.0 / 0.18215 * latents - return latents - - def __decode_latent(self, latents): - images = self.engines.get_engine("vae").infer({"latent": latents})["images"] - images = (images / 2 + 0.5).clamp(0, 1) - return images.cpu().permute(0, 2, 3, 1).float().numpy() - def __allocate_buffers(self, image_height, image_width, batch_size): # Allocate output tensors for I/O bindings for model_name, obj in self.models.items(): @@ -337,7 +241,7 @@ def __call__( with torch.inference_mode(), torch.autocast("cuda"): # CLIP text encoder - text_embeddings = self.__encode_prompt(prompt, negative_prompt) + text_embeddings = self.encode_prompt(self.engines.get_engine("clip"), prompt, negative_prompt) # Pre-initialize latents num_channels_latents = self.unet_in_channels @@ -352,30 +256,37 @@ def __call__( ) # UNet denoiser - latents = self.__denoise_latent(latents, text_embeddings) + latents = self.denoise_latent( + self.engines.get_engine("unet"), latents, text_embeddings, timestep_fp16=self.fp16 + ) # VAE decode latent - images = self.__decode_latent(latents) + images = self.decode_latent(self.engines.get_engine("vae"), latents) images, has_nsfw_concept = self.run_safety_checker(images, self.torch_device, text_embeddings.dtype) images = self.numpy_to_pil(images) return StableDiffusionPipelineOutput(images=images, nsfw_content_detected=has_nsfw_concept) -if __name__ == "__main__": - model_name_or_path = "runwayml/stable-diffusion-v1-5" +def example(): + pipeline_info = PipelineInfo("1.5") + model_name_or_path = pipeline_info.name() scheduler = DDIMScheduler.from_pretrained(model_name_or_path, subfolder="scheduler") - pipe = OnnxruntimeCudaStableDiffusionPipeline.from_pretrained( model_name_or_path, scheduler=scheduler, + pipeline_info=pipeline_info, ) # re-use cached folder to save ONNX models - pipe.set_cached_folder(model_name_or_path) + pipe.set_cached_folder(model_name_or_path, resume_download=True, local_files_only=True) pipe = pipe.to("cuda", torch_dtype=torch.float16) prompt = "photorealistic new zealand hills" image = pipe(prompt).images[0] - image.save("ort_trt_txt2img_new_zealand_hills.png") + image.save("ort_cuda_txt2img_new_zealand_hills.png") + + +if __name__ == "__main__": + example() diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_tensorrt_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_tensorrt_txt2img.py index 6f3c215f36318..c663e37c7ea7d 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_tensorrt_txt2img.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/onnxruntime_tensorrt_txt2img.py @@ -32,13 +32,11 @@ pip install onnxruntime-gpu """ -import gc +import logging import os -import shutil from typing import List, Optional, Union import torch -from cuda import cudart from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.pipelines.stable_diffusion import ( StableDiffusionPipeline, @@ -46,224 +44,15 @@ StableDiffusionSafetyChecker, ) from diffusers.schedulers import DDIMScheduler -from diffusers.utils import DIFFUSERS_CACHE, logging -from huggingface_hub import snapshot_download -from models import CLIP, VAE, UNet -from ort_utils import OrtCudaSession +from diffusion_models import PipelineInfo +from engine_builder_ort_trt import OrtTensorrtEngineBuilder +from ort_utils import StableDiffusionPipelineMixin from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer -import onnxruntime as ort +logger = logging.getLogger(__name__) -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - -class Engine(OrtCudaSession): - def __init__(self, engine_path, device_id, onnx_path, fp16, input_profile, workspace_size, enable_cuda_graph): - self.engine_path = engine_path - self.ort_trt_provider_options = self.get_tensorrt_provider_options( - input_profile, - workspace_size, - fp16, - device_id, - enable_cuda_graph, - ) - - sess_options = ort.SessionOptions() - sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL - ort_session = ort.InferenceSession( - onnx_path, - sess_options, - providers=[ - ("TensorrtExecutionProvider", self.ort_trt_provider_options), - ], - ) - - device = torch.device("cuda", device_id) - super().__init__(ort_session, device, enable_cuda_graph) - - def get_tensorrt_provider_options(self, input_profile, workspace_size, fp16, device_id, enable_cuda_graph): - trt_ep_options = { - "device_id": device_id, - "trt_fp16_enable": fp16, - "trt_engine_cache_enable": True, - "trt_timing_cache_enable": True, - "trt_detailed_build_log": True, - "trt_engine_cache_path": self.engine_path, - } - - if enable_cuda_graph: - trt_ep_options["trt_cuda_graph_enable"] = True - - if workspace_size > 0: - trt_ep_options["trt_max_workspace_size"] = workspace_size - - if input_profile: - min_shapes = [] - max_shapes = [] - opt_shapes = [] - for name, profile in input_profile.items(): - assert isinstance(profile, list) and len(profile) == 3 - min_shape = profile[0] - opt_shape = profile[1] - max_shape = profile[2] - assert len(min_shape) == len(opt_shape) and len(opt_shape) == len(max_shape) - - min_shapes.append(f"{name}:" + "x".join([str(x) for x in min_shape])) - opt_shapes.append(f"{name}:" + "x".join([str(x) for x in opt_shape])) - max_shapes.append(f"{name}:" + "x".join([str(x) for x in max_shape])) - - trt_ep_options["trt_profile_min_shapes"] = ",".join(min_shapes) - trt_ep_options["trt_profile_max_shapes"] = ",".join(max_shapes) - trt_ep_options["trt_profile_opt_shapes"] = ",".join(opt_shapes) - - logger.info("trt_ep_options=%s", trt_ep_options) - - return trt_ep_options - - -def get_onnx_path(model_name, onnx_dir, opt=True): - return os.path.join(onnx_dir, model_name + (".opt" if opt else "") + ".onnx") - - -def get_engine_path(engine_dir, model_name, profile_id): - return os.path.join(engine_dir, model_name + profile_id) - - -def has_engine_file(engine_path): - if os.path.isdir(engine_path): - children = os.scandir(engine_path) - for entry in children: - if entry.is_file() and entry.name.endswith(".engine"): - return True - return False - - -def get_work_space_size(model_name, max_workspace_size): - gibibyte = 2**30 - workspace_size = 4 * gibibyte if model_name == "clip" else max_workspace_size - if workspace_size == 0: - _, free_mem, _ = cudart.cudaMemGetInfo() - # The following logic are adopted from TensorRT demo diffusion. - if free_mem > 6 * gibibyte: - workspace_size = free_mem - 4 * gibibyte - return workspace_size - - -def build_engines( - models, - engine_dir, - onnx_dir, - onnx_opset, - opt_image_height, - opt_image_width, - opt_batch_size=1, - force_engine_rebuild=False, - static_batch=False, - static_image_shape=True, - max_workspace_size=0, - device_id=0, - enable_cuda_graph=False, -): - if force_engine_rebuild: - if os.path.isdir(onnx_dir): - logger.info("Remove existing directory %s since force_engine_rebuild is enabled", onnx_dir) - shutil.rmtree(onnx_dir) - if os.path.isdir(engine_dir): - logger.info("Remove existing directory %s since force_engine_rebuild is enabled", engine_dir) - shutil.rmtree(engine_dir) - - if not os.path.isdir(engine_dir): - os.makedirs(engine_dir) - - if not os.path.isdir(onnx_dir): - os.makedirs(onnx_dir) - - # Export models to ONNX - for model_name, model_obj in models.items(): - profile_id = model_obj.get_profile_id( - opt_batch_size, opt_image_height, opt_image_width, static_batch, static_image_shape - ) - engine_path = get_engine_path(engine_dir, model_name, profile_id) - if not has_engine_file(engine_path): - onnx_path = get_onnx_path(model_name, onnx_dir, opt=False) - onnx_opt_path = get_onnx_path(model_name, onnx_dir) - if not os.path.exists(onnx_opt_path): - if not os.path.exists(onnx_path): - logger.info(f"Exporting model: {onnx_path}") - model = model_obj.get_model() - with torch.inference_mode(), torch.autocast("cuda"): - inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width) - torch.onnx.export( - model, - inputs, - onnx_path, - export_params=True, - opset_version=onnx_opset, - do_constant_folding=True, - input_names=model_obj.get_input_names(), - output_names=model_obj.get_output_names(), - dynamic_axes=model_obj.get_dynamic_axes(), - ) - del model - torch.cuda.empty_cache() - gc.collect() - else: - logger.info("Found cached model: %s", onnx_path) - - # Optimize onnx - if not os.path.exists(onnx_opt_path): - logger.info("Generating optimizing model: %s", onnx_opt_path) - model_obj.optimize_trt(onnx_path, onnx_opt_path) - else: - logger.info("Found cached optimized model: %s", onnx_opt_path) - - built_engines = {} - for model_name, model_obj in models.items(): - profile_id = model_obj.get_profile_id( - opt_batch_size, opt_image_height, opt_image_width, static_batch, static_image_shape - ) - - engine_path = get_engine_path(engine_dir, model_name, profile_id) - onnx_opt_path = get_onnx_path(model_name, onnx_dir) - - if not has_engine_file(engine_path): - logger.info( - "Building TensorRT engine for %s from %s to %s. It can take a while to complete...", - model_name, - onnx_opt_path, - engine_path, - ) - else: - logger.info("Reuse cached TensorRT engine in directory %s", engine_path) - - input_profile = model_obj.get_input_profile( - opt_batch_size, - opt_image_height, - opt_image_width, - static_batch=static_batch, - static_image_shape=static_image_shape, - ) - - engine = Engine( - engine_path, - device_id, - onnx_opt_path, - fp16=True, - input_profile=input_profile, - workspace_size=get_work_space_size(model_name, max_workspace_size), - enable_cuda_graph=enable_cuda_graph, - ) - - built_engines[model_name] = engine - - return built_engines - - -def run_engine(engine, feed_dict): - return engine.infer(feed_dict) - - -class OnnxruntimeTensorRTStableDiffusionPipeline(StableDiffusionPipeline): +class OnnxruntimeTensorRTStableDiffusionPipeline(StableDiffusionPipelineMixin, StableDiffusionPipeline): r""" Pipeline for text-to-image generation using TensorRT execution provider in ONNX Runtime. @@ -285,11 +74,12 @@ def __init__( max_batch_size: int = 16, # ONNX export parameters onnx_opset: int = 17, - onnx_dir: str = "onnx", + onnx_dir: str = "onnx_trt", # TensorRT engine build parameters - engine_dir: str = "onnxruntime_tensorrt_engine", + engine_dir: str = "ORT_TRT", # use short name here to avoid path exceeds 260 chars in Windows. force_engine_rebuild: bool = False, enable_cuda_graph: bool = False, + pipeline_info: Optional[PipelineInfo] = None, ): super().__init__( vae, text_encoder, tokenizer, unet, scheduler, safety_checker, feature_extractor, requires_safety_checker @@ -299,16 +89,14 @@ def __init__( self.image_height = image_height self.image_width = image_width - self.inpaint = False self.onnx_opset = onnx_opset self.onnx_dir = onnx_dir self.engine_dir = engine_dir self.force_engine_rebuild = force_engine_rebuild - self.enable_cuda_graph = enable_cuda_graph - # Although cuda graph requires static input shape, engine built with dyamic batch gets better performance in T4. + # Although cuda graph requires static input shape, engine built with dynamic batch gets better performance in T4. # Use static batch could reduce GPU memory footprint. - self.build_static_batch = False + self.build_static_batch = enable_cuda_graph # TODO: support dynamic image shape. self.build_dynamic_shape = False @@ -318,54 +106,13 @@ def __init__( if self.build_dynamic_shape or self.image_height > 512 or self.image_width > 512: self.max_batch_size = 4 - self.models = {} # loaded in __load_models() self.engines = {} # loaded in build_engines() - - def __load_models(self): - self.embedding_dim = self.text_encoder.config.hidden_size - - self.models["clip"] = CLIP( - self.text_encoder, - device=self.torch_device, - max_batch_size=self.max_batch_size, - embedding_dim=self.embedding_dim, - ) - - self.models["unet"] = UNet( - self.unet, - device=self.torch_device, - fp16=True, - max_batch_size=self.max_batch_size, - embedding_dim=self.embedding_dim, - unet_dim=(9 if self.inpaint else 4), + self.engine_builder = OrtTensorrtEngineBuilder( + pipeline_info, max_batch_size=max_batch_size, use_cuda_graph=enable_cuda_graph ) - self.models["vae"] = VAE( - self.vae, device=self.torch_device, max_batch_size=self.max_batch_size, embedding_dim=self.embedding_dim - ) - - @classmethod - def set_cached_folder(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs): - cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) - resume_download = kwargs.pop("resume_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", False) - use_auth_token = kwargs.pop("use_auth_token", None) - revision = kwargs.pop("revision", None) - - cls.cached_folder = ( - pretrained_model_name_or_path - if os.path.isdir(pretrained_model_name_or_path) - else snapshot_download( - pretrained_model_name_or_path, - cache_dir=cache_dir, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - use_auth_token=use_auth_token, - revision=revision, - ) - ) + self.pipeline_info = pipeline_info + self.stages = pipeline_info.stages() def to( self, @@ -381,11 +128,9 @@ def to( self.torch_device = self._execution_device logger.info(f"Running inference on device: {self.torch_device}") - self.__load_models() - - self.engines = build_engines( - self.models, + self.engines = self.engine_builder.build_engines( self.engine_dir, + None, self.onnx_dir, self.onnx_opset, opt_image_height=self.image_height, @@ -394,96 +139,10 @@ def to( static_batch=self.build_static_batch, static_image_shape=not self.build_dynamic_shape, device_id=self.torch_device.index, - enable_cuda_graph=self.enable_cuda_graph, ) return self - def __encode_prompt(self, prompt, negative_prompt): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. - Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). - """ - # Tokenize prompt - text_input_ids = ( - self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - .input_ids.type(torch.int32) - .to(self.torch_device) - ) - - # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt - text_embeddings = run_engine(self.engines["clip"], {"input_ids": text_input_ids})["text_embeddings"].clone() - - # Tokenize negative prompt - uncond_input_ids = ( - self.tokenizer( - negative_prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - .input_ids.type(torch.int32) - .to(self.torch_device) - ) - - uncond_embeddings = run_engine(self.engines["clip"], {"input_ids": uncond_input_ids})["text_embeddings"] - - # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16) - - return text_embeddings - - def __denoise_latent(self, latents, text_embeddings, timesteps=None, mask=None, masked_image_latents=None): - if not isinstance(timesteps, torch.Tensor): - timesteps = self.scheduler.timesteps - for _step_index, timestep in enumerate(timesteps): - # Expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) - latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep) - if isinstance(mask, torch.Tensor): - latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) - - # Predict the noise residual - timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep - - noise_pred = run_engine( - self.engines["unet"], - {"sample": latent_model_input, "timestep": timestep_float, "encoder_hidden_states": text_embeddings}, - )["latent"] - - # Perform guidance - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - - latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample - - latents = 1.0 / 0.18215 * latents - return latents - - def __decode_latent(self, latents): - images = run_engine(self.engines["vae"], {"latent": latents})["images"] - images = (images / 2 + 0.5).clamp(0, 1) - return images.cpu().permute(0, 2, 3, 1).float().numpy() - - def __allocate_buffers(self, image_height, image_width, batch_size): - # Allocate output tensors for I/O bindings - for model_name, obj in self.models.items(): - self.engines[model_name].allocate_buffers(obj.get_shape_dict(batch_size, image_height, image_width)) - @torch.no_grad() def __call__( self, @@ -547,11 +206,11 @@ def __call__( f"Batch size {len(prompt)} is larger than allowed {self.max_batch_size}. If dynamic shape is used, then maximum batch size is 4" ) - self.__allocate_buffers(self.image_height, self.image_width, batch_size) + self.engine_builder.load_resources(self.image_height, self.image_width, batch_size) with torch.inference_mode(), torch.autocast("cuda"): # CLIP text encoder - text_embeddings = self.__encode_prompt(prompt, negative_prompt) + text_embeddings = self.encode_prompt(self.engines["clip"], prompt, negative_prompt) # Pre-initialize latents num_channels_latents = self.unet.config.in_channels @@ -566,10 +225,10 @@ def __call__( ) # UNet denoiser - latents = self.__denoise_latent(latents, text_embeddings) + latents = self.denoise_latent(self.engines["unet"], latents, text_embeddings) # VAE decode latent - images = self.__decode_latent(latents) + images = self.decode_latent(self.engines["vae"], latents) images, has_nsfw_concept = self.run_safety_checker(images, self.torch_device, text_embeddings.dtype) images = self.numpy_to_pil(images) @@ -577,8 +236,8 @@ def __call__( if __name__ == "__main__": - model_name_or_path = "runwayml/stable-diffusion-v1-5" - + pipeline_info = PipelineInfo("1.5") + model_name_or_path = pipeline_info.name() scheduler = DDIMScheduler.from_pretrained(model_name_or_path, subfolder="scheduler") pipe = OnnxruntimeTensorRTStableDiffusionPipeline.from_pretrained( @@ -589,6 +248,7 @@ def __call__( image_height=512, image_width=512, max_batch_size=4, + pipeline_info=pipeline_info, ) # re-use cached folder to save ONNX models and TensorRT Engines diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py index 4512c971ac27c..ffcfd6d9fd7e0 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/optimize_pipeline.py @@ -13,7 +13,7 @@ # python optimize_pipeline.py -i ./sd-v1-5 -o ./sd-v1-5-fp16 --float16 # # Note that the optimizations are carried out for CUDA Execution Provider at first, other EPs may not have the support -# for the fused opeartors. The users could disable the operator fusion manually to workaround. +# for the fused operators. The users could disable the operator fusion manually to workaround. import argparse import logging @@ -49,7 +49,6 @@ def has_external_data(onnx_model_path): def _optimize_sd_pipeline( source_dir: Path, target_dir: Path, - overwrite: bool, use_external_data_format: Optional[bool], float16: bool, force_fp32_ops: List[str], @@ -61,7 +60,6 @@ def _optimize_sd_pipeline( Args: source_dir (Path): Root of input directory of stable diffusion onnx pipeline with float32 models. target_dir (Path): Root of output directory of stable diffusion onnx pipeline with optimized models. - overwrite (bool): Overwrite files if exists. use_external_data_format (Optional[bool]): use external data format. float16 (bool): use half precision force_fp32_ops(List[str]): operators that are forced to run in float32. @@ -144,6 +142,7 @@ def _optimize_sd_pipeline( opt_level=0, optimization_options=fusion_options, use_gpu=True, + provider=args.provider, ) if float16: @@ -168,6 +167,7 @@ def _optimize_sd_pipeline( optimize_by_onnxruntime( str(tmp_model_path), use_gpu=True, + provider=args.provider, optimized_model_path=str(ort_optimized_model_path), save_as_external_data=use_external_data_format, ) @@ -233,7 +233,7 @@ def optimize_stable_diffusion_pipeline( args, ): if os.path.exists(output_dir): - if args.overwrite: + if overwrite: shutil.rmtree(output_dir, ignore_errors=True) else: raise RuntimeError("output directory existed:{output_dir}. Add --overwrite to empty the directory.") @@ -247,7 +247,6 @@ def optimize_stable_diffusion_pipeline( _optimize_sd_pipeline( source_dir, target_dir, - overwrite, use_external_data_format, float16, args.force_fp32_ops, @@ -319,11 +318,19 @@ def parse_arguments(argv: Optional[List[str]] = None): required=False, action="store_true", help="Onnx model larger than 2GB need to use external data format. " - "If specifed, save each onnx model to two files: one for onnx graph, another for weights. " + "If specified, save each onnx model to two files: one for onnx graph, another for weights. " "If not specified, use same format as original model by default. ", ) parser.set_defaults(use_external_data_format=None) + parser.add_argument( + "--provider", + required=False, + type=str, + default=None, + help="Execution provider to use.", + ) + FusionOptions.add_arguments(parser) args = parser.parse_args(argv) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py index 0824c8f07d6e2..2c4b8e8a1639e 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_optimizer.py @@ -12,6 +12,7 @@ from pathlib import Path import onnx +from optimize_pipeline import has_external_data from onnxruntime.transformers.fusion_options import FusionOptions from onnxruntime.transformers.onnx_model_clip import ClipOnnxModel @@ -32,21 +33,25 @@ def __init__(self, model_type: str): "clip": ClipOnnxModel, } - def optimize_by_ort(self, onnx_model): + def optimize_by_ort(self, onnx_model, use_external_data_format=False): # Use this step to see the final graph that executed by Onnx Runtime. with tempfile.TemporaryDirectory() as tmp_dir: # Save to a temporary file so that we can load it with Onnx Runtime. logger.info("Saving a temporary model to run OnnxRuntime graph optimizations...") tmp_model_path = Path(tmp_dir) / "model.onnx" - onnx_model.save_model_to_file(str(tmp_model_path)) - ort_optimized_model_path = tmp_model_path + onnx_model.save_model_to_file(str(tmp_model_path), use_external_data_format=use_external_data_format) + ort_optimized_model_path = Path(tmp_dir) / "optimized.onnx" optimize_by_onnxruntime( - str(tmp_model_path), use_gpu=True, optimized_model_path=str(ort_optimized_model_path) + str(tmp_model_path), + use_gpu=True, + optimized_model_path=str(ort_optimized_model_path), + save_as_external_data=use_external_data_format, + external_data_filename="optimized.onnx_data", ) model = onnx.load(str(ort_optimized_model_path), load_external_data=True) return self.model_type_class_mapping[self.model_type](model) - def optimize(self, input_fp32_onnx_path, optimized_onnx_path, float16=True): + def optimize(self, input_fp32_onnx_path, optimized_onnx_path, float16=True, keep_io_types=False, keep_outputs=None): """Optimize onnx model using ONNX Runtime transformers optimizer""" logger.info(f"Optimize {input_fp32_onnx_path}...") fusion_options = FusionOptions(self.model_type) @@ -54,6 +59,8 @@ def optimize(self, input_fp32_onnx_path, optimized_onnx_path, float16=True): fusion_options.enable_packed_kv = False fusion_options.enable_packed_qkv = False + use_external_data_format = has_external_data(input_fp32_onnx_path) + m = optimize_model( input_fp32_onnx_path, model_type=self.model_type, @@ -64,21 +71,24 @@ def optimize(self, input_fp32_onnx_path, optimized_onnx_path, float16=True): use_gpu=True, ) - if self.model_type == "clip": - m.prune_graph(outputs=["text_embeddings"]) # remove the pooler_output, and only keep the first output. + if keep_outputs is None and self.model_type == "clip": + # remove the pooler_output, and only keep the first output. + keep_outputs = ["text_embeddings"] + + if keep_outputs: + m.prune_graph(outputs=keep_outputs) if float16: logger.info("Convert to float16 ...") m.convert_float_to_float16( - keep_io_types=False, - op_block_list=["RandomNormalLike"], + keep_io_types=keep_io_types, ) - # Note that ORT 1.15 could not save model larger than 2GB. This only works for float16 + # Note that ORT < 1.16 could not save model larger than 2GB. if float16 or (self.model_type != "unet"): - m = self.optimize_by_ort(m) + m = self.optimize_by_ort(m, use_external_data_format=use_external_data_format) m.get_operator_statistics() m.get_fused_operator_statistics() - m.save_model_to_file(optimized_onnx_path, use_external_data_format=(self.model_type == "unet") and not float16) + m.save_model_to_file(optimized_onnx_path, use_external_data_format=use_external_data_format) logger.info("%s is optimized: %s", self.model_type, optimized_onnx_path) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_utils.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_utils.py index 7192e4ad5584f..5c2145845e757 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_utils.py +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/ort_utils.py @@ -7,122 +7,24 @@ import logging import os import shutil -from collections import OrderedDict -from typing import Any, Dict +from typing import Union import torch import onnxruntime as ort -from onnxruntime.transformers.io_binding_helper import TypeHelper +from onnxruntime.transformers.io_binding_helper import CudaSession logger = logging.getLogger(__name__) -class OrtCudaSession: - """Inference Session with IO Binding for ONNX Runtime CUDA or TensorRT provider""" - - def __init__(self, ort_session: ort.InferenceSession, device: torch.device, enable_cuda_graph=False): - self.ort_session = ort_session - self.input_names = [input.name for input in self.ort_session.get_inputs()] - self.output_names = [output.name for output in self.ort_session.get_outputs()] - self.io_name_to_numpy_type = TypeHelper.get_io_numpy_type_map(self.ort_session) - self.io_binding = self.ort_session.io_binding() - self.enable_cuda_graph = enable_cuda_graph - - self.input_tensors = OrderedDict() - self.output_tensors = OrderedDict() - self.device = device - - def __del__(self): - del self.input_tensors - del self.output_tensors - del self.io_binding - del self.ort_session - - def allocate_buffers(self, shape_dict: Dict[str, tuple]): - """Allocate tensors for I/O Binding""" - if self.enable_cuda_graph: - for name, shape in shape_dict.items(): - if name in self.input_names: - # Reuse allocated buffer when the shape is same - if name in self.input_tensors: - if tuple(self.input_tensors[name].shape) == tuple(shape): - continue - raise RuntimeError("Expect static input shape for cuda graph") - - numpy_dtype = self.io_name_to_numpy_type[name] - tensor = torch.empty(tuple(shape), dtype=TypeHelper.numpy_type_to_torch_type(numpy_dtype)).to( - device=self.device - ) - self.input_tensors[name] = tensor - - self.io_binding.bind_input( - name, - tensor.device.type, - tensor.device.index, - numpy_dtype, - list(tensor.size()), - tensor.data_ptr(), - ) - - for name, shape in shape_dict.items(): - if name in self.output_names: - # Reuse allocated buffer when the shape is same - if name in self.output_tensors and tuple(self.output_tensors[name].shape) == tuple(shape): - continue - - numpy_dtype = self.io_name_to_numpy_type[name] - tensor = torch.empty(tuple(shape), dtype=TypeHelper.numpy_type_to_torch_type(numpy_dtype)).to( - device=self.device - ) - self.output_tensors[name] = tensor - - self.io_binding.bind_output( - name, - tensor.device.type, - tensor.device.index, - numpy_dtype, - list(tensor.size()), - tensor.data_ptr(), - ) - - def infer(self, feed_dict): - """Bind input tensors and run inference""" - for name, tensor in feed_dict.items(): - assert isinstance(tensor, torch.Tensor) and tensor.is_contiguous() - if name in self.input_names: - if self.enable_cuda_graph: - assert self.input_tensors[name].nelement() == tensor.nelement() - assert tensor.device.type == "cuda" - # Update input tensor inplace since cuda graph requires input and output has fixed memory address. - from cuda import cudart - - cudart.cudaMemcpy( - self.input_tensors[name].data_ptr(), - tensor.data_ptr(), - tensor.element_size() * tensor.nelement(), - cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice, - ) - else: - self.io_binding.bind_input( - name, - tensor.device.type, - tensor.device.index, - TypeHelper.torch_type_to_numpy_type(tensor.dtype), - [1] if len(tensor.shape) == 0 else list(tensor.shape), - tensor.data_ptr(), - ) - - self.ort_session.run_with_iobinding(self.io_binding) - - return self.output_tensors - - -class Engine(OrtCudaSession): +# ----------------------------------------------------------------------------------------------------- +# Utilities for CUDA EP +# ----------------------------------------------------------------------------------------------------- +class Engine(CudaSession): def __init__(self, engine_path, provider: str, device_id: int = 0, enable_cuda_graph=False): self.engine_path = engine_path self.provider = provider - self.provider_options = self.get_cuda_provider_options(device_id, enable_cuda_graph) + self.provider_options = CudaSession.get_cuda_provider_options(device_id, enable_cuda_graph) device = torch.device("cuda", device_id) ort_session = ort.InferenceSession( @@ -135,13 +37,6 @@ def __init__(self, engine_path, provider: str, device_id: int = 0, enable_cuda_g super().__init__(ort_session, device, enable_cuda_graph) - def get_cuda_provider_options(self, device_id: int, enable_cuda_graph: bool) -> Dict[str, Any]: - return { - "device_id": device_id, - "arena_extend_strategy": "kSameAsRequested", - "enable_cuda_graph": enable_cuda_graph, - } - class Engines: def __init__(self, provider, onnx_opset: int = 14): @@ -197,9 +92,16 @@ def build( model = model_obj.get_model().to(model_obj.device) with torch.inference_mode(): inputs = model_obj.get_sample_input(1, 512, 512) + fp32_inputs = tuple( + [ + (tensor.to(torch.float32) if tensor.dtype == torch.float16 else tensor) + for tensor in inputs + ] + ) + torch.onnx.export( model, - inputs, + fp32_inputs, onnx_path, export_params=True, opset_version=self.onnx_opset, @@ -224,3 +126,125 @@ def build( def get_engine(self, model_name): return self.engines[model_name] + + +def run_engine(engine, feed_dict): + return engine.infer(feed_dict) + + +# ----------------------------------------------------------------------------------------------------- +# Utilities for both CUDA and TensorRT EP +# ----------------------------------------------------------------------------------------------------- + + +class StableDiffusionPipelineMixin: + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def encode_prompt(self, clip_engine, prompt, negative_prompt): + """ + Encodes the prompt into text encoder hidden states. + """ + + # Tokenize prompt + text_input_ids = ( + self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + .input_ids.type(torch.int32) + .to(self.torch_device) + ) + + # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt + text_embeddings = run_engine(clip_engine, {"input_ids": text_input_ids})["text_embeddings"].clone() + + # Tokenize negative prompt + uncond_input_ids = ( + self.tokenizer( + negative_prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + .input_ids.type(torch.int32) + .to(self.torch_device) + ) + + uncond_embeddings = run_engine(clip_engine, {"input_ids": uncond_input_ids})["text_embeddings"] + + # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16) + + return text_embeddings + + def denoise_latent( + self, + unet_engine, + latents, + text_embeddings, + timesteps=None, + mask=None, + masked_image_latents=None, + timestep_fp16=False, + ): + if not isinstance(timesteps, torch.Tensor): + timesteps = self.scheduler.timesteps + + for _step_index, timestep in enumerate(timesteps): + # Expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) + latent_model_input = self.scheduler.scale_model_input(latent_model_input, timestep) + if isinstance(mask, torch.Tensor): + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + + # Predict the noise residual + timestep_float = timestep.to(torch.float16) if timestep_fp16 else timestep.to(torch.float32) + + noise_pred = run_engine( + unet_engine, + {"sample": latent_model_input, "timestep": timestep_float, "encoder_hidden_states": text_embeddings}, + )["latent"] + + # Perform guidance + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + latents = self.scheduler.step(noise_pred, timestep, latents).prev_sample + + latents = 1.0 / 0.18215 * latents + return latents + + def decode_latent(self, vae_engine, latents): + images = run_engine(vae_engine, {"latent": latents})["images"] + images = (images / 2 + 0.5).clamp(0, 1) + return images.cpu().permute(0, 2, 3, 1).float().numpy() + + def set_cached_folder(self, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs): + from diffusers.utils import DIFFUSERS_CACHE + from huggingface_hub import snapshot_download + + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + + self.cached_folder = ( + pretrained_model_name_or_path + if os.path.isdir(pretrained_model_name_or_path) + else snapshot_download( + pretrained_model_name_or_path, + cache_dir=cache_dir, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + ) + ) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_img2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_img2img_xl.py new file mode 100644 index 0000000000000..0e2aeb6174666 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_img2img_xl.py @@ -0,0 +1,232 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from TensorRT demo diffusion, which has the following license: +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- + +import time + +import torch +from diffusion_models import PipelineInfo +from pipeline_stable_diffusion import StableDiffusionPipeline + + +class Img2ImgXLPipeline(StableDiffusionPipeline): + """ + Stable Diffusion Img2Img XL pipeline using NVidia TensorRT. + """ + + def __init__(self, pipeline_info: PipelineInfo, *args, **kwargs): + """ + Initializes the Img2Img XL Diffusion pipeline. + + Args: + pipeline_info (PipelineInfo): + Version and Type of stable diffusion pipeline. + """ + assert pipeline_info.is_sd_xl_refiner() + + super().__init__(pipeline_info, *args, **kwargs) + + self.requires_aesthetics_score = True + + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype + ): + if self.requires_aesthetics_score: + add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) + add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,)) + else: + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_neg_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype) + add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0).to(device=self.device) + return add_time_ids + + def _infer( + self, + prompt, + negative_prompt, + init_image, + image_height, + image_width, + denoising_steps=30, + guidance=5.0, + seed=None, + warmup=False, + return_type="image", + ): + assert len(prompt) == len(negative_prompt) + + # TODO(tianleiwu): Need we use image_height and image_width for the target size here? + original_size = (1024, 1024) + crops_coords_top_left = (0, 0) + target_size = (1024, 1024) + strength = 0.3 + aesthetic_score = 6.0 + negative_aesthetic_score = 2.5 + + self.set_denoising_steps(denoising_steps) + self.set_random_seed(seed) + + with torch.inference_mode(), torch.autocast("cuda"): + batch_size = len(prompt) + + torch.cuda.synchronize() + e2e_tic = time.perf_counter() + + # Initialize timesteps + timesteps, t_start = self.initialize_timesteps(self.denoising_steps, strength) + latent_timestep = timesteps[:1].repeat(batch_size) + + # CLIP text encoder 2 + text_embeddings, pooled_embeddings2 = self.encode_prompt( + prompt, + negative_prompt, + encoder="clip2", + tokenizer=self.tokenizer2, + pooled_outputs=True, + output_hidden_states=True, + ) + + # Time embeddings + add_time_ids = self._get_add_time_ids( + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + dtype=text_embeddings.dtype, + ) + + add_time_ids = add_time_ids.repeat(batch_size, 1) + + add_kwargs = {"text_embeds": pooled_embeddings2, "time_ids": add_time_ids} + + # Pre-process input image + init_image = self.preprocess_images(batch_size, (init_image,))[0] + + # VAE encode init image + if init_image.shape[1] == 4: + init_latents = init_image + else: + init_latents = self.encode_image(init_image) + + # Add noise to latents using timesteps + noise = torch.randn(init_latents.shape, device=self.device, dtype=torch.float32, generator=self.generator) + latents = self.scheduler.add_noise(init_latents, noise, t_start, latent_timestep) + + # UNet denoiser + latents = self.denoise_latent( + latents, + text_embeddings, + timesteps=timesteps, + step_offset=t_start, + denoiser="unetxl", + guidance=guidance, + add_kwargs=add_kwargs, + ) + + with torch.inference_mode(): + # VAE decode latent + if return_type == "latents": + images = latents * self.vae_scaling_factor + else: + images = self.decode_latent(latents) + + torch.cuda.synchronize() + e2e_toc = time.perf_counter() + + if not warmup: + print("SD-XL Refiner Pipeline") + self.print_summary(e2e_tic, e2e_toc, batch_size) + self.save_images(images, "img2img-xl", prompt) + + return images, (e2e_toc - e2e_tic) * 1000.0 + + def run( + self, + prompt, + negative_prompt, + init_image, + image_height, + image_width, + denoising_steps=30, + guidance=5.0, + seed=None, + warmup=False, + return_type="images", + ): + """ + Run the diffusion pipeline. + + Args: + prompt (str): + The text prompt to guide image generation. + negative_prompt (str): + The prompt not to guide the image generation. + init_image (tuple[torch.Tensor]): + Image from base pipeline. + image_height (int): + Height (in pixels) of the image to be generated. Must be a multiple of 8. + image_width (int): + Width (in pixels) of the image to be generated. Must be a multiple of 8. + denoising_steps (int): + Number of denoising steps. More steps usually lead to higher quality image at the expense of slower inference. + guidance (float): + Higher guidance scale encourages to generate images that are closely linked to the text prompt. + seed (int): + Seed for the random generator + warmup (bool): + Indicate if this is a warmup run. + return_type (str): + It can be "latents" or "images". + """ + + if self.is_backend_tensorrt(): + import tensorrt as trt + from trt_utilities import TRT_LOGGER + + with trt.Runtime(TRT_LOGGER): + return self._infer( + prompt, + negative_prompt, + init_image, + image_height, + image_width, + denoising_steps=denoising_steps, + guidance=guidance, + seed=seed, + warmup=warmup, + return_type=return_type, + ) + else: + return self._infer( + prompt, + negative_prompt, + init_image, + image_height, + image_width, + denoising_steps=denoising_steps, + guidance=guidance, + seed=seed, + warmup=warmup, + return_type=return_type, + ) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py new file mode 100644 index 0000000000000..a053c9d5d0835 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_stable_diffusion.py @@ -0,0 +1,429 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from TensorRT demo diffusion, which has the following license: +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- + +import os +import pathlib +import random + +import nvtx +import torch +from cuda import cudart +from diffusion_models import PipelineInfo, get_tokenizer +from diffusion_schedulers import DDIMScheduler, EulerAncestralDiscreteScheduler, UniPCMultistepScheduler +from engine_builder import EngineType +from engine_builder_ort_trt import OrtTensorrtEngineBuilder +from engine_builder_tensorrt import TensorrtEngineBuilder + + +class StableDiffusionPipeline: + """ + Stable Diffusion pipeline using TensorRT. + """ + + def __init__( + self, + pipeline_info: PipelineInfo, + max_batch_size=16, + scheduler="DDIM", + device="cuda", + output_dir=".", + hf_token=None, + verbose=False, + nvtx_profile=False, + use_cuda_graph=False, + framework_model_dir="pytorch_model", + engine_type: EngineType = EngineType.ORT_TRT, + ): + """ + Initializes the Diffusion pipeline. + + Args: + pipeline_info (PipelineInfo): + Version and Type of pipeline. + max_batch_size (int): + Maximum batch size for dynamic batch engine. + scheduler (str): + The scheduler to guide the denoising process. Must be one of [DDIM, EulerA, UniPC]. + device (str): + PyTorch device to run inference. Default: 'cuda' + output_dir (str): + Output directory for log files and image artifacts + hf_token (str): + HuggingFace User Access Token to use for downloading Stable Diffusion model checkpoints. + verbose (bool): + Enable verbose logging. + nvtx_profile (bool): + Insert NVTX profiling markers. + use_cuda_graph (bool): + Use CUDA graph to capture engine execution and then launch inference + framework_model_dir (str): + cache directory for framework checkpoints + engine_type (EngineType) + backend engine type like ORT_TRT or TRT + """ + + self.pipeline_info = pipeline_info + self.version = pipeline_info.version + + self.vae_scaling_factor = pipeline_info.vae_scaling_factor() + + self.max_batch_size = max_batch_size + + self.framework_model_dir = framework_model_dir + self.output_dir = output_dir + for directory in [self.framework_model_dir, self.output_dir]: + if not os.path.exists(directory): + print(f"[I] Create directory: {directory}") + pathlib.Path(directory).mkdir(parents=True) + + self.hf_token = hf_token + self.device = device + self.torch_device = torch.device(device, torch.cuda.current_device()) + self.verbose = verbose + self.nvtx_profile = nvtx_profile + + # Scheduler options + sched_opts = {"num_train_timesteps": 1000, "beta_start": 0.00085, "beta_end": 0.012} + if self.version in ("2.0", "2.1"): + sched_opts["prediction_type"] = "v_prediction" + else: + sched_opts["prediction_type"] = "epsilon" + + if scheduler == "DDIM": + self.scheduler = DDIMScheduler(device=self.device, **sched_opts) + elif scheduler == "EulerA": + self.scheduler = EulerAncestralDiscreteScheduler(device=self.device, **sched_opts) + elif scheduler == "UniPC": + self.scheduler = UniPCMultistepScheduler(device=self.device) + else: + raise ValueError("Scheduler should be either DDIM, EulerA or UniPC") + + self.stages = pipeline_info.stages() + + self.vae_torch_fallback = self.pipeline_info.is_sd_xl() + + self.use_cuda_graph = use_cuda_graph + + self.tokenizer = None + self.tokenizer2 = None + + self.generator = None + self.denoising_steps = None + + # backend engine + self.engine_type = engine_type + if engine_type == EngineType.TRT: + self.backend = TensorrtEngineBuilder(pipeline_info, max_batch_size, hf_token, device, use_cuda_graph) + elif engine_type == EngineType.ORT_TRT: + self.backend = OrtTensorrtEngineBuilder(pipeline_info, max_batch_size, hf_token, device, use_cuda_graph) + else: + raise RuntimeError(f"Backend engine type {engine_type.name} is not supported") + + # Load text tokenizer + if not self.pipeline_info.is_sd_xl_refiner(): + self.tokenizer = get_tokenizer( + self.pipeline_info, self.framework_model_dir, self.hf_token, subfolder="tokenizer" + ) + + if self.pipeline_info.is_sd_xl(): + self.tokenizer2 = get_tokenizer( + self.pipeline_info, self.framework_model_dir, self.hf_token, subfolder="tokenizer_2" + ) + + # Create CUDA events + self.events = {} + for stage in ["clip", "denoise", "vae", "vae_encoder"]: + for marker in ["start", "stop"]: + self.events[stage + "-" + marker] = cudart.cudaEventCreate()[1] + + def is_backend_tensorrt(self): + return self.engine_type == EngineType.TRT + + def set_denoising_steps(self, denoising_steps: int): + if self.denoising_steps != denoising_steps: + assert self.denoising_steps is None # TODO(tianleiwu): support changing steps in different runs + # Pre-compute latent input scales and linear multistep coefficients + self.scheduler.set_timesteps(denoising_steps) + self.scheduler.configure() + self.denoising_steps = denoising_steps + + def load_resources(self, image_height, image_width, batch_size): + # If engine is built with static input shape, call this only once after engine build. + # Otherwise, it need be called before every inference run. + self.backend.load_resources(image_height, image_width, batch_size) + + def set_random_seed(self, seed): + # Initialize noise generator. Usually, it is done before a batch of inference. + self.generator = torch.Generator(device="cuda").manual_seed(seed) if isinstance(seed, int) else None + + def teardown(self): + for e in self.events.values(): + cudart.cudaEventDestroy(e) + + if self.backend: + self.backend.teardown() + + def run_engine(self, model_name, feed_dict): + return self.backend.run_engine(model_name, feed_dict) + + def initialize_latents(self, batch_size, unet_channels, latent_height, latent_width): + latents_dtype = torch.float32 # text_embeddings.dtype + latents_shape = (batch_size, unet_channels, latent_height, latent_width) + latents = torch.randn(latents_shape, device=self.device, dtype=latents_dtype, generator=self.generator) + # Scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def initialize_timesteps(self, timesteps, strength): + self.scheduler.set_timesteps(timesteps) + offset = self.scheduler.steps_offset if hasattr(self.scheduler, "steps_offset") else 0 + init_timestep = int(timesteps * strength) + offset + init_timestep = min(init_timestep, timesteps) + t_start = max(timesteps - init_timestep + offset, 0) + timesteps = self.scheduler.timesteps[t_start:].to(self.device) + return timesteps, t_start + + def preprocess_images(self, batch_size, images=()): + if self.nvtx_profile: + nvtx_image_preprocess = nvtx.start_range(message="image_preprocess", color="pink") + init_images = [] + for i in images: + image = i.to(self.device).float() + if image.shape[0] != batch_size: + image = image.repeat(batch_size, 1, 1, 1) + init_images.append(image) + if self.nvtx_profile: + nvtx.end_range(nvtx_image_preprocess) + return tuple(init_images) + + def encode_prompt( + self, prompt, negative_prompt, encoder="clip", tokenizer=None, pooled_outputs=False, output_hidden_states=False + ): + if tokenizer is None: + tokenizer = self.tokenizer + + if self.nvtx_profile: + nvtx_clip = nvtx.start_range(message="clip", color="green") + cudart.cudaEventRecord(self.events["clip-start"], 0) + + # Tokenize prompt + text_input_ids = ( + tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + .input_ids.type(torch.int32) + .to(self.device) + ) + + # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt + outputs = self.run_engine(encoder, {"input_ids": text_input_ids}) + text_embeddings = outputs["text_embeddings"].clone() + if output_hidden_states: + hidden_states = outputs["hidden_states"].clone() + + # Tokenize negative prompt + uncond_input_ids = ( + tokenizer( + negative_prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + .input_ids.type(torch.int32) + .to(self.device) + ) + + outputs = self.run_engine(encoder, {"input_ids": uncond_input_ids}) + uncond_embeddings = outputs["text_embeddings"] + if output_hidden_states: + uncond_hidden_states = outputs["hidden_states"] + + # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance + text_embeddings = torch.cat([uncond_embeddings, text_embeddings]).to(dtype=torch.float16) + + if pooled_outputs: + pooled_output = text_embeddings + + if output_hidden_states: + text_embeddings = torch.cat([uncond_hidden_states, hidden_states]).to(dtype=torch.float16) + + cudart.cudaEventRecord(self.events["clip-stop"], 0) + if self.nvtx_profile: + nvtx.end_range(nvtx_clip) + + if pooled_outputs: + return text_embeddings, pooled_output + return text_embeddings + + def denoise_latent( + self, + latents, + text_embeddings, + denoiser="unet", + timesteps=None, + step_offset=0, + mask=None, + masked_image_latents=None, + guidance=7.5, + image_guidance=1.5, + add_kwargs=None, + ): + assert guidance > 1.0, "Guidance has to be > 1.0" + assert image_guidance > 1.0, "Image guidance has to be > 1.0" + + cudart.cudaEventRecord(self.events["denoise-start"], 0) + if not isinstance(timesteps, torch.Tensor): + timesteps = self.scheduler.timesteps + for step_index, timestep in enumerate(timesteps): + if self.nvtx_profile: + nvtx_latent_scale = nvtx.start_range(message="latent_scale", color="pink") + + # Expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, step_offset + step_index, timestep + ) + + if isinstance(mask, torch.Tensor): + latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1) + if self.nvtx_profile: + nvtx.end_range(nvtx_latent_scale) + + # Predict the noise residual + if self.nvtx_profile: + nvtx_unet = nvtx.start_range(message="unet", color="blue") + + timestep_float = timestep.float() if timestep.dtype != torch.float32 else timestep + + sample_inp = latent_model_input + timestep_inp = timestep_float + embeddings_inp = text_embeddings + + params = {"sample": sample_inp, "timestep": timestep_inp, "encoder_hidden_states": embeddings_inp} + if add_kwargs: + params.update(add_kwargs) + + noise_pred = self.run_engine(denoiser, params)["latent"] + + if self.nvtx_profile: + nvtx.end_range(nvtx_unet) + + if self.nvtx_profile: + nvtx_latent_step = nvtx.start_range(message="latent_step", color="pink") + + # Perform guidance + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance * (noise_pred_text - noise_pred_uncond) + + if type(self.scheduler) == UniPCMultistepScheduler: + latents = self.scheduler.step(noise_pred, timestep, latents, return_dict=False)[0] + else: + latents = self.scheduler.step(noise_pred, latents, step_offset + step_index, timestep) + + if self.nvtx_profile: + nvtx.end_range(nvtx_latent_step) + + latents = 1.0 / self.vae_scaling_factor * latents + cudart.cudaEventRecord(self.events["denoise-stop"], 0) + return latents + + def encode_image(self, init_image): + if self.nvtx_profile: + nvtx_vae = nvtx.start_range(message="vae_encoder", color="red") + cudart.cudaEventRecord(self.events["vae_encoder-start"], 0) + init_latents = self.run_engine("vae_encoder", {"images": init_image})["latent"] + cudart.cudaEventRecord(self.events["vae_encoder-stop"], 0) + if self.nvtx_profile: + nvtx.end_range(nvtx_vae) + + init_latents = self.vae_scaling_factor * init_latents + return init_latents + + def decode_latent(self, latents): + if self.nvtx_profile: + nvtx_vae = nvtx.start_range(message="vae", color="red") + cudart.cudaEventRecord(self.events["vae-start"], 0) + images = self.backend.vae_decode(latents) + cudart.cudaEventRecord(self.events["vae-stop"], 0) + if self.nvtx_profile: + nvtx.end_range(nvtx_vae) + return images + + def print_summary(self, tic, toc, batch_size, vae_enc=False): + print("|------------|--------------|") + print("| {:^10} | {:^12} |".format("Module", "Latency")) + print("|------------|--------------|") + if vae_enc: + print( + "| {:^10} | {:>9.2f} ms |".format( + "VAE-Enc", + cudart.cudaEventElapsedTime(self.events["vae_encoder-start"], self.events["vae_encoder-stop"])[1], + ) + ) + print( + "| {:^10} | {:>9.2f} ms |".format( + "CLIP", cudart.cudaEventElapsedTime(self.events["clip-start"], self.events["clip-stop"])[1] + ) + ) + print( + "| {:^10} | {:>9.2f} ms |".format( + "UNet x " + str(self.denoising_steps), + cudart.cudaEventElapsedTime(self.events["denoise-start"], self.events["denoise-stop"])[1], + ) + ) + print( + "| {:^10} | {:>9.2f} ms |".format( + "VAE-Dec", cudart.cudaEventElapsedTime(self.events["vae-start"], self.events["vae-stop"])[1] + ) + ) + print("|------------|--------------|") + print("| {:^10} | {:>9.2f} ms |".format("Pipeline", (toc - tic) * 1000.0)) + print("|------------|--------------|") + print(f"Throughput: {batch_size / (toc - tic):.2f} image/s") + + @staticmethod + def to_pil_image(images): + images = ( + ((images + 1) * 255 / 2).clamp(0, 255).detach().permute(0, 2, 3, 1).round().type(torch.uint8).cpu().numpy() + ) + from PIL import Image + + return [Image.fromarray(images[i]) for i in range(images.shape[0])] + + def save_images(self, images, pipeline, prompt): + image_name_prefix = ( + pipeline + "".join(set(["-" + prompt[i].replace(" ", "_")[:10] for i in range(len(prompt))])) + "-" + ) + + images = self.to_pil_image(images) + random_session_id = str(random.randint(1000, 9999)) + for i, image in enumerate(images): + image_path = os.path.join( + self.output_dir, image_name_prefix + str(i + 1) + "-" + random_session_id + ".png" + ) + print(f"Saving image {i+1} / {len(images)} to: {image_path}") + image.save(image_path) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py new file mode 100644 index 0000000000000..82f73e8b3cc61 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img.py @@ -0,0 +1,155 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from TensorRT demo diffusion, which has the following license: +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- + +import time + +import torch +from diffusion_models import PipelineInfo +from pipeline_stable_diffusion import StableDiffusionPipeline + + +class Txt2ImgPipeline(StableDiffusionPipeline): + """ + Stable Diffusion Txt2Img pipeline using NVidia TensorRT. + """ + + def __init__(self, pipeline_info: PipelineInfo, **kwargs): + """ + Initializes the Txt2Img Diffusion pipeline. + + Args: + pipeline_info (PipelineInfo): + Version and Type of stable diffusion pipeline. + """ + super().__init__(pipeline_info, **kwargs) + + def _infer( + self, + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=50, + guidance=7.5, + seed=None, + warmup=False, + return_type="latents", + ): + assert len(prompt) == len(negative_prompt) + batch_size = len(prompt) + + self.set_denoising_steps(denoising_steps) + self.set_random_seed(seed) + + with torch.inference_mode(), torch.autocast("cuda"): + # Pre-initialize latents + latents = self.initialize_latents( + batch_size=batch_size, + unet_channels=4, + latent_height=(image_height // 8), + latent_width=(image_width // 8), + ) + + torch.cuda.synchronize() + e2e_tic = time.perf_counter() + + # CLIP text encoder + text_embeddings = self.encode_prompt(prompt, negative_prompt) + + # UNet denoiser + latents = self.denoise_latent(latents, text_embeddings, guidance=guidance) + + # VAE decode latent + images = self.decode_latent(latents) + + torch.cuda.synchronize() + e2e_toc = time.perf_counter() + + if not warmup: + self.print_summary(e2e_tic, e2e_toc, batch_size) + self.save_images(images, "txt2img", prompt) + + return images, (e2e_toc - e2e_tic) * 1000.0 + + def run( + self, + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=30, + guidance=7.5, + seed=None, + warmup=False, + return_type="images", + ): + """ + Run the diffusion pipeline. + + Args: + prompt (str): + The text prompt to guide image generation. + negative_prompt (str): + The prompt not to guide the image generation. + image_height (int): + Height (in pixels) of the image to be generated. Must be a multiple of 8. + image_width (int): + Width (in pixels) of the image to be generated. Must be a multiple of 8. + denoising_steps (int): + Number of denoising steps. More steps usually lead to higher quality image at the expense of slower inference. + guidance (float): + Higher guidance scale encourages to generate images that are closely linked to the text prompt. + seed (int): + Seed for the random generator + warmup (bool): + Indicate if this is a warmup run. + return_type (str): + type of return. The value can be "latents" or "images". + """ + if self.is_backend_tensorrt(): + import tensorrt as trt + from trt_utilities import TRT_LOGGER + + with trt.Runtime(TRT_LOGGER): + return self._infer( + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=denoising_steps, + guidance=guidance, + seed=seed, + warmup=warmup, + return_type=return_type, + ) + else: + return self._infer( + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=denoising_steps, + guidance=guidance, + seed=seed, + warmup=warmup, + return_type=return_type, + ) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py new file mode 100644 index 0000000000000..d8f00ed619354 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/pipeline_txt2img_xl.py @@ -0,0 +1,198 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +# Modified from TensorRT demo diffusion, which has the following license: +# +# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# -------------------------------------------------------------------------- + +import time + +import torch +from diffusion_models import PipelineInfo +from pipeline_stable_diffusion import StableDiffusionPipeline + + +class Txt2ImgXLPipeline(StableDiffusionPipeline): + """ + Stable Diffusion Txt2Img XL pipeline. + """ + + def __init__(self, pipeline_info: PipelineInfo, *args, **kwargs): + """ + Initializes the Txt2Img XL Diffusion pipeline. + + Args: + pipeline_info (PipelineInfo): + Version and Type of stable diffusion pipeline. + """ + assert pipeline_info.is_sd_xl_base() + + super().__init__(pipeline_info, *args, **kwargs) + + def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + def _infer( + self, + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=30, + guidance=5.0, + seed=None, + warmup=False, + return_type="images", + ): + assert len(prompt) == len(negative_prompt) + + # TODO(tianleiwu): Need we use image_height and image_width for the target size here? + original_size = (1024, 1024) + crops_coords_top_left = (0, 0) + target_size = (1024, 1024) + batch_size = len(prompt) + + self.set_denoising_steps(denoising_steps) + self.set_random_seed(seed) + + with torch.inference_mode(), torch.autocast("cuda"): + # Pre-initialize latents + latents = self.initialize_latents( + batch_size=batch_size, + unet_channels=4, + latent_height=(image_height // 8), + latent_width=(image_width // 8), + ) + + torch.cuda.synchronize() + e2e_tic = time.perf_counter() + + # CLIP text encoder + text_embeddings = self.encode_prompt( + prompt, negative_prompt, encoder="clip", tokenizer=self.tokenizer, output_hidden_states=True + ) + # CLIP text encoder 2 + text_embeddings2, pooled_embeddings2 = self.encode_prompt( + prompt, + negative_prompt, + encoder="clip2", + tokenizer=self.tokenizer2, + pooled_outputs=True, + output_hidden_states=True, + ) + + # Merged text embeddings + text_embeddings = torch.cat([text_embeddings, text_embeddings2], dim=-1) + + # Time embeddings + add_time_ids = self._get_add_time_ids( + original_size, crops_coords_top_left, target_size, dtype=text_embeddings.dtype + ) + add_time_ids = add_time_ids.repeat(batch_size, 1) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0).to(self.device) + + add_kwargs = {"text_embeds": pooled_embeddings2, "time_ids": add_time_ids} + + # UNet denoiser + latents = self.denoise_latent( + latents, text_embeddings, denoiser="unetxl", guidance=guidance, add_kwargs=add_kwargs + ) + + # VAE decode latent + if return_type == "latents": + images = latents * self.vae_scaling_factor + else: + images = self.decode_latent(latents) + + torch.cuda.synchronize() + e2e_toc = time.perf_counter() + + if not warmup: + print("SD-XL Base Pipeline") + self.print_summary(e2e_tic, e2e_toc, batch_size) + if return_type == "images": + self.save_images(images, "txt2img-xl", prompt) + + return images, (e2e_toc - e2e_tic) * 1000.0 + + def run( + self, + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=30, + guidance=5.0, + seed=None, + warmup=False, + return_type="images", + ): + """ + Run the diffusion pipeline. + + Args: + prompt (str): + The text prompt to guide image generation. + negative_prompt (str): + The prompt not to guide the image generation. + image_height (int): + Height (in pixels) of the image to be generated. Must be a multiple of 8. + image_width (int): + Width (in pixels) of the image to be generated. Must be a multiple of 8. + denoising_steps (int): + Number of denoising steps. More steps usually lead to higher quality image at the expense of slower inference. + guidance (float): + Higher guidance scale encourages to generate images that are closely linked to the text prompt. + seed (int): + Seed for the random generator + warmup (bool): + Indicate if this is a warmup run. + return_type (str): + It can be "latents" or "images". + """ + + if self.is_backend_tensorrt(): + import tensorrt as trt + from trt_utilities import TRT_LOGGER + + with trt.Runtime(TRT_LOGGER): + return self._infer( + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=denoising_steps, + guidance=guidance, + seed=seed, + warmup=warmup, + return_type=return_type, + ) + else: + return self._infer( + prompt, + negative_prompt, + image_height, + image_width, + denoising_steps=denoising_steps, + guidance=guidance, + seed=seed, + warmup=warmup, + return_type=return_type, + ) diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda.txt index b942749f8dcd2..2a3caf4c2392b 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda.txt +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-cuda.txt @@ -1,8 +1,14 @@ -r requirements.txt -onnxruntime-gpu>=1.14 +onnxruntime-gpu>=1.16 py3nvml>=0.2.7 + # cuda-python is needed for cuda graph. It shall be compatible with CUDA version of torch and onnxruntime-gpu. -cuda-python==11.7.0 -#To export onnx of stable diffusion, please install PyTorch 1.13.1+cu117 -#--extra-index-url https://download.pytorch.org/whl/cu117 -#torch==1.13.1+cu117 +cuda-python==11.8.0 +# For windows, cuda-python need the following +pywin32; platform_system == "Windows" + +nvtx + +# To export onnx, please install PyTorch 2.10 like +# pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118 +# pip3 install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-tensorrt.txt b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-tensorrt.txt index 567f39c0119e6..5b59c64ab7470 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-tensorrt.txt +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/requirements-tensorrt.txt @@ -1,18 +1,2 @@ -diffusers>=0.16.0 -transformers>=4.26.0 -numpy>=1.24.1 -accelerate -onnx>=1.13.0 -coloredlogs -packaging -protobuf -psutil -sympy +-r requirements-cuda.txt tensorrt>=8.6.1 -onnxruntime-gpu>=1.15.1 -py3nvml -# cuda-python version shall be compatible with CUDA version of torch and onnxruntime-gpu -cuda-python==11.7.0 -#To export onnx of stable diffusion, please install PyTorch 1.13.1+cu117 -#--extra-index-url https://download.pytorch.org/whl/cu117 -#torch==1.13.1+cu117 diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/trt_utilities.py b/onnxruntime/python/tools/transformers/models/stable_diffusion/trt_utilities.py new file mode 100644 index 0000000000000..d03a9f9f55372 --- /dev/null +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/trt_utilities.py @@ -0,0 +1,12 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import tensorrt as trt + +TRT_LOGGER = trt.Logger(trt.Logger.ERROR) + + +def init_trt_plugins(): + # Register TensorRT plugins + trt.init_libnvinfer_plugins(TRT_LOGGER, "") diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index 60be2d84b2bc8..e9c24ed3eb09b 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -610,7 +610,7 @@ def convert_float_to_float16(self, use_symbolic_shape_infer=True, **kwargs): When symbolic shape inference is used (even if it failed), ONNX shape inference will be disabled. - Note that onnx shape inference will fail for model larger than 2GB. For large model, you have to eanble + Note that onnx shape inference will fail for model larger than 2GB. For large model, you have to enable symbolic shape inference. If your model is not optimized, you can also use model path to call convert_float_to_float16 in float16.py (see https://github.com/microsoft/onnxruntime/pull/15067) to avoid the 2GB limit. @@ -832,7 +832,7 @@ def get_first_output(node): # Keep track of nodes to keep. The key is first output of node, and the value is the node. output_to_node = {} - # Start from graph outputs, and find parent nodes recurisvely, and add nodes to the output_to_node dictionary. + # Start from graph outputs, and find parent nodes recursively, and add nodes to the output_to_node dictionary. dq = deque() for output in keep_outputs: if output in output_name_to_node: @@ -1161,7 +1161,7 @@ def has_same_value( signature_cache1 (dict): Optional dictionary to store data signatures of tensor1 in order to speed up comparison. signature_cache2 (dict): Optional dictionary to store data signatures of tensor2 in order to speed up comparison. Returns: - bool: True when two intializers has same value. + bool: True when two initializers has same value. """ sig1 = ( signature_cache1[tensor1.name] diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index 3f274eb6c835a..5ded027b36f74 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -69,6 +69,8 @@ def optimize_by_onnxruntime( save_as_external_data: bool = False, external_data_filename: str = "", external_data_file_threshold: int = 1024, + *, + provider: Optional[str] = None, ) -> str: """ Use onnxruntime to optimize model. @@ -82,6 +84,7 @@ def optimize_by_onnxruntime( save_as_external_data (bool): whether to save external data outside of ONNX model external_data_filename (str): name of external data file. If not provided, name is automatically created from ONNX model. external_data_file_threshold (int): threshold to decide whether to save tensor in ONNX model or in external data file + provider (str or None): execution provider to use if use_gpu Returns: optimized_model_path (str): the path of optimized model """ @@ -90,8 +93,12 @@ def optimize_by_onnxruntime( import onnxruntime - if use_gpu and set(onnxruntime.get_available_providers()).isdisjoint( - ["CUDAExecutionProvider", "ROCMExecutionProvider", "MIGraphXExecutionProvider"] + if ( + use_gpu + and provider is None + and set(onnxruntime.get_available_providers()).isdisjoint( + ["CUDAExecutionProvider", "ROCMExecutionProvider", "MIGraphXExecutionProvider"] + ) ): logger.error("There is no gpu for onnxruntime to do optimization.") return onnx_model_path @@ -138,17 +145,32 @@ def optimize_by_onnxruntime( kwargs["disabled_optimizers"] = disabled_optimizers if not use_gpu: - onnxruntime.InferenceSession(onnx_model_path, sess_options, providers=["CPUExecutionProvider"], **kwargs) + providers = ["CPUExecutionProvider"] + elif provider is not None: + if provider == "dml": + providers = ["DmlExecutionProvider"] + elif provider == "rocm": + providers = ["ROCMExecutionProvider"] + elif provider == "migraphx": + providers = ["MIGraphXExecutionProvider", "ROCMExecutionProvider"] + elif provider == "cuda": + providers = ["CUDAExecutionProvider"] + elif provider == "tensorrt": + providers = ["TensorrtExecutionProvider", "CUDAExecutionProvider"] + else: + providers = ["CUDAExecutionProvider"] + + providers.append("CPUExecutionProvider") else: - gpu_ep = [] + providers = [] if torch_version.hip: - gpu_ep.append("MIGraphXExecutionProvider") - gpu_ep.append("ROCMExecutionProvider") + providers.append("MIGraphXExecutionProvider") + providers.append("ROCMExecutionProvider") else: - gpu_ep.append("CUDAExecutionProvider") + providers.append("CUDAExecutionProvider") - onnxruntime.InferenceSession(onnx_model_path, sess_options, providers=gpu_ep, **kwargs) + onnxruntime.InferenceSession(onnx_model_path, sess_options, providers=providers, **kwargs) assert os.path.exists(optimized_model_path) and os.path.isfile(optimized_model_path) logger.debug("Save optimized model by onnxruntime to %s", optimized_model_path) @@ -220,6 +242,8 @@ def optimize_model( use_gpu: bool = False, only_onnxruntime: bool = False, verbose: bool = False, + *, + provider: Optional[str] = None, ): """Optimize Model by OnnxRuntime and/or python fusion logic. @@ -257,6 +281,7 @@ def optimize_model( use_gpu (bool, optional): use gpu or not for onnxruntime. Defaults to False. only_onnxruntime (bool, optional): only use onnxruntime to optimize model, and no python fusion. Defaults to False. + provider (str, optional): execution provider to use if use_gpu. Defaults to None. Returns: object of an optimizer class. @@ -302,6 +327,7 @@ def optimize_model( temp_model_path = optimize_by_onnxruntime( input, use_gpu=use_gpu, + provider=provider, optimized_model_path=optimized_model_path, opt_level=opt_level, disabled_optimizers=disabled_optimizers, @@ -316,6 +342,7 @@ def optimize_model( temp_model_path = optimize_by_onnxruntime( input, use_gpu=use_gpu, + provider=provider, optimized_model_path=optimized_model_path, opt_level=1, disabled_optimizers=disabled_optimizers, @@ -423,6 +450,14 @@ def _parse_arguments(): ) parser.set_defaults(use_gpu=False) + parser.add_argument( + "--provider", + required=False, + type=str, + default=None, + help="Execution provider to use if use_gpu", + ) + parser.add_argument( "--only_onnxruntime", required=False, @@ -501,6 +536,7 @@ def main(): opt_level=args.opt_level, optimization_options=optimization_options, use_gpu=args.use_gpu, + provider=args.provider, only_onnxruntime=args.only_onnxruntime, ) diff --git a/onnxruntime/test/contrib_ops/qordered_attention_test.cc b/onnxruntime/test/contrib_ops/qordered_attention_test.cc index 24e4bff528285..1dd0162ad722f 100644 --- a/onnxruntime/test/contrib_ops/qordered_attention_test.cc +++ b/onnxruntime/test/contrib_ops/qordered_attention_test.cc @@ -240,12 +240,6 @@ static std::vector transpose(const T& src, size_t h, size_t w) { } TEST(QOrderedTest, Attention_WithData_ROW_ORDER) { - int cuda_runtime_version = 0; - // Need 11.4 or higher cuda runtime - if ((cudaRuntimeGetVersion(&cuda_runtime_version) != cudaSuccess) || (cuda_runtime_version < 11040)) { - return; - } - // Needs Turing architecture if (NeedSkipIfCudaArchLowerThan(750) || NeedSkipIfCudaArchGreaterEqualThan(800)) { return; diff --git a/onnxruntime/test/contrib_ops/qordered_longformer_attention_op_test.cc b/onnxruntime/test/contrib_ops/qordered_longformer_attention_op_test.cc index 55209d9422fdd..06fe42ca989d7 100644 --- a/onnxruntime/test/contrib_ops/qordered_longformer_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/qordered_longformer_attention_op_test.cc @@ -34,12 +34,6 @@ static void run_qordered_longformer_attention_op_test( const int64_t head_size, const int64_t window, int64_t input_hidden_size = 0) { - int cuda_runtime_version = 0; - // Need 11.4 or higher cuda runtime - if ((cudaRuntimeGetVersion(&cuda_runtime_version) != cudaSuccess) || (cuda_runtime_version < 11040)) { - return; - } - // Needs Turing architecture if (NeedSkipIfCudaArchLowerThan(750) || NeedSkipIfCudaArchGreaterEqualThan(800)) { return; diff --git a/onnxruntime/test/contrib_ops/qordered_matmul_op_test.cc b/onnxruntime/test/contrib_ops/qordered_matmul_op_test.cc index e5b3d59ef86e3..e3905db6355d9 100644 --- a/onnxruntime/test/contrib_ops/qordered_matmul_op_test.cc +++ b/onnxruntime/test/contrib_ops/qordered_matmul_op_test.cc @@ -21,12 +21,6 @@ static void RunQOrdered_MatMul_Test( OrderCublasLt weight_order, float scale_A, float scale_B, float scale_C, float scale_Y, bool add_bias = false, bool broadcast_c_batch = false, bool per_channel = false) { - int cuda_runtime_version = 0; - // Need 11.4 or higher cuda runtime - if ((cudaRuntimeGetVersion(&cuda_runtime_version) != cudaSuccess) || (cuda_runtime_version < 11040)) { - return; - } - // Needs Turing architecture if (NeedSkipIfCudaArchLowerThan(750) || NeedSkipIfCudaArchGreaterEqualThan(800)) { return; diff --git a/onnxruntime/test/contrib_ops/qordered_qdq_op_test.cc b/onnxruntime/test/contrib_ops/qordered_qdq_op_test.cc index 15e97751acf2d..0f3f702695b80 100644 --- a/onnxruntime/test/contrib_ops/qordered_qdq_op_test.cc +++ b/onnxruntime/test/contrib_ops/qordered_qdq_op_test.cc @@ -73,12 +73,6 @@ static void RunQOrdered_Quantize_Test( std::vector const& shape, OrderCublasLt order_q, float scale) { - int cuda_runtime_version = 0; - // Need 11.4 or higher cuda runtime - if ((cudaRuntimeGetVersion(&cuda_runtime_version) != cudaSuccess) || (cuda_runtime_version < 11040)) { - return; - } - auto qvec = QuantizeTransform(shape, scale, fvec, order_q); std::vector> execution_providers; @@ -153,12 +147,6 @@ static void RunQOrdered_Dequantize_Test( std::vector const& shape, OrderCublasLt order_q, float scale) { - int cuda_runtime_version = 0; - // Need 11.4 or higher cuda runtime - if ((cudaRuntimeGetVersion(&cuda_runtime_version) != cudaSuccess) || (cuda_runtime_version < 11040)) { - return; - } - auto fvec = DequantizeTransform(shape, scale, qvec, order_q); std::vector> execution_providers; diff --git a/onnxruntime/test/framework/function_test.cc b/onnxruntime/test/framework/function_test.cc index e126979532644..6e745776ab6b0 100644 --- a/onnxruntime/test/framework/function_test.cc +++ b/onnxruntime/test/framework/function_test.cc @@ -6,6 +6,7 @@ #include "onnx/defs/parser.h" #include "core/common/span_utils.h" +#include "core/framework/float8.h" #include "core/graph/model.h" #include "core/providers/cpu/cpu_execution_provider.h" #include "core/session/inference_session.h" @@ -69,7 +70,9 @@ static void Check(const char* source, float threshold = 0.001f; for (size_t i = 0; i < size; ++i) { - ASSERT_NEAR(data[i], output_values[i], threshold) << "at position i:" << i; + if (!std::isnan(data[i]) && !std::isnan(output_values[i])) { + ASSERT_NEAR(data[i], output_values[i], threshold) << "at position i:" << i; + } } } @@ -389,25 +392,13 @@ TEST(FunctionTest, AttrSaturateNan) { > agraph (float[N] x) => (float[N] y) { - y0 = local.myfun (x) - y1 = local.myfun (x) - y = Add (y0, y1) - } - - < - opset_import: [ "" : 19 ], - domain: "local" - > - myfun (x) => (y) { - x2 = Constant () - x2_ = Cast(x2) - x3 = CastLike(x2, x2_) - x3_ = Cast(x3) - y = Add (x, x3_) + x_E4M3FNUZ = Cast(x) + x_E4M3FNUZ_2 = CastLike(x, x_E4M3FNUZ) # NaN when OOR + y = Cast(x_E4M3FNUZ_2) } )"; - Check(code, "x", {1.0, 2.0, 1e6}, "y", {243.0, 245.0, 2000241}); // std::numeric_limits::quiet_NaN()}); + Check(code, "x", {1.0, 2.0, 1e6}, "y", {1.0, 2.0, std::numeric_limits::quiet_NaN()}); } #endif diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 077c6ff58e2da..486ec37d1eebd 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -890,6 +890,47 @@ TEST(InferenceSessionTests, ConfigureVerbosityLevel) { #endif } +TEST(InferenceSessionTests, UseUserSpecifiedLoggingFunctionInSession) { + SessionOptions so; + /* + typedef void(ORT_API_CALL* OrtLoggingFunction)( + void* param, OrtLoggingLevel severity, const char* category, const char* logid, const char* code_location, + const char* message); + */ + std::vector log_msgs; + so.user_logging_function = [](void* param, OrtLoggingLevel severity, const char* category, const char* logid, const char* code_location, + const char* message) { + ORT_UNUSED_PARAMETER(severity); + ORT_UNUSED_PARAMETER(category); + ORT_UNUSED_PARAMETER(logid); + ORT_UNUSED_PARAMETER(code_location); + std::vector* v_ptr = reinterpret_cast*>(param); + std::vector& msg_vector = *v_ptr; + msg_vector.push_back(std::string(message)); + }; + so.user_logging_param = &log_msgs; + so.session_log_severity_level = static_cast(Severity::kVERBOSE); + so.session_log_verbosity_level = 1; + so.session_logid = "InferenceSessionTests.UseUserSpecifiedLoggingFunctionInSession"; + + InferenceSession session_object{so, GetEnvironment()}; + ASSERT_STATUS_OK(session_object.Load(MODEL_URI)); + ASSERT_STATUS_OK(session_object.Initialize()); + + RunOptions run_options; + run_options.run_tag = "one session/one tag"; + RunModel(session_object, run_options); + +// vlog output is disabled in release builds +#ifndef NDEBUG + bool have_log_entry_with_vlog_session_msg = + (std::find_if(log_msgs.begin(), log_msgs.end(), + [&](std::string msg) { return msg.find("Added input argument with name") != string::npos; }) != + log_msgs.end()); + ASSERT_TRUE(have_log_entry_with_vlog_session_msg); +#endif +} + TEST(InferenceSessionTests, TestWithIstream) { SessionOptions so; @@ -2056,7 +2097,7 @@ TEST(InferenceSessionTests, TestStrictShapeInference) { ASSERT_STATUS_OK(session_options.config_options.AddConfigEntry(kOrtSessionOptionsConfigStrictShapeTypeInference, "1")); tester.Run(session_options, OpTester::ExpectResult::kExpectFailure, - "Mismatch between number of source and target dimensions. Source=1 Target=2", + "Mismatch between number of inferred and declared dimensions. inferred=1 declared=2", excluded_provider_types); } diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index 82e5efd92a8f1..e1ce1d4abf81d 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -850,6 +850,130 @@ TEST_F(SessionStateTestSharedInitalizersWithPrePacking, test3) { ASSERT_EQ(session_state_2.GetUsedSharedPrePackedWeightCounter(), static_cast(1)); } +// Pre-packing enabled + shared initializers + +// pre-packed weights container + subgraphs = +// caching enabled in pre-packed weights used in subgraphs +TEST_F(SessionStateTestSharedInitalizersWithPrePacking, test4) { + SessionOptions sess_options; + sess_options.enable_mem_pattern = true; + sess_options.execution_mode = ExecutionMode::ORT_SEQUENTIAL; + sess_options.use_deterministic_compute = false; + sess_options.enable_mem_reuse = true; + // Enable pre-packing + sess_options.config_options.configurations[kOrtSessionOptionsConfigDisablePrepacking] = "0"; + + // Enable shared initializer + OrtMemoryInfo mem_info(CPU, OrtDeviceAllocator); + std::vector float_data(1, 1); + auto value = std::make_unique(); + Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape(std::vector{1}), + reinterpret_cast(float_data.data()), mem_info, *value); + + ASSERT_STATUS_OK(sess_options.AddInitializer("if_shared", value.get())); + + // Enable pre-packed weights container + PrepackedWeightsContainer prepacked_weights_container; + + // First session/model + Model model_1("graph_main", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + domain_to_version, std::vector(), + DefaultLoggingManager().DefaultLogger()); + + CreateGraphWithSubgraph(model_1.MainGraph()); + PlaceAllNodesToCPUEP(model_1.MainGraph()); + SessionState session_state_1(model_1.MainGraph(), + execution_providers, + tp.get(), + nullptr, /*inter_op_thread_pool*/ + dtm, + DefaultLoggingManager().DefaultLogger(), + profiler, + sess_options, + &prepacked_weights_container); + + ASSERT_STATUS_OK(session_state_1.FinalizeSessionState(std::basic_string(), + kernel_registry_manager)); + + // At the main graph level, there should be no pre-packing calls as there are + // no initializers (shared or otherwise) consumed by any nodes in the main graph + ASSERT_EQ(session_state_1.GetNumberOfPrepacksCounter(), static_cast(0)); + + auto if_index_1 = 1; + if (session_state_1.GetKernel(0)->Node().OpType() == "If") { + if_index_1 = 0; + } + + const auto& subgraph_session_states = session_state_1.GetSubgraphSessionStateMap(); + const auto& if_node_session_states = subgraph_session_states.at(if_index_1); + const auto& session_state_1_then_branch_session_state = *if_node_session_states.at("then_branch"); + const auto& session_state_1_else_branch_session_state = *if_node_session_states.at("else_branch"); + + auto if_node_branches_prepack_counter_1 = + session_state_1_then_branch_session_state.GetNumberOfPrepacksCounter() + + session_state_1_else_branch_session_state.GetNumberOfPrepacksCounter(); + + // We should be seeing 2 pre-pack calls in the "If" node (one in each subgraph) + ASSERT_EQ(if_node_branches_prepack_counter_1, static_cast(2)); + + auto if_node_branches_shared_prepack_counter_1 = + session_state_1_then_branch_session_state.GetUsedSharedPrePackedWeightCounter() + + session_state_1_else_branch_session_state.GetUsedSharedPrePackedWeightCounter(); + + // We should only be seeing 1 shared pre-pack weights usage in the "If" node + // Either the "then branch" or "else branch" will be using the shared version + // depending on which branch writes to the shared container + ASSERT_EQ(if_node_branches_shared_prepack_counter_1, static_cast(1)); + + // Second session/model + Model model_2("graph_main", false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + domain_to_version, std::vector(), + DefaultLoggingManager().DefaultLogger()); + + CreateGraphWithSubgraph(model_2.MainGraph()); + PlaceAllNodesToCPUEP(model_2.MainGraph()); + SessionState session_state_2(model_2.MainGraph(), + execution_providers, + tp.get(), + nullptr, /*inter_op_thread_pool*/ + dtm, + DefaultLoggingManager().DefaultLogger(), + profiler, + sess_options, + &prepacked_weights_container); + + ASSERT_STATUS_OK(session_state_2.FinalizeSessionState(std::basic_string(), + kernel_registry_manager)); + + // At the main graph level, there should be no pre-packing calls as there are + // no initializers (shared or otherwise) consumed by any nodes in the main graph + ASSERT_EQ(session_state_2.GetNumberOfPrepacksCounter(), static_cast(0)); + + auto if_index_2 = 1; + if (session_state_2.GetKernel(0)->Node().OpType() == "If") { + if_index_2 = 0; + } + + const auto& subgraph_session_states_2 = session_state_2.GetSubgraphSessionStateMap(); + const auto& if_node_session_states_2 = subgraph_session_states_2.at(if_index_2); + const auto& session_state_2_then_branch_session_state = *if_node_session_states_2.at("then_branch"); + const auto& session_state_2_else_branch_session_state = *if_node_session_states_2.at("else_branch"); + + auto if_node_branches_prepack_counter_2 = + session_state_2_then_branch_session_state.GetNumberOfPrepacksCounter() + + session_state_2_else_branch_session_state.GetNumberOfPrepacksCounter(); + + // We should be seeing 2 pre-pack calls in the "If" node (one in each subgraph) + ASSERT_EQ(if_node_branches_prepack_counter_2, static_cast(2)); + + auto if_node_branches_shared_prepack_counter_2 = + session_state_2_then_branch_session_state.GetUsedSharedPrePackedWeightCounter() + + session_state_2_else_branch_session_state.GetUsedSharedPrePackedWeightCounter(); + + // We should be seeing 2 shared pre-pack weights calls in the "If" node + // Both branches will be using the shared version coming from the first model. + ASSERT_EQ(if_node_branches_shared_prepack_counter_2, static_cast(2)); +} + INSTANTIATE_TEST_SUITE_P(SessionStateTests, SessionStatePrepackingTest, testing::Values(PrepackingTestParam{false, false}, diff --git a/onnxruntime/test/framework/tensor_test.cc b/onnxruntime/test/framework/tensor_test.cc index f24064a403c5d..38e3f184ebc18 100644 --- a/onnxruntime/test/framework/tensor_test.cc +++ b/onnxruntime/test/framework/tensor_test.cc @@ -214,6 +214,7 @@ TEST(TensorTest, Strided) { TensorShape shape({2, 3, 4}); auto alloc = TestCPUExecutionProvider()->CreatePreferredAllocators()[0]; void* data = alloc->Alloc(shape.Size() * sizeof(float)); + Tensor t(DataTypeImpl::GetType(), shape, data, alloc->Info()); EXPECT_TRUE(t.IsContiguous()); const TensorShapeVector strides{12, 4, 1}; @@ -227,6 +228,7 @@ TEST(TensorTest, Strided) { ASSERT_EQ(t.Shape(), new_shape); ASSERT_THAT(t.Strides(), testing::ContainerEq(gsl::make_span(new_strides))); ASSERT_EQ(t.SizeInBytes(), sizeof(float) * 24); + Tensor t2(DataTypeImpl::GetType(), new_shape, data, alloc->Info(), 0L, gsl::make_span(new_strides)); EXPECT_FALSE(t2.IsContiguous()); ASSERT_EQ(t2.Shape(), new_shape); @@ -237,7 +239,9 @@ TEST(TensorTest, Strided) { ASSERT_EQ(t2.Shape(), shape); ASSERT_THAT(t2.Strides(), testing::ContainerEq(gsl::make_span(strides))); ASSERT_EQ(t2.SizeInBytes(), sizeof(float) * 24); + alloc->Free(data); + data = alloc->Alloc(sizeof(int64_t)); const TensorShapeVector single_element_strides{0, 0, 0}; Tensor t3(DataTypeImpl::GetType(), shape, data, alloc->Info(), 0L, gsl::make_span(single_element_strides)); @@ -246,8 +250,10 @@ TEST(TensorTest, Strided) { ASSERT_THAT(t3.Strides(), testing::ContainerEq(gsl::make_span(single_element_strides))); ASSERT_EQ(t3.SizeInBytes(), sizeof(int64_t)); alloc->Free(data); + const TensorShapeVector zero_strides{0, 0, 0}; - Tensor t4(DataTypeImpl::GetType(), shape, alloc, zero_strides); + Tensor t4(DataTypeImpl::GetType(), shape, alloc); + t4.SetShapeAndStrides(shape, zero_strides); EXPECT_FALSE(t4.IsContiguous()); EXPECT_EQ(t4.Shape(), shape); ASSERT_THAT(t4.Strides(), testing::ContainerEq(gsl::make_span(zero_strides))); diff --git a/onnxruntime/test/framework/test_tensor_loader.cc b/onnxruntime/test/framework/test_tensor_loader.cc index f627409bb0e60..e71830be08b5e 100644 --- a/onnxruntime/test/framework/test_tensor_loader.cc +++ b/onnxruntime/test/framework/test_tensor_loader.cc @@ -1,11 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include "gtest/gtest.h" + #include "core/common/common.h" #include "core/framework/callback.h" #include "core/framework/tensorprotoutils.h" -#include "gtest/gtest.h" -#include "file_util.h" +#include "test/util/include/file_util.h" +#include "test/util/include/asserts.h" #ifdef _WIN32 #include @@ -30,13 +32,14 @@ TEST(CApiTensorTest, load_simple_float_tensor_not_enough_space) { std::vector output(1); OrtValue value; OrtMemoryInfo cpu_memory_info(onnxruntime::CPU, OrtDeviceAllocator, OrtDevice(), 0, OrtMemTypeDefault); - auto st = utils::TensorProtoToMLValue(Env::Default(), nullptr, p, - MemBuffer(output.data(), output.size() * sizeof(float), cpu_memory_info), value); - // check the result - ASSERT_FALSE(st.IsOK()); + + ASSERT_STATUS_NOT_OK( + utils::TensorProtoToOrtValue(Env::Default(), nullptr, p, + MemBuffer(output.data(), output.size() * sizeof(float), cpu_memory_info), + value)); } -TEST(CApiTensorTest, load_simple_float_tensor) { +TEST(CApiTensorTest, load_simple_float_tensor_membuffer) { // construct a tensor proto onnx::TensorProto p; p.mutable_float_data()->Add(1.0f); @@ -51,9 +54,37 @@ TEST(CApiTensorTest, load_simple_float_tensor) { std::vector output(3); OrtValue value; OrtMemoryInfo cpu_memory_info(onnxruntime::CPU, OrtDeviceAllocator, OrtDevice(), 0, OrtMemTypeDefault); - auto st = utils::TensorProtoToMLValue(Env::Default(), nullptr, p, - MemBuffer(output.data(), output.size() * sizeof(float), cpu_memory_info), value); - ASSERT_TRUE(st.IsOK()) << st.ErrorMessage(); + ASSERT_STATUS_OK( + utils::TensorProtoToOrtValue(Env::Default(), nullptr, p, + MemBuffer(output.data(), output.size() * sizeof(float), cpu_memory_info), + value)); + float* real_output; + auto ort_st = g_ort->GetTensorMutableData(&value, (void**)&real_output); + ASSERT_EQ(ort_st, nullptr) << g_ort->GetErrorMessage(ort_st); + // check the result + ASSERT_EQ(real_output[0], 1.0f); + ASSERT_EQ(real_output[1], 2.2f); + ASSERT_EQ(real_output[2], 3.5f); + g_ort->ReleaseStatus(ort_st); +} + +TEST(CApiTensorTest, load_simple_float_tensor_allocator) { + // construct a tensor proto + onnx::TensorProto p; + p.mutable_float_data()->Add(1.0f); + p.mutable_float_data()->Add(2.2f); + p.mutable_float_data()->Add(3.5f); + p.mutable_dims()->Add(3); + p.set_data_type(onnx::TensorProto_DataType_FLOAT); + std::string s; + // save it to a buffer + ASSERT_TRUE(p.SerializeToString(&s)); + // deserialize it + AllocatorPtr tmp_allocator = std::make_shared(); + OrtValue value; + + ASSERT_STATUS_OK(utils::TensorProtoToOrtValue(Env::Default(), nullptr, p, tmp_allocator, value)); + float* real_output; auto ort_st = g_ort->GetTensorMutableData(&value, (void**)&real_output); ASSERT_EQ(ort_st, nullptr) << g_ort->GetErrorMessage(ort_st); @@ -106,9 +137,9 @@ static void run_external_data_test() { } OrtValue value; OrtMemoryInfo cpu_memory_info(onnxruntime::CPU, OrtDeviceAllocator, OrtDevice(), 0, OrtMemTypeDefault); - auto st = utils::TensorProtoToMLValue(Env::Default(), nullptr, p, - MemBuffer(output.data(), output.size() * sizeof(float), cpu_memory_info), value); - ASSERT_TRUE(st.IsOK()) << st.ErrorMessage(); + ASSERT_STATUS_OK(utils::TensorProtoToOrtValue( + Env::Default(), nullptr, p, MemBuffer(output.data(), output.size() * sizeof(float), cpu_memory_info), value)); + float* real_output; auto ort_st = g_ort->GetTensorMutableData(&value, (void**)&real_output); ASSERT_EQ(ort_st, nullptr) << g_ort->GetErrorMessage(ort_st); @@ -156,11 +187,10 @@ TEST(CApiTensorTest, load_huge_tensor_with_external_data) { std::vector output(total_ele_count); OrtValue value; OrtMemoryInfo cpu_memory_info(onnxruntime::CPU, OrtDeviceAllocator, OrtDevice(), 0, OrtMemTypeDefault); - auto st = utils::TensorProtoToMLValue(Env::Default(), nullptr, p, - MemBuffer(output.data(), output.size() * sizeof(int), cpu_memory_info), value); + ASSERT_STATUS_OK( + utils::TensorProtoToOrtValue(Env::Default(), nullptr, p, + MemBuffer(output.data(), output.size() * sizeof(int), cpu_memory_info), value)); - // check the result - ASSERT_TRUE(st.IsOK()) << "Error from TensorProtoToMLValue: " << st.ErrorMessage(); int* buffer; auto ort_st = g_ort->GetTensorMutableData(&value, (void**)&buffer); ASSERT_EQ(ort_st, nullptr) << "Error from OrtGetTensorMutableData: " << g_ort->GetErrorMessage(ort_st); diff --git a/onnxruntime/test/framework/tunable_op_test.cc b/onnxruntime/test/framework/tunable_op_test.cc index 6aa7c5ee9f8ac..0d9e557ebc813 100644 --- a/onnxruntime/test/framework/tunable_op_test.cc +++ b/onnxruntime/test/framework/tunable_op_test.cc @@ -263,8 +263,7 @@ TEST(TunableOp, OpWrapsMutableFunctor) { class VecAddMoveOnlyFunctor { public: - VecAddMoveOnlyFunctor() { - } + VecAddMoveOnlyFunctor() = default; VecAddMoveOnlyFunctor(VecAddMoveOnlyFunctor&&) = default; ORT_DISALLOW_COPY_AND_ASSIGNMENT(VecAddMoveOnlyFunctor); @@ -290,8 +289,7 @@ TEST(TunableOp, OpWrapsMoveOnlyFunctor) { class VecAddWithIsSupportedMethod { public: - VecAddWithIsSupportedMethod() { - } + VecAddWithIsSupportedMethod() = default; VecAddWithIsSupportedMethod(VecAddWithIsSupportedMethod&&) = default; ORT_DISALLOW_COPY_AND_ASSIGNMENT(VecAddWithIsSupportedMethod); diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 062ca4ece86bf..287d657a2ce28 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -56,6 +56,7 @@ void usage() { "\t [QNN only] [rpc_control_latency]: QNN rpc control latency. default to 10.\n" "\t [QNN only] [htp_performance_mode]: QNN performance mode, options: 'burst', 'balanced', 'default', 'high_performance', \n" "\t 'high_power_saver', 'low_balanced', 'low_power_saver', 'power_saver', 'sustained_high_performance'. Default to 'default'. \n" + "\t [QNN only] [qnn_saver_path]: QNN Saver backend path. e.g '/folderpath/libQnnSaver.so'.\n" "\t [Usage]: -e -i '| |' \n\n" "\t [Example] [For QNN EP] -e qnn -i \"profiling_level|detailed backend_path|/folderpath/libQnnCpu.so\" \n\n" "\t [SNPE only] [runtime]: SNPE runtime, options: 'CPU', 'GPU', 'GPU_FLOAT16', 'DSP', 'AIP_FIXED_TF'. \n" @@ -477,6 +478,8 @@ int real_main(int argc, char* argv[], Ort::Env& env) { std::string str = str_stream.str(); ORT_THROW("Wrong value for htp_performance_mode. select from: " + str); } + } else if (key == "qnn_saver_path") { + // no validation } else { ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path', 'qnn_context_cache_enable', 'qnn_context_cache_path', 'profiling_level', 'rpc_control_latency', 'htp_performance_mode'])"); diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index dce1f2d40e8b9..6acf631d53cd9 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -1000,6 +1000,26 @@ TEST_F(GraphTransformationTests, ConstantFoldingAShapeNodeDeepInTheGraph) { ASSERT_TRUE(op_to_count.size() == 0U); } +// Test we don't fail when constant folding hits a string initializer +TEST_F(GraphTransformationTests, ConstantFoldingStringInitializer) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "gh_issue_17392.onnx"; + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(model_uri, model, nullptr, *logger_)); + Graph& graph = model->MainGraph(); + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["Identity"], 1); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + std::unique_ptr e = std::make_unique(CPUExecutionProviderInfo()); + ASSERT_STATUS_OK(graph_transformation_mgr.Register( + std::make_unique(*e.get(), false /*skip_dequantize_linear*/), TransformerLevel::Level1)); + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + op_to_count = CountOpsInGraph(graph); + + ASSERT_EQ(op_to_count.size(), 0U) << "Identity node should have been removed"; +} + // Check transformations in the case of a subgraph with constant inputs. TEST_F(GraphTransformationTests, SubgraphWithConstantInputs) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "constant-subgraph.onnx"; @@ -6280,7 +6300,7 @@ TEST_F(GraphTransformationTests, ConstantSharing_ShouldNotShareForGraphOutput) { TEST_F(GraphTransformationTests, GatherToSplitFusion) { auto build_test_case = [&](ModelTestBuilder& builder) { auto* data_arg = builder.MakeInput({{54}}); - auto* shape_arg = builder.MakeInput({{1}}); + auto* shape_arg = builder.MakeInput({{4}}); auto* reshape_out = builder.MakeIntermediate({{2, 3, 3, 3}}); auto* gather_index_1 = builder.MakeInitializer({}, {static_cast(0)}); auto* gather_index_2 = builder.MakeInitializer({}, {static_cast(1)}); @@ -6393,7 +6413,7 @@ TEST_F(GraphTransformationTests, GatherToSplitFusion) { TEST_F(GraphTransformationTests, GatherToSplitFusion_NoSqueeze) { auto build_test_case = [&](ModelTestBuilder& builder) { auto* data_arg = builder.MakeInput({{54}}); - auto* shape_arg = builder.MakeInput({{1}}); + auto* shape_arg = builder.MakeInput({{4}}); auto* reshape_out = builder.MakeIntermediate({{2, 3, 3, 3}}); auto* gather_index_1 = builder.MakeInitializer({1}, {static_cast(0)}); auto* gather_index_2 = builder.MakeInitializer({1}, {static_cast(1)}); diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index d3616a14d8a5d..1bf1cbacf479e 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -3085,12 +3085,12 @@ TEST(QDQTransformerTests, QDQPropagation_Per_Layer_No_Propagation) { check_graph, TransformerLevel::Default, TransformerLevel::Level1, - 18); // disable TransposeOptimizer for simplicity + 18); TransformerTester(build_test_case, check_graph, TransformerLevel::Default, TransformerLevel::Level1, - 19); // disable TransposeOptimizer for simplicity + 19); }; test_case({1, 13, 13, 23}, {0, 2, 3, 1}, false /*use_contrib_qdq*/); diff --git a/onnxruntime/test/optimizer/transpose_optimizer_test.cc b/onnxruntime/test/optimizer/transpose_optimizer_test.cc index 1f4c499985ad0..4f4157bd7b1cf 100644 --- a/onnxruntime/test/optimizer/transpose_optimizer_test.cc +++ b/onnxruntime/test/optimizer/transpose_optimizer_test.cc @@ -12,6 +12,9 @@ #include "core/graph/node_attr_utils.h" #include "core/framework/op_node_proto_helper.h" #include "core/framework/utils.h" +#include "core/optimizer/transpose_optimization/onnx_transpose_optimization.h" +#include "core/optimizer/transpose_optimization/optimizer_api.h" +#include "core/optimizer/transpose_optimization/ort_optimizer_utils.h" #include "core/session/onnxruntime_session_options_config_keys.h" #include "test/test_environment.h" @@ -19,6 +22,7 @@ #include "test/providers/internal_testing/internal_testing_execution_provider.h" #include "test/util/include/asserts.h" #include "test/util/include/inference_session_wrapper.h" +#include "test/util/include/test_utils.h" namespace onnxruntime { namespace test { @@ -316,20 +320,6 @@ TEST(TransposeOptimizerTests, TestPadNonconst) { /*opset_version*/ {11, 18}); } -// The CUDA Resize kernel assumes that the input is NCHW and -// Resize can't be supported in ORT builds with CUDA enabled. -// TODO: Enable this once the CUDA Resize kernel is implemented -// "generically" (i.e.) aligning with the generic nature of the -// ONNX spec. -// See https://github.com/microsoft/onnxruntime/pull/10824 for -// a similar fix applied to the CPU Resize kernel. -// Per tests included in #10824, the ROCM EP also generates -// incorrect results when this handler is used, so the Resize -// handler is not enabled even for those builds. -// -// The QNN EP requires the input to be NHWC, so the Resize handler is also not enabled -// for QNN builds. -#if !defined(USE_CUDA) && !defined(USE_ROCM) && !defined(USE_QNN) TEST(TransposeOptimizerTests, TestResize) { auto build_test_case_1 = [&](ModelTestBuilder& builder) { auto* input0_arg = MakeInput(builder, {{4, -1, 2, -1}}, {4, 6, 2, 10}, 0.0, 1.0); @@ -358,7 +348,9 @@ TEST(TransposeOptimizerTests, TestResize) { TransformerTester(build_test_case_1, check_optimized_graph_1, TransformerLevel::Default, - TransformerLevel::Level1, + // need the level 2 TransposeOptimizer as pushing a Transpose through a Resize requires it to be + // assigned to the CPU EP first + TransformerLevel::Level2, /*opset_version*/ {10, 18}); } @@ -386,7 +378,9 @@ TEST(TransposeOptimizerTests, TestResizeOpset11) { TransformerTester(build_test_case_1, check_optimized_graph_1, TransformerLevel::Default, - TransformerLevel::Level1, + // need the level 2 TransposeOptimizer as pushing a Transpose through a Resize requires it to be + // assigned to the CPU EP first + TransformerLevel::Level2, /*opset_version*/ {11, 18}); } @@ -414,7 +408,9 @@ TEST(TransposeOptimizerTests, TestResizeOpset15) { TransformerTester(build_test_case_1, check_optimized_graph_1, TransformerLevel::Default, - TransformerLevel::Level1, + // need the level 2 TransposeOptimizer as pushing a Transpose through a Resize requires it to be + // assigned to the CPU EP first + TransformerLevel::Level2, /*opset_version*/ {15, 18}); } @@ -444,7 +440,9 @@ TEST(TransposeOptimizerTests, TestResizeSizeRoi) { TransformerTester(build_test_case_1, check_optimized_graph_1, TransformerLevel::Default, - TransformerLevel::Level1, + // need the level 2 TransposeOptimizer as pushing a Transpose through a Resize requires it to be + // assigned to the CPU EP first + TransformerLevel::Level2, /*opset_version*/ {15, 18}); } @@ -478,7 +476,9 @@ TEST(TransposeOptimizerTests, TestResizeRoiScalesZeroRank0) { TransformerTester(build_test_case_1, check_optimized_graph_1, TransformerLevel::Default, - TransformerLevel::Level1, + // need the level 2 TransposeOptimizer as pushing a Transpose through a Resize requires it to be + // assigned to the CPU EP first + TransformerLevel::Level2, {12, 18}); } @@ -507,7 +507,9 @@ TEST(TransposeOptimizerTests, TestResizeNonconst) { TransformerTester(build_test_case_1, check_optimized_graph_1, TransformerLevel::Default, - TransformerLevel::Level1, + // need the level 2 TransposeOptimizer as pushing a Transpose through a Resize requires it to be + // assigned to the CPU EP first + TransformerLevel::Level2, /*opset_version*/ {11, 18}); } @@ -536,11 +538,12 @@ TEST(TransposeOptimizerTests, TestResizeNonconstOpset13) { TransformerTester(build_test_case_1, check_optimized_graph_1, TransformerLevel::Default, - TransformerLevel::Level1, + // need the level 2 TransposeOptimizer as pushing a Transpose through a Resize requires it to be + // assigned to the CPU EP first + TransformerLevel::Level2, /*opset_version*/ {13, 18}); } -#endif TEST(TransposeOptimizerTests, TestAdd) { auto build_test_case_1 = [&](ModelTestBuilder& builder) { auto* input0_arg = builder.MakeInput({4, 6, 10}, 0.0, 1.0); @@ -4395,9 +4398,9 @@ TEST(TransposeOptimizerTests, RegressionTest_GitHubIssue9671) { SessionOptions so; so.session_logid = "TransposeOptimizerTests.RegressionTest_GitHubIssue9671"; - InferenceSession session_object{so, GetEnvironment()}; - ASSERT_STATUS_OK(session_object.Load(model_uri)); - ASSERT_STATUS_OK(session_object.Initialize()); // optimizers run during initialization + InferenceSession session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.Load(model_uri)); + ASSERT_STATUS_OK(session.Initialize()); // optimizers run during initialization } // regression test for a model where the transpose optimizations incorrectly removed a node providing an implicit @@ -4409,9 +4412,9 @@ TEST(TransposeOptimizerTests, RegressionTest_GitHubIssue10305) { SessionOptions so; so.session_logid = "TransposeOptimizerTests.RegressionTest_GitHubIssue10305"; - InferenceSession session_object{so, GetEnvironment()}; - ASSERT_STATUS_OK(session_object.Load(model_uri)); - ASSERT_STATUS_OK(session_object.Initialize()); // optimizers run during initialization + InferenceSession session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.Load(model_uri)); + ASSERT_STATUS_OK(session.Initialize()); // optimizers run during initialization } // regression test for a model with DQ node with per-axis dequantization followed by a Transpose. @@ -4432,30 +4435,31 @@ TEST(TransposeOptimizerTests, RegressionTest_GitHubIssue12151) { { so.graph_optimization_level = TransformerLevel::Default; // off - InferenceSession session_object{so, GetEnvironment()}; - ASSERT_STATUS_OK(session_object.Load(model_uri)); - ASSERT_STATUS_OK(session_object.Initialize()); - ASSERT_STATUS_OK(session_object.Run(feeds, output_names, &fetches_orig)); + InferenceSession session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.Load(model_uri)); + ASSERT_STATUS_OK(session.Initialize()); + ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches_orig)); } { so.graph_optimization_level = TransformerLevel::Level1; // enable transpose optimizer - InferenceSession session_object{so, GetEnvironment()}; - ASSERT_STATUS_OK(session_object.Load(model_uri)); - ASSERT_STATUS_OK(session_object.Initialize()); - ASSERT_STATUS_OK(session_object.Run(feeds, output_names, &fetches)); + InferenceSession session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.Load(model_uri)); + ASSERT_STATUS_OK(session.Initialize()); + ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches)); } ASSERT_THAT(fetches_orig[0].Get().DataAsSpan(), testing::ContainerEq(fetches[0].Get().DataAsSpan())); } +// These tests uses internal testing EP with static kernels which requires a full build, +// and the NHWC Conv which requires contrib ops +#if !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_CONTRIB_OPS) + // Test a Transpose node followed by a Reshape that is logically equivalent to an Transpose can be merged. // The test graph was extracted from a model we were trying to use with the QNN EP. TEST(TransposeOptimizerTests, QnnTransposeReshape) { - // test uses internal testing EP with static kernels which requires a full build, - // and the NHWC Conv with requires contrib ops -#if !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_CONTRIB_OPS) Status status; auto model_uri = ORT_TSTR("testdata/layout_transform_reshape.onnx"); @@ -4497,14 +4501,17 @@ TEST(TransposeOptimizerTests, QnnTransposeReshape) { for (const auto& node : graph.Nodes()) { EXPECT_TRUE(node.GetExecutionProviderType() == expected_ep) << node.OpType() << " node named '" << node.Name() << "' was not assigned to the internal testing EP."; + + if (node.Name() == "Mul_212" || node.Name() == "Add_213") { + // check that the special case in TransposeInputs for a single element input reconnects things back up correctly + const auto& inputs = node.InputDefs(); + EXPECT_EQ(inputs.size(), size_t(2)); + EXPECT_TRUE(inputs[1]->Exists()); + } } -#endif } TEST(TransposeOptimizerTests, QnnTransposeReshapeQDQ) { - // test uses internal testing EP with static kernels which requires a full build, - // and the NHWC Conv with requires contrib ops -#if !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_CONTRIB_OPS) Status status; auto model_uri = ORT_TSTR("testdata/layout_transform_reshape.qdq.onnx"); @@ -4541,7 +4548,128 @@ TEST(TransposeOptimizerTests, QnnTransposeReshapeQDQ) { EXPECT_TRUE(node.GetExecutionProviderType() == expected_ep) << node.OpType() << " node named '" << node.Name() << "' was not assigned to the internal testing EP."; } -#endif +} + +// Validate handling for EP with layout specific Resize that prefers NHWC +TEST(TransposeOptimizerTests, QnnResizeOpset11) { + Status status; + auto model_uri = ORT_TSTR("testdata/nhwc_resize_scales_opset11.onnx"); + + SessionOptions so; + // Uncomment to debug + // ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kDebugLayoutTransformation, "1")); + + using InternalTestingEP = onnxruntime::internal_testing_ep::InternalTestingExecutionProvider; + + // set the test EP to support all ops in the model so that the layout transform applies to all nodes + const std::unordered_set empty_set; + auto internal_testing_ep = std::make_unique(empty_set, empty_set, DataLayout::NHWC); + internal_testing_ep->EnableStaticKernels().TakeAllNodes(); + + InferenceSessionWrapper session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.RegisterExecutionProvider(std::move(internal_testing_ep))); + ASSERT_STATUS_OK(session.Load(model_uri)); + ASSERT_STATUS_OK(session.Initialize()); + + const auto& graph = session.GetGraph(); + // all nodes should be assigned to the internal testing EP, which also means they should be in NHWC layout + std::string expected_ep(onnxruntime::utils::kInternalTestingExecutionProvider); + for (const auto& node : graph.Nodes()) { + EXPECT_TRUE(node.GetExecutionProviderType() == expected_ep) << node.OpType() << " node named '" << node.Name() + << "' was not assigned to the internal testing EP."; + if (node.OpType() == "Resize") { + EXPECT_EQ(node.Domain(), kMSInternalNHWCDomain) << "Resize was not converted to NHWC layout"; + } + } + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["Transpose"], 2) << "Resize should have been wrapped in 2 Transpose nodes to convert to NHWC"; + + // And the post-Resize Transpose should have been pushed all the way to the end + GraphViewer viewer(graph); + EXPECT_EQ(graph.GetNode(viewer.GetNodesInTopologicalOrder().back())->OpType(), "Transpose"); +} +#endif // !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_CONTRIB_OPS) + +static void CheckSharedInitializerHandling(bool broadcast) { + auto model_uri = broadcast ? ORT_TSTR("testdata/transpose_optimizer_shared_initializers_broadcast.onnx") + : ORT_TSTR("testdata/transpose_optimizer_shared_initializers.onnx"); + + RandomValueGenerator random{123}; + std::vector input_dims{1, 2, 2, 3}; + std::vector input_data = random.Gaussian(input_dims, 0.0f, 1.0f); + + OrtValue input; + CreateMLValue(TestCPUExecutionProvider()->CreatePreferredAllocators()[0], input_dims, input_data, &input); + + NameMLValMap feeds{{"input0", input}}; + + std::vector output_names{"output0"}; + std::vector fetches_orig; + std::vector fetches; + + SessionOptions so; + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kDebugLayoutTransformation, "1")); + ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kOrtSessionOptionsDisableQuantQDQ, "1")); + + // get results with no modifications to the model + { + so.graph_optimization_level = TransformerLevel::Default; // off + InferenceSessionWrapper session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.Load(model_uri)); + ASSERT_STATUS_OK(session.Initialize()); + ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches_orig)); + } + + { + InferenceSessionWrapper session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.Load(model_uri)); + + // we call the ONNX transpose optimizer directly to simplify the model required to exercise the shared initializer + // handling. this means we don't need to disable optimizers that might alter the graph before the + // transpose optimizer runs (at a minimum ConstantFolding, CommonSubexpressionElimination and ConstantSharing). + Graph& graph = session.GetMutableGraph(); + CPUAllocator allocator; + + using namespace onnx_transpose_optimization; + auto api_graph = MakeApiGraph(graph, TestCPUExecutionProvider()->CreatePreferredAllocators()[0], + /*new_node_ep*/ nullptr); + + // default optimization cost check + OptimizeResult result = Optimize(*api_graph); + + ASSERT_EQ(result.error_msg, std::nullopt); + ASSERT_TRUE(result.graph_modified); + ASSERT_TRUE(graph.GraphResolveNeeded()); + + std::map op_to_count = CountOpsInGraph(graph); + EXPECT_EQ(op_to_count["Transpose"], 0) << "The Transpose nodes should have been pushed through and canceled out."; + + ASSERT_STATUS_OK(graph.Resolve()); + + ASSERT_STATUS_OK(session.Initialize()); + ASSERT_STATUS_OK(session.Run(feeds, output_names, &fetches)); + } + + ASSERT_THAT(fetches_orig[0].Get().DataAsSpan(), + testing::ContainerEq(fetches[0].Get().DataAsSpan())); +} + +// test we re-use a modified shared initializer wherever possible. model has one initializer that is used by 3 DQ nodes +// and one initializer that is used by 2 Add nodes. both cases should be handled with the initializer being +// modified in-place for the first usage, and the Transpose added to the second usage being cancelled out when the +// original Transpose at the start of the model is pushed down. +TEST(TransposeOptimizerTests, SharedInitializerHandling) { + CheckSharedInitializerHandling(/*broadcast*/ false); +} + +// same setup as the above test, however the initializer is broadcast to bring UnsqueezeInput into play. +// the in-place modification of the initializer for the first usage results in +// -> Transpose -> Squeeze -> {DQ | Add} +// the later usages of the initializer should attempt to cancel out the Squeeze in UnsqueezeInput, +// followed by canceling out the Transpose in TransposeInput. +TEST(TransposeOptimizerTests, SharedInitializerHandlingBroadcast) { + CheckSharedInitializerHandling(/*broadcast*/ true); } } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/controlflow/scan_test.cc b/onnxruntime/test/providers/cpu/controlflow/scan_test.cc index 6d8e05b93510a..8008fd129c19b 100644 --- a/onnxruntime/test/providers/cpu/controlflow/scan_test.cc +++ b/onnxruntime/test/providers/cpu/controlflow/scan_test.cc @@ -578,7 +578,7 @@ TEST(Scan9, DISABLED_BadShape) { ShortSequenceOneInBatchOneLoopStateVar( options, "Node:concat Output:concat_out_1 [ShapeInferenceError] Mismatch between number of source and target dimensions. " - "Source=2 Target=1"); + "inferred=2 declared=1"); } TEST(Scan8, ShortSequenceTwoInBatchOneLoopStateVar) { diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index 9b41ba8c0d2ba..da906ebf76f79 100644 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++ b/onnxruntime/test/providers/cpu/model_tests.cc @@ -1172,6 +1172,12 @@ ::std::vector<::std::basic_string> GetParameterStrings() { ORT_TSTR("bvlc_alexnet"), ORT_TSTR("bvlc_reference_caffenet"), ORT_TSTR("coreml_VGG16_ImageNet"), + ORT_TSTR("VGG 16-fp32"), + ORT_TSTR("VGG 19-caffe2"), + ORT_TSTR("VGG 19-bn"), + ORT_TSTR("VGG 16-bn"), + ORT_TSTR("VGG 19"), + ORT_TSTR("VGG 16"), ORT_TSTR("faster_rcnn"), ORT_TSTR("GPT2"), ORT_TSTR("GPT2_LM_HEAD"), diff --git a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc index c8343483b80a6..cb5fc8095982c 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_fp16_test.cc @@ -109,7 +109,7 @@ TEST(ConvFp16Test, Conv1D_Invalid_Input_Shape) { TestConvFp16Op(attrs, {X, dummy_vals}, {X_shape, dummy_shape}, dummy_vals, dummy_shape, false, OpTester::ExpectResult::kExpectFailure, "Node:node1 Output:Y [ShapeInferenceError] Can't merge shape info. " - "Both source and target dimension have values but they differ. Source=0 Target=2 Dimension=2", + "Both inferred and declared dimension have values but they differ. Inferred=0 Declared=2 Dimension=2", -1); // use latest opset for shape inferencing errors } @@ -132,7 +132,7 @@ TEST(ConvFp16Test, Conv2D_Invalid_Input_Shape) { TestConvFp16Op(attrs, {X, dummy_vals}, {X_shape, dummy_shape}, dummy_vals, dummy_shape, false, OpTester::ExpectResult::kExpectFailure, "Node:node1 Output:Y [ShapeInferenceError] Can't merge shape info. " - "Both source and target dimension have values but they differ. Source=1 Target=2 Dimension=0", + "Both inferred and declared dimension have values but they differ. Inferred=1 Declared=2 Dimension=0", -1); // use latest opset for shape inferencing errors } diff --git a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc index e01fd8c78e55f..5103aed50b152 100644 --- a/onnxruntime/test/providers/cpu/nn/conv_op_test.cc +++ b/onnxruntime/test/providers/cpu/nn/conv_op_test.cc @@ -249,7 +249,7 @@ TEST(ConvTest, Conv1D_Invalid_Input_Shape) { TestConvOp(attrs, {X, dummy_vals}, {X_shape, dummy_shape}, dummy_vals, dummy_shape, false, OpTester::ExpectResult::kExpectFailure, "Node:node1 Output:Y [ShapeInferenceError] Can't merge shape info. " - "Both source and target dimension have values but they differ. Source=0 Target=2 Dimension=2", + "Both inferred and declared dimension have values but they differ. Inferred=0 Declared=2 Dimension=2", -1); // use latest opset for shape inferencing errors } @@ -272,7 +272,7 @@ TEST(ConvTest, Conv2D_Invalid_Input_Shape) { TestConvOp(attrs, {X, dummy_vals}, {X_shape, dummy_shape}, dummy_vals, dummy_shape, false, OpTester::ExpectResult::kExpectFailure, "Node:node1 Output:Y [ShapeInferenceError] Can't merge shape info. " - "Both source and target dimension have values but they differ. Source=1 Target=2 Dimension=0", + "Both inferred and declared dimension have values but they differ. Inferred=1 Declared=2 Dimension=0", -1); // use latest opset for shape inferencing errors } diff --git a/onnxruntime/test/providers/cpu/nn/tfidfvectorizer_test.cc b/onnxruntime/test/providers/cpu/nn/tfidfvectorizer_test.cc index a22253dbb74d4..379b892f39135 100644 --- a/onnxruntime/test/providers/cpu/nn/tfidfvectorizer_test.cc +++ b/onnxruntime/test/providers/cpu/nn/tfidfvectorizer_test.cc @@ -91,7 +91,7 @@ TEST(TfIdfVectorizerTest, Int32_TF_onlyBigrams_Skip0_Empty_Dim1Fail) { test.Run(OpTester::ExpectResult::kExpectFailure, "Can't merge shape info. " - "Both source and target dimension have values but they differ. Source=7 Target=0 Dimension=0"); + "Both inferred and declared dimension have values but they differ. Inferred=7 Declared=0 Dimension=0"); } TEST(TfIdfVectorizerTest, Int32_TF_onlyBigrams_Skip0_Empty_Dim1Success) { @@ -136,7 +136,7 @@ TEST(TfIdfVectorizerTest, Int32_TF_onlyBigrams_Skip0_Empty_Dim2) { test.AddOutput("Y", out_dims, output); test.Run(OpTester::ExpectResult::kExpectFailure, - "Mismatch between number of source and target dimensions. Source=2 Target=1"); + "Mismatch between number of inferred and declared dimensions. inferred=2 declared=1"); } TEST(TfIdfVectorizerTest, Int32_TF_onlyBigrams_Skip01_Empty_Dim2) { @@ -159,7 +159,7 @@ TEST(TfIdfVectorizerTest, Int32_TF_onlyBigrams_Skip01_Empty_Dim2) { test.AddOutput("Y", out_dims, output); test.Run(OpTester::ExpectResult::kExpectFailure, - "Mismatch between number of source and target dimensions. Source=2 Target=1"); + "Mismatch between number of inferred and declared dimensions. inferred=2 declared=1"); } TEST(TfIdfVectorizerTest, Int32_TF_onlyBigrams_Skip0_Empty_Dim2N) { diff --git a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc index 0434b16dc66ce..2ead9ec91f93f 100644 --- a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc @@ -780,7 +780,7 @@ TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_5DTrilinear_pytorch_half_pixel) { } TEST(ResizeOpTest, ResizeOpLinearScalesNoOpTest) { - // To test NNAPI EP, we need the sclaes/sizes to be in initializers + // To test NNAPI EP, we need the scales/sizes to be in initializers auto run_test = [](bool scales_in_initializer) { OpTester test("Resize", 13); std::vector roi{}; diff --git a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc index c334e0c5ddcb6..0e7ac5ed2b2f0 100644 --- a/onnxruntime/test/providers/cpu/tensor/transpose_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/transpose_test.cc @@ -37,7 +37,7 @@ TEST(TransposeOpTest, PermRankDoesNotMatchTensorRank) { // This failure comes from shape inference, because in this case it knows the input dims. // But in the real world, the model can supply different input dims at runtime. test.Run(OpTester::ExpectResult::kExpectFailure, - "Node:node1 Output:Y [ShapeInferenceError] Mismatch between number of source and target dimensions. Source=3 Target=4"); + "Node:node1 Output:Y [ShapeInferenceError] Mismatch between number of inferred and declared dimensions. inferred=3 declared=4"); } // Some of the tests can't run on TensorrtExecutionProvider because of errors. diff --git a/onnxruntime/test/providers/kernel_compute_test_utils.cc b/onnxruntime/test/providers/kernel_compute_test_utils.cc index 0bcf3cbbfd089..977a5bd9ea7b8 100644 --- a/onnxruntime/test/providers/kernel_compute_test_utils.cc +++ b/onnxruntime/test/providers/kernel_compute_test_utils.cc @@ -5,6 +5,8 @@ #include "test/providers/kernel_compute_test_utils.h" +#include + #include "core/framework/execution_providers.h" #include "core/optimizer/optimizer_execution_frame.h" #include "test/util/include/default_providers.h" @@ -55,12 +57,20 @@ void KernelComputeTester::Run(std::unordered_set strided_outputs) { } #if defined(USE_CUDA) || defined(USE_ROCM) if ((provider_ == kCudaExecutionProvider || provider_ == kRocmExecutionProvider) && !data.is_cpu_data_) { - OrtValue gpu_value; const Tensor& tensor = data.value_.Get(); - Tensor::InitOrtValue(tensor.DataType(), tensor.Shape(), - execution_providers.Get(ep_type)->CreatePreferredAllocators()[0], gpu_value, - tensor.Strides()); - ASSERT_STATUS_OK(dtm.CopyTensor(tensor, *gpu_value.GetMutable())); + + Tensor gpu_tensor(tensor.DataType(), tensor.Shape(), + execution_providers.Get(ep_type)->CreatePreferredAllocators()[0]); + + if (const auto strides = tensor.Strides(); !strides.empty()) { + gpu_tensor.SetShapeAndStrides(tensor.Shape(), strides); + } + + ASSERT_STATUS_OK(dtm.CopyTensor(tensor, gpu_tensor)); + + OrtValue gpu_value; + Tensor::InitOrtValue(std::move(gpu_tensor), gpu_value); + initializer_map[name] = gpu_value; } #endif @@ -161,8 +171,7 @@ void KernelComputeTester::Run(std::unordered_set strided_outputs) { } else { const Tensor& tensor = outputs[i].Get(); Tensor::InitOrtValue(tensor.DataType(), tensor.Shape(), - execution_providers.Get(cpu_ep_type)->CreatePreferredAllocators()[0], cpu_value, - tensor.Strides()); + execution_providers.Get(cpu_ep_type)->CreatePreferredAllocators()[0], cpu_value); ASSERT_STATUS_OK(dtm.CopyTensor(tensor, *cpu_value.GetMutable())); } diff --git a/onnxruntime/test/providers/kernel_compute_test_utils.h b/onnxruntime/test/providers/kernel_compute_test_utils.h index aed5856fea5a2..a93fd24a3ad4f 100644 --- a/onnxruntime/test/providers/kernel_compute_test_utils.h +++ b/onnxruntime/test/providers/kernel_compute_test_utils.h @@ -75,7 +75,12 @@ class KernelComputeTester { OrtValue value; TensorShape shape(dims); auto allocator = AllocatorManager::Instance().GetAllocator(CPU); - Tensor::InitOrtValue(DataTypeImpl::GetType(), shape, std::move(allocator), value, strides); + Tensor::InitOrtValue(DataTypeImpl::GetType(), shape, std::move(allocator), value); + + if (!strides.empty()) { + value.GetMutable()->SetShapeAndStrides(shape, strides); + } + if (values) { Tensor* tensor = value.GetMutable(); auto* p_data = tensor->MutableData(); diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index a441e828c0cc6..5f63813d8d84e 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include +#include #include "core/session/onnxruntime_cxx_api.h" #include "core/session/onnxruntime_session_options_config_keys.h" @@ -172,7 +173,7 @@ TEST(QnnEP, TestDisableCPUFallback_ConflictingConfig) { // The models passed to this function are subgraphs extracted from a larger model that exhibited // shape inferencing issues on QNN. Thus, the models are expected to have a specific input/output // types and shapes. -static void RunNHWCResizeModel(const ORTCHAR_T* ort_model_path, bool use_htp) { +static void RunNHWCResizeModel(const ORTCHAR_T* ort_model_path, bool use_htp, bool enable_qnn_saver = false) { Ort::SessionOptions so; // Ensure all type/shape inference warnings result in errors! @@ -183,8 +184,14 @@ static void RunNHWCResizeModel(const ORTCHAR_T* ort_model_path, bool use_htp) { #if defined(_WIN32) options["backend_path"] = use_htp ? "QnnHtp.dll" : "QnnCpu.dll"; + if (enable_qnn_saver) { + options["qnn_saver_path"] = "QnnSaver.dll"; + } #else options["backend_path"] = use_htp ? "libQnnHtp.so" : "libQnnCpu.so"; + if (enable_qnn_saver) { + options["qnn_saver_path"] = "libQnnSaver.so"; + } #endif so.AppendExecutionProvider("QNN", options); @@ -226,7 +233,7 @@ static void RunNHWCResizeModel(const ORTCHAR_T* ort_model_path, bool use_htp) { auto typeshape = ort_output.GetTensorTypeAndShapeInfo(); std::vector output_shape = typeshape.GetShape(); - ASSERT_THAT(output_shape, ::testing::ElementsAre(1, 6, 7, 10)); + EXPECT_THAT(output_shape, ::testing::ElementsAre(1, 6, 7, 10)); } // Test shape inference of NHWC Resize operator (opset 11) that uses @@ -253,6 +260,23 @@ TEST_F(QnnCPUBackendTests, TestNHWCResizeShapeInference_sizes_opset18) { RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.onnx", false); } +// Test that QNN Saver generates the expected files for a model meant to run on the QNN CPU backend. +TEST_F(QnnCPUBackendTests, QnnSaver_OutputFiles) { + const std::filesystem::path qnn_saver_output_dir = "saver_output"; + + // Remove pre-existing QNN Saver output files. Note that fs::remove_all() can handle non-existing paths. + std::filesystem::remove_all(qnn_saver_output_dir); + ASSERT_FALSE(std::filesystem::exists(qnn_saver_output_dir)); + + RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.onnx", + false, // use_htp + true); // enable_qnn_saver + + // Check that QNN Saver output files exist. + EXPECT_TRUE(std::filesystem::exists(qnn_saver_output_dir / "saver_output.c")); + EXPECT_TRUE(std::filesystem::exists(qnn_saver_output_dir / "params.bin")); +} + #if defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) // Test shape inference of QDQ NHWC Resize operator (opset 18) that uses @@ -261,6 +285,23 @@ TEST_F(QnnHTPBackendTests, TestNHWCResizeShapeInference_qdq_sizes_opset18) { RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.quant.onnx", true); } +// Test that QNN Saver generates the expected files for a model meant to run on the QNN HTP backend. +TEST_F(QnnHTPBackendTests, QnnSaver_OutputFiles) { + const std::filesystem::path qnn_saver_output_dir = "saver_output"; + + // Remove pre-existing QNN Saver output files. Note that fs::remove_all() can handle non-existing paths. + std::filesystem::remove_all(qnn_saver_output_dir); + ASSERT_FALSE(std::filesystem::exists(qnn_saver_output_dir)); + + RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.onnx", + true, // use_htp + true); // enable_qnn_saver + + // Check that QNN Saver output files exist. + EXPECT_TRUE(std::filesystem::exists(qnn_saver_output_dir / "saver_output.c")); + EXPECT_TRUE(std::filesystem::exists(qnn_saver_output_dir / "params.bin")); +} + #endif // defined(__aarch64__) || defined(_M_ARM64) || defined(__linux__) #endif // !defined(ORT_MINIMAL_BUILD) diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.cc b/onnxruntime/test/providers/qnn/qnn_test_utils.cc index 51df93f8853ec..a067c9c53e57a 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.cc +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.cc @@ -9,6 +9,7 @@ #include "test/util/include/default_providers.h" #include "test/util/include/test/test_environment.h" +#include "core/platform/env_var_utils.h" #include "core/common/span_utils.h" #include "core/framework/compute_capability.h" #include "core/graph/graph.h" @@ -41,7 +42,22 @@ std::vector GetFloatDataInRange(float min_val, float max_val, size_t num_ return data; } -void RunQnnModelTest(const GetTestModelFn& build_test_case, const ProviderOptions& provider_options, +void TryEnableQNNSaver(ProviderOptions& qnn_options) { + // Allow dumping QNN API calls to file by setting an environment variable that enables the QNN Saver backend. + constexpr auto kEnableQNNSaverEnvironmentVariableName = "ORT_UNIT_TEST_ENABLE_QNN_SAVER"; + static std::optional enable_qnn_saver = onnxruntime::ParseEnvironmentVariable( + kEnableQNNSaverEnvironmentVariableName); + + if (enable_qnn_saver.has_value() && *enable_qnn_saver != 0) { +#if defined(_WIN32) + qnn_options["qnn_saver_path"] = "QnnSaver.dll"; +#else + qnn_options["qnn_saver_path"] = "libQnnSaver.so"; +#endif // defined(_WIN32) + } +} + +void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions provider_options, int opset_version, ExpectedEPNodeAssignment expected_ep_assignment, float fp32_abs_err, logging::Severity log_severity) { EPVerificationParams verification_params; @@ -65,6 +81,7 @@ void RunQnnModelTest(const GetTestModelFn& build_test_case, const ProviderOption // Serialize the model to a string. std::string model_data; model.ToProto().SerializeToString(&model_data); + TryEnableQNNSaver(provider_options); RunAndVerifyOutputsWithEP(AsByteSpan(model_data.data(), model_data.size()), "QNN_EP_TestLogID", QnnExecutionProviderWithOptions(provider_options), helper.feeds_, verification_params); diff --git a/onnxruntime/test/providers/qnn/qnn_test_utils.h b/onnxruntime/test/providers/qnn/qnn_test_utils.h index 14c62f98f6a3e..b4c84d893c828 100644 --- a/onnxruntime/test/providers/qnn/qnn_test_utils.h +++ b/onnxruntime/test/providers/qnn/qnn_test_utils.h @@ -220,6 +220,25 @@ void InferenceModel(const std::string& model_data, const char* log_id, ExpectedEPNodeAssignment expected_ep_assignment, const NameMLValMap& feeds, std::vector& output_vals); +/** + * If the ORT_UNIT_TEST_ENABLE_QNN_SAVER environment variable is enabled (set to 1), this function modifies + * the QNN EP provider options to enable the QNN Saver backend, which dumps QNN API calls (and weights) to disk. + * + * - saver_output/saver_output.c: C file containing all QNN API calls. + * - saver_output/params.bin: binary file containing all input/output/parameter tensor data provided during tensor + * creation, op config validation, and graph execution. + * + * Enabling the QNN Saver backend has 2 note-worthy effects: + * 1. All QNN API calls will succeed. + * 2. Inference output returns dummy data. + * + * Because output files from QNN Saver are always overwritten, it is recommended to run individual unit tests via the + * --gtest_filter command-line option. Ex: --gtest_filter=QnnHTPBackendTests.Resize_DownSample_Linear_AlignCorners + * + * \param qnn_options QNN EP provider options that may be modified to enable QNN Saver. + */ +void TryEnableQNNSaver(ProviderOptions& qnn_options); + /** * Tests the accuracy of a QDQ model on QNN EP by runnning 3 inferences: * @@ -240,7 +259,7 @@ void InferenceModel(const std::string& model_data, const char* log_id, */ template inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTestQDQModelFn& qdq_model_fn, - const ProviderOptions& qnn_options, int opset_version, + ProviderOptions qnn_options, int opset_version, ExpectedEPNodeAssignment expected_ep_assignment, float fp32_abs_err = 1e-4f, logging::Severity log_severity = logging::Severity::kERROR) { // Add kMSDomain to cover contrib op like Gelu @@ -300,6 +319,7 @@ inline void TestQDQModelAccuracy(const GetTestModelFn& f32_model_fn, const GetTe qdq_model.ToProto().SerializeToString(&qdq_model_data); // Run QDQ model on QNN EP and collect outputs. + TryEnableQNNSaver(qnn_options); std::vector qnn_qdq_outputs; InferenceModel(qdq_model_data, "qdq_model_logger", QnnExecutionProviderWithOptions(qnn_options), expected_ep_assignment, qdq_helper.feeds_, qnn_qdq_outputs); @@ -538,7 +558,7 @@ inline GetTestQDQModelFn BuildQDQOpTestCase(const std::string& op_typ * \param fp32_abs_err The acceptable error between CPU EP and QNN EP. * \param log_severity The logger's minimum severity level. */ -void RunQnnModelTest(const GetTestModelFn& build_test_case, const ProviderOptions& provider_options, +void RunQnnModelTest(const GetTestModelFn& build_test_case, ProviderOptions provider_options, int opset_version, ExpectedEPNodeAssignment expected_ep_assignment, float fp32_abs_err = 1e-5f, logging::Severity log_severity = logging::Severity::kERROR); diff --git a/onnxruntime/test/python/onnxruntime_test_collective.py b/onnxruntime/test/python/onnxruntime_test_collective.py index db1ebb5384730..4882b403c3c91 100644 --- a/onnxruntime/test/python/onnxruntime_test_collective.py +++ b/onnxruntime/test/python/onnxruntime_test_collective.py @@ -155,6 +155,7 @@ def _create_alltoall_ut_model_for_boolean_tensor( ) return ORTBertPretrainTest._create_model_with_opsets(graph_def) + @unittest.skipIf(not ort.has_collective_ops(), reason="onnx not compiled with mpi support") @parameterized.expand( [ (np.float32, TensorProto.FLOAT), @@ -193,6 +194,7 @@ def test_all_reduce(self, np_elem_type, elem_type): outputs[0], size * input, err_msg=f"{rank}: AllGather ({np_elem_type}, {elem_type}): results mismatch" ) + @unittest.skipIf(not ort.has_collective_ops(), reason="onnx not compiled with mpi support") @parameterized.expand( [ (np.float32, TensorProto.FLOAT, TensorProto.FLOAT), @@ -231,6 +233,7 @@ def test_all_gather(self, np_elem_type, elem_type, communication_elem_type): err_msg=f"{rank}: AllGather (axis0) ({np_elem_type}, {elem_type}, {communication_elem_type}): results mismatch", ) + @unittest.skipIf(not ort.has_collective_ops(), reason="onnx not compiled with mpi support") def test_all_gather_bool(self): model = self._create_allgather_ut_model((4,), 0, TensorProto.INT64, TensorProto.INT64) rank, _ = self._get_rank_size() @@ -250,6 +253,7 @@ def test_all_gather_bool(self): np.testing.assert_allclose(y, y_expected, err_msg=f"{rank}: AllGather (bool): results mismatch") + @unittest.skipIf(not ort.has_collective_ops(), reason="onnx not compiled with mpi support") def test_all_gather_axis1(self): model = self._create_allgather_ut_model((128, 128), 1) rank, size = self._get_rank_size() @@ -268,6 +272,7 @@ def test_all_gather_axis1(self): np.testing.assert_allclose(outputs[0], expected_output, err_msg=f"{rank}: AllGather (axis1): results mismatch") + @unittest.skipIf(not ort.has_collective_ops(), reason="onnx not compiled with mpi support") @parameterized.expand( [ (np.float32, TensorProto.FLOAT, TensorProto.FLOAT), @@ -349,6 +354,7 @@ def test_all_to_all(self, np_elem_type, elem_type, communication_elem_type): err_msg=f"{rank}: AllToAll ({np_elem_type}, {elem_type}, {communication_elem_type}): results mismatch", ) + @unittest.skipIf(not ort.has_collective_ops(), reason="onnx not compiled with mpi support") def test_all_to_all_bool(self): rank, _ = self._get_rank_size() diff --git a/onnxruntime/test/python/quantization/op_test_utils.py b/onnxruntime/test/python/quantization/op_test_utils.py index e94ac5c961583..f26b6297cdbda 100644 --- a/onnxruntime/test/python/quantization/op_test_utils.py +++ b/onnxruntime/test/python/quantization/op_test_utils.py @@ -279,6 +279,9 @@ def check_model_correctness( ops_set = set(node.op_type for node in model_onnx.graph.node) check_reference_evaluator = not (ops_set & {"EmbedLayerNormalization", "Conv", "Attention", "Transpose"}) + with open(model_path_to_check, "rb") as f: + model_check = onnx.load(f) + if check_reference_evaluator and onnx_recent_enough: ref = ReferenceEvaluator(model_path_origin) ref_origin_results = ref.run(None, inputs) @@ -289,7 +292,7 @@ def check_model_correctness( output, rtol=rtol, atol=atol, - err_msg=f"Model {model_path_to_check!r} failed for providers={providers!r}.", + err_msg=f"Model {model_path_origin!r} failed for providers={providers!r}.", ) # Verifies the shapes in the quantized model. @@ -301,40 +304,52 @@ def check_model_correctness( expected_shapes[init.name] = tuple(init.dims) checked = 0 f8_quantization = False - with open(model_path_to_check, "rb") as f: - model_check = onnx.load(f) - for init in model_check.graph.initializer: - if init.name.endswith("_quantized"): - name = init.name.replace("_quantized", "") - expected = expected_shapes[name] - shape = tuple(init.dims) - if not dynamic and expected != shape: - raise AssertionError( - f"Shape mismatch for initializer {init.name!r} from {init.name!r}, " - f"shape={shape} != {expected} (expected)." - ) - else: - checked += 1 - if "zero_point" in init.name: - dt = init.data_type - f8_quantization = f8_quantization or dt in ( - TensorProto.FLOAT8E4M3FN, - TensorProto.FLOAT8E4M3FNUZ, - TensorProto.FLOAT8E5M2, - TensorProto.FLOAT8E5M2FNUZ, + for init in model_check.graph.initializer: + if init.name.endswith("_quantized"): + name = init.name.replace("_quantized", "") + expected = expected_shapes[name] + shape = tuple(init.dims) + if not dynamic and expected != shape: + raise AssertionError( + f"Shape mismatch for initializer {init.name!r} from {init.name!r}, " + f"shape={shape} != {expected} (expected)." ) - if checked == 0: - raise AssertionError( - f"Unable to check expected shape, expected_shapes={expected_shapes}, " - f"names={[init.name for init in model_check.graph.initializer]}." + else: + checked += 1 + if "zero_point" in init.name: + dt = init.data_type + f8_quantization = f8_quantization or dt in ( + TensorProto.FLOAT8E4M3FN, + TensorProto.FLOAT8E4M3FNUZ, + TensorProto.FLOAT8E5M2, + TensorProto.FLOAT8E5M2FNUZ, ) + if checked == 0: + raise AssertionError( + f"Unable to check expected shape, expected_shapes={expected_shapes}, " + f"names={[init.name for init in model_check.graph.initializer]}." + ) if f8_quantization: check_sign_f8_quantization(model_path_origin, model_path_to_check) # Verifies the expected outputs. if check_reference_evaluator and onnx_recent_enough: + reference_new_ops = [QGemm] + has_missing_reference_ops = any( + node.domain not in ["", "ai.onnx"] + and not any( + node.domain == new_node.op_domain and node.op_type == new_node.__name__ + for new_node in reference_new_ops + ) + for node in model_check.graph.node + ) + if has_missing_reference_ops: + # We need to skip the test if the model contains ops that are not supported. + testcase.skipTest( + f"Model {model_path_to_check!r} contains ops that are not supported by the reference evaluator." + ) # Needs pv.Version(onnx.__version__) >= pv.Version("1.16.0") - ref = ReferenceEvaluator(model_path_to_check, new_ops=[QGemm]) + ref = ReferenceEvaluator(model_check, new_ops=reference_new_ops) target_results = ref.run(None, inputs) testcase.assertEqual(len(origin_results), len(target_results), "result count are different") for idx, ref_output in enumerate(origin_results): diff --git a/onnxruntime/test/python/quantization/test_conv_dynamic.py b/onnxruntime/test/python/quantization/test_conv_dynamic.py index 9578c9fe708aa..18467bcbc1083 100644 --- a/onnxruntime/test/python/quantization/test_conv_dynamic.py +++ b/onnxruntime/test/python/quantization/test_conv_dynamic.py @@ -95,7 +95,7 @@ def test_quant_conv(self): for use_quant_config in [True, False]: self.dynamic_quant_conv_test(QuantType.QUInt8, extra_options={}, use_quant_config=use_quant_config) - # TODO: uncomment following after ConvInteger s8 supportted + # TODO: uncomment following after ConvInteger s8 supported # def test_quant_conv_s8s8(self): # self.dynamic_quant_conv_test(QuantType.QInt8, extra_options={'ActivationSymmetric': True}) diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 8357ce22fb710..b605673d6606f 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -184,6 +184,8 @@ static constexpr PATH_TYPE SEQUENCE_MODEL_URI_2 = TSTR("testdata/optional_sequen #endif static constexpr PATH_TYPE CUSTOM_OP_MODEL_URI = TSTR("testdata/foo_1.onnx"); static constexpr PATH_TYPE CUSTOM_OP_LIBRARY_TEST_MODEL_URI = TSTR("testdata/custom_op_library/custom_op_test.onnx"); +static constexpr PATH_TYPE CUSTOM_OP_LIBRARY_COPY_TENSOR_ARRAY_2 = TSTR("testdata/custom_op_library/copy_2_inputs_2_outputs.onnx"); +static constexpr PATH_TYPE CUSTOM_OP_LIBRARY_COPY_TENSOR_ARRAY_3 = TSTR("testdata/custom_op_library/copy_3_inputs_3_outputs.onnx"); #if !defined(DISABLE_FLOAT8_TYPES) static constexpr PATH_TYPE CUSTOM_OP_LIBRARY_TEST_MODEL_FLOAT8_URI = TSTR("testdata/custom_op_library/custom_op_test_float8.onnx"); #endif @@ -1406,6 +1408,59 @@ TEST(CApiTest, test_custom_op_library) { #endif } +// It has memory leak. The OrtCustomOpDomain created in custom_op_library.cc:RegisterCustomOps function was not freed +#if defined(__ANDROID__) +TEST(CApiTest, test_custom_op_library_copy_variadic) { +// To accomodate a reduced op build pipeline +#elif defined(REDUCED_OPS_BUILD) && defined(USE_CUDA) +TEST(CApiTest, test_custom_op_library_copy_variadic) { +#else +TEST(CApiTest, test_custom_op_library_copy_variadic) { +#endif + std::cout << "Running inference using custom op shared library" << std::endl; + + std::vector inputs(2); + inputs[0].name = "input_0"; + inputs[0].dims = {15}; + inputs[0].values = {1.1f, 2.2f, 3.3f, 4.4f, 5.5f, + 6.6f, 7.7f, 8.8f, 9.9f, 10.0f, + 11.1f, 12.2f, 13.3f, 14.4f, 15.5f}; + inputs[1].name = "input_1"; + inputs[1].dims = {15}; + inputs[1].values = {15.5f, 14.4f, 13.3f, 12.2f, 11.1f, + 10.0f, 9.9f, 8.8f, 7.7f, 6.6f, + 5.5f, 4.4f, 3.3f, 2.2f, 1.1f}; + + // prepare expected inputs and outputs + std::vector expected_dims_y = {15}; + std::vector expected_values_y = inputs[1].values; + + onnxruntime::PathString lib_name; +#if defined(_WIN32) + lib_name = ORT_TSTR("custom_op_library.dll"); +#elif defined(__APPLE__) + lib_name = ORT_TSTR("libcustom_op_library.dylib"); +#else + lib_name = ORT_TSTR("./libcustom_op_library.so"); +#endif + + TestInference(*ort_env, CUSTOM_OP_LIBRARY_COPY_TENSOR_ARRAY_2, + inputs, "output_1", expected_dims_y, + expected_values_y, 0, nullptr, lib_name.c_str()); + + inputs.push_back({}); + inputs[2].name = "input_2"; + inputs[2].dims = {15}; + inputs[2].values = {6.6f, 7.7f, 8.8f, 9.9f, 10.0f, + 1.1f, 2.2f, 3.3f, 4.4f, 5.5f, + 11.1f, 12.2f, 13.3f, 14.4f, 15.5f}; + + expected_values_y = inputs[2].values; + TestInference(*ort_env, CUSTOM_OP_LIBRARY_COPY_TENSOR_ARRAY_3, + inputs, "output_2", expected_dims_y, + expected_values_y, 0, nullptr, lib_name.c_str()); +} + #if !defined(DISABLE_FLOAT8_TYPES) struct InputF8 { @@ -3126,9 +3181,12 @@ struct Merge { ORT_ENFORCE(ort_api->KernelInfoGetAttribute_int64(info, "reverse", &reverse) == nullptr); reverse_ = reverse != 0; } - void Compute(const Ort::Custom::Tensor& strings_in, - std::string_view string_in, - Ort::Custom::Tensor* strings_out) { + Ort::Status Compute(const Ort::Custom::Tensor& strings_in, + std::string_view string_in, + Ort::Custom::Tensor* strings_out) { + if (strings_in.NumberOfElement() == 0) { + return Ort::Status("the 1st input must have more than one string!", OrtErrorCode::ORT_INVALID_ARGUMENT); + } std::vector string_pool; for (const auto& s : strings_in.Data()) { string_pool.emplace_back(s.data(), s.size()); @@ -3141,6 +3199,7 @@ struct Merge { std::reverse(string_pool.begin(), string_pool.end()); } strings_out->SetStringOutput(string_pool, {static_cast(string_pool.size())}); + return Ort::Status(nullptr); } bool reverse_ = false; }; diff --git a/onnxruntime/test/testdata/custom_op_library/copy_2_inputs_2_outputs.onnx b/onnxruntime/test/testdata/custom_op_library/copy_2_inputs_2_outputs.onnx new file mode 100644 index 0000000000000..1756aae6e8d5a --- /dev/null +++ b/onnxruntime/test/testdata/custom_op_library/copy_2_inputs_2_outputs.onnx @@ -0,0 +1,16 @@ + :ï +d +input_0 +input_1output_0output_1copy_tensor_array"CopyTensorArrayAllVariadic: test.customopgraphZ +input_0 + + ÿÿÿÿÿÿÿÿÿZ +input_1 + + ÿÿÿÿÿÿÿÿÿb +output_0 + + ÿÿÿÿÿÿÿÿÿb +output_1 + + ÿÿÿÿÿÿÿÿÿB \ No newline at end of file diff --git a/onnxruntime/test/testdata/custom_op_library/copy_3_inputs_3_outputs.onnx b/onnxruntime/test/testdata/custom_op_library/copy_3_inputs_3_outputs.onnx new file mode 100644 index 0000000000000..86c9ec1c1fc37 --- /dev/null +++ b/onnxruntime/test/testdata/custom_op_library/copy_3_inputs_3_outputs.onnx @@ -0,0 +1,23 @@ + :À +t +input_0 +input_1 +input_2output_0output_1output_2copy_tensor_array"CopyTensorArrayCombined: test.customopgraphZ +input_0 + + ÿÿÿÿÿÿÿÿÿZ +input_1 + + ÿÿÿÿÿÿÿÿÿZ +input_2 + + ÿÿÿÿÿÿÿÿÿb +output_0 + + ÿÿÿÿÿÿÿÿÿb +output_1 + + ÿÿÿÿÿÿÿÿÿb +output_2 + + ÿÿÿÿÿÿÿÿÿB \ No newline at end of file diff --git a/onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc b/onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc index f9e537fb61047..f7a623901f263 100644 --- a/onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc +++ b/onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc @@ -11,16 +11,20 @@ using namespace Ort::Custom; namespace Cpu { -void KernelOne(const Ort::Custom::Tensor& X, - const Ort::Custom::Tensor& Y, - Ort::Custom::Tensor& Z) { - auto input_shape = X.Shape(); +Ort::Status KernelOne(const Ort::Custom::Tensor& X, + const Ort::Custom::Tensor& Y, + Ort::Custom::Tensor& Z) { + if (X.NumberOfElement() != Y.NumberOfElement()) { + return Ort::Status("x and y has different number of elements", OrtErrorCode::ORT_INVALID_ARGUMENT); + } + auto x_shape = X.Shape(); auto x_raw = X.Data(); auto y_raw = Y.Data(); - auto z_raw = Z.Allocate(input_shape); + auto z_raw = Z.Allocate(x_shape); for (int64_t i = 0; i < Z.NumberOfElement(); ++i) { z_raw[i] = x_raw[i] + y_raw[i]; } + return Ort::Status{nullptr}; } // lite custom op as a function @@ -162,6 +166,46 @@ void FilterFloat8(const Ort::Custom::Tensor& floats_in, } #endif +// a sample custom op accepting variadic inputs, and generate variadic outputs by simply 1:1 copying. +template +Ort::Status CopyTensorArrayAllVariadic(const Ort::Custom::TensorArray& inputs, Ort::Custom::TensorArray& outputs) { + for (size_t ith_input = 0; ith_input < inputs.Size(); ++ith_input) { + const auto& input = inputs[ith_input]; + const auto& input_shape = input->Shape(); + const T* raw_input = reinterpret_cast(input->DataRaw()); + auto num_elements = input->NumberOfElement(); + T* raw_output = outputs.AllocateOutput(ith_input, input_shape); + if (!raw_output) { + return Ort::Status("Failed to allocate output!", OrtErrorCode::ORT_FAIL); + } + for (int64_t jth_elem = 0; jth_elem < num_elements; ++jth_elem) { + raw_output[jth_elem] = raw_input[jth_elem]; + } + } + return Ort::Status{nullptr}; +} + +template +Ort::Status CopyTensorArrayCombined(const Ort::Custom::Tensor& first_input, + const Ort::Custom::TensorArray& other_inputs, + Ort::Custom::Tensor& first_output, + Ort::Custom::TensorArray& other_outputs) { + const auto first_input_shape = first_input.Shape(); + const T* raw_input = reinterpret_cast(first_input.DataRaw()); + + T* raw_output = first_output.Allocate(first_input_shape); + if (!raw_output) { + return Ort::Status("Failed to allocate output!", OrtErrorCode::ORT_FAIL); + } + + auto num_elements = first_input.NumberOfElement(); + for (int64_t ith_elem = 0; ith_elem < num_elements; ++ith_elem) { + raw_output[ith_elem] = raw_input[ith_elem]; + } + + return CopyTensorArrayAllVariadic(other_inputs, other_outputs); +} + void RegisterOps(Ort::CustomOpDomain& domain) { static const std::unique_ptr c_CustomOpOne{Ort::Custom::CreateLiteCustomOp("CustomOpOne", "CPUExecutionProvider", KernelOne)}; static const std::unique_ptr c_CustomOpTwo{Ort::Custom::CreateLiteCustomOp("CustomOpTwo", "CPUExecutionProvider", KernelTwo)}; @@ -171,6 +215,8 @@ void RegisterOps(Ort::CustomOpDomain& domain) { static const std::unique_ptr c_Select{Ort::Custom::CreateLiteCustomOp("Select", "CPUExecutionProvider", Select)}; static const std::unique_ptr c_Fill{Ort::Custom::CreateLiteCustomOp("Filter", "CPUExecutionProvider", Filter)}; static const std::unique_ptr c_Box{Ort::Custom::CreateLiteCustomOp("Box", "CPUExecutionProvider", Box)}; + static const std::unique_ptr c_CopyTensorArrayAllVariadic{Ort::Custom::CreateLiteCustomOp("CopyTensorArrayAllVariadic", "CPUExecutionProvider", CopyTensorArrayAllVariadic)}; + static const std::unique_ptr c_CopyTensorArrayCombined{Ort::Custom::CreateLiteCustomOp("CopyTensorArrayCombined", "CPUExecutionProvider", CopyTensorArrayCombined)}; #if !defined(DISABLE_FLOAT8_TYPES) static const CustomOpOneFloat8 c_CustomOpOneFloat8; @@ -185,6 +231,9 @@ void RegisterOps(Ort::CustomOpDomain& domain) { domain.Add(c_Select.get()); domain.Add(c_Fill.get()); domain.Add(c_Box.get()); + domain.Add(c_CopyTensorArrayAllVariadic.get()); + domain.Add(c_CopyTensorArrayCombined.get()); + #if !defined(DISABLE_FLOAT8_TYPES) domain.Add(&c_CustomOpOneFloat8); domain.Add(c_FilterFloat8.get()); diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index 71a10f646a7c6..c142106ed506c 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -233,7 +233,64 @@ "^test_resize_upsample_sizes_nearest_cuda", "^test_resize_upsample_sizes_nearest_floor_align_corners_cuda", "^test_resize_upsample_sizes_nearest_not_larger_cuda", - "^test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric_cuda" + "^test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric_cuda", + // onnx 1.15 (opset 20) new and updated op tests + "^test_ai_onnx_ml_label_encoder_string_int", + "^test_ai_onnx_ml_label_encoder_string_int_no_default", + "^test_ai_onnx_ml_label_encoder_tensor_mapping", + "^test_ai_onnx_ml_label_encoder_tensor_value_only_mapping", + "^test_gridsample_aligncorners_true", + "^test_gridsample_bicubic_align_corners_0_additional_1", + "^test_gridsample_bicubic_align_corners_1_additional_1", + "^test_gridsample_bicubic", + "^test_gridsample_bilinear_align_corners_0_additional_1", + "^test_gridsample_bilinear_align_corners_1_additional_1", + "^test_gridsample_bilinear", + "^test_gridsample_border_padding", + "^test_gridsample", + "^test_gridsample_nearest_align_corners_0_additional_1", + "^test_gridsample_nearest_align_corners_1_additional_1", + "^test_gridsample_nearest", + "^test_gridsample_reflection_padding", + "^test_gridsample_volumetric_bilinear_align_corners_0", + "^test_gridsample_volumetric_bilinear_align_corners_1", + "^test_gridsample_volumetric_nearest_align_corners_0", + "^test_gridsample_volumetric_nearest_align_corners_1", + "^test_gridsample_zeros_padding", + "^test_image_decoder_decode_bmp_rgb", + "^test_image_decoder_decode_jpeg2k_rgb", + "^test_image_decoder_decode_jpeg_bgr", + "^test_image_decoder_decode_jpeg_grayscale", + "^test_image_decoder_decode_jpeg_rgb", + "^test_image_decoder_decode_png_rgb", + "^test_image_decoder_decode_pnm_rgb", + "^test_image_decoder_decode_tiff_rgb", + "^test_image_decoder_decode_webp_rgb", + "^test_regex_full_match_basic", + "^test_regex_full_match_email_domain", + "^test_regex_full_match_empty", + "^test_string_concat_broadcasting", + "^test_string_concat", + "^test_string_concat_empty_string", + "^test_string_concat_utf8", + "^test_string_concat_zero_dimensional", + "^test_string_split_basic", + "^test_string_split_consecutive_delimiters", + "^test_string_split_empty_string_delimiter", + "^test_string_split_empty_tensor", + "^test_string_split_maxsplit", + "^test_string_split_no_delimiter", + "^test_dft_axis", + "^test_dft", + "^test_dft_inverse", + "^test_isinf", + "^test_isinf_float16", + "^test_isinf_negative", + "^test_isinf_positive", + "^test_isnan", + "^test_isnan_float16", + "^test_reduce_max_bool_inputs", + "^test_reduce_min_bool_inputs" ], "current_failing_tests_x86": [ "^test_vgg19", @@ -316,7 +373,24 @@ "^test_layer_normalization_4d_axis_negative_1_expanded_ver18_cpu", "^test_layer_normalization_4d_axis_negative_2_expanded_ver18_cpu", "^test_layer_normalization_4d_axis_negative_3_expanded_ver18_cpu", - "^test_layer_normalization_default_axis_expanded_ver18_cpu" + "^test_layer_normalization_default_axis_expanded_ver18_cpu", + // onnx 1.15 (opset 20) new and updated op tests (test_affine_grid_???_expanded utilizes ConstantOfShape so it needs to be skipped as well) + // https://dev.azure.com/onnxruntime/onnxruntime/_build/results?buildId=1139541&view=logs&j=249e9d58-0012-5814-27cf-6a201adbd9cf&t=bb33e81f-0527-50e0-0fd2-e94f509f0a82 + // only supported with cpu provider + "^test_affine_grid_2d", + "^test_affine_grid_2d_align_corners", + "^test_affine_grid_2d_align_corners_expanded", + "^test_affine_grid_2d_expanded", + "^test_affine_grid_3d", + "^test_affine_grid_3d_align_corners", + "^test_affine_grid_3d_align_corners_expanded", + "^test_affine_grid_3d_expanded", + "^test_constantofshape_float_ones", + "^test_constantofshape_int_shape_zero", + "^test_constantofshape_int_zeros", + // https://dev.azure.com/onnxruntime/onnxruntime/_build/results?buildId=1141563&view=logs&j=a018b46d-e41a-509d-6581-c95fdaa42fcd&t=d61c1d37-f101-5d28-982f-e5931b720302 + "^test_gelu_tanh_2_cpu", + "^test_gelu_tanh_2_expanded_cpu" ], "current_failing_tests_NNAPI": [ "^test_maxpool_2d_uint8", @@ -390,6 +464,14 @@ "^test_squeeze_negative_axes" ], "current_failing_tests_OPENVINO_CPU_FP32": [ + "^test_affine_grid_2d_align_corners", + "^test_affine_grid_2d_align_corners_expanded", + "^test_affine_grid_2d", + "^test_affine_grid_2d_expanded", + "^test_affine_grid_3d_align_corners", + "^test_affine_grid_3d_align_corners_expanded", + "^test_affine_grid_3d", + "^test_affine_grid_3d_expanded", "^test_operator_permute2", "^test_operator_repeat", "^test_operator_repeat_dim_overflow", @@ -569,7 +651,22 @@ "^test_sequence_map_identity_1_sequence_cpu", "^test_sequence_map_identity_1_sequence_expanded_cpu", "^test_sequence_map_identity_2_sequences_cpu", - "^test_sequence_map_identity_2_sequences_expanded_cpu" + "^test_sequence_map_identity_2_sequences_expanded_cpu", + // onnx 1.15 (opset 20) new and updated op tests (test_affine_grid_???_expanded utilizes ConstantOfShape so it needs to be skipped as well) + // https://dev.azure.com/onnxruntime/onnxruntime/_build/results?buildId=1139542&view=logs&j=3032dfba-5baf-5872-0871-2e69cb7f4b6a&t=f0d05deb-fc26-5aaf-e43e-7db2764c07da + // only supported with cpu provider + "^test_affine_grid_2d", + "^test_affine_grid_2d_align_corners", + "^test_affine_grid_2d_align_corners_expanded", + "^test_affine_grid_2d_expanded", + "^test_affine_grid_3d", + "^test_affine_grid_3d_align_corners", + "^test_affine_grid_3d_align_corners_expanded", + "^test_affine_grid_3d_expanded", + "^test_constantofshape_float_ones", + "^test_constantofshape_int_shape_zero", + "^test_constantofshape_int_zeros" + ], // ORT first supported opset 7, so models with nodes that require versions prior to opset 7 are not supported "tests_with_pre_opset7_dependencies": [ diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_overrides.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_overrides.jsonc index caeea0a758ad9..07385ac9ade05 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_overrides.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_overrides.jsonc @@ -7,6 +7,9 @@ "test_dft": 1e-3, "test_dft_axis": 1e-3, "test_dft_inverse": 1e-3, + "test_dft_opset19": 1e-3, + "test_dft_axis_opset19": 1e-3, + "test_dft_inverse_opset19": 1e-3, "test_stft": 1e-4, "test_stft_with_window": 1e-4 }, diff --git a/onnxruntime/test/testdata/transform/gh_issue_17392.onnx b/onnxruntime/test/testdata/transform/gh_issue_17392.onnx new file mode 100644 index 0000000000000..ca9b78a4179bd --- /dev/null +++ b/onnxruntime/test/testdata/transform/gh_issue_17392.onnx @@ -0,0 +1,9 @@ + :Z ++C"Constant* + value_stringsJabcJdef  + +CY"IdentityConstantb +Y + + +B \ No newline at end of file diff --git a/onnxruntime/test/testdata/transpose_optimizer_shared_initializers.onnx b/onnxruntime/test/testdata/transpose_optimizer_shared_initializers.onnx new file mode 100644 index 0000000000000..9d82d68a40098 Binary files /dev/null and b/onnxruntime/test/testdata/transpose_optimizer_shared_initializers.onnx differ diff --git a/onnxruntime/test/testdata/transpose_optimizer_shared_initializers.py b/onnxruntime/test/testdata/transpose_optimizer_shared_initializers.py new file mode 100644 index 0000000000000..d4e3f7e8cbab6 --- /dev/null +++ b/onnxruntime/test/testdata/transpose_optimizer_shared_initializers.py @@ -0,0 +1,60 @@ +import numpy as np +import onnx +from onnx import TensorProto, helper + + +# Create a model with shared initializers that can be updated in-place by the transpose optimizer, +# including ones behind a DQ node. The transpose optimizer updates the first usage and inserts +# Transpose/Unsqueeze ops on the others (see UnsqueezeInput and TransposeInput). +# When we push the Transpose past other usages we should be able to cancel out those Transpose/Unsqueeze ops. +# We need 3 DQ nodes to ensure the Transpose or Unsqueeze added by the transpose optimizer is not +# removed prematurely. +def create_model(broadcast_weights: bool): + if broadcast_weights: + bias_shape = [2, 2] + bias_values = np.random.randn(2, 2) + else: + bias_shape = [1, 3, 2, 2] + bias_values = np.random.randn(1, 3, 2, 2) + + graph = helper.make_graph( + name="graph", + inputs=[ + helper.make_tensor_value_info("input0", TensorProto.FLOAT, [1, 2, 2, 3]), + ], + initializer=[ + helper.make_tensor("bias_quant", TensorProto.UINT8, bias_shape, bias_values.astype(np.uint8)), + helper.make_tensor("bias_fp32", TensorProto.FLOAT, bias_shape, bias_values.astype(np.float32)), + helper.make_tensor("dq_scale0", TensorProto.FLOAT, [], [1.5]), + helper.make_tensor("dq_zp0", TensorProto.UINT8, [], [5]), + helper.make_tensor("dq_scale1", TensorProto.FLOAT, [], [0.5]), + ], + nodes=[ + # Transpose input from channels last to channels first + helper.make_node("Transpose", ["input0"], ["input_T"], perm=[0, 3, 1, 2]), + helper.make_node("DequantizeLinear", ["bias_quant", "dq_scale0", "dq_zp0"], ["DQ0"], "DQ0"), + helper.make_node("Add", ["input_T", "DQ0"], ["A0"], "A0"), + helper.make_node("DequantizeLinear", ["bias_quant", "dq_scale1"], ["DQ1"], "DQ1"), + helper.make_node("Add", ["A0", "DQ1"], ["A1"], "A1"), + helper.make_node("DequantizeLinear", ["bias_quant", "dq_scale0"], ["DQ2"], "DQ2"), + helper.make_node("Add", ["A1", "DQ2"], ["A2"], "A2"), + helper.make_node("Add", ["A2", "bias_fp32"], ["A3"], "A3"), + helper.make_node("Add", ["A3", "bias_fp32"], ["A4"], "A4"), + # NCHW to NHWC + helper.make_node("Transpose", ["A4"], ["output0"], perm=[0, 2, 3, 1]), + ], + outputs=[ + helper.make_tensor_value_info("output0", TensorProto.FLOAT, [1, 2, 2, 3]), + ], + ) + + model = helper.make_model(graph) + onnx.checker.check_model(model, full_check=True) + return model + + +if __name__ == "__main__": + model = create_model(broadcast_weights=False) + onnx.save(model, "transpose_optimizer_shared_initializers.onnx") + model = create_model(broadcast_weights=True) + onnx.save(model, "transpose_optimizer_shared_initializers_broadcast.onnx") diff --git a/onnxruntime/test/testdata/transpose_optimizer_shared_initializers_broadcast.onnx b/onnxruntime/test/testdata/transpose_optimizer_shared_initializers_broadcast.onnx new file mode 100644 index 0000000000000..8bb2c6fd4a8b5 Binary files /dev/null and b/onnxruntime/test/testdata/transpose_optimizer_shared_initializers_broadcast.onnx differ diff --git a/onnxruntime/test/xctest/xcgtest.mm b/onnxruntime/test/xctest/xcgtest.mm index 5367f3e89c07c..c02f18d906cbe 100644 --- a/onnxruntime/test/xctest/xcgtest.mm +++ b/onnxruntime/test/xctest/xcgtest.mm @@ -201,7 +201,7 @@ + (void)registerTestClasses { delete listeners.Release(listeners.default_result_printer()); free(argv); - BOOL runDisabledTests = testing::GTEST_FLAG(also_run_disabled_tests); + BOOL runDisabledTests = GTEST_FLAG_GET(also_run_disabled_tests); NSMutableDictionary* testFilterMap = [NSMutableDictionary dictionary]; NSCharacterSet* decimalDigitCharacterSet = [NSCharacterSet decimalDigitCharacterSet]; diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index 174edabbc91fe..968eece361724 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -9,6 +9,7 @@ #include "api.h" #include +#include #include namespace { @@ -17,6 +18,14 @@ OrtErrorCode g_last_error_code; std::string g_last_error_message; } // namespace +enum DataLocation { + DATA_LOCATION_NONE = 0, + DATA_LOCATION_CPU = 1, + DATA_LOCATION_CPU_PINNED = 2, + DATA_LOCATION_TEXTURE = 3, + DATA_LOCATION_GPU_BUFFER = 4 +}; + static_assert(sizeof(const char*) == sizeof(size_t), "size of a pointer and a size_t value should be the same."); static_assert(sizeof(size_t) == 4, "size of size_t should be 4 in this build (wasm32)."); @@ -223,13 +232,23 @@ void OrtFree(void* ptr) { } } -OrtValue* OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* dims, size_t dims_length) { +OrtValue* OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* dims, size_t dims_length, int data_location) { + if (data_location != DATA_LOCATION_CPU && + data_location != DATA_LOCATION_CPU_PINNED && + data_location != DATA_LOCATION_GPU_BUFFER) { + std::ostringstream ostr; + ostr << "Invalid data location: " << data_location; + CheckStatus(Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, ostr.str().c_str())); + return nullptr; + } + std::vector shapes(dims_length); for (size_t i = 0; i < dims_length; i++) { shapes[i] = dims[i]; } if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) { + // data_location is ignored for string tensor. It is always CPU. OrtAllocator* allocator = nullptr; RETURN_NULLPTR_IF_ERROR(GetAllocatorWithDefaultOptions, &allocator); @@ -244,12 +263,16 @@ OrtValue* OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* return UNREGISTER_AUTO_RELEASE(value); } else { - OrtMemoryInfo* memoryInfo = nullptr; - RETURN_NULLPTR_IF_ERROR(CreateCpuMemoryInfo, OrtDeviceAllocator, OrtMemTypeDefault, &memoryInfo); - REGISTER_AUTO_RELEASE_HANDLE(MemoryInfo, memoryInfo); + OrtMemoryInfo* memory_info = nullptr; + if (data_location != DATA_LOCATION_GPU_BUFFER) { + RETURN_NULLPTR_IF_ERROR(CreateCpuMemoryInfo, OrtDeviceAllocator, OrtMemTypeDefault, &memory_info); + } else { + RETURN_NULLPTR_IF_ERROR(CreateMemoryInfo, "WebGPU_Buffer", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info); + } + REGISTER_AUTO_RELEASE_HANDLE(MemoryInfo, memory_info); OrtValue* value = nullptr; - int error_code = CHECK_STATUS(CreateTensorWithDataAsOrtValue, memoryInfo, data, data_length, + int error_code = CHECK_STATUS(CreateTensorWithDataAsOrtValue, memory_info, data, data_length, dims_length > 0 ? shapes.data() : nullptr, dims_length, static_cast(data_type), &value); @@ -373,15 +396,85 @@ void OrtReleaseRunOptions(OrtRunOptions* run_options) { Ort::GetApi().ReleaseRunOptions(run_options); } +OrtIoBinding* OrtCreateBinding(OrtSession* session) { + OrtIoBinding* binding = nullptr; + int error_code = CHECK_STATUS(CreateIoBinding, session, &binding); + return (error_code == ORT_OK) ? binding : nullptr; +} + +int EMSCRIPTEN_KEEPALIVE OrtBindInput(OrtIoBinding* io_binding, + const char* name, + OrtValue* input) { + return CHECK_STATUS(BindInput, io_binding, name, input); +} + +int EMSCRIPTEN_KEEPALIVE OrtBindOutput(OrtIoBinding* io_binding, + const char* name, + OrtValue* output, + int output_location) { + if (output) { + return CHECK_STATUS(BindOutput, io_binding, name, output); + } else { + if (output_location != DATA_LOCATION_NONE && + output_location != DATA_LOCATION_CPU && + output_location != DATA_LOCATION_CPU_PINNED && + output_location != DATA_LOCATION_GPU_BUFFER) { + std::ostringstream ostr; + ostr << "Invalid data location (" << output_location << ") for output: \"" << name << "\"."; + return CheckStatus(Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, ostr.str().c_str())); + } + + OrtMemoryInfo* memory_info = nullptr; + if (output_location != DATA_LOCATION_GPU_BUFFER) { + RETURN_ERROR_CODE_IF_ERROR(CreateCpuMemoryInfo, OrtDeviceAllocator, OrtMemTypeDefault, &memory_info); + } else { + RETURN_ERROR_CODE_IF_ERROR(CreateMemoryInfo, "WebGPU_Buffer", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info); + } + REGISTER_AUTO_RELEASE_HANDLE(MemoryInfo, memory_info); + return CHECK_STATUS(BindOutputToDevice, io_binding, name, memory_info); + } +} + +void OrtClearBoundOutputs(OrtIoBinding* io_binding) { + Ort::GetApi().ClearBoundOutputs(io_binding); +} + +void OrtReleaseBinding(OrtIoBinding* io_binding) { + Ort::GetApi().ReleaseIoBinding(io_binding); +} + +int OrtRunWithBinding(OrtSession* session, + OrtIoBinding* io_binding, + size_t output_count, + OrtValue** outputs, + OrtRunOptions* run_options) { + RETURN_ERROR_CODE_IF_ERROR(RunWithBinding, session, run_options, io_binding); + + OrtAllocator* allocator = nullptr; + RETURN_ERROR_CODE_IF_ERROR(GetAllocatorWithDefaultOptions, &allocator); + + size_t binding_output_count = 0; + OrtValue** binding_outputs = nullptr; + RETURN_ERROR_CODE_IF_ERROR(GetBoundOutputValues, io_binding, allocator, &binding_outputs, &binding_output_count); + REGISTER_AUTO_RELEASE_BUFFER(OrtValue*, binding_outputs, allocator); + + if (binding_output_count != output_count) { + return CheckStatus( + Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, "Output count is inconsistent with IO Binding output data.")); + } + + for (size_t i = 0; i < output_count; i++) { + outputs[i] = binding_outputs[i]; + } + + return ORT_OK; +} + int OrtRun(OrtSession* session, const char** input_names, const ort_tensor_handle_t* inputs, size_t input_count, const char** output_names, size_t output_count, ort_tensor_handle_t* outputs, OrtRunOptions* run_options) { - auto status_code = CHECK_STATUS(Run, session, run_options, input_names, inputs, input_count, output_names, output_count, outputs); -#if defined(USE_JSEP) - EM_ASM({ Module.jsepRunPromiseResolve ?.($0); }, status_code); -#endif - return status_code; + return CHECK_STATUS(Run, session, run_options, input_names, inputs, input_count, output_names, output_count, outputs); } char* OrtEndProfiling(ort_session_handle_t session) { diff --git a/onnxruntime/wasm/api.h b/onnxruntime/wasm/api.h index 398c901e0e5ed..9a0664697f0ff 100644 --- a/onnxruntime/wasm/api.h +++ b/onnxruntime/wasm/api.h @@ -15,6 +15,9 @@ struct OrtSession; using ort_session_handle_t = OrtSession*; +struct OrtIoBinding; +using ort_io_binding_handle_t = OrtIoBinding*; + struct OrtSessionOptions; using ort_session_options_handle_t = OrtSessionOptions*; @@ -164,9 +167,10 @@ void EMSCRIPTEN_KEEPALIVE OrtFree(void* ptr); * @param data_length size of the buffer 'data' in bytes. * @param dims a pointer to an array of dims. the array should contain (dims_length) element(s). * @param dims_length the length of the tensor's dimension + * @param data_location specify the memory location of the tensor data. 0 for CPU, 1 for GPU buffer. * @returns a tensor handle. Caller must release it after use by calling OrtReleaseTensor(). */ -ort_tensor_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* dims, size_t dims_length); +ort_tensor_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* dims, size_t dims_length, int data_location); /** * get type, shape info and data of the specified tensor. @@ -216,6 +220,58 @@ int EMSCRIPTEN_KEEPALIVE OrtAddRunConfigEntry(ort_run_options_handle_t run_optio */ void EMSCRIPTEN_KEEPALIVE OrtReleaseRunOptions(ort_run_options_handle_t run_options); +/** + * create an instance of ORT IO binding. + */ +ort_io_binding_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateBinding(ort_session_handle_t session); + +/** + * bind an input tensor to the IO binding instance. A cross device copy will be performed if necessary. + * @param io_binding handle of the IO binding + * @param name name of the input + * @param input handle of the input tensor + * @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message. + */ +int EMSCRIPTEN_KEEPALIVE OrtBindInput(ort_io_binding_handle_t io_binding, + const char* name, + ort_tensor_handle_t input); + +/** + * bind an output tensor or location to the IO binding instance. + * @param io_binding handle of the IO binding + * @param name name of the output + * @param output handle of the output tensor. nullptr for output location binding. + * @param output_location specify the memory location of the output tensor data. + * @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message. + */ +int EMSCRIPTEN_KEEPALIVE OrtBindOutput(ort_io_binding_handle_t io_binding, + const char* name, + ort_tensor_handle_t output, + int output_location); + +/** + * clear all bound outputs. + */ +void EMSCRIPTEN_KEEPALIVE OrtClearBoundOutputs(ort_io_binding_handle_t io_binding); + +/** + * release the specified ORT IO binding. + */ +void EMSCRIPTEN_KEEPALIVE OrtReleaseBinding(ort_io_binding_handle_t io_binding); + +/** + * inference the model. + * @param session handle of the specified session + * @param io_binding handle of the IO binding + * @param run_options handle of the run options + * @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message. + */ +int EMSCRIPTEN_KEEPALIVE OrtRunWithBinding(ort_session_handle_t session, + ort_io_binding_handle_t io_binding, + size_t output_count, + ort_tensor_handle_t* outputs, + ort_run_options_handle_t run_options); + /** * inference the model. * @param session handle of the specified session diff --git a/onnxruntime/wasm/js_internal_api.js b/onnxruntime/wasm/js_internal_api.js index 15d393f4ce62d..427ad6f6d14f3 100644 --- a/onnxruntime/wasm/js_internal_api.js +++ b/onnxruntime/wasm/js_internal_api.js @@ -14,40 +14,156 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea Module.jsepReleaseKernel = releaseKernel; Module.jsepRunKernel = runKernel; - Module['jsepOnRunStart'] = sessionId => { - Module['jsepRunPromise'] = new Promise(r => { - Module.jsepRunPromiseResolve = r; - }); - - if (Module.jsepSessionState) { - throw new Error('Session already started'); - } - - Module.jsepSessionState = { - sessionId, - errors: [] + // This is a simplified version of cwrap() with options.async === true (-sASYNCIFY=1) + // It removes some overhead in cwarp() and ccall() that we don't need. + // + // Currently in JSEP build, we only use this for the following functions: + // - OrtRun() + // - OrtRunWithBinding() + // - OrtBindInput() + // + // Note: about parameters "getFunc" and "setFunc": + // - Emscripten has different behaviors for Debug and Release builds for generating exported function wrapper. + // + // - In Debug build, it will generate a wrapper function for each exported function. For example, it generates a + // wrapper for OrtRun() like this (minified): + // ``` + // var _OrtRun = Module["_OrtRun"] = createExportWrapper("OrtRun"); + // ``` + // + // - In Release build, it will generate a lazy loading wrapper for each exported function. For example, it generates + // a wrapper for OrtRun() like this (minified): + // ``` + // d._OrtRun = (a, b, c, e, f, h, l, q) => (d._OrtRun = J.ka)(a, b, c, e, f, h, l, q); + // ``` + // + // The behavior of these two wrappers are different. The debug build will assign `Module["_OrtRun"]` only once + // because `createExportWrapper()` does not reset `Module["_OrtRun"]` inside. The release build, however, will + // reset d._OrtRun to J.ka when the first time it is called. + // + // The difference is important because we need to design the async wrapper in a way that it can handle both cases. + // + // Now, let's look at how the async wrapper is designed to work for both cases: + // + // - Debug build: + // 1. When Web assembly is being loaded, `Module["_OrtRun"]` is assigned to `createExportWrapper("OrtRun")`. + // 2. When the first time `Module["jsepInit"]` is called, `Module["_OrtRun"]` is re-assigned to a new async + // wrapper function. + // Value of `Module["_OrtRun"]` will not be changed again. + // + // - Release build: + // 1. When Web assembly is being loaded, `Module["_OrtRun"]` is assigned to a lazy loading wrapper function. + // 2. When the first time `Module["jsepInit"]` is called, `Module["_OrtRun"]` is re-assigned to a new async + // wrapper function. + // 3. When the first time `Module["_OrtRun"]` is called, the async wrapper will be called. It will call into this + // function: + // ``` + // (a, b, c, e, f, h, l, q) => (d._OrtRun = J.ka)(a, b, c, e, f, h, l, q); + // ``` + // This function will assign d._OrtRun (ie. the minimized `Module["_OrtRun"]`) to the real function (J.ka). + // 4. Since d._OrtRun is re-assigned, we need to update the async wrapper to re-assign its stored + // function to the updated value (J.ka), and re-assign the value of `d._OrtRun` back to the async wrapper. + // Value of `Module["_OrtRun"]` will not be changed again. + // + // The value of `Module["_OrtRun"]` will need to be assigned for 2 times for debug build and 4 times for release + // build. + // + // This is why we need this `getFunc` and `setFunc` parameters. They are used to get the current value of an + // exported function and set the new value of an exported function. + // + const jsepWrapAsync = (func, getFunc, setFunc) => { + return (...args) => { + // cache the async data before calling the function. + const previousAsync = Asyncify.currData; + + const previousFunc = getFunc?.(); + const ret = func(...args); + const newFunc = getFunc?.(); + if (previousFunc !== newFunc) { + // The exported function has been updated. + // Set the sync function reference to the new function. + func = newFunc; + // Set the exported function back to the async wrapper. + setFunc(previousFunc); + // Remove getFunc and setFunc. They are no longer needed. + setFunc = null; + getFunc = null; + } + + // If the async data has been changed, it means that the function started an async operation. + if (Asyncify.currData != previousAsync) { + // returns the promise + return Asyncify.whenDone(); + } + // the function is synchronous. returns the result. + return ret; }; }; - Module['jsepOnRunEnd'] = sessionId => { - if (Module.jsepSessionState.sessionId !== sessionId) { - throw new Error('Session ID mismatch'); - } - - const errorPromises = Module.jsepSessionState.errors; - Module.jsepSessionState = null; - - return errorPromises.length === 0 ? Promise.resolve() : new Promise((resolve, reject) => { - Promise.all(errorPromises).then(errors => { - errors = errors.filter(e => e); - if (errors.length > 0) { - reject(new Error(errors.join('\n'))); - } else { - resolve(); + // This is a wrapper for OrtRun() and OrtRunWithBinding() to ensure that Promises are handled correctly. + const runAsync = (runAsyncFunc) => { + return async (...args) => { + try { + // Module.jsepSessionState should be null, unless we are in the middle of a session. + // If it is not null, it means that the previous session has not finished yet. + if (Module.jsepSessionState) { + throw new Error('Session already started'); + } + const state = Module.jsepSessionState = {sessionHandle: args[0], errors: []}; + + // Run the acyncified function: OrtRun() or OrtRunWithBinding() + const ret = await runAsyncFunc(...args); + + // Check if the session is still valid. this object should be the same as the one we set above. + if (Module.jsepSessionState !== state) { + throw new Error('Session mismatch'); + } + + // Flush the backend. This will submit all pending commands to the GPU. + backend['flush'](); + + // Await all pending promises. This includes GPU validation promises for diagnostic purposes. + const errorPromises = state.errors; + if (errorPromises.length > 0) { + let errors = await Promise.all(errorPromises); + errors = errors.filter(e => e); + if (errors.length > 0) { + throw new Error(errors.join('\n')); + } } - }, reason => { - reject(reason); - }); - }); + + return ret; + } finally { + Module.jsepSessionState = null; + } + }; + }; + + // replace the original functions with asyncified versions + Module['_OrtRun'] = runAsync(jsepWrapAsync( + Module['_OrtRun'], + () => Module['_OrtRun'], + v => Module['_OrtRun'] = v)); + Module['_OrtRunWithBinding'] = runAsync(jsepWrapAsync( + Module['_OrtRunWithBinding'], + () => Module['_OrtRunWithBinding'], + v => Module['_OrtRunWithBinding'] = v)); + Module['_OrtBindInput'] = jsepWrapAsync( + Module['_OrtBindInput'], + () => Module['_OrtBindInput'], + v => Module['_OrtBindInput'] = v); + + // expose webgpu backend functions + Module['jsepRegisterBuffer'] = (sessionId, index, buffer, size) => { + return backend['registerBuffer'](sessionId, index, buffer, size); + }; + Module['jsepUnregisterBuffers'] = sessionId => { + backend['unregisterBuffers'](sessionId); + }; + Module['jsepGetBuffer'] = (dataId) => { + return backend['getBuffer'](dataId); + }; + Module['jsepCreateDownloader'] = (gpuBuffer, size, type) => { + return backend['createDownloader'](gpuBuffer, size, type); }; }; diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index 91b1df7b7cf2d..86d3cdee9ba98 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -4171,7 +4171,7 @@ Return true if all elements are true and false otherwise. "T", OpSchema::Variadic, /*is_homogeneous*/ false, /*min_arity*/ 1) - .TypeConstraint("T", OpSchema::all_tensor_types_with_bfloat(), + .TypeConstraint("T", OpSchema::all_tensor_types_ir4(), "Allow inputs and outputs to be any kind of tensor."); #endif // ENABLE_TRITON diff --git a/orttraining/orttraining/models/runner/training_runner.cc b/orttraining/orttraining/models/runner/training_runner.cc index d35366e556a42..9ac9f3ee090bb 100644 --- a/orttraining/orttraining/models/runner/training_runner.cc +++ b/orttraining/orttraining/models/runner/training_runner.cc @@ -1252,7 +1252,7 @@ Status WithOrtValuesFromTensorProtos( OrtValue ort_value; - ORT_RETURN_IF_ERROR(utils::TensorProtoToMLValue( + ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue( Env::Default(), model_location.c_str(), tensor_proto, mem_buffer, ort_value)); diff --git a/orttraining/orttraining/models/runner/training_util.cc b/orttraining/orttraining/models/runner/training_util.cc index fa9c3f24b2ee5..7764508d9a091 100644 --- a/orttraining/orttraining/models/runner/training_util.cc +++ b/orttraining/orttraining/models/runner/training_util.cc @@ -52,7 +52,7 @@ common::Status DataSet::AddData(const vector& featu OrtValue ort_value; OrtMemoryInfo info("Cpu", OrtDeviceAllocator, OrtDevice{}, 0, OrtMemTypeDefault); std::unique_ptr buffer = std::make_unique(cpu_tensor_length); - ORT_RETURN_IF_ERROR(utils::TensorProtoToMLValue( + ORT_RETURN_IF_ERROR(utils::TensorProtoToOrtValue( Env::Default(), nullptr, tensor_proto, MemBuffer(buffer.get(), cpu_tensor_length, info), ort_value)); sample->push_back(ort_value); diff --git a/orttraining/orttraining/python/orttraining_python_module.cc b/orttraining/orttraining/python/orttraining_python_module.cc index 7024244629c3e..88ef90a7feaa8 100644 --- a/orttraining/orttraining/python/orttraining_python_module.cc +++ b/orttraining/orttraining/python/orttraining_python_module.cc @@ -15,6 +15,12 @@ namespace onnxruntime { namespace python { namespace py = pybind11; +#if defined(USE_MPI) && defined(ORT_USE_NCCL) +static constexpr bool HAS_COLLECTIVE_OPS = true; +#else +static constexpr bool HAS_COLLECTIVE_OPS = false; +#endif + using namespace onnxruntime::logging; std::unique_ptr CreateExecutionProviderInstance( @@ -361,6 +367,8 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) { }, "Clean the execution provider instances used in ort training module."); + m.def("has_collective_ops", []() -> bool { return HAS_COLLECTIVE_OPS; }); + // See documentation for class TrainingEnvInitialzer earlier in this module // for an explanation as to why this is needed. auto atexit = py::module_::import("atexit"); diff --git a/orttraining/orttraining/python/training/ort_triton/_codegen.py b/orttraining/orttraining/python/training/ort_triton/_codegen.py index c071f01f87ea5..8e21013da2353 100644 --- a/orttraining/orttraining/python/training/ort_triton/_codegen.py +++ b/orttraining/orttraining/python/training/ort_triton/_codegen.py @@ -263,10 +263,20 @@ def ReduceKernelNode( # noqa: N802 "Rsqrt": "{indent}{o0} = 1.0 / tl.sqrt({i0})\n", "Cast": "{indent}{o0} = {i0}.to(tl.{dtype})\n", "CastBool": "{indent}{o0} = {i0} != 0\n", - "Erf": "{indent}{o0} = tl.libdevice.erf({i0})\n", - "Gelu": "{indent}{o0} = (tl.libdevice.erf({i0} / 1.41421356237) + 1.0) * 0.5\n", + "Erf": "{indent}{o0} = tl.erf({i0})\n", + "Gelu": "{indent}{o0} = {i0} * 0.5 * (tl.math.erf({i0} * 0.70710678118654752440) + 1.0)\n", + "QuickGelu": "{indent}{o0} = {i0} * tl.sigmoid({i0} * {alpha})\n", + "GeluGrad": ( + "{indent}{o0} = {i0} * (0.5 * (1.0 + tl.math.erf(0.70710678118654752440 * {i1})) + " + "{i1} * 1.12837916709551257390 * 0.70710678118654752440 * 0.5 * tl.exp(-0.5 * {i1} * {i1}))\n" + ), + "QuickGeluGrad": ( + "{indent}tmp_v = {i1} * {alpha}\n" + "{indent}tmp_sigmoid = tl.sigmoid(tmp_v)\n" + "{indent}{o0} = {i0} * tmp_sigmoid * (1.0 + tmp_v * (1.0 - tmp_sigmoid))\n" + ), "Exp": "{indent}{o0} = tl.exp({i0})\n", - "Tanh": "{indent}{o0} = tl.libdevice.tanh({i0})\n", + "Tanh": "{indent}{o0} = tl.math.tanh({i0})\n", "Where": "{indent}{o0} = tl.where({i0}, {i1}, {i2})\n", "Sigmoid": "{indent}{o0} = tl.sigmoid({i0})\n", "Log": "{indent}{o0} = tl.log({i0})\n", @@ -303,6 +313,9 @@ def ComputeNode( # noqa: N802 else: kwargs["dtype"] = to_dtype.__name__ + if op_type == "QuickGelu" or op_type == "QuickGeluGrad": + kwargs["alpha"] = str(node.attributes.get("alpha", 1.702)) + if op_type == "Sum": output_var = kwargs["o0"] formula = " + ".join([kwargs[f"i{idx}"] for idx in range(len(node.inputs))]) diff --git a/orttraining/orttraining/python/training/ort_triton/_common.py b/orttraining/orttraining/python/training/ort_triton/_common.py index 65540202420b5..82ac82cfa2919 100644 --- a/orttraining/orttraining/python/training/ort_triton/_common.py +++ b/orttraining/orttraining/python/training/ort_triton/_common.py @@ -131,6 +131,10 @@ class TypeAndShapeInfer: "ReduceMax": _infer_reduction, "ReduceMin": _infer_reduction, "Sum": _infer_elementwise, + "Gelu": _infer_unary, + "QuickGelu": _infer_unary, + "GeluGrad": _infer_elementwise, + "QuickGeluGrad": _infer_elementwise, } @classmethod diff --git a/orttraining/orttraining/python/training/ort_triton/_ir.py b/orttraining/orttraining/python/training/ort_triton/_ir.py index 8aa5c1b13159b..f7d3b31eac5b6 100644 --- a/orttraining/orttraining/python/training/ort_triton/_ir.py +++ b/orttraining/orttraining/python/training/ort_triton/_ir.py @@ -5,7 +5,7 @@ from abc import abstractmethod from collections import defaultdict -from typing import Dict, List, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple import numpy as np import sympy @@ -184,14 +184,25 @@ class ComputeNode(IRNode): Each operator is represented as a ComputeNode. """ - def __init__(self, op_type: str, inputs: List[TensorArg], outputs: List[TensorArg]): + def __init__( + self, + op_type: str, + inputs: List[TensorArg], + outputs: List[TensorArg], + attributes: Dict[str, Any] = {}, # noqa: B006 + ): super().__init__(inputs, outputs) self._op_type: str = op_type + self._attributes: Dict[str, Any] = attributes @property def op_type(self): return self._op_type + @property + def attributes(self): + return self._attributes + class ReduceNode(ComputeNode): def __init__(self, op_type: str, inputs: List[TensorArg], outputs: List[TensorArg], offset_calc: OffsetCalculator): diff --git a/orttraining/orttraining/python/training/ort_triton/_lowering.py b/orttraining/orttraining/python/training/ort_triton/_lowering.py index 16db9ab000834..5de60e69437a0 100644 --- a/orttraining/orttraining/python/training/ort_triton/_lowering.py +++ b/orttraining/orttraining/python/training/ort_triton/_lowering.py @@ -9,7 +9,7 @@ from typing import Any, Dict, List, Set, Tuple import sympy -from onnx import NodeProto +from onnx import NodeProto, helper from ._common import AutotuneConfigs, TensorInfo from ._ir import ( @@ -378,7 +378,10 @@ def _to_compute_node(self, node: NodeProto, offset_calc: OffsetCalculator): return DropoutNode(inputs, outputs, offset_calc) if is_reduction_node(node): return ReduceNode(op_type, inputs, outputs, offset_calc) - return ComputeNode(op_type, inputs, outputs) + attributes = {} + for attr in node.attribute: + attributes[attr.name] = helper.get_attribute_value(attr) + return ComputeNode(op_type, inputs, outputs, attributes) def _analyze_kernel_io_list(self): cross_kernel_inputs = set() diff --git a/orttraining/orttraining/python/training/ort_triton/_op_config.py b/orttraining/orttraining/python/training/ort_triton/_op_config.py index f58d0e1847207..7d9af00933a75 100644 --- a/orttraining/orttraining/python/training/ort_triton/_op_config.py +++ b/orttraining/orttraining/python/training/ort_triton/_op_config.py @@ -36,6 +36,10 @@ "DropoutGrad": {"domain": "com.microsoft", "versions": [1]}, "Identity": {"versions": [13], "is_no_op": True}, "Sum": {"versions": [13]}, + "Gelu": {"domain": "com.microsoft", "versions": [1]}, + "QuickGelu": {"domain": "com.microsoft", "versions": [1]}, + "GeluGrad": {"domain": "com.microsoft", "versions": [1]}, + "QuickGeluGrad": {"domain": "com.microsoft", "versions": [1]}, } _REDUCTION_OPS = { diff --git a/orttraining/orttraining/test/framework/checkpointing_test.cc b/orttraining/orttraining/test/framework/checkpointing_test.cc index b91cc2f1d5f5f..a7ee776b9bc39 100644 --- a/orttraining/orttraining/test/framework/checkpointing_test.cc +++ b/orttraining/orttraining/test/framework/checkpointing_test.cc @@ -52,21 +52,18 @@ void CompareOrtValuesToTensorProtoValues( ASSERT_EQ(name_to_ort_value.size(), name_to_tensor_proto.size()); NameMLValMap name_to_ort_value_from_tensor_proto{}; - std::vector> tensor_buffers{}; + AllocatorPtr tmp_allocator = std::make_shared(); for (const auto& name_and_tensor_proto : name_to_tensor_proto) { const auto& name = name_and_tensor_proto.first; const auto& tensor_proto = name_and_tensor_proto.second; TensorShape shape{tensor_proto.dims().data(), static_cast(tensor_proto.dims().size())}; ASSERT_EQ(tensor_proto.data_type(), ONNX_NAMESPACE::TensorProto_DataType_FLOAT); - std::vector tensor_buffer(shape.Size() * sizeof(float)); - MemBuffer m(tensor_buffer.data(), tensor_buffer.size(), cpu_alloc_info); OrtValue ort_value; - ASSERT_STATUS_OK(utils::TensorProtoToMLValue( - Env::Default(), model_path.c_str(), tensor_proto, m, ort_value)); + ASSERT_STATUS_OK(utils::TensorProtoToOrtValue(Env::Default(), model_path.c_str(), tensor_proto, + tmp_allocator, ort_value)); name_to_ort_value_from_tensor_proto.emplace(name, ort_value); - tensor_buffers.emplace_back(std::move(tensor_buffer)); } for (const auto& name_and_ort_value : name_to_ort_value) { diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py index 318de843efb8f..d205e8f2377dd 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_triton.py @@ -135,8 +135,31 @@ def _torch_layer_norm(input, weight, bias, **kwargs): return torch.nn.functional.layer_norm(input, normalized_shape, weight, bias) +def _torch_gelu(input): + return torch.nn.functional.gelu(input) + + +def _torch_quick_gelu(input, **kwargs): + alpha = kwargs.get("alpha", 1.702) + return input * torch.sigmoid(input * alpha) + + +def _torch_gelu_grad(dy, x): + alpha = 0.70710678118654752440 + beta = 1.12837916709551257390 * 0.70710678118654752440 * 0.5 + cdf = 0.5 * (1 + torch.erf(x * alpha)) + pdf = beta * torch.exp(x * x * -0.5) + return dy * (cdf + x * pdf) + + +def _torch_quick_gelu_grad(dy, x, **kwargs): + alpha = kwargs.get("alpha", 1.702) + sigmoid = torch.sigmoid(x * alpha) + return dy * sigmoid * (1.0 + x * alpha * (1.0 - sigmoid)) + + class TorchFuncExecutor: - _INFER_FUNC_MAP = { # noqa: RUF012 + _TORCH_FUNC_MAP = { # noqa: RUF012 "Add": _torch_add, "Sub": _torch_sub, "Mul": _torch_mul, @@ -154,13 +177,17 @@ class TorchFuncExecutor: "ReduceMin": _torch_reduce_min, "Softmax": _torch_softmax, "LayerNormalization": _torch_layer_norm, + "Gelu": _torch_gelu, + "QuickGelu": _torch_quick_gelu, + "GeluGrad": _torch_gelu_grad, + "QuickGeluGrad": _torch_quick_gelu_grad, } @classmethod def run(cls, op_type, *torch_tensors, **kwargs): - if op_type not in cls._INFER_FUNC_MAP: + if op_type not in cls._TORCH_FUNC_MAP: raise NotImplementedError(f"Unsupported op type: {op_type}") - return cls._INFER_FUNC_MAP[op_type](*torch_tensors, **kwargs) + return cls._TORCH_FUNC_MAP[op_type](*torch_tensors, **kwargs) def _run_op_test(op_type, onnx_dtype, create_model_func, gen_inputs_func, **kwargs): @@ -169,6 +196,8 @@ def _run_op_test(op_type, onnx_dtype, create_model_func, gen_inputs_func, **kwar pt_inputs = gen_inputs_func(_onnx_dtype_to_torch_dtype(onnx_dtype)) ort_inputs = copy.deepcopy(pt_inputs) ort_inputs = [tensor.to(torch.uint8) if tensor.dtype == torch.bool else tensor for tensor in ort_inputs] + if "::" in op_type: + _, op_type = op_type.split("::") pt_outputs = TorchFuncExecutor.run(op_type, *pt_inputs, **kwargs) model_str = create_model_func(op_type, onnx_dtype, **kwargs).SerializeToString() ort_outputs = call_triton_by_onnx(hash(model_str), model_str, *[to_dlpack(tensor) for tensor in ort_inputs]) @@ -260,13 +289,27 @@ def _run_tunable_op_test(module_cls, dtype, gen_inputs_func, tunable_op, impl_co del os.environ["ORTMODULE_TUNING_RESULTS_PATH"] -@pytest.mark.parametrize("op_type", ["Add", "Sub", "Mul", "Div"]) +@pytest.mark.parametrize( + "op", + [ + ("Add", {}), + ("Sub", {}), + ("Mul", {}), + ("Div", {}), + ("com.microsoft::GeluGrad", {}), + ("com.microsoft::QuickGeluGrad", {}), + ("com.microsoft::QuickGeluGrad", {"alpha": 1.0}), + ], +) @pytest.mark.parametrize("onnx_dtype", [TensorProto.FLOAT, TensorProto.FLOAT16]) @pytest.mark.parametrize("input_shapes", [([1024, 2], [1024, 2]), ([2, 3, 3, 3], [3, 1, 3]), ([2049], [1])]) -def test_binary_elementwise_op(op_type, onnx_dtype, input_shapes): - def _create_model(op_type, onnx_dtype): +def test_binary_elementwise_op(op, onnx_dtype, input_shapes): + def _create_model(op_type, onnx_dtype, **kwargs): + domain = "" + if "::" in op_type: + domain, op_type = op_type.split("::") graph = helper.make_graph( - [helper.make_node(op_type, ["X", "Y"], ["Z"], name="test")], + [helper.make_node(op_type, ["X", "Y"], ["Z"], name="test", domain=domain, **kwargs)], "test", [ helper.make_tensor_value_info("X", onnx_dtype, None), @@ -282,7 +325,7 @@ def _gen_inputs(dtype): torch.randn(*input_shapes[1], dtype=dtype, device=DEVICE), ] - _run_op_test(op_type, onnx_dtype, _create_model, _gen_inputs) + _run_op_test(op[0], onnx_dtype, _create_model, _gen_inputs, **op[1]) @pytest.mark.parametrize("onnx_dtype", [TensorProto.FLOAT, TensorProto.FLOAT16]) @@ -303,13 +346,25 @@ def _gen_inputs(dtype): _run_op_test("Sum", onnx_dtype, _create_model, _gen_inputs) -@pytest.mark.parametrize("op_type", ["Sqrt", "Exp"]) +@pytest.mark.parametrize( + "op", + [ + ("Sqrt", {}), + ("Exp", {}), + ("com.microsoft::Gelu", {}), + ("com.microsoft::QuickGelu", {}), + ("com.microsoft::QuickGelu", {"alpha": 1.0}), + ], +) @pytest.mark.parametrize("onnx_dtype", [TensorProto.FLOAT, TensorProto.FLOAT16]) @pytest.mark.parametrize("input_shape", [[1024, 4], [2, 3, 3, 3], [2049, 1]]) -def test_unary_elementwise_op(op_type, onnx_dtype, input_shape): - def _create_model(op_type, onnx_dtype): +def test_unary_elementwise_op(op, onnx_dtype, input_shape): + def _create_model(op_type, onnx_dtype, **kwargs): + domain = "" + if "::" in op_type: + domain, op_type = op_type.split("::") graph = helper.make_graph( - [helper.make_node(op_type, ["X"], ["Y"], name="test")], + [helper.make_node(op_type, ["X"], ["Y"], name="test", domain=domain, **kwargs)], "test", [helper.make_tensor_value_info("X", onnx_dtype, None)], [helper.make_tensor_value_info("Y", onnx_dtype, None)], @@ -319,7 +374,7 @@ def _create_model(op_type, onnx_dtype): def _gen_inputs(dtype): return [torch.rand(*input_shape, dtype=dtype, device=DEVICE)] - _run_op_test(op_type, onnx_dtype, _create_model, _gen_inputs) + _run_op_test(op[0], onnx_dtype, _create_model, _gen_inputs, **op[1]) @pytest.mark.parametrize("onnx_dtype", [TensorProto.FLOAT, TensorProto.FLOAT16]) diff --git a/requirements-dev.txt b/requirements-dev.txt index 73e04e6b37c0b..1b5ca65cf8037 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -17,4 +17,4 @@ scikit-learn scipy sympy wheel -setuptools>=41.4.0 +setuptools>=61.0.0 diff --git a/requirements-training.txt b/requirements-training.txt index 4b1be6cef9b7c..dbfd7305d1bec 100644 --- a/requirements-training.txt +++ b/requirements-training.txt @@ -6,4 +6,4 @@ onnx packaging protobuf sympy -setuptools>=41.4.0 +setuptools>=61.0.0 diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index c4fb5499983cb..c25a61ee5d568 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -2055,7 +2055,7 @@ def build_nuget_package( ): if not (is_windows() or is_linux()): raise BuildError( - "Currently csharp builds and nuget package creation is only supportted on Windows and Linux platforms." + "Currently csharp builds and nuget package creation is only supported on Windows and Linux platforms." ) csharp_build_dir = os.path.join(source_dir, "csharp") diff --git a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml index f17bc8de5739b..cf73691a5eecc 100644 --- a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml @@ -11,7 +11,7 @@ steps: packageType: upack feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0' - version: 1.0.81 + version: 1.0.90 downloadPath: $(Build.BinariesDirectory)/deps # The private ADO project @@ -22,7 +22,7 @@ steps: packageType: upack feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325' definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a' - version: 1.0.81 + version: 1.0.90 downloadPath: $(Build.BinariesDirectory)/deps # You can add more ADO accounts at here. diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml index 8868e671a5fa5..09c52f4d5ba0d 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/win-ci-prebuild-steps.yml @@ -31,7 +31,7 @@ steps: architecture: ${{parameters.BuildArch}} - script: | - python -m pip install -q setuptools wheel numpy flatbuffers + python -m pip install --upgrade "setuptools>=68.2.2" wheel numpy flatbuffers workingDirectory: '$(Build.BinariesDirectory)' displayName: 'Install python modules' diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml index 96a0ebd753d8e..fe7f752513f3c 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml @@ -98,7 +98,6 @@ jobs: cd '$(Build.SourcesDirectory)/cmake/external/emsdk' ./emsdk install 3.1.44 ccache-git-emscripten-64bit ./emsdk activate 3.1.44 ccache-git-emscripten-64bit - ln -s $(Build.SourcesDirectory)/cmake/external/emsdk/ccache/git-emscripten_64bit/bin/ccache /usr/local/bin/ccache displayName: 'emsdk install and activate ccache for emscripten' condition: eq('${{ parameters.WithCache }}', 'true') diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml index d737376eb99b5..788b02f539821 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml @@ -29,6 +29,7 @@ jobs: pool: ${{ parameters.PoolName }} variables: + webgpuCommandlineExtraFlags: '--chromium-flags=--ignore-gpu-blocklist --chromium-flags=--gpu-vendor-id=0x10de' runCodesignValidationInjection: false timeoutInMinutes: 60 workspace: @@ -159,12 +160,22 @@ jobs: npm test -- -e=edge -b=webgl,wasm,xnnpack workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests (wasm,webgl,xnnpack backend)' - condition: ne('${{ parameters.RunWebGpuTests }}', 'true') + condition: eq('${{ parameters.RunWebGpuTests }}', 'false') - script: | - npm test -- -e=edge -b=webgl,wasm,xnnpack,webgpu --chromium-flags=--ignore-gpu-blocklist --chromium-flags=--gpu-vendor-id=0x10de + npm test -- -e=edge -b=webgl,wasm,xnnpack,webgpu $(webgpuCommandlineExtraFlags) workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests (ALL backends)' - condition: ne('${{ parameters.RunWebGpuTests }}', 'false') + condition: eq('${{ parameters.RunWebGpuTests }}', 'true') + - script: | + npm test -- suite1 -e=edge -b=webgpu --io-binding=gpu-tensor $(webgpuCommandlineExtraFlags) + workingDirectory: '$(Build.SourcesDirectory)\js\web' + displayName: 'Run ort-web tests (Suite1, webgpu, IO-binding=gpu-tensor)' + condition: eq('${{ parameters.RunWebGpuTests }}', 'true') + - script: | + npm test -- suite1 -e=edge -b=webgpu --io-binding=gpu-location $(webgpuCommandlineExtraFlags) + workingDirectory: '$(Build.SourcesDirectory)\js\web' + displayName: 'Run ort-web tests (Suite1, webgpu, IO-binding=gpu-location)' + condition: eq('${{ parameters.RunWebGpuTests }}', 'true') - script: | npm test -- --webgl-texture-pack-mode -b=webgl -e=edge workingDirectory: '$(Build.SourcesDirectory)\js\web' diff --git a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 index cdf504c8e3b03..bbdb411b790a0 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 +++ b/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 @@ -22,7 +22,7 @@ RUN dnf install -y \ ln -s /usr/bin/pip3 pip3.8; RUN pip3 install --upgrade pip -RUN pip3 install setuptools>=41.0.0 +RUN pip3 install setuptools>=68.2.2 # Install TensorRT RUN dnf install -y libnvinfer8 libnvonnxparsers8 libnvparsers8 libnvinfer-plugin8 libnvinfer-lean8 libnvinfer-vc-plugin8 libnvinfer-dispatch8 diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_6_tensorrt8_4 b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_6_tensorrt8_4 index 10f404c7c6a85..8b32425afce1c 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_6_tensorrt8_4 +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_6_tensorrt8_4 @@ -31,7 +31,7 @@ RUN apt-get install -y --no-install-recommends \ ln -s /usr/bin/pip3 pip; RUN pip install --upgrade pip -RUN pip install setuptools>=41.0.0 +RUN pip install setuptools>=68.2.2 # Install TensorRT RUN v="8.4.1-1+cuda11.6" &&\ diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_5 b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_5 index cacc09f0c7455..cfc7023ef8e61 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_5 +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_5 @@ -28,7 +28,7 @@ RUN apt-get install -y --no-install-recommends \ ln -s /usr/bin/pip3 pip; RUN pip install --upgrade pip -RUN pip install setuptools>=41.0.0 +RUN pip install setuptools>=68.2.2 # Install TensorRT RUN v="8.5.1-1+cuda11.8" &&\ diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6 b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6 index 0a4885e774047..edc41197be5c9 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6 +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_cuda11_8_tensorrt8_6 @@ -28,7 +28,7 @@ RUN apt-get install -y --no-install-recommends \ ln -s /usr/bin/pip3 pip; RUN pip install --upgrade pip -RUN pip install setuptools>=41.0.0 +RUN pip install setuptools>=68.2.2 # Install TensorRT RUN v="8.6.1.6-1+cuda11.8" &&\ diff --git a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin index c9308ade37396..21b09b2d8978e 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin +++ b/tools/ci_build/github/linux/docker/Dockerfile.ubuntu_tensorrt_bin @@ -42,7 +42,7 @@ RUN apt-get install -y --no-install-recommends \ ln -s /usr/bin/pip3 pip; RUN pip install --upgrade pip -RUN pip install setuptools>=41.0.0 +RUN pip install setuptools>=68.2.2 # Install TensorRT from tar.gz RUN tar -xzvf /TensorRT-${TAR_TRT_VERSION}.tar.gz diff --git a/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt b/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt index 8a9c4dac1dd58..5341ae062d332 100644 --- a/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt +++ b/tools/ci_build/github/linux/docker/inference/x64/python/cpu/scripts/requirements.txt @@ -2,9 +2,9 @@ numpy==1.21.6 ; python_version < '3.11' numpy==1.24.2 ; python_version >= '3.11' mypy pytest -setuptools>=41.4.0 +setuptools>=68.2.2 wheel -git+http://github.com/onnx/onnx.git@e2525550194ce3d8a2c4a3af451c9d9b3ae6650e#egg=onnx +git+http://github.com/onnx/onnx.git@fdefbe85ed9c362b95b9b401cd19db068a76141f#egg=onnx protobuf==3.20.2 sympy==1.12 flatbuffers diff --git a/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt b/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt index 6b8003c01c24d..b2893286803b0 100644 --- a/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/manylinux/requirements.txt @@ -2,9 +2,9 @@ numpy==1.21.6 ; python_version < '3.11' numpy==1.24.2 ; python_version >= '3.11' mypy pytest -setuptools>=41.4.0 +setuptools>=68.2.2 wheel -git+http://github.com/onnx/onnx.git@e2525550194ce3d8a2c4a3af451c9d9b3ae6650e#egg=onnx +git+http://github.com/onnx/onnx.git@fdefbe85ed9c362b95b9b401cd19db068a76141f#egg=onnx protobuf==3.20.2 sympy==1.12 flatbuffers diff --git a/tools/ci_build/github/linux/docker/scripts/requirements.txt b/tools/ci_build/github/linux/docker/scripts/requirements.txt index 9dbe856753faa..5d48a93b09c90 100644 --- a/tools/ci_build/github/linux/docker/scripts/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/requirements.txt @@ -3,9 +3,9 @@ numpy==1.21.6 ; python_version < '3.11' numpy==1.24.2 ; python_version >= '3.11' mypy pytest -setuptools>=41.4.0 +setuptools>=68.2.2 wheel>=0.35.1 -git+http://github.com/onnx/onnx.git@e2525550194ce3d8a2c4a3af451c9d9b3ae6650e#egg=onnx +git+http://github.com/onnx/onnx.git@fdefbe85ed9c362b95b9b401cd19db068a76141f#egg=onnx argparse sympy==1.12 flatbuffers diff --git a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch2.0.0_cu11.8/requirements.txt b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch2.0.0_cu11.8/requirements.txt index fa28a810370f7..b3b2651c8d26d 100644 --- a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch2.0.0_cu11.8/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch2.0.0_cu11.8/requirements.txt @@ -5,4 +5,4 @@ torchvision==0.15.1+cu118 torchtext==0.15.1 # TODO(bmeswani): packaging 22.0 removes support for LegacyVersion leading to errors because transformers 4.4.2 uses LegacyVersion packaging==21.3 -setuptools>=41.4.0 +setuptools>=68.2.2 diff --git a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_cpu/requirements.txt b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_cpu/requirements.txt index 94b16d7ff4894..95d02b8400339 100644 --- a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_cpu/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/requirements_torch_cpu/requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/torch_stable.html torch==2.0.0+cpu -setuptools>=41.4.0 +setuptools>=68.2.2 diff --git a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/torch_eager_cpu/requirements.txt b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/torch_eager_cpu/requirements.txt index b071770c629e2..08e251eddbf96 100644 --- a/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/torch_eager_cpu/requirements.txt +++ b/tools/ci_build/github/linux/docker/scripts/training/ortmodule/stage1/torch_eager_cpu/requirements.txt @@ -1,7 +1,7 @@ --pre -f https://download.pytorch.org/whl/torch_stable.html torch==1.13.1+cpu -setuptools>=41.4.0 +setuptools>=68.2.2 cerberus h5py scikit-learn diff --git a/tools/ci_build/github/windows/helpers.ps1 b/tools/ci_build/github/windows/helpers.ps1 index 6e81f901a8288..20df10b244408 100644 --- a/tools/ci_build/github/windows/helpers.ps1 +++ b/tools/ci_build/github/windows/helpers.ps1 @@ -444,7 +444,7 @@ function Install-Abseil { .Description The Install-UTF8-Range function installs Google's utf8_range library from source. utf8_range depends on Abseil. - + .PARAMETER cmake_path The full path of cmake.exe @@ -604,14 +604,17 @@ function Install-ONNX { pushd . Write-Host "Uninstalling onnx and ignore errors if there is any..." - python.exe -m pip uninstall -y onnx -qq + [string[]]$pip_args ="-m", "pip", "uninstall", "-y", "onnx", "-qq" + &"python.exe" $pip_args + if ($lastExitCode -ne 0) { + exit $lastExitCode + } Write-Host "Installing python packages..." - $p = Start-Process -NoNewWindow -Wait -PassThru -FilePath "python.exe" -ArgumentList "-m", "pip", "install", "--disable-pip-version-check", "setuptools", "wheel", "numpy", "protobuf==$protobuf_version" - $exitCode = $p.ExitCode - if ($exitCode -ne 0) { - Write-Host -Object "Install dependent python wheels failed. Exitcode: $exitCode" - exit $exitCode + [string[]]$pip_args = "-m", "pip", "install", "-qq", "--disable-pip-version-check", "setuptools>=68.2.2", "wheel", "numpy", "protobuf==$protobuf_version" + &"python.exe" $pip_args + if ($lastExitCode -ne 0) { + exit $lastExitCode } $url=Get-DownloadURL -name onnx -src_root $src_root diff --git a/winml/adapter/winml_adapter_environment.cpp b/winml/adapter/winml_adapter_environment.cpp index 43babdf43967e..e2da473c7d5b5 100644 --- a/winml/adapter/winml_adapter_environment.cpp +++ b/winml/adapter/winml_adapter_environment.cpp @@ -9,6 +9,7 @@ #include "winml_adapter_apis.h" #include "core/framework/error_code_helper.h" #include "core/session/ort_env.h" +#include "core/session/user_logging_sink.h" #ifdef USE_DML #include "abi_custom_registry_impl.h" @@ -18,12 +19,12 @@ #endif USE_DML namespace winmla = Windows::AI::MachineLearning::Adapter; -class WinmlAdapterLoggingWrapper : public LoggingWrapper { +class WinmlAdapterLoggingWrapper : public onnxruntime::UserLoggingSink { public: WinmlAdapterLoggingWrapper( OrtLoggingFunction logging_function, OrtProfilingFunction profiling_function, void* logger_param ) - : LoggingWrapper(logging_function, logger_param), + : onnxruntime::UserLoggingSink(logging_function, logger_param), profiling_function_(profiling_function) {} void SendProfileEvent(onnxruntime::profiling::EventRecord& event_record) const override { diff --git a/winml/test/model/model_tests.cpp b/winml/test/model/model_tests.cpp index 0b4c10eac9142..5057f74046638 100644 --- a/winml/test/model/model_tests.cpp +++ b/winml/test/model/model_tests.cpp @@ -380,13 +380,6 @@ std::string GetFullNameOfTest(ITestCase* testCase, winml::LearningModelDeviceKin name += tokenizedModelPath[tokenizedModelPath.size() - 2] += "_"; // model name name += tokenizedModelPath[tokenizedModelPath.size() - 3]; // opset version - // To introduce models from model zoo, the model path is structured like this "///?.onnx" - std::string source = tokenizedModelPath[tokenizedModelPath.size() - 4]; - // `models` means the root of models, to be ompatible with the old structure, that is, the source name is empty. - if (source != "models") { - name += "_" + source; - } - std::replace_if( name.begin(), name.end(), [](char c) { return !absl::ascii_isalnum(c); }, '_' ); @@ -405,6 +398,13 @@ std::string GetFullNameOfTest(ITestCase* testCase, winml::LearningModelDeviceKin ModifyNameIfDisabledTest(/*inout*/ name, deviceKind); } + // To introduce models from model zoo, the model path is structured like this "///?.onnx" + std::string source = tokenizedModelPath[tokenizedModelPath.size() - 4]; + // `models` means the root of models, to be ompatible with the old structure, that is, the source name is empty. + if (source != "models") { + name += "_" + source; + } + return name; }