diff --git a/.github/workflows/rust-ci.yml b/.github/workflows/rust-ci.yml
index 6c3f2eb0fbbe1..725c40c2ded53 100644
--- a/.github/workflows/rust-ci.yml
+++ b/.github/workflows/rust-ci.yml
@@ -24,7 +24,7 @@ jobs:
name: Download prebuilt ONNX Runtime archive from build.rs
runs-on: ubuntu-latest
env:
- ORT_RUST_STRATEGY=download
+ ORT_RUST_STRATEGY: download
steps:
- uses: actions/checkout@v4
- uses: ./.github/actions/rust-toolchain-setup
diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml
index 95607f297c6bd..3ef5076583001 100644
--- a/.github/workflows/stale.yml
+++ b/.github/workflows/stale.yml
@@ -29,7 +29,7 @@ jobs:
# Label you want to apply to issues that have been inactive for the amount of time specified by days-before-issue-stale
stale-issue-label: "stale"
# Comment that you want to add to issues that are labeled by the actions/stale action
- stale-issue-message: "This issue has been automatically marked as stale due to inactivity and will be closed in 7 days if no further activity occurs. If further support is needed, please provide an update and/or more details."
+ stale-issue-message: "This issue has been automatically marked as stale due to inactivity and will be closed in 30 days if no further activity occurs. If further support is needed, please provide an update and/or more details."
# Comment that you want to add to issues that are closed by the actions/stale action
close-issue-message: "This issue has been automatically closed due to inactivity. Please reactivate if further support is needed."
# If you never want this action to label PRs, set this value to -1
diff --git a/.github/workflows/windows.yml b/.github/workflows/windows.yml
index ba24e7eebfb03..3a780f87d2300 100644
--- a/.github/workflows/windows.yml
+++ b/.github/workflows/windows.yml
@@ -49,13 +49,10 @@ jobs:
- uses: actions/checkout@v4
with:
submodules: true
- - uses: actions/setup-python@v4
- with:
- python-version: '3.8.x'
- architecture: 'x64'
- uses: conda-incubator/setup-miniconda@v2
with:
- activate-environment: ""
+ activate-environment: "ort_build"
+ python-version: "3.8"
- name: 'Install LLVM-Dev'
shell: pwsh
run: |
diff --git a/.gitignore b/.gitignore
index 6937f338b8a6b..4d0a1205b7c19 100644
--- a/.gitignore
+++ b/.gitignore
@@ -195,3 +195,4 @@ Package.pins
Package.resolved
.build/
.swiftpm/
+repros/
diff --git a/.pipelines/windowsai-steps.yml b/.pipelines/windowsai-steps.yml
index 45ebf889c5da1..292ce60c6b6cf 100644
--- a/.pipelines/windowsai-steps.yml
+++ b/.pipelines/windowsai-steps.yml
@@ -84,7 +84,7 @@ jobs:
7z x cmake-3.26.3-windows-x86_64.zip
set PYTHONHOME=$(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools
set PYTHONPATH=$(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools
- $(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools\python.exe "$(Build.SourcesDirectory)\tools\ci_build\build.py" --build_dir $(Build.BinariesDirectory) --build_shared_lib --enable_onnx_tests --ms_experimental --use_dml --use_winml --cmake_generator "Visual Studio 17 2022" --update --config RelWithDebInfo --enable_lto --use_telemetry --disable_rtti --enable_wcos $(BuildFlags) --cmake_extra_defines CMAKE_SYSTEM_VERSION=10.0.19041.0 --cmake_path $(Build.BinariesDirectory)\cmake-3.26.3-windows-x86_64\bin\cmake.exe --ctest_path $(Build.BinariesDirectory)\cmake-3.26.3-windows-x86_64\bin\ctest.exe
+ $(Build.BinariesDirectory)\${{ parameters.PythonPackageName }}.3.9.7\tools\python.exe "$(Build.SourcesDirectory)\tools\ci_build\build.py" --build_dir $(Build.BinariesDirectory) --build_shared_lib --enable_onnx_tests --ms_experimental --use_dml --use_winml --cmake_generator "Visual Studio 17 2022" --update --config RelWithDebInfo --enable_lto --use_telemetry --disable_rtti --enable_wcos $(BuildFlags) --cmake_extra_defines "CMAKE_EXE_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" "CMAKE_SHARED_LINKER_FLAGS_RELWITHDEBINFO=/PROFILE" CMAKE_SYSTEM_VERSION=10.0.19041.0 --cmake_path $(Build.BinariesDirectory)\cmake-3.26.3-windows-x86_64\bin\cmake.exe --ctest_path $(Build.BinariesDirectory)\cmake-3.26.3-windows-x86_64\bin\ctest.exe
workingDirectory: '$(Build.BinariesDirectory)'
displayName: 'Generate cmake config'
diff --git a/README.md b/README.md
index 22ef387f5a7cd..33bce867e3bde 100644
--- a/README.md
+++ b/README.md
@@ -30,7 +30,7 @@
|Android|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/Android%20CI%20Pipeline?label=Android)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=53)||
|iOS|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/iOS%20CI%20Pipeline?label=iOS)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=134)||
|Web|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/ONNX%20Runtime%20Web%20CI%20Pipeline?label=Web)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=161)||
-|Other|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/onnxruntime-binary-size-checks-ci-pipeline?repoName=microsoft%2Fonnxruntime&label=Binary+Size+Check)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=187&repoName=microsoft%2Fonnxruntime)
-[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/onnxruntime-python-checks-ci-pipeline?label=Python+Checks)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=164)||
+|Other|[![Build Status](https://dev.azure.com/onnxruntime/onnxruntime/_apis/build/status/onnxruntime-binary-size-checks-ci-pipeline?repoName=microsoft%2Fonnxruntime&label=Binary+Size+Check)](https://dev.azure.com/onnxruntime/onnxruntime/_build/latest?definitionId=187&repoName=microsoft%2Fonnxruntime)||
## Third-party Pipeline Status
diff --git a/build_arm64x.bat b/build_arm64x.bat
new file mode 100644
index 0000000000000..fbcdd373086a9
--- /dev/null
+++ b/build_arm64x.bat
@@ -0,0 +1,12 @@
+:: Copyright (c) Microsoft Corporation. All rights reserved.
+:: Licensed under the MIT License.
+
+@echo off
+
+setlocal
+set PATH=C:\Program Files\Git\usr\bin;%PATH%
+set LINK_REPRO_NAME=/mylink.rsp
+
+rem Requires a Python install to be available in your PATH
+python "%~dp0\tools\ci_build\build.py" --arm64 --buildasx --build_dir "%~dp0\build\arm64-x" %*
+python "%~dp0\tools\ci_build\build.py" --arm64ec --buildasx --build_dir "%~dp0\build\arm64ec-x" %*
diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json
index 12fbb291c3a70..137ea8a50c011 100644
--- a/cgmanifests/generated/cgmanifest.json
+++ b/cgmanifests/generated/cgmanifest.json
@@ -36,7 +36,7 @@
"component": {
"type": "git",
"git": {
- "commitHash": "29bf8085f3bf17b84d30e34b3d7ff8248fda404e",
+ "commitHash": "dcd5bd5fd593e31465af3d9ef291d26c646b0a4f",
"repositoryUrl": "https://github.com/abseil/abseil-cpp.git"
},
"comments": "abseil_cpp"
@@ -126,7 +126,7 @@
"component": {
"type": "git",
"git": {
- "commitHash": "f8d7d77c06936315286eb55f8de22cd23c188571",
+ "commitHash": "530d5c8c84abd2a46f38583ee817743c9b3a42b4",
"repositoryUrl": "https://github.com/google/googletest.git"
},
"comments": "googletest"
@@ -316,7 +316,7 @@
"component": {
"type": "git",
"git": {
- "commitHash": "a4f72a314a85732ed67d5aa8d1088d207a7e0e61",
+ "commitHash": "5356c4a943a35e74d7cdc69486afcb8703b9a59a",
"repositoryUrl": "https://github.com/ROCmSoftwarePlatform/composable_kernel.git"
},
"comments": "composable_kernel"
diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt
index e82219a0aff64..7494035e4784e 100644
--- a/cmake/CMakeLists.txt
+++ b/cmake/CMakeLists.txt
@@ -1258,13 +1258,7 @@ if (onnxruntime_USE_OPENVINO)
endif()
# Check OpenVINO version for support
- if (${VER} MATCHES "2022.1" OR $ENV{INTEL_OPENVINO_DIR} MATCHES "2022.1")
- set(OPENVINO_VERSION "2022.1")
- add_definitions(-DOPENVINO_2022_1=1)
- elseif (${VER} MATCHES "2022.2" OR $ENV{INTEL_OPENVINO_DIR} MATCHES "2022.2")
- set(OPENVINO_VERSION "2022.2")
- add_definitions(-DOPENVINO_2022_2=1)
- elseif ($ENV{INTEL_OPENVINO_DIR} MATCHES "2022.3")
+ if ($ENV{INTEL_OPENVINO_DIR} MATCHES "2022.3")
set(OPENVINO_VERSION "2022.3")
add_definitions(-DOPENVINO_2022_3=1)
elseif ($ENV{INTEL_OPENVINO_DIR} MATCHES "2023.0")
@@ -1273,9 +1267,12 @@ if (onnxruntime_USE_OPENVINO)
elseif ($ENV{INTEL_OPENVINO_DIR} MATCHES "2023.1")
set(OPENVINO_VERSION "2023.1")
add_definitions(-DOPENVINO_2023_1=1)
- elseif ($ENV{INTEL_OPENVINO_DIR} MATCHES "openvino")
- set(OPENVINO_VERSION "2023.1")
+ elseif ($ENV{INTEL_OPENVINO_DIR} MATCHES "2023.2")
+ set(OPENVINO_VERSION "2023.2")
- add_definitions(-DOPENVINO_2023_1=1)
+ add_definitions(-DOPENVINO_2023_2=1)
+ elseif ($ENV{INTEL_OPENVINO_DIR} MATCHES "openvino")
+ set(OPENVINO_VERSION "2023.2")
+ add_definitions(-DOPENVINO_2023_2=1)
else()
message(FATAL_ERROR "Unsupported OpenVINO version: ${INTEL_OPENVINO_DIR}")
endif()
@@ -1587,6 +1584,13 @@ set(VERSION_STRING "Internal Build" CACHE STRING "String representation of
if (WIN32)
list(APPEND onnxruntime_EXTERNAL_LIBRARIES ${SYS_PATH_LIB})
list(APPEND onnxruntime_EXTERNAL_LIBRARIES debug Dbghelp)
+ # In a onecore build the umbrella libs already contains references to the APIs in advapi32, so in onecore build we do not need to link to advapi32
# In a non-onecore build, usually we also do not need to link to advapi32 because VC++ by default should have provided everything we need, except when the build target is Windows ARM32.
+ # In the future we will add a build option to allow users disabling all API uses from advapi32 because some Windows environments do not have these APIs. For example, some Windows do not have
+ # Windows Registry so we cannot query Registry values.
+ if(onnxruntime_target_platform STREQUAL "ARM" AND CMAKE_CXX_STANDARD_LIBRARIES MATCHES kernel32.lib)
+ list(APPEND onnxruntime_EXTERNAL_LIBRARIES advapi32)
+ endif()
else()
list(APPEND onnxruntime_EXTERNAL_LIBRARIES nsync::nsync_cpp)
list(APPEND onnxruntime_EXTERNAL_LIBRARIES ${ICONV_LIB} ${CMAKE_DL_LIBS} Threads::Threads)
@@ -1776,3 +1780,8 @@ if(TARGET onnxruntime)
"${PROJECT_BINARY_DIR}/${PROJECT_NAME}ConfigVersion.cmake"
DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/${PROJECT_NAME}")
endif()
+
+if(DEFINED BUILD_AS_ARM64X)
+ set(ARM64X_TARGETS onnxruntime)
+ include("${CMAKE_SOURCE_DIR}/arm64x.cmake")
+endif()
diff --git a/cmake/arm64x.cmake b/cmake/arm64x.cmake
new file mode 100644
index 0000000000000..be476e09625bd
--- /dev/null
+++ b/cmake/arm64x.cmake
@@ -0,0 +1,33 @@
+set(arm64ReproDir "${CMAKE_SOURCE_DIR}/repros")
+
+if("${BUILD_AS_ARM64X}" STREQUAL "ARM64")
+ foreach (n ${ARM64X_TARGETS})
+ add_custom_target(mkdirs_${n} ALL COMMAND cmd /c (if exist \"${arm64ReproDir}/${n}_temp/\" rmdir /s /q \"${arm64ReproDir}/${n}_temp\") && mkdir \"${arm64ReproDir}/${n}_temp\" )
+ add_dependencies(${n} mkdirs_${n})
+ target_link_options(${n} PRIVATE "/LINKREPRO:${arm64ReproDir}/${n}_temp")
+ add_custom_target(${n}_checkRepro ALL COMMAND cmd /c if exist \"${n}_temp/*.obj\" if exist \"${n}\" rmdir /s /q \"${n}\" 2>nul && if not exist \"${n}\" ren \"${n}_temp\" \"${n}\" DEPENDS ${n}
+ WORKING_DIRECTORY ${arm64ReproDir})
+ endforeach()
+
+
+elseif("${BUILD_AS_ARM64X}" STREQUAL "ARM64EC")
+ foreach (n ${ARM64X_TARGETS})
+ set(ARM64_LIBS)
+ set(ARM64_OBJS)
+ set(ARM64_DEF)
+
+ file(GLOB ARM64_OBJS "${arm64ReproDir}/${n}/*.obj")
+ file(GLOB ARM64_DEF "${arm64ReproDir}/${n}/*.def")
+ file(GLOB ARM64_LIBS "${arm64ReproDir}/${n}/*.LIB")
+
+ if(NOT "${ARM64_DEF}" STREQUAL "")
+ set(ARM64_DEF "/defArm64Native:${ARM64_DEF}")
+ endif()
+ target_sources(${n} PRIVATE ${ARM64_OBJS})
+ target_link_options(${n} PRIVATE /machine:arm64x "${ARM64_DEF}")
+
+ if(NOT "${ARM64_LIBS}" STREQUAL "")
+ target_link_libraries(${n} PUBLIC ${ARM64_LIBS})
+ endif()
+ endforeach()
+endif()
diff --git a/cmake/deps.txt b/cmake/deps.txt
index e065cacdfc423..ff07803013071 100644
--- a/cmake/deps.txt
+++ b/cmake/deps.txt
@@ -12,7 +12,7 @@
# NOTE: You must run deps_update_and_upload.py and generate_cgmanifest.py when ready to test your changes in a CI.
# See https://microsoft.sharepoint.com/teams/ONNX2/_layouts/OneNote.aspx?id=%2Fteams%2FONNX2%2FShared%20Documents%2FNotebooks%2FONNX%20Ecosystem%20Team%20Notebook&wd=target%28Development.one%7C63D3AB47-51D1-4A62-9965-66882234BD44%2FAdd%20or%20update%20a%20dependency%20in%20deps.txt%7C0E9ED71D-89D5-40FA-B05F-C0123289C591%2F%29
#
-abseil_cpp;https://github.com/abseil/abseil-cpp/archive/refs/tags/20230802.0.zip;04271dfbfac59269b6939e1e9d5faf0d18a7ba91
+abseil_cpp;https://github.com/abseil/abseil-cpp/archive/dcd5bd5fd593e31465af3d9ef291d26c646b0a4f.zip;6cc204586014e189f5c0fe3274f83162fa7c700c
cxxopts;https://github.com/jarro2783/cxxopts/archive/3c73d91c0b04e2b59462f0a741be8c07024c1bc0.zip;6c6ca7f8480b26c8d00476e0e24b7184717fe4f0
date;https://github.com/HowardHinnant/date/archive/refs/tags/v3.0.1.zip;2dac0c81dc54ebdd8f8d073a75c053b04b56e159
dlpack;https://github.com/dmlc/dlpack/archive/refs/tags/v0.6.zip;4d565dd2e5b31321e5549591d78aa7f377173445
@@ -27,7 +27,7 @@ fp16;https://github.com/Maratyszcza/FP16/archive/0a92994d729ff76a58f692d3028ca1b
fxdiv;https://github.com/Maratyszcza/FXdiv/archive/63058eff77e11aa15bf531df5dd34395ec3017c8.zip;a5658f4036402dbca7cebee32be57fb8149811e1
google_benchmark;https://github.com/google/benchmark/archive/refs/tags/v1.7.0.zip;e97c368b176e8614e3f1bf13dd9abcf6a7ad9908
google_nsync;https://github.com/google/nsync/archive/refs/tags/1.26.0.zip;5e7c00ef6bf5b787386fc040067903ec774e2752
-googletest;https://github.com/google/googletest/archive/refs/tags/v1.14.0.zip;0ac421f2ec11af38b0fff0f1992184032731a8bc
+googletest;https://github.com/google/googletest/archive/530d5c8c84abd2a46f38583ee817743c9b3a42b4.zip;5e3a61db2aa975cfd0f97ba92c818744e7fa7034
googlexnnpack;https://github.com/google/XNNPACK/archive/0da379fc4808f9601faef392352018c741c0f297.zip;663883491e380b628e0a5b162b5f2658032fae73
json;https://github.com/nlohmann/json/archive/refs/tags/v3.10.5.zip;f257f8dc27c5b8c085dc887b40cddd18ae1f725c
microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf368104cd22a87b4dd0c80228919bb2df3e2a14
diff --git a/cmake/external/dnnl.cmake b/cmake/external/dnnl.cmake
index 397c4d6abeb9a..d7b70640781d0 100644
--- a/cmake/external/dnnl.cmake
+++ b/cmake/external/dnnl.cmake
@@ -25,6 +25,16 @@ elseif(onnxruntime_USE_DNNL AND onnxruntime_DNNL_GPU_RUNTIME STREQUAL "ocl" AND
set(DNNL_GPU_CMAKE_ARGS "-DDNNL_GPU_RUNTIME=OCL " "-DOPENCLROOT=${onnxruntime_DNNL_OPENCL_ROOT}")
endif()
+if(onnxruntime_USE_DNNL AND onnxruntime_DNNL_AARCH64_RUNTIME STREQUAL "acl" AND onnxruntime_DNNL_ACL_ROOT STREQUAL "")
+ message(FATAL_ERROR "--dnnl_acl_root required")
+elseif(onnxruntime_USE_DNNL AND onnxruntime_DNNL_AARCH64_RUNTIME STREQUAL "" AND NOT (onnxruntime_DNNL_ACL_ROOT STREQUAL ""))
+ message(FATAL_ERROR "--dnnl_aarch64_runtime required")
+elseif(onnxruntime_USE_DNNL AND onnxruntime_DNNL_AARCH64_RUNTIME STREQUAL "acl" AND NOT (onnxruntime_DNNL_ACL_ROOT STREQUAL ""))
+ file(TO_CMAKE_PATH ${onnxruntime_DNNL_ACL_ROOT} onnxruntime_DNNL_ACL_ROOT)
+ set(ACL_INCLUDE_DIR ${onnxruntime_DNNL_ACL_ROOT}/arm_compute)
+ set(DNNL_AARCH64_CMAKE_ARGS "-DDNNL_AARCH64_USE_ACL=ON")
+endif()
+
if (onnxruntime_USE_DNNL)
set(DNNL_SOURCE ${CMAKE_CURRENT_BINARY_DIR}/dnnl/src/dnnl/src)
set(DNNL_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/dnnl/install)
@@ -51,7 +61,7 @@ if (onnxruntime_USE_DNNL)
GIT_TAG ${DNNL_TAG}
# PATCH_COMMAND ${MKLDNN_PATCH_DISCARD_COMMAND} COMMAND ${DNNL_PATCH_COMMAND}
SOURCE_DIR ${DNNL_SOURCE}
- CMAKE_ARGS -DDNNL_BUILD_TESTS=OFF -DDNNL_ENABLE_CONCURRENT_EXEC=ON -DDNNL_BUILD_EXAMPLES=OFF -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DCMAKE_INSTALL_PREFIX=${DNNL_INSTALL} ${DNNL_GPU_CMAKE_ARGS}
+ CMAKE_ARGS -DDNNL_BUILD_TESTS=OFF -DDNNL_ENABLE_CONCURRENT_EXEC=ON -DDNNL_BUILD_EXAMPLES=OFF -DCMAKE_BUILD_TYPE=${CMAKE_BUILD_TYPE} -DCMAKE_INSTALL_PREFIX=${DNNL_INSTALL} ${DNNL_GPU_CMAKE_ARGS} ${DNNL_AARCH64_CMAKE_ARGS}
)
link_directories(${DNNL_LIB_DIR})
endif()
diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake
index 0fa5163dc06bf..78f63227c8392 100644
--- a/cmake/external/onnxruntime_external_deps.cmake
+++ b/cmake/external/onnxruntime_external_deps.cmake
@@ -47,8 +47,8 @@ if (onnxruntime_BUILD_UNIT_TESTS)
FetchContent_Declare(
googletest
URL ${DEP_URL_googletest}
- FIND_PACKAGE_ARGS 1.14.0...<2.0.0 NAMES GTest
URL_HASH SHA1=${DEP_SHA1_googletest}
+ FIND_PACKAGE_ARGS 1.14.0...<2.0.0 NAMES GTest
)
endif()
@@ -124,7 +124,7 @@ if(CMAKE_CROSSCOMPILING AND NOT ONNX_CUSTOM_PROTOC_EXECUTABLE)
if(protoc_binary_SOURCE_DIR)
message("Use prebuilt protoc")
set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc.exe)
- set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE})
+ set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE})
endif()
elseif(CMAKE_HOST_SYSTEM_NAME STREQUAL "Linux")
if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$")
@@ -140,7 +140,7 @@ if(CMAKE_CROSSCOMPILING AND NOT ONNX_CUSTOM_PROTOC_EXECUTABLE)
if(protoc_binary_SOURCE_DIR)
message("Use prebuilt protoc")
set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc)
- set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE})
+ set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE})
endif()
elseif ((CMAKE_SYSTEM_NAME STREQUAL "Emscripten" OR CMAKE_SYSTEM_NAME STREQUAL "Android" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") AND CMAKE_HOST_SYSTEM_NAME STREQUAL "Darwin")
FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_mac_universal} URL_HASH SHA1=${DEP_SHA1_protoc_mac_universal})
@@ -281,7 +281,7 @@ if ((CPUINFO_SUPPORTED OR onnxruntime_USE_XNNPACK) AND NOT ANDROID)
pytorch_clog
URL ${DEP_URL_pytorch_cpuinfo}
URL_HASH SHA1=${DEP_SHA1_pytorch_cpuinfo}
- SOURCE_SUBDIR deps/clog
+ SOURCE_SUBDIR deps/clog
)
set(ONNXRUNTIME_CLOG_PROJ pytorch_clog)
set(ONNXRUNTIME_CLOG_TARGET_NAME clog)
diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake
index 04efa5c2b4f6d..26e4380af4c23 100644
--- a/cmake/onnxruntime_mlas.cmake
+++ b/cmake/onnxruntime_mlas.cmake
@@ -284,6 +284,8 @@ else()
set(X86 TRUE)
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$")
set(X86_64 TRUE)
+ elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^loongarch64.*")
+ set(LOONGARCH64 TRUE)
endif()
endif()
@@ -575,6 +577,26 @@ else()
set(MLAS_SOURCE_IS_NOT_SET 0)
endif()
endif()
+ if(LOONGARCH64 AND MLAS_SOURCE_IS_NOT_SET)
+ set(mlas_platform_srcs
+ ${MLAS_SRC_DIR}/qgemm_kernel_lsx.cpp
+ ${MLAS_SRC_DIR}/loongarch64/SgemmKernelLasx.S
+ ${MLAS_SRC_DIR}/loongarch64/DgemmKernelLsx.S
+ ${MLAS_SRC_DIR}/loongarch64/DgemmKernelLasx.S
+ ${MLAS_SRC_DIR}/loongarch64/SgemmKernelLsx.S
+ ${MLAS_SRC_DIR}/loongarch64/SconvKernelLsx.S
+ ${MLAS_SRC_DIR}/loongarch64/SconvKernelLasx.S
+ ${MLAS_SRC_DIR}/loongarch64/SpoolKernelLSX.S
+ ${MLAS_SRC_DIR}/loongarch64/SpoolKernelLasx.S
+ ${MLAS_SRC_DIR}/loongarch64/SgemmTransposePackB16x4LSX.S
+ ${MLAS_SRC_DIR}/loongarch64/SgemmTransposePackB16x4Lasx.S
+ ${MLAS_SRC_DIR}/loongarch64/SoftmaxKernelLasx.S
+ )
+ set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mlsx -mlasx")
+ if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH)
+ set(MLAS_SOURCE_IS_NOT_SET 0)
+ endif()
+ endif()
if(NOT ONNXRUNTIME_MLAS_MULTI_ARCH AND MLAS_SOURCE_IS_NOT_SET)
file(GLOB_RECURSE mlas_platform_srcs
"${MLAS_SRC_DIR}/scalar/*.cpp")
diff --git a/cmake/onnxruntime_providers_cuda.cmake b/cmake/onnxruntime_providers_cuda.cmake
index cf298aee9fa85..84d1376f99d5e 100644
--- a/cmake/onnxruntime_providers_cuda.cmake
+++ b/cmake/onnxruntime_providers_cuda.cmake
@@ -34,6 +34,8 @@
if (NOT onnxruntime_USE_NCCL)
list(REMOVE_ITEM onnxruntime_cuda_contrib_ops_cc_srcs
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/nccl_kernels.cc"
+ "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/sharded_moe.h"
+ "${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/sharded_moe.cc"
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/sharding_spec.cc"
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/sharding.cc"
"${ONNXRUNTIME_ROOT}/contrib_ops/cuda/collective/distributed_matmul.cc"
diff --git a/cmake/onnxruntime_providers_vitisai.cmake b/cmake/onnxruntime_providers_vitisai.cmake
index 7ac4a82c89a76..0951c2d02664d 100644
--- a/cmake/onnxruntime_providers_vitisai.cmake
+++ b/cmake/onnxruntime_providers_vitisai.cmake
@@ -15,16 +15,10 @@
"${ONNXRUNTIME_ROOT}/core/providers/vitisai/imp/*.cc"
"${ONNXRUNTIME_ROOT}/core/providers/vitisai/imp/*.h"
)
- list(REMOVE_ITEM onnxruntime_providers_vitisai_cc_srcs "${ONNXRUNTIME_ROOT}/core/providers/vitisai/onnxruntime_vitisai_ep_stub.cc")
source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_vitisai_cc_srcs})
onnxruntime_add_static_library(onnxruntime_providers_vitisai ${onnxruntime_providers_vitisai_cc_srcs})
onnxruntime_add_include_to_target(onnxruntime_providers_vitisai onnxruntime_common onnxruntime_framework onnx onnx_proto)
- onnxruntime_add_shared_library(onnxruntime_vitisai_ep ${ONNXRUNTIME_ROOT}/core/providers/vitisai/onnxruntime_vitisai_ep_stub.cc)
- onnxruntime_add_include_to_target(onnxruntime_vitisai_ep onnxruntime_common)
- target_include_directories(onnxruntime_vitisai_ep PRIVATE "${ONNXRUNTIME_ROOT}" "${ONNXRUNTIME_ROOT}/core/providers/vitisai/include")
- target_link_libraries(onnxruntime_providers_vitisai PUBLIC onnxruntime_vitisai_ep PRIVATE onnx protobuf::libprotobuf nlohmann_json::nlohmann_json )
- target_compile_definitions(onnxruntime_vitisai_ep
- PRIVATE "-DONNXRUNTIME_VITISAI_EP_STUB=1" "-DONNXRUNTIME_VITISAI_EP_EXPORT_DLL=1")
+ target_link_libraries(onnxruntime_providers_vitisai PRIVATE onnx protobuf::libprotobuf nlohmann_json::nlohmann_json)
if(NOT MSVC)
target_compile_options(onnxruntime_providers_vitisai PUBLIC $<$:-U_FORTIFY_SOURCE -D_FORTIFY_SOURCE=0>)
endif(NOT MSVC)
@@ -49,4 +43,4 @@
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR})
- endif()
\ No newline at end of file
+ endif()
diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake
index 345ef2b504aa4..61922961588b2 100644
--- a/cmake/onnxruntime_python.cmake
+++ b/cmake/onnxruntime_python.cmake
@@ -453,6 +453,12 @@ file(GLOB onnxruntime_python_quantization_operators_src CONFIGURE_DEPENDS
file(GLOB onnxruntime_python_quantization_cal_table_flatbuffers_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/tools/quantization/CalTableFlatBuffers/*.py"
)
+file(GLOB onnxruntime_python_quantization_fusions_src CONFIGURE_DEPENDS
+ "${ONNXRUNTIME_ROOT}/python/tools/quantization/fusions/*.py"
+)
+file(GLOB onnxruntime_python_quantization_ep_qnn_src CONFIGURE_DEPENDS
+ "${ONNXRUNTIME_ROOT}/python/tools/quantization/execution_providers/qnn/*.py"
+)
file(GLOB onnxruntime_python_transformers_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/tools/transformers/*.py"
)
@@ -547,6 +553,9 @@ add_custom_command(
COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/quantization
COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/quantization/operators
COMMAND ${CMAKE_COMMAND} -E make_directory $/onnxruntime/quantization/CalTableFlatBuffers
+ COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization/fusions
+ COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization/execution_providers
+ COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization/execution_providers/qnn
COMMAND ${CMAKE_COMMAND} -E make_directory $/quantization
COMMAND ${CMAKE_COMMAND} -E make_directory $/transformers
COMMAND ${CMAKE_COMMAND} -E make_directory $/transformers/test_data/models
@@ -617,6 +626,12 @@ add_custom_command(
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_quantization_cal_table_flatbuffers_src}
$/onnxruntime/quantization/CalTableFlatBuffers/
+ COMMAND ${CMAKE_COMMAND} -E copy
+ ${onnxruntime_python_quantization_fusions_src}
+ $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization/fusions/
+ COMMAND ${CMAKE_COMMAND} -E copy
+ ${onnxruntime_python_quantization_ep_qnn_src}
+ $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/quantization/execution_providers/qnn/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_transformers_src}
$/onnxruntime/transformers/
diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake
index 980bd59b22c3f..f70961a66329a 100644
--- a/cmake/onnxruntime_rocm_hipify.cmake
+++ b/cmake/onnxruntime_rocm_hipify.cmake
@@ -109,6 +109,8 @@ if (NOT onnxruntime_USE_NCCL)
# Those are string patterns to exclude. Do NOT use stars such as
# collective/*.cc or *.h.
list(APPEND contrib_ops_excluded_files "collective/nccl_kernels.cc")
+ list(APPEND contrib_ops_excluded_files "collective/sharded_moe.h")
+ list(APPEND contrib_ops_excluded_files "collective/sharded_moe.cc")
list(APPEND contrib_ops_excluded_files "collective/sharding.cc")
list(APPEND contrib_ops_excluded_files "collective/sharding_spec.cc")
list(APPEND contrib_ops_excluded_files "collective/distributed_matmul.cc")
diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake
index df62199dc2b42..7c8c70f913dca 100644
--- a/cmake/onnxruntime_unittests.cmake
+++ b/cmake/onnxruntime_unittests.cmake
@@ -1373,56 +1373,55 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
target_link_libraries(compare_two_sessions PRIVATE ${GETOPT_LIB_WIDE} tdh Advapi32)
endif()
- file(GLOB onnxruntime_mlas_test_src CONFIGURE_DEPENDS
- "${TEST_SRC_DIR}/mlas/unittest/*.h"
- "${TEST_SRC_DIR}/mlas/unittest/*.cpp"
- )
- onnxruntime_add_executable(onnxruntime_mlas_test ${onnxruntime_mlas_test_src})
- if(MSVC)
- target_compile_options(onnxruntime_mlas_test PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /wd26409>"
- "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd26409>")
- target_compile_options(onnxruntime_mlas_test PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /utf-8>"
- "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/utf-8>")
- target_compile_options(onnxruntime_mlas_test PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /wd6326>"
- "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd6326>")
- target_compile_options(onnxruntime_mlas_test PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /wd26426>"
- "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd26426>")
- endif()
- if(${CMAKE_SYSTEM_NAME} STREQUAL "iOS")
- set_target_properties(onnxruntime_mlas_test PROPERTIES
- XCODE_ATTRIBUTE_CODE_SIGNING_ALLOWED "NO"
+ if(NOT onnxruntime_target_platform STREQUAL "ARM64EC")
+ file(GLOB onnxruntime_mlas_test_src CONFIGURE_DEPENDS
+ "${TEST_SRC_DIR}/mlas/unittest/*.h"
+ "${TEST_SRC_DIR}/mlas/unittest/*.cpp"
)
- endif()
- target_include_directories(onnxruntime_mlas_test PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${ONNXRUNTIME_ROOT}
- ${CMAKE_CURRENT_BINARY_DIR})
- target_link_libraries(onnxruntime_mlas_test PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common)
- if (CPUINFO_SUPPORTED AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
- target_link_libraries(onnxruntime_mlas_test PRIVATE cpuinfo)
- endif()
- if(NOT WIN32)
- target_link_libraries(onnxruntime_mlas_test PRIVATE nsync::nsync_cpp ${CMAKE_DL_LIBS})
- endif()
- if (CMAKE_SYSTEM_NAME STREQUAL "Android")
- target_link_libraries(onnxruntime_mlas_test PRIVATE ${android_shared_libs})
- endif()
-
- if(WIN32)
- target_link_libraries(onnxruntime_mlas_test PRIVATE debug Dbghelp Advapi32)
- endif()
- if (onnxruntime_LINK_LIBATOMIC)
- target_link_libraries(onnxruntime_mlas_test PRIVATE atomic)
- endif()
- target_link_libraries(onnxruntime_mlas_test PRIVATE Threads::Threads)
-
- set_target_properties(onnxruntime_mlas_test PROPERTIES FOLDER "ONNXRuntimeTest")
- if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
- if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS)
- set_target_properties(onnxruntime_mlas_test PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1 -s PROXY_TO_PTHREAD=1 -s EXIT_RUNTIME=1")
- else()
- set_target_properties(onnxruntime_mlas_test PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1")
+ onnxruntime_add_executable(onnxruntime_mlas_test ${onnxruntime_mlas_test_src})
+ if(MSVC)
+ target_compile_options(onnxruntime_mlas_test PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /wd26409>"
+ "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd26409>")
+ target_compile_options(onnxruntime_mlas_test PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /utf-8>"
+ "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/utf-8>")
+ target_compile_options(onnxruntime_mlas_test PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /wd6326>"
+ "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd6326>")
+ target_compile_options(onnxruntime_mlas_test PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /wd26426>"
+ "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd26426>")
endif()
- endif()
-
+ if(${CMAKE_SYSTEM_NAME} STREQUAL "iOS")
+ set_target_properties(onnxruntime_mlas_test PROPERTIES
+ XCODE_ATTRIBUTE_CODE_SIGNING_ALLOWED "NO"
+ )
+ endif()
+ target_include_directories(onnxruntime_mlas_test PRIVATE ${ONNXRUNTIME_ROOT}/core/mlas/inc ${ONNXRUNTIME_ROOT}
+ ${CMAKE_CURRENT_BINARY_DIR})
+ target_link_libraries(onnxruntime_mlas_test PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common)
+ if (CPUINFO_SUPPORTED AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
+ target_link_libraries(onnxruntime_mlas_test PRIVATE cpuinfo)
+ endif()
+ if(NOT WIN32)
+ target_link_libraries(onnxruntime_mlas_test PRIVATE nsync::nsync_cpp ${CMAKE_DL_LIBS})
+ endif()
+ if (CMAKE_SYSTEM_NAME STREQUAL "Android")
+ target_link_libraries(onnxruntime_mlas_test PRIVATE ${android_shared_libs})
+ endif()
+ if(WIN32)
+ target_link_libraries(onnxruntime_mlas_test PRIVATE debug Dbghelp Advapi32)
+ endif()
+ if (onnxruntime_LINK_LIBATOMIC)
+ target_link_libraries(onnxruntime_mlas_test PRIVATE atomic)
+ endif()
+ target_link_libraries(onnxruntime_mlas_test PRIVATE Threads::Threads)
+ set_target_properties(onnxruntime_mlas_test PROPERTIES FOLDER "ONNXRuntimeTest")
+ if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
+ if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS)
+ set_target_properties(onnxruntime_mlas_test PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1 -s PROXY_TO_PTHREAD=1 -s EXIT_RUNTIME=1")
+ else()
+ set_target_properties(onnxruntime_mlas_test PROPERTIES LINK_FLAGS "-s ALLOW_MEMORY_GROWTH=1")
+ endif()
+ endif()
+endif()
# Training API Tests
# Disabling training_api_test_trainer. CXXOPT generates a ton of warnings because of which nuget pipeline is failing.
# TODO(askhade): Fix the warnings.
diff --git a/cmake/patches/abseil/absl_gh_issue_1435_workaround.patch b/cmake/patches/abseil/absl_gh_issue_1435_workaround.patch
deleted file mode 100644
index 0a864cdc019b4..0000000000000
--- a/cmake/patches/abseil/absl_gh_issue_1435_workaround.patch
+++ /dev/null
@@ -1,17 +0,0 @@
---- absl/container/internal/layout.h 2023-11-28 09:35:48
-+++ absl/container/internal/layout.updated.h 2023-11-28 10:13:14
-@@ -181,9 +181,11 @@
- #include
- #endif
-
--#if defined(__GXX_RTTI)
--#define ABSL_INTERNAL_HAS_CXA_DEMANGLE
--#endif
-+// Comment out ABSL_INTERNAL_HAS_CXA_DEMANGLE definition to work around this issue:
-+// https://github.com/abseil/abseil-cpp/issues/1435
-+// #if defined(__GXX_RTTI)
-+// #define ABSL_INTERNAL_HAS_CXA_DEMANGLE
-+// #endif
-
- #ifdef ABSL_INTERNAL_HAS_CXA_DEMANGLE
- #include
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj b/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj
index 0c74a23204d4f..1d15383239baf 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/Microsoft.ML.OnnxRuntime.csproj
@@ -6,7 +6,7 @@
true
- netstandard2.0
+ netstandard2.0;netcoreapp3.1;net6.0
diff --git a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs
index 86b44a6784817..163a2b394c4ae 100644
--- a/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs
+++ b/csharp/src/Microsoft.ML.OnnxRuntime/OrtValue.shared.cs
@@ -263,12 +263,16 @@ public ReadOnlyMemory<byte> GetStringElementAsMemory(int index)
/// UTF-16 string instance
public string GetStringElement(int index)
{
- var chars = GetStringTensorElementChars(index);
- if (chars.Length == 0)
+ GetStringTensorElementBuffer((UIntPtr)index, out uint bytesLen, out IntPtr bufferPtr);
+ if (bytesLen == 0)
{
return string.Empty;
}
- return new string(chars);
+
+ unsafe
+ {
+ return Encoding.UTF8.GetString((byte*)bufferPtr.ToPointer(), (int)bytesLen);
+ }
}
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index c73f978bdf404..e5b43ddba8cc7 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -1599,14 +1599,14 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Inputs (1 - ∞)
-- inputs (variadic) : T
+- inputs (variadic, heterogeneous) : T
- List of tensors for inputs
#### Outputs (1 - ∞)
-- outputs (variadic) : T
+- outputs (variadic, heterogeneous) : T
- One or more outputs, list of tensors for outputs
diff --git a/docs/Memory_Optimizer.md b/docs/Memory_Optimizer.md
index 0147a937db81d..97f7e7ff2c14b 100644
--- a/docs/Memory_Optimizer.md
+++ b/docs/Memory_Optimizer.md
@@ -17,55 +17,83 @@ Classical scenarios include:
Not all models and recipes need this optimizer technique. Imagine if your training recipe uses a batch size 6 (GPU compute and memory are fully saturated), and you don't need bump it to 8 to maintain a fixed global batch size. Enabling recompute maybe not bring better throughput on batch size 8 than the original batch size 6.
-## Quick trial
+## Usage
-1. Make sure ONNX Runtime training wheel is installed and correctly configured.
-2. Integrate models using `ORTModule`, be noted log_level should be equal or lower than INFO.
- > ort_model = ORTModule(pt_model, DebugOptions(log_level=LogLevel.INFO))
-3. Run the training as usual; then stop it after training few steps.
-4. Check the logs, you could find something like this:
+
+Make sure ONNX Runtime training wheel is installed and correctly configured.
+Integrate models using `ORTModule`.
+```diff
+ model = build_model()
+
++ from onnxruntime.training.ortmodule import ORTModule
++ model = ORTModule(model)
+```
+
+There are two modes to enable the memory optimizations:
+- Aggressively Recompute All Within Each Transformer Layer, enabled by `export ORTMODULE_MEMORY_OPT_LEVEL=1`. This will recompute all detected subgraphs within each Transformer Attention+MLP layer. It is easy to enable, but be noted this recompute plan may NOT be the best one. In this mode, `ORTMODULE_MEMORY_OPT_CONFIG` env values passed by users are not respected.
+- User Specified Subgraph Recompute, enabled by `export ORTMODULE_MEMORY_OPT_LEVEL=0` and `export ORTMODULE_MEMORY_OPT_CONFIG=,,...`. This is an advanced usage, that allows users to find the most suitable graphs to recompute, at the cost of overhead to look for the best plans.
+
+### Mode 1 - Simple Usage (Aggressively Recompute All Within Each Transformer Layer)
+
+
+1. Set memory optimization level to be TRANSFORMER_LAYERWISE_RECOMPUTE, by `export ORTMODULE_MEMORY_OPT_LEVEL=1`
+2. Run the training as usual; check the logs, you could find something like this if the current log level <= LogLevel.INFO:
```
- Memory Optimizer : OFF : Enable with env ORTMODULE_MEMORY_OPT_CONFIG=, available configs:
- Config Freq Max Saving(B) Saving Symbolic(Bytes)
- - Plan 1 : OFF : Reshape+Where+BiasSoftmax+:1:-1 5 671,088,640 640.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
- - Plan 2 : OFF : Cast+:1:-1 6 402,587,648 inputs_input_ids_dim0*inputs_input_ids_dim1*(384.0*inputs_input_ids_dim1 - 64.0)
- - Plan 3 : OFF : Reshape+Where+:1:-1 1 134,217,728 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
- - Plan 4 : OFF : BiasSoftmax+:1:-1 1 134,086,656 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
- - Plan 5 : OFF : BiasGelu+:1:-1 6 125,808,640 inputs_input_ids_dim0*(122880.0*inputs_input_ids_dim1 - 20480.0)
- - Plan 6 : OFF : FusedMatMul+:1:-1 6 125,808,640 inputs_input_ids_dim0*(122880.0*inputs_input_ids_dim1 - 20480.0)
- - Plan 7 : OFF : FusedMatMul+Add+FusedMatMul+Add+Add+Add+:1:-1 5 26,214,400 25600.0*inputs_input_ids_dim0*inputs_input_ids_dim1
- - Plan 8 : OFF : Add+:1:-1 1 5,237,760 5120.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
- - Plan 9 : OFF : Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1 1 4,096 4.0*inputs_input_ids_dim0*inputs_input_ids_dim1
- - Plan 10 : OFF : Cast+:2:-1 1 2,048 2.0*inputs_input_ids_dim0*inputs_input_ids_dim1
-
-
- Note 1: use comma as delimiter to enable multiple memory optimization plans at the same time:
- export ORTMODULE_MEMORY_OPT_CONFIG=,,...
- Note 2: memory saving is calculated based on the 1st batch symbolic dim values:
- inputs_input_ids_dim0=1, inputs_input_ids_dim1=1024, inputs_attention_mask_dim0=1, inputs_attention_mask_dim1=1024, inputs_labels_dim0=1, inputs_labels_dim1=1024,
+ Memory Optimizer : ON : Memory Optimization Level: [TRANSFORMER_LAYERWISE_RECOMPUTE], Optimization Config: [Reshape+Where+:1:-1,BiasSoftmax+:1:-1,Cast+:1:-1,BiasGelu+:1:-1,FusedMatMul+:1:-1,Add+:1:-1,Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1]
+ Configs Freq Max Saving(Bytes) Saving Symbolic(Bytes)
+ - Plan 1 : ON : Reshape+Where+:1:-1 1 134,217,728 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
+ - Plan 2 : ON : BiasSoftmax+:1:-1 1 134,086,656 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
+ - Plan 3 : ON : Cast+:1:-1 1 67,043,328 64.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
+ - Plan 4 : ON : BiasGelu+:1:-1 1 20,951,040 20480.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
+ - Plan 5 : ON : FusedMatMul+:1:-1 1 20,951,040 20480.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
+ - Plan 6 : ON : Add+:1:-1 1 5,237,760 5120.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
+ - Plan 7 : ON : Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1 1 4,096 4.0*inputs_input_ids_dim0*inputs_input_ids_dim1
+ - Plan 8 : OFF : Cast+:2:-1 1 2,048 2.0*inputs_input_ids_dim0*inputs_input_ids_dim1
```
-5. As shown above, `Config` is a string representative for a re-computable subgraph. All are disabled for recompute in this case.
-6. Set environment variable `ORTMODULE_MEMORY_OPT_CONFIG` to enable some of the subgraph to do recompute. In below example, `6` `BiasGelu+` related subgraphs are allowed to recompute.
-`BiasGelu+` is the subgraph string representative; `1` in the middle indicates 'Recompute' is enabled (0, on the contrary indicates it's disabled); `6` means the initial 6 subgraph occurrences will be recomputed, all others are left as it is, filling `-1` will make all occurrences be recomputed.
+3. As shown above, `Config` is a string representative for a re-computable subgraph. All are enabled for recompute in this case.
+
+
+### Mode 2 - Advanced Usage (User Selected Subgraph Recompute)
+
+1. Be noted that `ORTMODULE_MEMORY_OPT_LEVEL` is 0 by default. Run the training as usual; then stop it after training a few steps.
+2. Check the logs, you could find something like this if the current log level <= LogLevel.INFO:
```
- export ORTMODULE_MEMORY_OPT_CONFIG="BiasGelu+:1:6" # Use comma as separator for enabling more than one subgraphs.
+ Memory Optimizer : OFF : Enable with env ORTMODULE_MEMORY_OPT_LEVEL=1 or ORTMODULE_MEMORY_OPT_CONFIG=,,...
+ Configs Freq Max Saving(Bytes) Saving Symbolic(Bytes)
+ - Plan 1 : OFF : Reshape+Where+:1:-1 1 134,217,728 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
+ - Plan 2 : OFF : BiasSoftmax+:1:-1 1 134,086,656 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
+ - Plan 3 : OFF : Cast+:1:-1 1 67,043,328 64.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
+ - Plan 4 : OFF : BiasGelu+:1:-1 1 20,951,040 20480.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
+ - Plan 5 : OFF : FusedMatMul+:1:-1 1 20,951,040 20480.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
+ - Plan 6 : OFF : Add+:1:-1 1 5,237,760 5120.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
+ - Plan 7 : OFF : Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1 1 4,096 4.0*inputs_input_ids_dim0*inputs_input_ids_dim1
+ - Plan 8 : OFF : Cast+:2:-1 1 2,048 2.0*inputs_input_ids_dim0*inputs_input_ids_dim1
```
-7. Then run the training again, and you will see logs like this:
+3. As shown above, `Config` is a string representative for a re-computable subgraph. All are disabled for recompute in this case.
+4. Set environment variable `ORTMODULE_MEMORY_OPT_CONFIG` to enable some of the subgraphs to do recompute.
+ ```bash
+ # Use comma as a separator for enabling more than one subgraphs.
+ export ORTMODULE_MEMORY_OPT_CONFIG="BiasGelu+:1:1"
+ # Explanation:
+ # > BiasGelu+ is the subgraph string representative;
+ # > 1 in the middle indicates 'Recompute' is enabled (0, on the contrary indicates it's disabled)
+ # > The last 1 means the initial 1 subgraph occurrences will be recomputed, all others are left as it is, filling `-1` will make all occurrences be recomputed.
+
+ ```
+5. Then run the training again, and you will see logs like this:
```
- Memory Optimizer : ON : User config: Reshape+Where+BiasSoftmax+:1:-1, probe level: 1, available configs:
- Config Freq Max Saving(B) Saving Symbolic(Bytes)
- - Plan 1 : OFF : Reshape+Where+BiasSoftmax+:1:-1 5 671,088,640 640.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
- - Plan 2 : OFF : Cast+:1:-1 6 402,587,648 inputs_input_ids_dim0*inputs_input_ids_dim1*(384.0*inputs_input_ids_dim1 - 64.0)
- - Plan 3 : OFF : Reshape+Where+:1:-1 1 134,217,728 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
- - Plan 4 : OFF : BiasSoftmax+:1:-1 1 134,086,656 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
- - Plan 5 : ON : BiasGelu+:1:-1 6 125,808,640 inputs_input_ids_dim0*(122880.0*inputs_input_ids_dim1 - 20480.0)
- - Plan 6 : OFF : FusedMatMul+:1:-1 6 125,808,640 inputs_input_ids_dim0*(122880.0*inputs_input_ids_dim1 - 20480.0)
- - Plan 7 : OFF : FusedMatMul+Add+FusedMatMul+Add+Add+Add+:1:-1 5 26,214,400 25600.0*inputs_input_ids_dim0*inputs_input_ids_dim1
- - Plan 8 : OFF : Add+:1:-1 1 5,237,760 5120.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
- - Plan 9 : OFF : Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1 1 4,096 4.0*inputs_input_ids_dim0*inputs_input_ids_dim1
- - Plan 10 : OFF : Cast+:2:-1 1 2,048 2.0*inputs_input_ids_dim0*inputs_input_ids_dim1
+ Memory Optimizer : ON : Memory Optimization Level: [USER_SPECIFIED], Optimization Config: [BiasGelu+:1:-1]
+ Configs Freq Max Saving(Bytes) Saving Symbolic(Bytes)
+ - Plan 1 : OFF : Reshape+Where+:1:-1 1 134,217,728 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1**2
+ - Plan 2 : OFF : BiasSoftmax+:1:-1 1 134,086,656 128.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
+ - Plan 3 : OFF : Cast+:1:-1 1 67,043,328 64.0*inputs_input_ids_dim0*inputs_input_ids_dim1*(inputs_input_ids_dim1 - 1)
+ - Plan 4 : ON : BiasGelu+:1:-1 1 20,951,040 20480.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
+ - Plan 5 : OFF : FusedMatMul+:1:-1 1 20,951,040 20480.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
+ - Plan 6 : OFF : Add+:1:-1 1 5,237,760 5120.0*inputs_input_ids_dim0*(inputs_input_ids_dim1 - 1)
+ - Plan 7 : OFF : Reshape+Unsqueeze+Unsqueeze+Cast+Sub+Mul+Cast+:1:-1 1 4,096 4.0*inputs_input_ids_dim0*inputs_input_ids_dim1
+ - Plan 8 : OFF : Cast+:2:-1 1 2,048 2.0*inputs_input_ids_dim0*inputs_input_ids_dim1
```
-8. You may need iterate few times on step 6 and 7 until you find a good config for this model to run a bigger batch size. Or you may fail to find if memory optimization does not apply to the model well.
+6. You may need to iterate a few times on steps 4 and 5 until you find a good config for this model to run a bigger batch size. Or you may fail to find one if memory optimization does not apply to the model well.
## Optimization Configuration
@@ -73,11 +101,13 @@ The basic optimization unit is represented with a unique `cluster id`, for examp
Following `cluster id` is the `optimization strategy`: 0 - none, 1 - recompute, 2 - recompute with compromised memory saving.
Following `optimization strategy` is the `request count` to apply the given optimization. Using `-1` to apply all. This would give user a bit more flexibility to avoid unnecessary memory saving.
-## Compromised Recompute
+### Compromised Recompute
If you check the above logs, there is a config `Cast+:2:-1`, `2` indicates it's a recomputation than can save part of the stashed activation size, not all. Recompute the subgraphs under it usually will save part of the activation (for example half of them), not all of them. Follow the same way to enable it.
-## Memory Optimization Debug Infos
+## Dev Notes
+
+### Memory Optimization Debug Infos
Using following log level
> ort_model = ORTModule(pt_model, DebugOptions(log_level=LogLevel.DEVINFO))
@@ -132,4 +162,4 @@ MemoryInsight Summary - User config: not provided
## Notes
-The feature is in experimental stage, we will tune and refine it according to real use cases.
+The feature is in the experimental stage, we will tune and refine it according to real use cases.
diff --git a/docs/ORTModule_Training_Guidelines.md b/docs/ORTModule_Training_Guidelines.md
index 7fa89cca381d9..bede16204d420 100644
--- a/docs/ORTModule_Training_Guidelines.md
+++ b/docs/ORTModule_Training_Guidelines.md
@@ -146,7 +146,6 @@ Check [DebugOptions implementation](../orttraining/orttraining/python/training/o
export ORTMODULE_ONNX_OPSET_VERSION=14
```
-
#### ORTMODULE_FALLBACK_POLICY
- **Feature Area**: *ORTMODULE/FallbackToPytorch*
@@ -155,7 +154,6 @@ Check [DebugOptions implementation](../orttraining/orttraining/python/training/o
export ORTMODULE_FALLBACK_POLICY="FALLBACK_DISABLE"
```
-
#### ORTMODULE_LOG_LEVEL
- **Feature Area**: *ORTMODULE/DebugOptions*
@@ -182,7 +180,6 @@ The output directory of the onnx models by default is set to the current working
> On the other hand, if the wrapped computation graph is small, it is reasonable to allow it.
> Overall users should be aware that ORT performance boost might be trivial when they explicitly allow it.
-
#### ORTMODULE_ENABLE_CUSTOM_AUTOGRAD
- **Feature Area**: *ORTMODULE/PythonOp (torch.autograd.Function)*
@@ -199,8 +196,6 @@ The output directory of the onnx models by default is set to the current working
enable_custom_autograd_support(False)
```
-
-
#### ORTMODULE_ENABLE_COMPUTE_OPTIMIZER
- **Feature Area**: *ORTMODULE/Optimizations*
@@ -278,6 +273,26 @@ data sparsity based performance optimizations.
export ORTMODULE_USE_EFFICIENT_ATTENTION=1
```
+#### ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT
+
+- **Feature Area**: *ORTMODULE/Optimizations*
+- **Description**: By default, this is enabled. This env var can be used for enabling or disabling the module deep copy when preparing output data which will be used by ONNX export.
+A classical usage of disabling the deep copy: when the deep copy before module export brings the memory peak, we should disable it and have a try.
+
+ ```bash
+ export ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT=1 # Enable
+ export ORTMODULE_DEEPCOPY_BEFORE_MODEL_EXPORT=0 # Disable
+ ```
+
+#### ORTMODULE_MEMORY_OPT_LEVEL
+
+- **Feature Area**: *ORTMODULE/Optimizations*
+- **Description**: By default, the level is 0. This env var can be used for enabling recomputation to reduce the memory peak requirement. Setting the level to 1 means all detected subgraphs within each transformer-based model layer generating stashed activations will be recomputed. This is conceptually equivalent to PyTorch's gradient checkpoint. When the level is not 0, check [Memory Optimizer for ONNX Runtime Training](Memory_Optimizer.md) for more details.
+
+ ```bash
+ export ORTMODULE_MEMORY_OPT_LEVEL=0
+ ```
+
### 2.2 Memory Optimization
Q: *Want to run a bigger batch size?*
@@ -379,6 +394,30 @@ Check [FP16_Optimizer implementation](../orttraining/orttraining/python/training
export ORTMODULE_USE_TRITON=1
```
+#### ORTMODULE_TRITON_CONFIG_FILE
+
+- **Feature Area**: *ORTMODULE/TritonOp*
+- **Description**: Triton codegen currently supports some Ops such as some elementwise Ops and some reduction Ops. If Triton optimization is enabled, all these supported Ops will be optimized by default if possible. User can provide a customized JSON config file to control which Ops to optimize and how to optimize them. Below is a sample of config JSON. For each Op, Opset version list and domain is needed. Currently "conditions" field can be used to control axis/axes attribute or input, by specifying the real value, or "single" means it contains only one dimension, or "constant" means it must be constant tensor. Save the JSON as a file somewhere and assign its path to the env variable below to enable the customized config.
+
+ ```json
+ {
+ "ops": {
+ "Add": {"versions": [13, 14]},
+ "Sub": {"versions": [13, 14]},
+ "Identity": {"versions": [13], "is_no_op": true},
+ "ReduceSum": {"versions": [13], "conditions": {"axes": "[-1]"}},
+ "Softmax": {"versions": [13]},
+ "SoftmaxGrad_13": {"domain": "com.microsoft", "versions": [1]}
+ },
+ "initializer": "scalar",
+ "min_nodes": 2
+ }
+ ```
+
+ ```bash
+ export ORTMODULE_TRITON_CONFIG_FILE=triton_config.json
+ ```
+
#### ORTMODULE_ENABLE_TUNING
- **Feature Area**: *ORTMODULE/TritonOp*
diff --git a/docs/python/api_summary.rst b/docs/python/api_summary.rst
index cecd62aff15c4..092b42010a5c6 100644
--- a/docs/python/api_summary.rst
+++ b/docs/python/api_summary.rst
@@ -274,6 +274,77 @@ SessionOptions
.. autoclass:: onnxruntime.SessionOptions
:members:
+.. autoclass:: onnxruntime.ExecutionMode
+ :members:
+
+.. autoclass:: onnxruntime.ExecutionOrder
+ :members:
+
+.. autoclass:: onnxruntime.GraphOptimizationLevel
+ :members:
+
+.. autoclass:: onnxruntime.OrtAllocatorType
+ :members:
+
+.. autoclass:: onnxruntime.OrtArenaCfg
+ :members:
+
+.. autoclass:: onnxruntime.OrtMemoryInfo
+ :members:
+
+.. autoclass:: onnxruntime.OrtMemType
+ :members:
+
+Functions
+---------
+
+Allocators
+^^^^^^^^^^
+
+.. autofunction:: onnxruntime.create_and_register_allocator
+
+.. autofunction:: onnxruntime.create_and_register_allocator_v2
+
+Telemetry events
+^^^^^^^^^^^^^^^^
+
+.. autofunction:: onnxruntime.disable_telemetry_events
+
+.. autofunction:: onnxruntime.enable_telemetry_events
+
+Providers
+^^^^^^^^^
+
+.. autofunction:: onnxruntime.get_all_providers
+
+.. autofunction:: onnxruntime.get_available_providers
+
+Build, Version
+^^^^^^^^^^^^^^
+
+.. autofunction:: onnxruntime.get_build_info
+
+.. autofunction:: onnxruntime.get_version_string
+
+.. autofunction:: onnxruntime.has_collective_ops
+
+Device
+^^^^^^
+
+.. autofunction:: onnxruntime.get_device
+
+Logging
+^^^^^^^
+
+.. autofunction:: onnxruntime.set_default_logger_severity
+
+.. autofunction:: onnxruntime.set_default_logger_verbosity
+
+Random
+^^^^^^
+
+.. autofunction:: onnxruntime.set_seed
+
Data
----
@@ -298,6 +369,9 @@ IOBinding
.. autoclass:: onnxruntime.IOBinding
:members:
+.. autoclass:: onnxruntime.SessionIOBinding
+ :members:
+
OrtDevice
^^^^^^^^^
diff --git a/include/onnxruntime/core/graph/constants.h b/include/onnxruntime/core/graph/constants.h
index 7e59aad80cc47..9b26ba914c7dd 100644
--- a/include/onnxruntime/core/graph/constants.h
+++ b/include/onnxruntime/core/graph/constants.h
@@ -55,4 +55,7 @@ constexpr const char* kAzureExecutionProvider = "AzureExecutionProvider";
constexpr const char* kExecutionProviderSharedLibraryPath = "shared_lib_path";
constexpr const char* kExecutionProviderSharedLibraryEntry = "provider_factory_entry_point";
+// For Priority based graph topology sorting.
+constexpr const char* kBackwardNodeAttributeName = "__backwardpass";
+
} // namespace onnxruntime
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index cddad732104ed..c41700453a73b 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -3598,6 +3598,7 @@ struct OrtApi {
* "qnn_context_cache_path": explicitly provide the QNN context cache file. Default to model_file.onnx.bin if not provided.
* "profiling_level": QNN profiling level, options: "off", "basic", "detailed". Default to off.
* "rpc_control_latency": QNN RPC control latency.
+ * "vtcm_mb": QNN VTCM size in MB. Default to 0 (not set).
* "htp_performance_mode": QNN performance mode, options: "burst", "balanced", "default", "high_performance",
* "high_power_saver", "low_balanced", "low_power_saver", "power_saver", "sustained_high_performance". Default to "default".
* "qnn_context_embed_mode", 1 means dump the QNN context binary into node attribute EPContext->ep_cache_context in the ONNX skeleton model.
diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
index 4628afbb5a702..a94973b2cc5d7 100644
--- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
+++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
@@ -88,9 +88,9 @@ static const char* const kOrtSessionOptionsDisableAheadOfTimeFunctionInlining =
// the memory.
static const char* const kOrtSessionOptionsMemoryOptimizerEnabler = "optimization.memory_optimizer_config";
-// Specifies the level for detecting subgraphs for memory footprint reduction.
-// The value should be an integer. The default value is 0.
-static const char* const kOrtSessionOptionsMemoryOptimizerProbeLevel = "optimization.enable_memory_probe_recompute_level";
+// Specifies the config for detecting subgraphs for memory footprint reduction.
+// The value should be a string contains int separated using commas. The default value is "0:0".
+static const char* const kOrtSessionOptionsMemoryOptimizerProbeConfig = "optimization.enable_memory_probe_recompute_config";
#endif
// Enable or disable using device allocator for allocating initialized tensor memory. "1": enable; "0": disable. The default is "0".
diff --git a/js/common/lib/backend.ts b/js/common/lib/backend.ts
index 67d283b694955..5460ae086fc2f 100644
--- a/js/common/lib/backend.ts
+++ b/js/common/lib/backend.ts
@@ -45,9 +45,17 @@ export interface InferenceSessionHandler extends SessionHandler {
* @ignore
*/
export interface TrainingSessionHandler extends SessionHandler {
+ readonly evalInputNames: readonly string[];
+ readonly evalOutputNames: readonly string[];
+
+ lazyResetGrad(): Promise;
runTrainStep(
feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType,
options: InferenceSession.RunOptions): Promise;
+ runOptimizerStep(options: InferenceSession.RunOptions): Promise;
+ runEvalStep(
+ feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType,
+ options: InferenceSession.RunOptions): Promise;
getParametersSize(trainableOnly: boolean): Promise;
loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise;
diff --git a/js/common/lib/env.ts b/js/common/lib/env.ts
index 76575ef7b9368..0cded7e5edbcb 100644
--- a/js/common/lib/env.ts
+++ b/js/common/lib/env.ts
@@ -92,11 +92,48 @@ export declare namespace Env {
async?: boolean;
}
+ export interface WebGpuProfilingDataV1TensorMetadata {
+ dims: readonly number[];
+ dataType: string;
+ }
+ export interface WebGpuProfilingDataV1 {
+ version: 1;
+ inputsMetadata: readonly WebGpuProfilingDataV1TensorMetadata[];
+ outputsMetadata: readonly WebGpuProfilingDataV1TensorMetadata[];
+ kernelId: number;
+ kernelType: string;
+ kernelName: string;
+ startTime: number;
+ endTime: number;
+ }
+
+ export type WebGpuProfilingData = WebGpuProfilingDataV1;
+
export interface WebGpuFlags {
/**
* Set or get the profiling mode.
+ *
+ * @deprecated Use `env.webgpu.profiling.mode` instead. If `env.webgpu.profiling.mode` is set, this property will be
+ * ignored.
*/
profilingMode?: 'off'|'default';
+ /**
+ * Set or get the profiling configuration.
+ */
+ profiling?: {
+ /**
+ * Set or get the profiling mode.
+ *
+ * @defaultValue `'off'`
+ */
+ mode?: 'off'|'default';
+
+ /**
+ * Set or get a callback function when a profiling data is received. If not set, the profiling data will be
+ * printed to console.
+ */
+ ondata?: (data: WebGpuProfilingData) => void;
+ };
/**
* Get the device for WebGPU.
*
diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts
index 03694738387f2..23bd4421ae672 100644
--- a/js/common/lib/training-session-impl.ts
+++ b/js/common/lib/training-session-impl.ts
@@ -18,18 +18,37 @@ const noBackendErrMsg: string = 'Training backend could not be resolved. ' +
'Make sure you\'re using the correct configuration & WebAssembly files.';
export class TrainingSession implements TrainingSessionInterface {
- private constructor(handler: TrainingSessionHandler) {
+ private constructor(handler: TrainingSessionHandler, hasOptimizerModel: boolean, hasEvalModel: boolean) {
this.handler = handler;
+ this.hasOptimizerModel = hasOptimizerModel;
+ this.hasEvalModel = hasEvalModel;
}
private handler: TrainingSessionHandler;
+ private hasOptimizerModel: boolean;
+ private hasEvalModel: boolean;
- get inputNames(): readonly string[] {
+ get trainingInputNames(): readonly string[] {
return this.handler.inputNames;
}
- get outputNames(): readonly string[] {
+ get trainingOutputNames(): readonly string[] {
return this.handler.outputNames;
}
+ get evalInputNames(): readonly string[] {
+ if (this.hasEvalModel) {
+ return this.handler.evalInputNames;
+ } else {
+ throw new Error('This training session has no evalModel loaded.');
+ }
+ }
+ get evalOutputNames(): readonly string[] {
+ if (this.hasEvalModel) {
+ return this.handler.evalOutputNames;
+ } else {
+ throw new Error('This training session has no evalModel loaded.');
+ }
+ }
+
static async create(trainingOptions: TrainingSessionCreateOptions, sessionOptions?: SessionOptions):
Promise {
const evalModel: string|Uint8Array = trainingOptions.evalModel || '';
@@ -43,7 +62,7 @@ export class TrainingSession implements TrainingSessionInterface {
if (backend.createTrainingSessionHandler) {
const handler = await backend.createTrainingSessionHandler(
trainingOptions.checkpointState, trainingOptions.trainModel, evalModel, optimizerModel, options);
- return new TrainingSession(handler);
+ return new TrainingSession(handler, !!trainingOptions.optimizerModel, !!trainingOptions.evalModel);
} else {
throw new Error(noBackendErrMsg);
}
@@ -53,13 +72,18 @@ export class TrainingSession implements TrainingSessionInterface {
* Helper function for runTrainStep and future runStep methods that handles the type-narrowing conversion from
* the given parameters to SessionHandler.FetchesType and RunOptions.
*
+ * @param inputNames the feeds object is checked that they contain all input names in the provided list of input
+ * names.
+ * @param outputNames the fetches object is checked that their keys match up with valid names in the list of output
+ * names.
* @param feeds the required input
* @param arg1 narrowed & converted into the SessionHandler.FetchesType or RunOptions object
* @param arg2 optional RunOptions object.
* @returns
*/
- typeNarrowingForRunStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions):
- [SessionHandler.FetchesType, RunOptions] {
+ typeNarrowingForRunStep(
+ inputNames: readonly string[], outputNames: readonly string[], feeds: FeedsType, arg1?: FetchesType|RunOptions,
+ arg2?: RunOptions): [SessionHandler.FetchesType, RunOptions] {
const fetches: {[name: string]: OnnxValue|null} = {};
let options: RunOptions = {};
// check inputs
@@ -88,7 +112,7 @@ export class TrainingSession implements TrainingSessionInterface {
if (typeof name !== 'string') {
throw new TypeError('\'fetches\' must be a string array or an object.');
}
- if (this.outputNames.indexOf(name) === -1) {
+ if (outputNames.indexOf(name) === -1) {
throw new RangeError(`'fetches' contains invalid output name: ${name}.`);
}
fetches[name] = null;
@@ -104,7 +128,7 @@ export class TrainingSession implements TrainingSessionInterface {
// if any output name is present and its value is valid OnnxValue, we consider it fetches
let isFetches = false;
const arg1Keys = Object.getOwnPropertyNames(arg1);
- for (const name of this.outputNames) {
+ for (const name of outputNames) {
if (arg1Keys.indexOf(name) !== -1) {
const v = (arg1 as InferenceSession.NullableOnnxValueMapType)[name];
if (v === null || v instanceof Tensor) {
@@ -130,7 +154,7 @@ export class TrainingSession implements TrainingSessionInterface {
}
// check if all inputs are in feed
- for (const name of this.inputNames) {
+ for (const name of inputNames) {
if (typeof feeds[name] === 'undefined') {
throw new Error(`input '${name}' is missing in 'feeds'.`);
}
@@ -138,7 +162,7 @@ export class TrainingSession implements TrainingSessionInterface {
// if no fetches is specified, we use the full output names list
if (isFetchesEmpty) {
- for (const name of this.outputNames) {
+ for (const name of outputNames) {
fetches[name] = null;
}
}
@@ -168,14 +192,40 @@ export class TrainingSession implements TrainingSessionInterface {
return returnValue;
}
+ async lazyResetGrad(): Promise {
+ await this.handler.lazyResetGrad();
+ }
+
runTrainStep(feeds: FeedsType, options?: RunOptions): Promise;
runTrainStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions): Promise;
async runTrainStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): Promise {
- const [fetches, options] = this.typeNarrowingForRunStep(feeds, arg1, arg2);
+ const [fetches, options] =
+ this.typeNarrowingForRunStep(this.trainingInputNames, this.trainingOutputNames, feeds, arg1, arg2);
const results = await this.handler.runTrainStep(feeds, fetches, options);
return this.convertHandlerReturnTypeToMapOfTensors(results);
}
+ async runOptimizerStep(options?: InferenceSession.RunOptions|undefined): Promise {
+ if (this.hasOptimizerModel) {
+ await this.handler.runOptimizerStep(options || {});
+ } else {
+ throw new Error('This TrainingSession has no OptimizerModel loaded.');
+ }
+ }
+
+ runEvalStep(feeds: FeedsType, options?: RunOptions|undefined): Promise;
+ runEvalStep(feeds: FeedsType, fetches: FetchesType, options?: RunOptions|undefined): Promise;
+ async runEvalStep(feeds: FeedsType, arg1?: FetchesType|RunOptions, arg2?: RunOptions): Promise {
+ if (this.hasEvalModel) {
+ const [fetches, options] =
+ this.typeNarrowingForRunStep(this.evalInputNames, this.evalOutputNames, feeds, arg1, arg2);
+ const results = await this.handler.runEvalStep(feeds, fetches, options);
+ return this.convertHandlerReturnTypeToMapOfTensors(results);
+ } else {
+ throw new Error('This TrainingSession has no EvalModel loaded.');
+ }
+ }
+
async getParametersSize(trainableOnly = true): Promise {
return this.handler.getParametersSize(trainableOnly);
}
diff --git a/js/common/lib/training-session.ts b/js/common/lib/training-session.ts
index 810ec2a8583b3..e54aed90e702c 100644
--- a/js/common/lib/training-session.ts
+++ b/js/common/lib/training-session.ts
@@ -22,6 +22,12 @@ export declare namespace TrainingSession {
export interface TrainingSession {
// #region run()
+ /**
+ * Lazily resets the gradients of all trainable parameters to zero. Should happen after the invocation of
+ * runOptimizerStep.
+ */
+ lazyResetGrad(): Promise;
+
/**
* Run TrainStep asynchronously with the given feeds and options.
*
@@ -39,7 +45,7 @@ export interface TrainingSession {
* @param feeds - Representation of the model input.
* @param fetches - Representation of the model output.
* detail.
- * @param options - Optional. A set of options that controls the behavior of model inference.
+ * @param options - Optional. A set of options that controls the behavior of model training.
* @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding
values.
*/
@@ -47,6 +53,38 @@ export interface TrainingSession {
feeds: InferenceSession.FeedsType, fetches: InferenceSession.FetchesType,
options?: InferenceSession.RunOptions): Promise;
+ /**
+ * Runs a single optimizer step, which performs weight updates for the trainable parameters using the optimizer model.
+ *
+ * @param options - Optional. A set of options that controls the behavior of model optimizing.
+ */
+ runOptimizerStep(options?: InferenceSession.RunOptions): Promise;
+
+ /**
+ * Run a single eval step with the given inputs and options using the eval model.
+ *
+ * @param feeds - Representation of the model input.
+ * @param options - Optional. A set of options that controls the behavior of model eval step.
+ * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding
+ values.
+ */
+ runEvalStep(feeds: InferenceSession.FeedsType, options?: InferenceSession.RunOptions):
+ Promise;
+
+ /**
+ * Run a single eval step with the given inputs and options using the eval model.
+ *
+ * @param feeds - Representation of the model input.
+ * @param fetches - Representation of the model output.
+ * detail.
+ * @param options - Optional. A set of options that controls the behavior of model eval step.
+ * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding
+ values.
+ */
+ runEvalStep(
+ feeds: InferenceSession.FeedsType, fetches: InferenceSession.FetchesType,
+ options?: InferenceSession.RunOptions): Promise;
+
// #endregion
// #region copy parameters
@@ -90,14 +128,25 @@ export interface TrainingSession {
// #region metadata
/**
- * Get input names of the loaded model.
+ * Get input names of the loaded training model.
+ */
+ readonly trainingInputNames: readonly string[];
+
+ /**
+ * Get output names of the loaded training model.
*/
- readonly inputNames: readonly string[];
+ readonly trainingOutputNames: readonly string[];
/**
- * Get output names of the loaded model.
+ * Get input names of the loaded eval model. Is an empty array if no eval model is loaded.
*/
- readonly outputNames: readonly string[];
+ readonly evalInputNames: readonly string[];
+
+ /**
+ * Get output names of the loaded eval model. Is an empty array if no eval model is loaded.
+ */
+ readonly evalOutputNames: readonly string[];
+
// #endregion
}
diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md
index 00c27fe3ab034..2f510308d9306 100644
--- a/js/web/docs/webgpu-operators.md
+++ b/js/web/docs/webgpu-operators.md
@@ -33,6 +33,7 @@ Do not modify directly.*
| ConvTranspose | ai.onnx(1-10,11+); com.ms.internal.nhwc(1-10,11+) | need perf optimization; ConvTranspose3d is not supported; need implementing activation |
| Cos | ai.onnx(7+) | |
| Cosh | ai.onnx(9+) | |
+| CumSum | ai.onnx(11-13,14+) | |
| Div | ai.onnx(7-12,13,14+) | |
| Einsum | ai.onnx(12+) | |
| Elu | ai.onnx(6+) | |
diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts
index 4ee1fd5442d83..4f4a06c37a94f 100644
--- a/js/web/lib/wasm/jsep/backend-webgpu.ts
+++ b/js/web/lib/wasm/jsep/backend-webgpu.ts
@@ -254,11 +254,9 @@ export class WebGpuBackend {
}
isQueryEnabled(): boolean {
- if (this.device.features.has('timestamp-query') && this.env.webgpu.profilingMode === 'default') {
- return true;
- } else {
- return false;
- }
+ return this.device.features.has('timestamp-query') &&
+ (this.env.webgpu.profiling?.mode === 'default' ||
+ (!this.env.webgpu.profiling?.mode && this.env.webgpu.profilingMode === 'default'));
}
/**
@@ -338,51 +336,26 @@ export class WebGpuBackend {
let uniformBufferBinding: GPUBindingResource|undefined;
if (programUniforms) {
let currentOffset = 0;
- let preLength = 0;
const offsets: number[] = [];
- let maxAlignmentOfField = 1;
+
programUniforms.forEach(v => {
const data = typeof v.data === 'number' ? [v.data] : v.data;
if (data.length === 0) {
return;
}
// https://www.w3.org/TR/WGSL/#alignof
- let baseAlignment: number;
- switch (data.length) {
- case 1:
- baseAlignment = 4;
- break;
- case 2:
- baseAlignment = 8;
- break;
- case 3:
- baseAlignment = 16;
- break;
- case 4:
- baseAlignment = 16;
- break;
- case 5:
- baseAlignment = 16;
- break;
- case 6:
- baseAlignment = 16;
- break;
- default:
- throw new Error(`unsupported data length: ${data.length}`);
- }
-
- if (preLength === 5 || preLength === 6) {
- baseAlignment = 16;
- }
- if (baseAlignment > maxAlignmentOfField) {
- maxAlignmentOfField = baseAlignment;
- }
+ const baseAlignment = data.length <= 2 ? data.length * 4 : 16;
currentOffset = Math.ceil(currentOffset / baseAlignment) * baseAlignment;
- preLength = data.length;
offsets.push(currentOffset);
- currentOffset += data.length * 4;
+ // When data.length > 4, the uniform variable is of type array,N>, where N =
+ // Math.ceil(data.length / 4) and SizeOf(vec4) = 16. The total byte length is N *
+ // SizeOf(vec4).
+ currentOffset += data.length > 4 ? Math.ceil(data.length / 4) * 16 : data.length * 4;
});
+ // Meet alignment of struct here: https://www.w3.org/TR/WGSL/#alignment-and-size. For simplicity, set
+ // maxAlignmentOfField to 16 since the underlying buffer has been rounded up to 16.
+ const maxAlignmentOfField = 16;
currentOffset = Math.ceil(currentOffset / maxAlignmentOfField) * maxAlignmentOfField;
const arrayBuffer = new ArrayBuffer(currentOffset);
programUniforms.forEach((v, i) => {
diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts
index d66357e729d5d..e6db631c44eea 100644
--- a/js/web/lib/wasm/jsep/init.ts
+++ b/js/web/lib/wasm/jsep/init.ts
@@ -175,8 +175,7 @@ export const init = async(module: OrtWasmModule, env: Env): Promise => {
// jsepCreateKernel
(name: string, kernel: number, attribute: unknown) => backend.createKernel(
name, kernel, attribute,
- env.debug || env.webgpu.profilingMode === 'default' ? module.UTF8ToString(module._JsepGetNodeName(kernel)) :
- `${kernel}`),
+ env.debug || backend.isQueryEnabled() ? module.UTF8ToString(module._JsepGetNodeName(kernel)) : `${kernel}`),
// jsepReleaseKernel
(kernel: number) => backend.releaseKernel(kernel),
diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
index 80f6e3bc11195..8e1ec782079be 100644
--- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
+++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
@@ -10,6 +10,7 @@ import * as binaryOps from './ops/binary-op';
import {concat, parseConcatAttributes} from './ops/concat';
import {conv, parseConvAttributes} from './ops/conv';
import {convTranspose, parseConvTransposeAttributes} from './ops/conv-transpose';
+import {cumsum, parseCumSumAttributes} from './ops/cumsum';
import {einsum, parseEinsumAttributes} from './ops/einsum';
import {expand} from './ops/expand';
import {gather, parseGatherAttributes} from './ops/gather';
@@ -22,7 +23,7 @@ import {multiHeadAttention, parseMultiHeadAttentionAttributes} from './ops/multi
import {pad, parsePadAttributes} from './ops/pad';
import * as pool from './ops/pool';
import {range} from './ops/range';
-import {parseReduceAttributes, reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce';
+import {reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce';
import {parseResizeAttributes, resize} from './ops/resize';
import {parseSkipLayerNormAttributes, skipLayerNorm} from './ops/skip-layer-norm';
import {parseSliceAttributes, slice} from './ops/slice';
@@ -63,6 +64,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new
['ConvTranspose', [convTranspose, parseConvTransposeAttributes]],
['Cos', [unaryOps.cos]],
['Cosh', [unaryOps.cosh]],
+ ['CumSum', [cumsum, parseCumSumAttributes]],
['Div', [binaryOps.div]],
['Einsum', [einsum, parseEinsumAttributes]],
['Elu', [unaryOps.elu, unaryOps.parseAlphaAttributes]],
@@ -97,16 +99,16 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new
['Pow', [binaryOps.pow]],
['Range', [range]],
['Reciprocal', [unaryOps.reciprocal]],
- ['ReduceMin', [reduceMin, parseReduceAttributes]],
- ['ReduceMean', [reduceMean, parseReduceAttributes]],
- ['ReduceMax', [reduceMax, parseReduceAttributes]],
- ['ReduceSum', [reduceSum, parseReduceAttributes]],
- ['ReduceProd', [reduceProd, parseReduceAttributes]],
- ['ReduceL1', [reduceL1, parseReduceAttributes]],
- ['ReduceL2', [reduceL2, parseReduceAttributes]],
- ['ReduceLogSum', [reduceLogSum, parseReduceAttributes]],
- ['ReduceLogSumExp', [reduceLogSumExp, parseReduceAttributes]],
- ['ReduceSumSquare', [reduceSumSquare, parseReduceAttributes]],
+ ['ReduceMin', [reduceMin]],
+ ['ReduceMean', [reduceMean]],
+ ['ReduceMax', [reduceMax]],
+ ['ReduceSum', [reduceSum]],
+ ['ReduceProd', [reduceProd]],
+ ['ReduceL1', [reduceL1]],
+ ['ReduceL2', [reduceL2]],
+ ['ReduceLogSum', [reduceLogSum]],
+ ['ReduceLogSumExp', [reduceLogSumExp]],
+ ['ReduceSumSquare', [reduceSumSquare]],
['Relu', [unaryOps.relu]],
['Resize', [resize, parseResizeAttributes]],
['Sigmoid', [unaryOps.sigmoid]],
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts
index a8f296ea0c865..47ec16a296712 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts
@@ -510,11 +510,7 @@ export const createMatmulProgramInfo =
name: 'MatMul',
shaderCache: {
hint: activationAttributes.activationCacheKey + `${elementsPerThread}` +
- `${activationAttributes.activation}` +
- `${activationAttributes.clipMax}` +
- `${activationAttributes.clipMin}` +
`${isVec4}` +
- `${hasBias}` +
`${isChannelsLast}`,
inputDependencies
},
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts b/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts
index b6c6853c8f222..1f27525f370f3 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/argminmax.ts
@@ -33,23 +33,23 @@ export const argMin = (context: ComputeContext, attributes: ArgMinMaxAttributes)
const idxZero = [];
for (let k = 0; k < input.rank; k++) {
if (axes.indexOf(k) >= 0 || axes.length === 0) {
- idxZero.push(`inputIndices[${k}] = 0;`); // first element
+ idxZero.push(`input_indices[${k}] = 0;`); // first element
}
}
return [
- `${idxZero.join('\n')}`, `var value = ${input.getByOffset('inputOffset')};\nvar bestIndex : i32 = 0;`,
- `if (${input.getByOffset('inputOffset')} ${attributes.selectLastIndex > 0 ? '<=' : '<'} value) {
- value = ${input.getByOffset('inputOffset')};
- bestIndex = i32(lastIndex);
+ `${idxZero.join('\n')}`, `var value = ${input.getByIndices('input_indices')};\nvar best_index : i32 = 0;`,
+ `if (${input.getByIndices('input_indices')} ${attributes.selectLastIndex > 0 ? '<=' : '<'} value) {
+ value = ${input.getByIndices('input_indices')};
+ best_index = i32(last_index);
}`,
- '', output.setByOffset('global_idx', 'bestIndex')
+ '', output.setByOffset('global_idx', 'best_index')
];
};
context.compute(
createReduceProgramInfo(
- 'ArgMin', {hint: attributes.cacheKey}, [context.inputs[0]], argMinMaxOp, [attributes.axis], DataType.int64,
- attributes.keepDims),
+ 'ArgMin', {hint: attributes.cacheKey, inputDependencies: ['rank']}, [context.inputs[0]], argMinMaxOp,
+ [attributes.axis], DataType.int64, attributes.keepDims),
{inputs: [0]});
};
@@ -59,23 +59,23 @@ export const argMax = (context: ComputeContext, attributes: ArgMinMaxAttributes)
const idxZero = [];
for (let k = 0; k < input.rank; k++) {
if (axes.indexOf(k) >= 0 || axes.length === 0) {
- idxZero.push(`inputIndices[${k}] = 0;`); // first element
+ idxZero.push(`input_indices[${k}] = 0;`); // first element
}
}
return [
- `${idxZero.join('\n')}`, `var value = ${input.getByOffset('inputOffset')};\nvar bestIndex : i32 = 0;`,
- `if (${input.getByOffset('inputOffset')} ${attributes.selectLastIndex > 0 ? '>=' : '>'} value) {
- value = ${input.getByOffset('inputOffset')};
- bestIndex = i32(lastIndex);
+ `${idxZero.join('\n')}`, `var value = ${input.getByIndices('input_indices')};\nvar best_index : i32 = 0;`,
+ `if (${input.getByIndices('input_indices')} ${attributes.selectLastIndex > 0 ? '>=' : '>'} value) {
+ value = ${input.getByIndices('input_indices')};
+ best_index = i32(last_index);
}`,
- '', output.setByOffset('global_idx', 'bestIndex')
+ '', output.setByOffset('global_idx', 'best_index')
];
};
context.compute(
createReduceProgramInfo(
- 'argMax', {hint: attributes.cacheKey}, [context.inputs[0]], argMinMaxOp, [attributes.axis], DataType.int64,
- attributes.keepDims),
+ 'argMax', {hint: attributes.cacheKey, inputDependencies: ['rank']}, [context.inputs[0]], argMinMaxOp,
+ [attributes.axis], DataType.int64, attributes.keepDims),
{inputs: [0]});
};
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts
index b7a391ee667bb..5fffa2f266603 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts
@@ -325,6 +325,24 @@ export const sumVector = (name: string, components: number) => {
return name;
};
+/**
+ * A helper function that returns variable element at index.
+ * @param name - the name of variable.
+ * @param index - the index of variable element.
+ * @param length - the length of variable.
+ */
+export const getElementAt = (name: string, index: number|string, length: number): string => {
+ if (name.startsWith('uniforms.') && length > 4) {
+ if (typeof (index) === 'string') {
+ return `${name}[(${index}) / 4][(${index}) % 4]`;
+ } else {
+ return `${name}[${Math.floor(index / 4)}][${index % 4}]`;
+ }
+ } else {
+ return length > 1 ? `${name}[${index}]` : name;
+ }
+};
+
/**
* A helper function to get a IndicesHelper for a given input or output.
*
@@ -362,11 +380,12 @@ const createIndicesHelper =
const uniformPrefix = useUniform ? 'uniforms.' : '';
const shape = `${uniformPrefix}${name}_shape`;
const strides = `${uniformPrefix}${name}_strides`;
+
let o2iSnippet = '';
for (let i = 0; i < rank - 1; i++) {
o2iSnippet += `
- let dim${i} = current / ${strides}[${i}];
- let rest${i} = current % ${strides}[${i}];
+ let dim${i} = current / ${getElementAt(strides, i, rank)};
+ let rest${i} = current % ${getElementAt(strides, i, rank)};
indices[${i}] = dim${i};
current = rest${i};
`;
@@ -389,7 +408,7 @@ const createIndicesHelper =
const offsets: string[] = [];
if (rank >= 2) {
for (let i = rank - 1; i >= 0; i--) {
- offsets.push(`${strides}[${i}] * (indices[${i}])`);
+ offsets.push(`${getElementAt(strides, i, rank)} * (indices[${i}])`);
}
}
@@ -410,7 +429,7 @@ const createIndicesHelper =
if (rank < 2) {
return `${varIndices}`;
} else {
- return `${varIndices}[${idx}]`;
+ return `${getElementAt(varIndices, idx, rank)}`;
}
};
@@ -418,7 +437,7 @@ const createIndicesHelper =
if (rank < 2) {
return `${varIndices}=${value};`;
} else {
- return `${varIndices}[${idx}]=${value};`;
+ return `${getElementAt(varIndices, idx, rank)}=${value};`;
}
};
@@ -660,7 +679,8 @@ export const internalVariable =
(name: string, type: number, shapeOrRank: number|readonly number[], components: 1|2|3|4 = 1): IndicesHelper =>
createIndicesHelper(name, type, shapeOrRank, 'internal', components);
-export type UniformsArrayType = Array<{name: string; type: string}>;
+export type UniformDataElementType = 'u32'|'f32'|'i32';
+export type UniformsArrayType = Array<{name: string; type: UniformDataElementType; length?: number}>;
/**
* A ShaderHelper is a helper class for generating WGSL code.
@@ -714,8 +734,9 @@ export interface ShaderHelper {
*
* @param name - the name of the uniform.
* @param type - the type of the uniform.
+ * @param length - the length of the uniform, default to 1 when it is not provided.
*/
- registerUniform(name: string, type: string): ShaderHelper;
+ registerUniform(name: string, type: string, length?: number): ShaderHelper;
/**
* A helper function to register multiple uniforms. Can be called multiple times to register multiple uniforms.
@@ -769,10 +790,10 @@ class ShaderHelperImpl implements ShaderHelper {
private appendVariableUniforms(variable: IndicesHelper): void {
if (variable.rank !== 0) {
if (variable.shape.startsWith('uniforms.')) {
- this.uniforms.push({name: variable.shape.replace('uniforms.', ''), type: variable.type.indices});
+ this.uniforms.push({name: variable.shape.replace('uniforms.', ''), type: 'u32', length: variable.rank});
}
if (variable.strides.startsWith('uniforms.')) {
- this.uniforms.push({name: variable.strides.replace('uniforms.', ''), type: variable.type.indices});
+ this.uniforms.push({name: variable.strides.replace('uniforms.', ''), type: 'u32', length: variable.rank});
}
}
}
@@ -808,8 +829,8 @@ class ShaderHelperImpl implements ShaderHelper {
return this;
}
- registerUniform(name: string, type: string): ShaderHelper {
- this.uniforms.push({name, type});
+ registerUniform(name: string, type: UniformDataElementType, length = 1): ShaderHelper {
+ this.uniforms.push({name, type, length});
return this;
}
@@ -827,8 +848,13 @@ class ShaderHelperImpl implements ShaderHelper {
}
const uniformSnippets: string[] = [];
- for (const {name, type} of this.uniforms) {
- uniformSnippets.push(`${name}:${type}`);
+ for (const {name, type, length} of this.uniforms) {
+ if (length && length > 4) {
+ uniformSnippets.push(`${name}:array, ${Math.ceil(length / 4)}>`);
+ } else {
+ const typeTemp = length == null || length === 1 ? type : `vec${length}<${type}>`;
+ uniformSnippets.push(`${name}:${typeTemp}`);
+ }
}
return `
@@ -872,5 +898,5 @@ export const getBroadcastDims = (inShape: readonly number[], outShape: readonly
return dims;
};
-// TODO: remove this limitation once >4D dims are supported by uniform.
-export const enableShapesUniforms = (rank: number): boolean => rank <= 4;
+// TODO: remove this when all related uses have been removed.
+export const enableShapesUniforms = (_rank: number): boolean => true;
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts
index e880afe09a5d8..32b1d52ed94ca 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts
@@ -209,18 +209,20 @@ const convTranspose2d =
(context: ComputeContext, inputs: readonly TensorView[], attributes: ConvTransposeAttributes): void => {
const adjustedAttributes = getAdjustedConvTransposeAttributes(attributes, inputs);
const isChannelsLast = attributes.format === 'NHWC';
- const hasBias = inputs.length === 3;
- if (adjustedAttributes.group !== 1) {
+ const outputShape = adjustedAttributes.outputShape;
+ const outChannels = outputShape[isChannelsLast ? 3 : 1];
+ const inputChannels = inputs[0].dims[isChannelsLast ? 3 : 1];
+ // Switch to naive method when outChannels and inputChannels are very small. It's because that in this case it's
+ // not suitable for matmul version since matmul uses tile size 32x32 resulting the underlying execution unit
+ // utilization rate is very low.
+ if (adjustedAttributes.group !== 1 || (outChannels === 1 && inputChannels === 1)) {
context.compute(createConvTranspose2DProgramInfo(inputs, adjustedAttributes));
return;
}
- const outputShape = adjustedAttributes.outputShape;
const outHeight = outputShape[isChannelsLast ? 1 : 2];
const outWidth = outputShape[isChannelsLast ? 2 : 3];
- const outChannels = outputShape[isChannelsLast ? 3 : 1];
const weightHeight = inputs[1].dims[2];
const weightWidth = inputs[1].dims[3];
- const inputChannels = inputs[0].dims[isChannelsLast ? 3 : 1];
const dimAOuter = isChannelsLast ? outHeight * outWidth : outChannels;
const dimBOuter = isChannelsLast ? outChannels : outHeight * outWidth;
@@ -240,6 +242,7 @@ const convTranspose2d =
// STEP.2: prepare reshaped inputs
const convTransposeInputs = [inputs[0], transposedWeight];
+ const hasBias = inputs.length === 3;
if (hasBias) {
if (!isChannelsLast && inputs[2].dims.length === 1) {
convTransposeInputs.push(inputs[2].reshape([inputs[2].dims[0], 1, 1]));
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts
index c7ea0cffe51c3..33a5db7ff6b25 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts
@@ -10,6 +10,7 @@ import {createConv2DMatMulProgramInfo} from './3rd-party/conv2d_mm_webgpu';
import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu';
import {createGroupedConvProgramInfo} from './conv-grouped';
import {InternalActivationAttributes, parseInternalActivationAttributes} from './fuse-utils';
+import {createNaiveMatmulProgramInfo} from './matmul';
import {createTransposeProgramInfo} from './transpose';
export const calculateOutputShape =
@@ -195,9 +196,19 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut
if (hasBias) {
matmulInputs.push(inputs[2]);
}
- context.compute(
- createMatmulProgramInfo(matmulInputs, adjustedAttributes, outputShape, matmulOutputShape, isChannelsLast),
- {inputs: matmulInputs});
+ const N = matmulOutputShape[2];
+ const K = matmulInputs[0].dims[matmulInputs[0].dims.length - 1];
+ // Tune the threshold.
+ if (N < 8 && K < 8) {
+ context.compute(
+ createNaiveMatmulProgramInfo(
+ matmulInputs, adjustedAttributes, outputShape, matmulOutputShape, isChannelsLast),
+ {inputs: matmulInputs});
+ } else {
+ context.compute(
+ createMatmulProgramInfo(matmulInputs, adjustedAttributes, outputShape, matmulOutputShape, isChannelsLast),
+ {inputs: matmulInputs});
+ }
return;
}
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts b/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts
new file mode 100644
index 0000000000000..2ff909c30e62e
--- /dev/null
+++ b/js/web/lib/wasm/jsep/webgpu/ops/cumsum.ts
@@ -0,0 +1,78 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+import {DataType} from '../../../wasm-common';
+import {TensorView} from '../../tensor-view';
+import {ShapeUtil} from '../../util';
+import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
+import {ComputeContext, ProgramInfo} from '../types';
+
+import {createTensorShapeVariables, getElementAt, inputVariable, outputVariable, ShaderHelper} from './common';
+
+
+export interface CumSumAttributes extends AttributeWithCacheKey {
+ readonly exclusive: boolean;
+ readonly reverse: boolean;
+}
+const createCumsumProgramInfo =
+ (inputType: number, inputShape: readonly number[], axisInput: TensorView, attributes: CumSumAttributes):
+ ProgramInfo => {
+ const outputSize = ShapeUtil.size(inputShape); // outputShape is same as inputShape.
+ const rank = inputShape.length; // input/output rank
+ const input = inputVariable('input', inputType, rank);
+ const output = outputVariable('output', inputType, rank);
+ const axisValue = axisInput.dataType === DataType.int32 ? axisInput.getInt32Array()[0] :
+ Number(axisInput.getBigInt64Array()[0]);
+ const axis = ShapeUtil.normalizeAxis(axisValue, rank);
+ const getShaderSource = (shaderHelper: ShaderHelper) => {
+ const index = ` i32(${input.indicesGet('inputIndices', 'uniforms.axis')}) `;
+ const max = getElementAt('uniforms.input_shape', 'uniforms.axis', rank);
+ const lowerLimit = attributes.reverse ? index + (attributes.exclusive ? ' + 1' : '') : '0';
+ const upperLimit = attributes.reverse ? max : index + (attributes.exclusive ? '' : ' + 1');
+ return `
+ ${
+ shaderHelper.registerUniform('outputSize', 'u32')
+ .registerUniform('axis', 'u32')
+ .declareVariables(input, output)}
+ ${shaderHelper.mainStart()}
+ ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
+ var inputIndices = ${output.offsetToIndices('global_idx')};
+ var sum = ${output.type.value}(0);
+ let first : i32 = ${lowerLimit};
+ let last : i32 = ${upperLimit};
+ for (var i : i32 = first; i < last; i++) {
+ ${input.indicesSet('inputIndices', 'uniforms.axis', 'u32(i)')};
+ sum = sum + ${input.getByIndices('inputIndices')};
+ }
+ ${output.setByOffset('global_idx', 'sum')};
+ }`;
+ };
+ return {
+ name: 'CumSum',
+ shaderCache: {hint: attributes.cacheKey, inputDependencies: ['rank']},
+ getRunData: () => ({
+ outputs: [{dims: inputShape, dataType: inputType}],
+ dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
+ programUniforms: [
+ {type: 'uint32', data: outputSize}, {type: 'int32', data: axis},
+ ...createTensorShapeVariables(inputShape), ...createTensorShapeVariables(inputShape)
+ ]
+
+ }),
+ getShaderSource
+ };
+ };
+
+
+export const cumsum = (context: ComputeContext, attributes: CumSumAttributes): void => {
+ const inputShape = context.inputs[0].dims;
+ const inputType = context.inputs[0].dataType;
+ const axis = context.inputs[1];
+ context.compute(createCumsumProgramInfo(inputType, inputShape, axis, attributes), {inputs: [0]});
+};
+
+export const parseCumSumAttributes = (attributes: Record): CumSumAttributes => {
+ const exclusive = attributes.exclusive as number === 1;
+ const reverse = attributes.reverse as number === 1;
+ return createAttributeWithCacheKey({exclusive, reverse});
+};
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts
index d998013352d77..3dc4e957e0fee 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/expand.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/expand.ts
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
+import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';
@@ -44,34 +45,51 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo =>
const inputShape = inputs[0].dims;
const shape = Array.from(inputs[1].getBigInt64Array(), Number);
const outputShape: number[] = calculateOutputShape(inputShape, shape);
- const outputSize = ShapeUtil.size(outputShape);
-
const dataType = inputs[0].dataType;
+ const components = dataType === DataType.bool ? 4 : 1;
+ const outputSize = ShapeUtil.size(outputShape) / components;
+
const enableInputShapeUniform = enableShapesUniforms(inputShape.length);
- const inputShapeOrRank = enableInputShapeUniform ? inputShape.length : inputShape;
- const input = inputVariable('input', dataType, inputShapeOrRank);
const enableOutputShapeUniform = enableShapesUniforms(outputShape.length);
- const outputShapeOrRank = enableOutputShapeUniform ? outputShape.length : outputShape;
- const output = outputVariable('output', dataType, outputShapeOrRank);
- const getShaderSource = (shaderHelper: ShaderHelper) => `
- const inputShape = ${input.indices(...inputShape)};
- ${shaderHelper.registerUniform('vec_size', 'u32').declareVariables(input, output)}
- ${shaderHelper.mainStart()}
- ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.vec_size')}
- let outputIndices = ${output.offsetToIndices('global_idx')};
- var inputIndices: ${input.type.indices};
- for (var i = 0; i < ${inputShape.length}; i++) {
- if (${input.indicesGet('inputShape', 'i')} == 1) {
- ${input.indicesSet('inputIndices', 'i', 0)}
- } else {
- ${
- input.indicesSet(
- 'inputIndices', 'i', output.indicesGet('outputIndices', `i + ${outputShape.length - inputShape.length}`))}
- }
+
+ const getShaderSource = (shaderHelper: ShaderHelper) => {
+ const inputShapeOrRank = enableInputShapeUniform ? inputShape.length : inputShape;
+ const outputShapeOrRank = enableOutputShapeUniform ? outputShape.length : outputShape;
+ const input = inputVariable('input', dataType, inputShapeOrRank, components);
+ const output = outputVariable('output', dataType, outputShapeOrRank, components);
+ let assignment: string;
+ if (dataType === DataType.bool) {
+ const singleAssignment = (resStr: string, x: number, typeCast = '') => `
+ let outputIndices${x} = ${output.offsetToIndices(`outputOffset + ${x}u`)};
+ let offset${x} = ${input.broadcastedIndicesToOffset(`outputIndices${x}`, output)};
+ let index${x} = offset${x} / 4u;
+ let component${x} = offset${x} % 4u;
+ ${resStr}[${x}] = ${typeCast}(${input.getByOffset(`index${x}`)}[component${x}]);
+ `;
+ assignment = `
+ let outputOffset = global_idx * ${components};
+ var data = vec4(0);
+ ${singleAssignment('data', 0, 'u32')}
+ ${singleAssignment('data', 1, 'u32')}
+ ${singleAssignment('data', 2, 'u32')}
+ ${singleAssignment('data', 3, 'u32')}
+ ${output.setByOffset('global_idx', 'data')}
+ }`;
+ } else {
+ assignment = `
+ let outputIndices = ${output.offsetToIndices('global_idx')};
+ let inputOffset = ${input.broadcastedIndicesToOffset('outputIndices', output)};
+ ${output.setByOffset('global_idx', input.getByOffset('inputOffset'))}
+ }`;
}
- ${output.setByOffset('global_idx', input.getByIndices('inputIndices'))}
- }`;
+ return `
+ ${shaderHelper.registerUniform('vec_size', 'u32').declareVariables(input, output)}
+ ${shaderHelper.mainStart()}
+ ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.vec_size')}
+ ${assignment}`;
+ };
+
const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}];
if (enableInputShapeUniform) {
programUniforms.push(...createTensorShapeVariables(inputShape));
@@ -81,7 +99,7 @@ const createExpandProgramInfo = (inputs: readonly TensorView[]): ProgramInfo =>
}
return {
name: 'Expand',
- shaderCache: {hint: `${outputShape}`, inputDependencies: [enableInputShapeUniform ? 'rank' : 'dims']},
+ shaderCache: {hint: `${outputShape.length}`, inputDependencies: [enableInputShapeUniform ? 'rank' : 'dims']},
getShaderSource,
getRunData: () => ({
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts
index 9924a50e2ae6f..a945954adcaa4 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/gather-elements.ts
@@ -4,9 +4,9 @@
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
-import {ComputeContext, ProgramInfo} from '../types';
+import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';
-import {inputVariable, outputVariable, ShaderHelper} from './common';
+import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common';
export interface GatherElementsAttributes extends AttributeWithCacheKey {
axis: number;
@@ -32,65 +32,59 @@ const createGatherElementsProgramInfo =
const inputShape = inputs[0].dims;
const inputOutputDataType = inputs[0].dataType;
const inputRank = inputShape.length;
- const inputStrides = ShapeUtil.computeStrides(inputShape);
- const inputSize = ShapeUtil.size(inputShape);
const indicesShape = inputs[1].dims;
const indicesDataType = inputs[1].dataType;
- const indicesSize = ShapeUtil.size(indicesShape);
-
const axis = ShapeUtil.normalizeAxis(attributes.axis, inputRank);
const axisDimLimit = inputShape[axis];
const outputShape = indicesShape.slice(0);
const outputSize = ShapeUtil.size(outputShape);
- const input = inputVariable('input', inputOutputDataType, inputShape);
- const indices = inputVariable('indices', indicesDataType, [indicesSize]);
- const output = outputVariable('output', inputOutputDataType, outputShape);
+ const input = inputVariable('input', inputOutputDataType, inputRank);
+ const indices = inputVariable('indicesInput', indicesDataType, indicesShape.length);
+ const output = outputVariable('output', inputOutputDataType, outputShape.length);
+
+ const programUniforms: ProgramUniform[] =
+ [{type: 'uint32', data: outputSize}, {type: 'int32', data: axisDimLimit}, {type: 'uint32', data: axis}];
+ programUniforms.push(...createTensorShapeVariables(inputShape));
+ programUniforms.push(...createTensorShapeVariables(indicesShape));
+ programUniforms.push(...createTensorShapeVariables(outputShape));
+ const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank', 'rank'];
// int64 indices would be treated as little endian i32 with assumption they fall in i32 limits
// That assumption is safe as it's not possible to allocate >2gb buffer for input tensor
// Input data will be treated as u32 or two u32 for 8-byte tensors
const getShaderSource = (shaderHelper: ShaderHelper) => `
- const inputStrides = array(${inputStrides.map(i => `${i}u`).join(',')});
- ${shaderHelper.declareVariables(input, indices, output)}
+ ${
+ shaderHelper.registerUniform('outputSize', 'u32')
+ .registerUniform('axisDimLimit', 'i32')
+ .registerUniform('axis', 'u32')
+ .declareVariables(input, indices, output)}
${shaderHelper.mainStart()}
- ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
+ ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
let outputIndices = ${output.offsetToIndices('global_idx')};
var idx = ${indices.getByOffset('global_idx')};
if (idx < 0) {
- idx = idx + ${axisDimLimit};
- }
-
- var srcOffset = u32(0);
-
- for (var i = 0; i < ${inputShape.length}; i++) {
- if (i == ${axis}) {
- srcOffset += u32(idx) * inputStrides[i];
- } else {
- srcOffset += ${output.indicesGet('outputIndices', 'i')} * inputStrides[i];
- }
- }
-
- // Should never hit this with valid values in indices
- // This is a guard against malicious data in the indices input
- if (srcOffset < 0 || srcOffset >= ${inputSize}) {
- return;
+ idx = idx + uniforms.axisDimLimit;
}
+ var inputIndices = ${input.type.indices}(outputIndices);
+ ${input.indicesSet('inputIndices', 'uniforms.axis', 'u32(idx)')};
+ let value = ${input.getByIndices('inputIndices')};
- output[global_idx] = input[srcOffset];
+ ${output.setByOffset('global_idx', 'value')};
}`;
return {
name: 'GatherElements',
- shaderCache: {hint: attributes.cacheKey},
+ shaderCache: {inputDependencies},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
- dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}
+ dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
+ programUniforms
}),
getShaderSource,
};
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts
index 5d6d6debadb9a..53ca094abfd62 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/gather.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/gather.ts
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
+import {DataType} from '../../../wasm-common';
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
@@ -29,7 +30,8 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath
outputShape.splice(axis, 1, ...indicesShape);
const axisDimLimit = inputShape[axis];
- const outputSize = ShapeUtil.size(outputShape);
+ const components = inputs[0].dataType === DataType.bool ? 4 : 1;
+ const outputSize = ShapeUtil.size(outputShape) / components;
const enableInputShapesUniforms = enableShapesUniforms(inputs[0].dims.length);
const inputShapeOrRank = enableInputShapesUniforms ? inputs[0].dims.length : inputs[0].dims;
@@ -38,10 +40,6 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath
const enableOutputShapesUniforms = enableShapesUniforms(outputShape.length);
const outputShapeOrRank = enableOutputShapesUniforms ? outputShape.length : outputShape;
- const data = inputVariable('data', inputs[0].dataType, inputShapeOrRank);
- const indices = inputVariable('inputIndices', inputs[1].dataType, indicesShapeOrRank);
- const output = outputVariable('output', inputs[0].dataType, outputShapeOrRank);
-
const programUniforms: ProgramUniform[] =
[{type: 'uint32', data: outputSize}, {type: 'int32', data: axisDimLimit}, {type: 'uint32', data: axis}];
if (enableInputShapesUniforms) {
@@ -58,46 +56,75 @@ const createGatherProgramInfo = (inputs: readonly TensorView[], attributes: Gath
inputDependencies.push(enableInputShapesUniforms ? 'rank' : 'dims');
inputDependencies.push(enableIndicesShapesUniforms ? 'rank' : 'dims');
- const calcDataIndices = (): string => {
- const indicesRank = indicesShape.length;
- let calcStr = `var indicesIndices = ${indices.type.indices}(0);`;
- for (let i = 0; i < indicesRank; i++) {
- calcStr += `${indicesRank > 1 ? `indicesIndices[${i}]` : 'indicesIndices'} = ${
- outputShape.length > 1 ? `outputIndices[uniforms.axis + ${i}]` : 'outputIndices'};`;
- }
- calcStr += `
- var idx = ${indices.getByIndices('indicesIndices')};
- if (idx < 0) {
- idx = idx + uniforms.axisDimLimit;
+ const getShaderSource = (shaderHelper: ShaderHelper) => {
+ const data = inputVariable('data', inputs[0].dataType, inputShapeOrRank, components);
+ const indices = inputVariable('inputIndices', inputs[1].dataType, indicesShapeOrRank);
+ const output = outputVariable('output', inputs[0].dataType, outputShapeOrRank, components);
+
+ const calcDataIndices = (x: number|string): string => {
+ const indicesRank = indicesShape.length;
+ let calcStr = `var indicesIndices${x} = ${indices.type.indices}(0);`;
+ for (let i = 0; i < indicesRank; i++) {
+ calcStr += `${indicesRank > 1 ? `indicesIndices${x}[${i}]` : `indicesIndices${x}`} = ${
+ outputShape.length > 1 ? `outputIndices${x}[uniforms.axis + ${i}]` : `outputIndices${x}`};`;
+ }
+ calcStr += `
+ var idx${x} = ${indices.getByIndices(`indicesIndices${x}`)};
+ if (idx${x} < 0) {
+ idx${x} = idx${x} + uniforms.axisDimLimit;
+ }
+ var dataIndices${x} = ${data.type.indices}(0);
+ `;
+ for (let i = 0, j = 0; i < inputRank; i++) {
+ if (i === axis) {
+ calcStr += `${inputRank > 1 ? `dataIndices${x}[${i}]` : `dataIndices${x}`} = u32(idx${x});`;
+ j += indicesRank;
+ } else {
+ calcStr += `${inputRank > 1 ? `dataIndices${x}[${i}]` : `dataIndices${x}`} = ${
+ outputShape.length > 1 ? `outputIndices${x}[${j}]` : `outputIndices${x}`};`;
+ j++;
}
- var dataIndices = ${data.type.indices}(0);
- `;
- for (let i = 0, j = 0; i < inputRank; i++) {
- if (i === axis) {
- calcStr += `${inputRank > 1 ? `dataIndices[${i}]` : 'dataIndices'} = u32(idx);`;
- j += indicesRank;
- } else {
- calcStr += `${inputRank > 1 ? `dataIndices[${i}]` : 'dataIndices'} = ${
- outputShape.length > 1 ? `outputIndices[${j}]` : 'outputIndices'};`;
- j++;
}
+ return calcStr;
+ };
+ let assignment: string;
+ if (inputs[0].dataType === DataType.bool) {
+ const singleAssignment = (resStr: string, x: number, typeCast = '') => `
+ let outputIndices${x} = ${output.offsetToIndices(`outputOffset + ${x}u`)};
+ ${calcDataIndices(x)};
+ let offset${x} = ${data.indicesToOffset(`dataIndices${x}`)};
+ let index${x} = offset${x} / 4u;
+ let component${x} = offset${x} % 4u;
+ ${resStr}[${x}] = ${typeCast}(${data.getByOffset(`index${x}`)}[component${x}]);
+ `;
+ assignment = `
+ let outputOffset = global_idx * ${components};
+ var value = vec4(0);
+ ${singleAssignment('value', 0, 'u32')}
+ ${singleAssignment('value', 1, 'u32')}
+ ${singleAssignment('value', 2, 'u32')}
+ ${singleAssignment('value', 3, 'u32')}
+ ${output.setByOffset('global_idx', 'value')}
+ `;
+ } else {
+ assignment = `
+ let outputIndices = ${output.offsetToIndices('global_idx')};
+ ${calcDataIndices('')};
+ let value = ${data.getByIndices('dataIndices')};
+ ${output.setByOffset('global_idx', 'value')};
+ `;
}
- return calcStr;
- };
-
- const getShaderSource = (shaderHelper: ShaderHelper) => `
+ return `
${
- shaderHelper.registerUniform('outputSize', 'u32')
- .registerUniform('axisDimLimit', 'i32')
- .registerUniform('axis', 'u32')
- .declareVariables(data, indices, output)}
+ shaderHelper.registerUniform('outputSize', 'u32')
+ .registerUniform('axisDimLimit', 'i32')
+ .registerUniform('axis', 'u32')
+ .declareVariables(data, indices, output)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
- let outputIndices = ${output.offsetToIndices('global_idx')};
- ${calcDataIndices()};
- let value = ${data.getByIndices('dataIndices')};
- ${output.setByOffset('global_idx', 'value')};
+ ${assignment}
}`;
+ };
return {
name: 'Gather',
shaderCache: {hint: attributes.cacheKey, inputDependencies},
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts
index 97f633c7cf47e..3a84844544c96 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/instance-norm.ts
@@ -7,7 +7,7 @@ import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo} from '../types';
-import {fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common';
+import {fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType} from './common';
export interface InstanceNormAttributes extends AttributeWithCacheKey {
epsilon: number;
@@ -26,22 +26,25 @@ const createInstanceNormProgramInfo =
const axis = 2;
const normCount = ShapeUtil.sizeToDimension(xShape, axis);
const normSize = ShapeUtil.sizeFromDimension(xShape, axis);
+ const components = getMaxComponents(normSize);
+ const normPackedSize = normSize / components;
const C = xShape[1];
- const x = inputVariable('x', inputs[0].dataType, [xShape[0], xShape[1], normSize]);
+ const x = inputVariable('x', inputs[0].dataType, [xShape[0], xShape[1], normPackedSize], components);
const scale = inputVariable('scale', inputs[1].dataType, inputs[1].dims);
const bias = inputVariable('bias', inputs[2].dataType, inputs[2].dims);
- const output = outputVariable('output', inputs[0].dataType, [xShape[0], xShape[1], normSize]);
+ const output = outputVariable('output', inputs[0].dataType, [xShape[0], xShape[1], normPackedSize], components);
const variables = [x, scale, bias, output];
const dataType = x.type.value;
+ const f32Type = components === 1 ? 'f32' : `vec${components}`;
const workgroupSize = 64;
const getShaderSource = (shaderHelper: ShaderHelper) => `
const C: u32 = ${C};
const normSize: u32 = ${normSize};
const epsilon: f32 = ${attributes.epsilon};
- var meanShared : ${dataType};
- var squaredNormShared : ${dataType};
- var workgroupShared : array<${dataType}, ${workgroupSize}>;
+ var meanShared : f32;
+ var squaredNormShared : f32;
+ var workgroupShared : array<${f32Type}, ${workgroupSize}>;
const workgroupSize = ${workgroupSize}u;
${shaderHelper.declareVariables(...variables)}
${shaderHelper.mainStart(workgroupSize)}
@@ -51,9 +54,9 @@ const createInstanceNormProgramInfo =
let localIndex = local_id.x;
// initialize workgroup memory
- var initial: ${dataType} = 0;
- for (var h = localIndex; h < normSize; h += workgroupSize) {
- initial = initial + ${x.get('batch', 'channel', 'h')};
+ var initial = ${f32Type}(0);
+ for (var h = localIndex; h < ${normPackedSize}; h += workgroupSize) {
+ initial = initial + ${f32Type}(${x.get('batch', 'channel', 'h')});
}
workgroupShared[localIndex] = initial;
workgroupBarrier();
@@ -66,14 +69,14 @@ const createInstanceNormProgramInfo =
workgroupBarrier();
}
if (localIndex == 0) {
- meanShared = workgroupShared[0] / ${dataType}(normSize);
+ meanShared = ${sumVector('workgroupShared[0]', components)} / f32(normSize);
}
workgroupBarrier();
// reinitialize workgroup memory.
- initial = 0;
- for (var h = localIndex; h < normSize; h += workgroupSize) {
- let deviation = ${x.get('batch', 'channel', 'h')} - meanShared;
+ initial = ${f32Type}(0);
+ for (var h = localIndex; h < ${normPackedSize}; h += workgroupSize) {
+ let deviation = ${f32Type}(${x.get('batch', 'channel', 'h')}) - ${f32Type}(meanShared);
initial = initial + deviation * deviation;
}
workgroupShared[localIndex] = initial;
@@ -87,15 +90,16 @@ const createInstanceNormProgramInfo =
workgroupBarrier();
}
if (localIndex == 0) {
- squaredNormShared = workgroupShared[0];
+ squaredNormShared = ${sumVector('workgroupShared[0]', components)};
}
workgroupBarrier();
- let invStdDev = 1 / sqrt(squaredNormShared / ${dataType}(normSize) + epsilon);
- let channelScale = invStdDev * ${scale.getByOffset('channel')};
- let channelShift = ${bias.getByOffset('channel')} - meanShared * channelScale;
- for (var h = localIndex; h < normSize; h += workgroupSize) {
- let value = ${x.get('batch', 'channel', 'h')} * channelScale + channelShift;
+ let invStdDev = 1 / sqrt(squaredNormShared / f32(normSize) + epsilon);
+ let channelScale = invStdDev * f32(${scale.getByOffset('channel')});
+ let channelShift = f32(${bias.getByOffset('channel')}) - meanShared * channelScale;
+ for (var h = localIndex; h < ${normPackedSize}; h += workgroupSize) {
+ let value = ${x.get('batch', 'channel', 'h')} * ${dataType}(${f32Type}(channelScale)) + ${dataType}(${
+ f32Type}(channelShift));
${output.set('batch', 'channel', 'h', 'value')};
}
}`;
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts
index 19ca4ac5358ae..de9309d1e436f 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts
@@ -2,10 +2,150 @@
// Licensed under the MIT License.
import {TensorView} from '../../tensor-view';
-import {BroadcastUtil} from '../../util';
-import {ComputeContext} from '../types';
+import {BroadcastUtil, ShapeUtil} from '../../util';
+import {ComputeContext, ProgramInfo, ProgramUniform} from '../types';
import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu';
+import {createTensorShapeVariables, getBroadcastDims, getMaxComponents, IndicesHelper, inputVariable, internalVariable, outputVariable, ShaderHelper,} from './common';
+import {getActivationSnippet, InternalActivationAttributes} from './fuse-utils';
+
+export const createNaiveMatmulProgramInfo =
+ (inputs: readonly TensorView[], activationAttributes: InternalActivationAttributes, outputShape: readonly number[],
+ reshapedOutputShape?: readonly number[],
+ isChannelsLast = false /* only used for conv2dByMatMul*/): ProgramInfo => {
+ const aShape = inputs[0].dims;
+ const bShape = inputs[1].dims;
+
+ const M = aShape[aShape.length - 2];
+ const N = bShape[bShape.length - 1];
+ const K = aShape[aShape.length - 1];
+ const components = getMaxComponents(N);
+ const aComponents = getMaxComponents(K);
+ const outputNumber = getMaxComponents(M);
+ const outputSize = ShapeUtil.size(outputShape) / components / outputNumber;
+ const hasBias = inputs.length > 2;
+ const outerDims = reshapedOutputShape ? reshapedOutputShape.slice(0, -2) : outputShape.slice(0, -2);
+ const batchSize = ShapeUtil.size(outerDims);
+ const outputShapeInShader = [batchSize, M, N];
+ const programUniforms: ProgramUniform[] = [
+ {type: 'uint32', data: outputSize}, {type: 'uint32', data: M}, {type: 'uint32', data: N},
+ {type: 'uint32', data: K}, ...createTensorShapeVariables(outerDims), ...createTensorShapeVariables(aShape),
+ ...createTensorShapeVariables(bShape)
+ ];
+ if (hasBias) {
+ programUniforms.push(...createTensorShapeVariables(inputs[2].dims));
+ }
+ programUniforms.push(...createTensorShapeVariables(outputShapeInShader));
+
+ const getShaderSource = (shaderHelper: ShaderHelper) => {
+ const batchDims = internalVariable('batch_dims', inputs[0].dataType, outerDims.length);
+ const a = inputVariable('a', inputs[0].dataType, aShape.length, aComponents);
+ const b = inputVariable('b', inputs[1].dataType, bShape.length, components);
+ const output = outputVariable('output', inputs[0].dataType, outputShapeInShader.length, components);
+ const {activationFunction, applyActivation} = getActivationSnippet(activationAttributes, output.type.value);
+ const inputVariables = [a, b];
+ let processBias = '';
+ if (hasBias) {
+ const biasComponents = isChannelsLast ? components : 1;
+ inputVariables.push(inputVariable('bias', inputs[2].dataType, inputs[2].dims.length, biasComponents));
+ processBias = `${
+ isChannelsLast ? `value += bias[col / ${biasComponents}];` :
+ `value += ${output.type.value}(bias[row + i]);`}`;
+ }
+
+ const outerDimsA = aShape.slice(0, -2);
+ const outerDimsB = bShape.slice(0, -2);
+ const broadCastADims = getBroadcastDims(outerDimsA, outerDims);
+ const broadCastBDims = getBroadcastDims(outerDimsB, outerDims);
+ const getIndices = (variable: IndicesHelper, broadCastDims: number[]) => {
+ const rank = variable.rank;
+ const name = variable.name;
+ if (rank === 2) {
+ return `var ${name}_indices = ${variable.type.indices}(0u, 0u);`;
+ }
+ const batchRank = batchDims.rank;
+ let resStr = `var ${name}_indices: ${variable.type.indices};`;
+ for (let i = rank - 2 - 1, j = batchRank - 1; i >= 0; i--, j--) {
+ resStr += `\n${name}_indices[${i}] = ${batchRank > 1 ? `batch_indices[${j}]` : 'batch_indices'};`;
+ }
+ broadCastDims.forEach(i => {
+ resStr += `\n${name}_indices[${i}] = 0;`;
+ });
+ resStr += `${name}_indices[${rank - 2}] = 0u;
+ ${name}_indices[${rank - 1}] = 0u;`;
+ return resStr;
+ };
+
+ const calcResult = (): string => {
+ let calcStr = `var a_data: ${a.type.value};`;
+ for (let i = 0; i < aComponents; i++) {
+ calcStr += `
+ let b_data${i} = b[(b_offset + (k + ${i}) * uniforms.N + col) / ${components}];`;
+ }
+ for (let i = 0; i < outputNumber; i++) {
+ calcStr += `a_data = a[(a_offset + (row + ${i}) * uniforms.K + k) / ${aComponents}];`;
+
+ for (let j = 0; j < aComponents; j++) {
+ calcStr += `
+ values[${i}] = fma(${b.type.value}(a_data${aComponents === 1 ? '' : `[${j}]`}), b_data${j}, values[${
+ i}]);\n`;
+ }
+ }
+ return calcStr;
+ };
+
+ return `
+ ${
+ shaderHelper.registerUniform('outputSize', 'u32')
+ .registerUniform('M', 'u32')
+ .registerUniform('N', 'u32')
+ .registerUniform('K', 'u32')
+ .registerInternalVariables(batchDims)
+ .declareVariables(...inputVariables, output)}
+ ${activationFunction}
+ ${shaderHelper.mainStart()}
+ ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
+ let col = (global_idx % (uniforms.N / ${components})) * ${components};
+ var index1 = global_idx / (uniforms.N / ${components});
+ let stride1 = uniforms.M / ${outputNumber};
+ let row = (index1 % stride1) * ${outputNumber};
+ let batch = index1 / stride1;
+
+ ${outputShape.length === 2 ? '' : `let batch_indices = ${batchDims.offsetToIndices('batch')};`}
+ ${getIndices(a, broadCastADims)}
+ let a_offset = ${a.indicesToOffset('a_indices')};
+ ${getIndices(b, broadCastBDims)}
+ let b_offset = ${b.indicesToOffset('b_indices')};
+ var values: array<${output.type.value}, ${outputNumber}>;
+ for (var k: u32 = 0u; k < uniforms.K; k = k + ${aComponents}) {
+ ${calcResult()}
+ }
+ for (var i = 0u; i < ${outputNumber}u; i++) {
+ var value = values[i];
+ ${processBias}
+ ${applyActivation}
+ let cur_indices = ${output.type.indices}(batch, row + i, col);
+ let offset = ${output.indicesToOffset('cur_indices')};
+ ${output.setByOffset(`offset / ${components}`, 'value')};
+ }
+ }
+ `;
+ };
+ return {
+ name: 'MatMulNaive',
+ shaderCache: {
+ hint: `${activationAttributes.activationCacheKey}_${components}_${aComponents}_${outputNumber}_${
+ isChannelsLast}`,
+ inputDependencies: hasBias ? ['rank', 'rank', 'rank'] : ['rank', 'rank']
+ },
+ getRunData: () => ({
+ outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
+ dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
+ programUniforms
+ }),
+ getShaderSource
+ };
+ };
const validateInputs = (inputs: readonly TensorView[]): void => {
if (!inputs || inputs.length !== 2) {
@@ -23,5 +163,12 @@ export const matMul = (context: ComputeContext): void => {
if (!outputShape) {
throw new Error('Can\'t use matmul on the given tensors');
}
- context.compute(createMatmulProgramInfo(context.inputs, {activation: '', activationCacheKey: ''}, outputShape));
+ const N = outputShape[outputShape.length - 1];
+ const K = context.inputs[0].dims[context.inputs[0].dims.length - 1];
+ if (N < 8 && K < 8) {
+ context.compute(
+ createNaiveMatmulProgramInfo(context.inputs, {activation: '', activationCacheKey: ''}, outputShape));
+ } else {
+ context.compute(createMatmulProgramInfo(context.inputs, {activation: '', activationCacheKey: ''}, outputShape));
+ }
};
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts
index 1538644412afd..84d04efc37f28 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts
@@ -1,12 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
+import {env} from 'onnxruntime-common';
+
import {TensorView} from '../../tensor-view';
import {PoolConvUtil, ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
-import {ComputeContext, ProgramInfo} from '../types';
+import {ComputeContext, ProgramInfo, ProgramInputTensorInfoDependency, ProgramUniform} from '../types';
-import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
+import {createTensorShapeVariables, getElementAt, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common';
// TODO: support:
// - ceil_mode "test_maxpool_2d_ceil"
@@ -15,12 +17,9 @@ import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './comm
// - [MaxPool] output[1] "test_maxpool_with_argmax_2d_precomputed_pads"
const validateInputs = (inputs: readonly TensorView[]): void => {
- if (!inputs || inputs.length !== 1) {
+ if (env.webgpu.validateInputContent && (!inputs || inputs.length !== 1)) {
throw new Error('Pool ops requires 1 input.');
}
- if (inputs[0].dims.length !== 4 && inputs[0].dims.length !== 3) {
- throw new Error('Pool ops supports 1-D or 2-D inputs only for now.');
- }
};
const getAdjustedPoolAttributesAndOutputShape = (
@@ -51,30 +50,83 @@ const getAdjustedPoolAttributesAndOutputShape = (
- shaderHelper: ShaderHelper, x: IndicesHelper, xShape: readonly number[], outputShape: readonly number[],
- attributes: AttributeType, op1: string, op2: string, start: string): string => {
+const getUniformAndPadInfo = (
+ outputShape: readonly number[],
+ attributes: AttributeType): [ProgramUniform[], UniformsArrayType, boolean, boolean, boolean] => {
const isChannelsLast = attributes.format === 'NHWC';
- const inputDims = xShape;
- const dataType = x.type.value;
- const rank = inputDims.length;
const outputSize = ShapeUtil.size(outputShape);
- const output = outputVariable('output', x.type.tensor, outputShape);
-
+ const kernelSize = ShapeUtil.size(attributes.kernelShape);
+ const programUniforms: ProgramUniform[] = [{type: 'uint32', data: outputSize}, {type: 'uint32', data: kernelSize}];
+ const uniforms: UniformsArrayType = [{name: 'outputSize', type: 'u32'}, {name: 'kernelSize', type: 'u32'}];
if (attributes.kernelShape.length <= 2) {
const kw = attributes.kernelShape[attributes.kernelShape.length - 1];
const sw = attributes.strides[attributes.strides.length - 1];
const pwStart = attributes.pads[attributes.pads.length / 2 - 1];
const pwEnd = attributes.pads[attributes.pads.length - 1];
- const dimIdxW = rank - (isChannelsLast ? 2 : 1);
+ const pwStartEnd = !!(pwStart + pwEnd);
+ programUniforms.push(
+ {type: 'uint32', data: kw},
+ {type: 'uint32', data: sw},
+ {type: 'uint32', data: pwStart},
+ {type: 'uint32', data: pwEnd},
+ );
+ uniforms.push(
+ {name: 'kw', type: 'u32'}, {name: 'sw', type: 'u32'}, {name: 'pwStart', type: 'u32'},
+ {name: 'pwEnd', type: 'u32'});
+
+ let phStartEnd = false;
+ if (attributes.kernelShape.length === 2) {
+ const kh = attributes.kernelShape[attributes.kernelShape.length - 2];
+ const sh = attributes.strides[attributes.strides.length - 2];
+ const phStart = attributes.pads[attributes.pads.length / 2 - 2];
+ const phEnd = attributes.pads[attributes.pads.length - 2];
+ phStartEnd = !!(phStart + phEnd);
+ programUniforms.push(
+ {type: 'uint32', data: kh}, {type: 'uint32', data: sh}, {type: 'uint32', data: phStart},
+ {type: 'uint32', data: phEnd});
+
+ uniforms.push(
+ {name: 'kh', type: 'u32'}, {name: 'sh', type: 'u32'}, {name: 'phStart', type: 'u32'},
+ {name: 'phEnd', type: 'u32'});
+ }
+ return [programUniforms, uniforms, true, pwStartEnd, phStartEnd];
+ } else {
+ if (isChannelsLast) {
+ throw new Error('Pooling with kernelShape.length > 2 is not supported for NHWC format.');
+ }
+ const kernelStrides = ShapeUtil.computeStrides(attributes.kernelShape);
+ programUniforms.push(
+ {type: 'uint32', data: kernelStrides}, {type: 'uint32', data: attributes.pads},
+ {type: 'uint32', data: attributes.strides});
+ uniforms.push(
+ {name: 'kernelStrides', type: 'u32', length: kernelStrides.length},
+ {name: 'pads', type: 'u32', length: attributes.pads.length},
+ {name: 'strides', type: 'u32', length: attributes.strides.length});
+
+ const hasPads = attributes.pads.reduce((sum, cur) => sum + cur);
+ return [programUniforms, uniforms, !!hasPads, false, false];
+ }
+};
+
+const generatePoolingCode = (
+ shaderHelper: ShaderHelper, x: IndicesHelper, rank: number, outputShapeRank: number, attributes: AttributeType,
+ op1: string, op2: string, start: number, uniforms: UniformsArrayType, hasPads: boolean, pwStartEnd: boolean,
+ phStartEnd: boolean): string => {
+ const isChannelsLast = attributes.format === 'NHWC';
+ const dataType = x.type.value;
+ const output = outputVariable('output', x.type.tensor, outputShapeRank);
+
+ if (attributes.kernelShape.length <= 2) {
let codeW = '';
let codeH = '';
let codeHEnd = '';
- if (pwStart + pwEnd !== 0) {
+ const dimIdxW = rank - (isChannelsLast ? 2 : 1);
+ if (pwStartEnd === true) {
codeW = `
- for (var i: u32 = 0u; i < ${kw}u; i++) {
- xIndices[${dimIdxW}] = indices[${dimIdxW}] * ${sw} - ${pwStart} + i;
- if (xIndices[${dimIdxW}] < 0 || xIndices[${dimIdxW}] >= ${inputDims[dimIdxW]}) {
+ for (var i: u32 = 0u; i < uniforms.kw; i++) {
+ xIndices[${dimIdxW}] = indices[${dimIdxW}] * uniforms.sw - uniforms.pwStart + i;
+ if (xIndices[${dimIdxW}] < 0 || xIndices[${dimIdxW}]
+ >= uniforms.x_shape[${dimIdxW}]) {
pad++;
continue;
}
@@ -83,33 +135,28 @@ const generatePoolingCode = = ${dimH}) {
- pad+= ${kw};
+ for (var j: u32 = 0u; j < uniforms.kh; j++) {
+ xIndices[${dimIdxH}] = indices[${dimIdxH}] * uniforms.sh - uniforms.phStart + j;
+ if (xIndices[${dimIdxH}] < 0 || xIndices[${dimIdxH}] >= uniforms.x_shape[${dimIdxH}]) {
+ pad += i32(uniforms.kw);
continue;
}
`;
} else {
codeH = `
- for (var j: u32 = 0u; j < ${kh}u; j++) {
- xIndices[${dimIdxH}] = indices[${dimIdxH}] * ${sh} - ${phStart} + j;
+ for (var j: u32 = 0u; j < uniforms.kh; j++) {
+ xIndices[${dimIdxH}] = indices[${dimIdxH}] * uniforms.sh - uniforms.phStart + j;
`;
}
codeHEnd = `
@@ -118,15 +165,15 @@ const generatePoolingCode = 2 is not supported for NHWC format.');
}
- const kernelSize = ShapeUtil.size(attributes.kernelShape);
- const kernelStrides = ShapeUtil.computeStrides(attributes.kernelShape);
- const stridesRank = kernelStrides.length;
+ const stridesRank = attributes.kernelShape.length;
const padsRank = attributes.pads.length;
- const hasPads = attributes.pads.reduce((sum, cur) => sum + cur);
let padCode = '';
if (hasPads) {
padCode = `
- if (xIndices[j] >= inputDims[j]) {
+ if (xIndices[j] >= uniforms.x_shape[j]) {
pad++;
isPad = true;
break;
@@ -166,37 +210,32 @@ const generatePoolingCode = (${attributes.pads.map(i => `${i}u`).join(',')});
- const inputDims = array(${inputDims.map(i => `${i}u`).join(',')});
- const kernelStrides = array(${kernelStrides.map(i => `${i}u`).join(',')});
- const strides = array(${attributes.strides.map(i => `${i}u`).join(',')});
+ ${shaderHelper.registerUniforms(uniforms).declareVariables(x, output)}
${shaderHelper.mainStart()}
- ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
-
+ ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
let indices = ${output.offsetToIndices('global_idx')};
- let xIndices = ${output.offsetToIndices('global_idx')};
+ var xIndices = ${output.offsetToIndices('global_idx')};
var offsets: array;
- var value = ${output.type.value}(${start});
+ var value = ${dataType}(${start});
var pad = 0;
var isPad = false;
- for (var i: u32 = 0u; i < ${kernelSize}u; i++) {
+ for (var i: u32 = 0u; i < uniforms.kernelSize; i++) {
var offset = i;
for (var j = 0u; j < ${stridesRank - 1}u; j++) {
- offsets[j] = offset / kernelStrides[j];
- offset -= offsets[j] * kernelStrides[j];
+ offsets[j] = offset / ${getElementAt('uniforms.kernelStrides', 'j', stridesRank)};
+ offset -= offsets[j] * ${getElementAt('uniforms.kernelStrides', 'j', stridesRank)};
}
offsets[${stridesRank - 1}] = offset;
isPad = false;
for (var j = ${rank - stridesRank}u; j < ${rank}u; j++) {
- xIndices[j] = indices[j] * strides[j - ${rank - stridesRank}u]
- + offsets[j - ${rank - stridesRank}u] - pads[j - 2u];
+ xIndices[j] = indices[j] * ${
+ getElementAt('uniforms.strides', `j - ${rank - stridesRank}u`, stridesRank)}
+ + offsets[j - ${rank - stridesRank}u] - ${getElementAt('uniforms.pads', 'j - 2u', padsRank)};
${padCode}
}
${op2}
@@ -236,27 +275,35 @@ const createAveragePoolProgramInfo =
(name: string, input: TensorView, isGlobalOperator: boolean, attributes: AveragePoolAttributes): ProgramInfo => {
const [adjustedAttributes, outputShape] =
getAdjustedPoolAttributesAndOutputShape(input, attributes, isGlobalOperator);
- const kernelSize = ShapeUtil.size(adjustedAttributes.kernelShape);
-
- const x = inputVariable('x', input.dataType, input.dims);
+ const x = inputVariable('x', input.dataType, input.dims.length);
const dataType = x.type.value;
const op1 = 'value += x_val;';
let op2 = '';
if (adjustedAttributes.countIncludePad) {
- op2 += `value /= ${dataType}(${kernelSize});`;
+ op2 += `value /= ${dataType}(uniforms.kernelSize);`;
} else {
- op2 += `value /= ${dataType}(${kernelSize} - pad);`;
+ op2 += `value /= ${dataType}(i32(uniforms.kernelSize) - pad);`;
}
+ const [programUniforms, uniforms, hasPads, pwStartEnd, phStartEnd] =
+ getUniformAndPadInfo(outputShape, adjustedAttributes);
+ programUniforms.push(...createTensorShapeVariables(input.dims));
+ programUniforms.push(...createTensorShapeVariables(outputShape));
+ const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank'];
return {
name,
- shaderCache: {hint: attributes.cacheKey},
+ shaderCache: {
+ hint: attributes.cacheKey + hasPads + pwStartEnd + phStartEnd + adjustedAttributes.countIncludePad,
+ inputDependencies
+ },
getRunData: () => ({
outputs: [{dims: outputShape, dataType: input.dataType}],
- dispatchGroup: {x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)}
+ dispatchGroup: {x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)},
+ programUniforms
}),
- getShaderSource: shaderHelper =>
- generatePoolingCode(shaderHelper, x, input.dims, outputShape, adjustedAttributes, op1, op2, '0.0'),
+ getShaderSource: shaderHelper => generatePoolingCode(
+ shaderHelper, x, input.dims.length, outputShape.length, adjustedAttributes, op1, op2, 0.0, uniforms,
+ hasPads, pwStartEnd, phStartEnd),
};
};
@@ -312,16 +359,23 @@ const createMaxPoolProgramInfo =
value = max(x_val, value);
`;
const op2 = '';
- const x = inputVariable('x', input.dataType, input.dims);
+ const x = inputVariable('x', input.dataType, input.dims.length);
+ const inputDependencies: ProgramInputTensorInfoDependency[] = ['rank'];
+ const [programUniforms, uniforms, hasPads, pwStartEnd, phStartEnd] =
+ getUniformAndPadInfo(outputShape, adjustedAttributes);
+ programUniforms.push(...createTensorShapeVariables(input.dims));
+ programUniforms.push(...createTensorShapeVariables(outputShape));
return {
name,
- shaderCache: {hint: attributes.cacheKey},
+ shaderCache: {hint: attributes.cacheKey + hasPads, inputDependencies},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: input.dataType}],
- dispatchGroup: {x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)}
+ dispatchGroup: {x: Math.ceil(ShapeUtil.size(outputShape) / 64 /* workgroup size */)},
+ programUniforms
}),
- getShaderSource: shaderHelper =>
- generatePoolingCode(shaderHelper, x, input.dims, outputShape, adjustedAttributes, op1, op2, '-1e5'),
+ getShaderSource: shaderHelper => generatePoolingCode(
+ shaderHelper, x, input.dims.length, outputShape.length, adjustedAttributes, op1, op2, -1e5, uniforms,
+ hasPads, pwStartEnd, phStartEnd),
};
};
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts
index b5c956e57a9b1..e8851ac546942 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/reduce.ts
@@ -7,7 +7,7 @@ import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo, ProgramShaderCacheInfo} from '../types';
-import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
+import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
import {reduceL1Shared, reduceL2Shared, reduceLogSumExpShared, reduceLogSumShared, reduceMaxShared, reduceMeanShared, reduceMinShared, reduceProdShared, reduceSumShared, reduceSumSquareShared} from './reduce-shared';
const validateInputs = (inputs: readonly TensorView[]): void => {
@@ -30,14 +30,14 @@ export type ReduceOp =
(input: IndicesHelper, output: IndicesHelper,
axes: readonly number[]) => [string, string, string, string, ...string[]];
-const noOp: ReduceOp = (input) => ['', '', `var value = ${input.getByOffset('inputOffset')};`, ''];
+const noOp: ReduceOp = (input) => ['', '', `var value = ${input.getByIndices('input_indices')};`, ''];
export const createReduceProgramInfo =
(name: string, shaderCache: ProgramShaderCacheInfo, inputs: readonly TensorView[], reduceOp: ReduceOp,
axesInput: number[], outputDataType: DataType, keepDims = false, noopWithEmptyAxes = false): ProgramInfo => {
const outputShape: number[] = [];
const inputShape = inputs[0].dims;
-
- const axes = ShapeUtil.normalizeAxes(axesInput, inputs[0].dims.length);
+ const inputRank = inputShape.length;
+ const axes = ShapeUtil.normalizeAxes(axesInput, inputRank);
const reduceOnAllAxes = !noopWithEmptyAxes && axes.length === 0;
inputShape.forEach((d, i) => {
if (reduceOnAllAxes || axes.indexOf(i) >= 0) {
@@ -48,53 +48,50 @@ export const createReduceProgramInfo =
outputShape.push(d);
}
});
-
- const idxCopy: string[] = []; // copy output indexes to input indexes
-
- const input = inputVariable('_A', inputs[0].dataType, inputShape);
- const output = outputVariable('output', outputDataType, outputShape);
- const ops = reduceOp(input, output, axes);
- const inputOffsetAssignment = `inputOffset = ${input.indicesToOffset('inputIndices')};`;
- const initinputOffsetLet = `let ${inputOffsetAssignment};`;
- const initinputOffsetVar = `var ${inputOffsetAssignment};`;
- const initinputOffset = (ops[1] === '') ? '' : initinputOffsetVar;
- let reduceOps = ((ops[1] === '') ? initinputOffsetLet : inputOffsetAssignment) + '\n' + ops[2];
-
- for (let k = 0, l = 0; k < inputs[0].dims.length; k++) {
- // if this axis is reduced
- if (reduceOnAllAxes || axes.indexOf(k) >= 0) {
- if (keepDims) {
+ const outputRank = outputShape.length;
+ const outputSize = ShapeUtil.size(outputShape);
+ const getShaderSource = (shaderHelper: ShaderHelper) => {
+ const idxCopy: string[] = []; // copy output indexes to input indexes
+
+ const input = inputVariable('_A', inputs[0].dataType, inputRank);
+ const output = outputVariable('output', outputDataType, outputRank);
+ const ops = reduceOp(input, output, axes);
+ let reduceOps = ops[2];
+
+ for (let k = 0, l = 0; k < inputRank; k++) {
+ // if this axis is reduced
+ if (reduceOnAllAxes || axes.indexOf(k) >= 0) {
+ if (keepDims) {
+ l++;
+ }
+ // loop over the d-th axis
+ reduceOps = `for(var j${k}: u32 = 0; j${k} < ${inputShape[k]}; j${k}++) {
+ ${ops[2].includes('last_index') ? `let last_index = j${k};` : ''}
+ ${input.indicesSet('input_indices', k, `j${k}`)}
+ ${reduceOps}
+ }`;
+ } else {
+ idxCopy.push(`${input.indicesSet('input_indices', k, output.indicesGet('output_indices', l))};`);
l++;
}
- // loop over the d-th axis
- reduceOps = `for(var j${k}: u32 = 0; j${k} < ${inputs[0].dims[k]}; j${k}++) {
- ${ops[2].includes('lastIndex') ? `let lastIndex = j${k};` : ''}
- ${input.indicesSet('inputIndices', k, `j${k}`)}
- ${reduceOps}
- }`;
- } else {
- idxCopy.push(`${input.indicesSet('inputIndices', k, output.indicesGet('outputIndices', l))};`);
- l++;
}
- }
+ return `
- const outputSize = ShapeUtil.size(outputShape);
- const getShaderSource = (shaderHelper: ShaderHelper) => `
- ${shaderHelper.declareVariables(input, output)}
+ ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)}
${shaderHelper.mainStart()}
- ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
- var inputIndices: ${input.type.indices};
- let outputIndices = ${output.offsetToIndices('global_idx')};
+ ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
+ var input_indices: ${input.type.indices};
+ let output_indices = ${output.offsetToIndices('global_idx')};
${idxCopy.join('\n')}
${ops[0]} // init ops for reduce max/min
- ${initinputOffset}
${ops[1]}
${reduceOps}
${ops[3]}
${ops.length === 4 ? output.setByOffset('global_idx', 'value') : ops.slice(4).join('\n')}
}`;
+ };
return {
name,
@@ -102,7 +99,11 @@ export const createReduceProgramInfo =
getShaderSource,
getRunData: () => ({
outputs: [{dims: outputShape, dataType: outputDataType}],
- dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}
+ dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
+ programUniforms: [
+ {type: 'uint32', data: outputSize}, ...createTensorShapeVariables(inputShape),
+ ...createTensorShapeVariables(outputShape)
+ ]
}),
};
};
@@ -125,7 +126,7 @@ const runReduceProgram =
context.compute(
createReduceProgramInfo(
- name, {hint: updatedAttributes.cacheKey}, [inputs[0]],
+ name, {hint: updatedAttributes.cacheKey, inputDependencies: ['rank']}, [inputs[0]],
updatedAttributes.noopWithEmptyAxes && updatedAttributes.axes.length === 0 ? noOp : reduceOp,
updatedAttributes.axes, inputs[0].dataType, updatedAttributes.keepDims,
updatedAttributes.noopWithEmptyAxes),
@@ -137,7 +138,7 @@ const reduceLogSumNaive = (context: ComputeContext, attributes: ReduceAttributes
const reduceOp: ReduceOp = (input, output) =>
[`var value = ${output.type.storage}(0);`,
'',
- `value += ${input.getByOffset('inputOffset')};`,
+ `value += ${input.getByIndices('input_indices')};`,
'value = log(value);',
];
runReduceProgram(context, 'ReduceLogSum', attributes, reduceOp);
@@ -148,7 +149,7 @@ const reduceL1Naive = (context: ComputeContext, attributes: ReduceAttributes): v
const reduceOp: ReduceOp = (input, output) =>
[`var value = ${output.type.storage}(0);`,
'',
- `value += abs(${input.getByOffset('inputOffset')});`,
+ `value += abs(${input.getByIndices('input_indices')});`,
'',
];
runReduceProgram(context, 'ReduceL1', attributes, reduceOp);
@@ -159,7 +160,7 @@ const reduceL2Naive = (context: ComputeContext, attributes: ReduceAttributes): v
const reduceOp: ReduceOp = (input, output) =>
[`var t = ${output.type.value}(0); var value = ${output.type.value}(0);`,
'',
- `t = ${input.getByOffset('inputOffset')}; value += (t * t);`,
+ `t = ${input.getByIndices('input_indices')}; value += (t * t);`,
'value = sqrt(value);',
];
runReduceProgram(context, 'ReduceL2', attributes, reduceOp);
@@ -170,7 +171,7 @@ const reduceLogSumExpNaive = (context: ComputeContext, attributes: ReduceAttribu
const reduceOp: ReduceOp = (input, output) =>
[`var value = ${output.type.storage}(0);`,
'',
- `value += exp(${input.getByOffset('inputOffset')});`,
+ `value += exp(${input.getByIndices('input_indices')});`,
'value = log(value);',
];
runReduceProgram(context, 'ReduceLogSumExp', attributes, reduceOp);
@@ -182,14 +183,14 @@ const reduceMaxNaive = (context: ComputeContext, attributes: ReduceAttributes):
const idxZero = [];
for (let k = 0; k < input.rank; k++) {
if (axes.indexOf(k) >= 0 || axes.length === 0) {
- idxZero.push(input.indicesSet('inputIndices', k, 0));
+ idxZero.push(input.indicesSet('input_indices', k, 0));
}
}
return [
`${idxZero.join('\n')}`,
- `var value = ${input.getByOffset('inputOffset')};`,
- `value = max(value, ${input.getByOffset('inputOffset')});`,
+ `var value = ${input.getByIndices('input_indices')};`,
+ `value = max(value, ${input.getByIndices('input_indices')});`,
'',
];
};
@@ -210,7 +211,7 @@ const reduceMeanNaive = (context: ComputeContext, attributes: ReduceAttributes):
return [
'var sum = f32(0);',
'',
- `sum += f32(${input.getByOffset('inputOffset')});`,
+ `sum += f32(${input.getByIndices('input_indices')});`,
`let value = ${output.type.value}(sum / ${size});`,
];
};
@@ -223,14 +224,14 @@ const reduceMinNaive = (context: ComputeContext, attributes: ReduceAttributes):
const idxZero = [];
for (let k = 0; k < input.rank; k++) {
if (axes.indexOf(k) >= 0 || axes.length === 0) {
- idxZero.push(`inputIndices[${k}] = 0;`); // first element
+ idxZero.push(`input_indices[${k}] = 0;`); // first element
}
}
return [
`${idxZero.join('\n')}`,
- `var value = ${input.getByOffset('inputOffset')};`,
- `value = min(value, ${input.getByOffset('inputOffset')});`,
+ `var value = ${input.getByIndices('input_indices')};`,
+ `value = min(value, ${input.getByIndices('input_indices')});`,
'',
];
};
@@ -242,7 +243,7 @@ const reduceProdNaive = (context: ComputeContext, attributes: ReduceAttributes):
const reduceOp: ReduceOp = (input, output) =>
[`var value = ${output.type.storage}(1);`,
'',
- `value *= ${input.getByOffset('inputOffset')};`,
+ `value *= ${input.getByIndices('input_indices')};`,
'',
];
runReduceProgram(context, 'ReduceProd', attributes, reduceOp);
@@ -253,7 +254,7 @@ const reduceSumNaive = (context: ComputeContext, attributes: ReduceAttributes):
const reduceOp: ReduceOp = (input, output) =>
[`var value = ${output.type.storage}(0);`,
'',
- `value += ${input.getByOffset('inputOffset')};`,
+ `value += ${input.getByIndices('input_indices')};`,
'',
];
runReduceProgram(context, 'ReduceSum', attributes, reduceOp);
@@ -264,7 +265,7 @@ const reduceSumSquareNaive = (context: ComputeContext, attributes: ReduceAttribu
const reduceOp: ReduceOp = (input, output) =>
[`var t = ${output.type.value}(0); var value = ${output.type.value}(0);`,
'',
- `t = ${input.getByOffset('inputOffset')}; value += t * t;`,
+ `t = ${input.getByIndices('input_indices')}; value += t * t;`,
'',
];
runReduceProgram(context, 'ReduceSumSquare', attributes, reduceOp);
@@ -273,7 +274,7 @@ const reduceSumSquareNaive = (context: ComputeContext, attributes: ReduceAttribu
const useNaiveReduceMethod =
(shape: readonly number[], axes: readonly number[], noopWithEmptyAxes: boolean): boolean => {
if (axes.length === 0) {
- return noopWithEmptyAxes ? true : false;
+ return noopWithEmptyAxes;
}
let outputSize = 1;
@@ -289,7 +290,7 @@ const useNaiveReduceMethod =
// The condition data is very rough, although considering the count of Execution Unit (EU), the potential
// work groups in a EU and the counts of loops in the naive and shared methods, also doing experiments
// on some machines.
- return reduceSize < 32 && outputSize > 1024 ? true : false;
+ return reduceSize < 32 && outputSize > 1024;
};
export const reduceMean = (context: ComputeContext, attributes: ReduceAttributes): void => {
@@ -371,6 +372,3 @@ export const reduceLogSum = (context: ComputeContext, attributes: ReduceAttribut
reduceLogSumShared(context, attributes);
}
};
-
-export const parseReduceAttributes = (attributes: Record): ReduceAttributes =>
- createAttributeWithCacheKey(attributes as Omit);
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts
index 973a607f9377e..e1369c2c2b43b 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/resize.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/resize.ts
@@ -7,7 +7,7 @@ import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo} from '../types';
-import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
+import {createTensorShapeVariables, getElementAt, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
type CoordinateTransformMode = 'half_pixel'|'asymmetric'|'pytorch_half_pixel'|'tf_half_pixel_for_nn'|'align_corners'|
'tf_crop_and_resize'|'half_pixel_symmetric';
@@ -245,69 +245,67 @@ const adjustOutputShape = (inputShape: readonly number[], scales: number[], attr
};
const calculateOriginalIndicesFromOutputIndices =
- (output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[], scales: readonly number[],
- roi: readonly number[]): string => `
- fn calculateOriginalIndicesFromOutputIndices(outputIndices: ${output.type.indices}) -> array<${
+ (output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[], scalesLength: number,
+ roiLength: number): string => `
+ fn calculateOriginalIndicesFromOutputIndices(output_indices: ${output.type.indices}) -> array<${
output.type.value}, ${outputShape.length}> {
- const inputShape = array(${inputShape.map(i => `${i}u`).join(',')});
- const outputShape = array(${outputShape.map(i => `${i}u`).join(',')});
- const scales = array<${output.type.value}, ${scales.length}>(${scales.map(i => `${i}f`).join(',')});
- const roi = array<${output.type.value}, ${roi.length}>(${roi.map(i => `${i}f`).join(',')});
- var originalIndices: array<${output.type.value}, ${outputShape.length}>;
+ var original_indices: array<${output.type.value}, ${outputShape.length}>;
for (var i:u32 = 0; i < ${outputShape.length}; i++) {
- var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'};
- if (scales[i] == 1.0) {
- originalIndices[i] = ${output.type.value}(outputIndex);
+ var output_index = ${output.type.value}(${output.indicesGet('output_indices', 'i')});
+ var scale = ${getElementAt('uniforms.scales', 'i', scalesLength)};
+ var roi_low = ${getElementAt('uniforms.roi', 'i', roiLength)};
+ var roi_hi = ${getElementAt('uniforms.roi', `i + ${inputShape.length}`, roiLength)};
+ if (scale == 1.0) {
+ original_indices[i] = output_index;
} else {
- originalIndices[i] = getOriginalCoordinateFromResizedCoordinate(${output.type.value}(outputIndex), scales[i],
- ${output.type.value}(outputShape[i]), ${output.type.value}(inputShape[i]), roi[i], roi[i + ${
- inputShape.length}]);
+ var input_shape_i = ${output.type.value}(${getElementAt('uniforms.input_shape', 'i', inputShape.length)});
+ var output_shape_i = ${output.type.value}(${getElementAt('uniforms.output_shape', 'i', outputShape.length)});
+ original_indices[i] = getOriginalCoordinateFromResizedCoordinate(output_index, scale, output_shape_i,
+ input_shape_i, roi_low, roi_hi);
}
}
- return originalIndices;
+ return original_indices;
}`;
const calculateInputIndicesFromOutputIndices =
(input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[],
- scales: readonly number[], roi: readonly number[], useExtrapolation: boolean): string => `
- fn calculateInputIndicesFromOutputIndices(outputIndices: ${output.type.indices}) -> ${input.type.indices} {
- const inputShape = array(${inputShape.map(i => `${i}u`).join(',')});
- const outputShape = array(${outputShape.map(i => `${i}u`).join(',')});
- const scales = array<${input.type.value}, ${scales.length}>(${scales.map(i => `${i}`).join(',')});
- const roi = array<${input.type.value}, ${roi.length}>(${roi.map(i => `${i}`).join(',')});
- var inputIndices: ${input.type.indices};
- for (var i:u32 = 0; i < ${outputShape.length}; i++) {
- var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'};
- var inputIndex: u32;
- if (scales[i] == 1.0) {
- inputIndex = outputIndex;
- } else {
- var original_idx = getOriginalCoordinateFromResizedCoordinate(${input.type.value}(outputIndex), scales[i],
- ${input.type.value}(outputShape[i]), ${input.type.value}(inputShape[i]), roi[i], roi[i + ${
- inputShape.length}]);
- if (!${useExtrapolation} || (original_idx >= 0 && original_idx < ${input.type.value}(inputShape[i]))) {
- if (original_idx < 0) {
- inputIndex = 0;
- } else if (original_idx > (${input.type.value}(inputShape[i]) - 1)) {
- inputIndex = inputShape[i] - 1;
- } else {
- inputIndex = u32(getNearestPixelFromOriginal(original_idx, scales[i] < 1));
- }
+ scalesLength: number, roiLength: number, useExtrapolation: boolean): string => `
+ fn calculateInputIndicesFromOutputIndices(output_indices: ${output.type.indices}) -> ${input.type.indices} {
+ var input_indices: ${input.type.indices};
+ for (var i:u32 = 0; i < ${outputShape.length}; i++) {
+ var output_index = ${output.type.value}(${output.indicesGet('output_indices', 'i')});
+ var input_index: u32;
+ var scale = ${getElementAt('uniforms.scales', 'i', scalesLength)};
+ if (scale == 1.0) {
+ input_index = u32(output_index);
+ } else {
+ var roi_low = ${getElementAt('uniforms.roi', 'i', roiLength)};
+ var roi_hi = ${getElementAt('uniforms.roi', `i + ${inputShape.length}`, roiLength)};
+ var input_shape_i = ${output.type.value}(${getElementAt('uniforms.input_shape', 'i', inputShape.length)});
+ var output_shape_i = ${output.type.value}(${getElementAt('uniforms.output_shape', 'i', outputShape.length)});
+ var original_idx = getOriginalCoordinateFromResizedCoordinate(output_index, scale, output_shape_i,
+ input_shape_i, roi_low, roi_hi);
+ if (!${useExtrapolation} || (original_idx >= 0 && original_idx < input_shape_i)) {
+ if (original_idx < 0) {
+ input_index = 0;
+ } else if (original_idx > (input_shape_i - 1)) {
+ input_index = u32(input_shape_i) - 1;
} else {
- inputIndex = u32(original_idx);
+ input_index = u32(getNearestPixelFromOriginal(original_idx, scale < 1));
}
+ } else {
+ input_index = u32(original_idx);
}
- ${input.indicesSet('inputIndices', 'i', 'inputIndex')}
}
- return inputIndices;
+ ${input.indicesSet('input_indices', 'i', ' input_index')}
+ }
+ return input_indices;
}`;
-
const checkInputIndices = (input: IndicesHelper, inputShape: readonly number[]): string => `
- fn checkInputIndices(inputIndices: ${input.type.indices}) -> bool {
- const inputShape = array(${inputShape.map(i => `${i}u`).join(',')});
+ fn checkInputIndices(input_indices: ${input.type.indices}) -> bool {
for (var i:u32 = 0; i < ${inputShape.length}; i++) {
- var inputIndex = ${inputShape.length === 1 ? 'inputIndices' : 'inputIndices[i]'};
- if (inputIndex < 0 || inputIndex >= inputShape[i]) {
+ var input_index = ${input.indicesGet('input_indices', 'i')};
+ if (input_index < 0 || input_index >= ${getElementAt('uniforms.input_shape', 'i', inputShape.length)}) {
return false;
}
}
@@ -322,18 +320,18 @@ const bilinearInterpolation =
const dType = input.type.value;
return `
fn getInputValue(batch: u32, channel: u32, row: u32, col: u32) -> ${dType} {
- var inputIndices: ${input.type.indices};
- inputIndices[${heightIdx}] = max(0, min(row, ${inputShape[heightIdx]} - 1));
- inputIndices[${widthIdx}] = max(0, min(col, ${inputShape[widthIdx]} - 1));
+ var input_indices: ${input.type.indices};
+ ${input.indicesSet('input_indices', heightIdx, `max(0, min(row, ${inputShape[heightIdx]} - 1))`)};
+ ${input.indicesSet('input_indices', widthIdx, `max(0, min(col, ${inputShape[widthIdx]} - 1))`)};
if (${inputShape.length} > 2) {
- inputIndices[${channelIdx}] = channel;
- inputIndices[${batchIdx}] = batch;
+ ${input.indicesSet('input_indices', channelIdx, 'channel')};
+ ${input.indicesSet('input_indices', batchIdx, 'batch')};
};
- return input[${input.indicesToOffset('inputIndices')}];
+ return ${input.getByIndices('input_indices')};
}
- fn bilinearInterpolation(outputIndices: ${output.type.indices}) -> ${dType} {
- var originalIndices = calculateOriginalIndicesFromOutputIndices(outputIndices);
+ fn bilinearInterpolation(output_indices: ${output.type.indices}) -> ${dType} {
+ var originalIndices = calculateOriginalIndicesFromOutputIndices(output_indices);
var row:${dType} = originalIndices[${heightIdx}];
var col:${dType} = originalIndices[${widthIdx}];
if (${useExtrapolation} && (row < 0 || row > (${inputShape[heightIdx]} - 1) || col < 0 || col > ${
@@ -373,10 +371,10 @@ const bicubicInterpolation =
const createCubicInterpolationFunction = (idx: number): string => {
const direction = idx === heightIdx ? 'row' : 'col';
return `
- fn ${direction}CubicInterpolation(inputIndices: ${input.type.indices}, outputIndices: ${
+ fn ${direction}CubicInterpolation(input_indices: ${input.type.indices}, output_indices: ${
output.type.indices}) -> ${dType} {
- var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : `outputIndices[${idx}]`};
- var originalIdx: ${dType} = getOriginalCoordinateFromResizedCoordinate(${dType}(outputIndex), ${scales[idx]},
+ var output_index = ${output.indicesGet('output_indices', idx)};
+ var originalIdx: ${dType} = getOriginalCoordinateFromResizedCoordinate(${dType}(output_index), ${scales[idx]},
${dType}(${outputShape[idx]}), ${dType}(${inputShape[idx]}), ${roi[idx]}, ${roi[idx]} + ${inputShape.length});
var fractOriginalIdx: ${dType} = originalIdx - floor(originalIdx);
var coefs = getCubicInterpolationCoefs(fractOriginalIdx);
@@ -397,10 +395,11 @@ const bicubicInterpolation =
${direction} = max(0, min(${direction}, ${inputShape[idx]} - 1));
}
}
- var inputIndicesCopy: ${input.type.indices} = inputIndices;
- inputIndicesCopy[${idx}] = u32(${direction});
- data[i + 1] = ${idx === heightIdx ? `input[${input.indicesToOffset('inputIndicesCopy')}];` : `
- rowCubicInterpolation(inputIndicesCopy, outputIndices);`}
+ var input_indices_copy: ${input.type.indices} = input_indices;
+ ${input.indicesSet('input_indices_copy', idx, `u32(${direction})`)};
+ data[i + 1] = ${
+ idx === heightIdx ? input.getByIndices('input_indices_copy') :
+ 'rowCubicInterpolation(input_indices_copy, output_indices)'};
}
return cubicInterpolation1D(data, coefs);
}`;
@@ -429,9 +428,9 @@ const bicubicInterpolation =
return (x[0] * coefs[0] + x[1] * coefs[1]+ x[2] * coefs[2]+ x[3] * coefs[3]) / coefsSum;
}
- fn bicubicInterpolation(outputIndices: ${output.type.indices}) -> ${dType} {
- var inputIndices: ${input.type.indices} = outputIndices;
- return colCubicInterpolation(inputIndices, outputIndices);
+ fn bicubicInterpolation(output_indices: ${output.type.indices}) -> ${dType} {
+ var input_indices: ${input.type.indices} = output_indices;
+ return colCubicInterpolation(input_indices, output_indices);
}
`;
};
@@ -450,8 +449,8 @@ const createResizeProgramInfo =
outputShape = adjustOutputShape(inputShape, scales, attributes);
}
}
- const output = outputVariable('output', inputTensor.dataType, outputShape);
- const input = inputVariable('input', inputTensor.dataType, inputShape);
+ const output = outputVariable('output', inputTensor.dataType, outputShape.length);
+ const input = inputVariable('input', inputTensor.dataType, inputShape.length);
const outputSize = ShapeUtil.size(outputShape);
const noScale = inputShape.length === outputShape.length && inputShape.every((d, i) => d === outputShape[i]);
const useExtrapolation = attributes.coordinateTransformMode === 'tf_crop_and_resize';
@@ -467,11 +466,11 @@ const createResizeProgramInfo =
${getNearestPixelFromOriginal(attributes.nearestMode, opsetVersion, dataType)};
${
calculateInputIndicesFromOutputIndices(
- input, output, inputShape, outputShape, scales, roi, useExtrapolation)};
+ input, output, inputShape, outputShape, scales.length, roi.length, useExtrapolation)};
`;
case 'linear':
return `
- ${calculateOriginalIndicesFromOutputIndices(output, inputShape, outputShape, scales, roi)};
+ ${calculateOriginalIndicesFromOutputIndices(output, inputShape, outputShape, scales.length, roi.length)};
${
bilinearInterpolation(
input, output, inputShape, scales, useExtrapolation, attributes.extrapolationValue)};
@@ -488,25 +487,29 @@ const createResizeProgramInfo =
}
})()};
`}
- ${shaderHelper.declareVariables(input, output)}
+ ${
+ shaderHelper.registerUniform('output_size', 'u32')
+ .registerUniform('scales', 'f32', scales.length)
+ .registerUniform('roi', 'f32', roi.length)
+ .declareVariables(input, output)}
${shaderHelper.mainStart()}
- ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
+ ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
${noScale ? 'output[global_idx] = input[global_idx];' : `
- let outputIndices = ${output.offsetToIndices('global_idx')};
- var inputIndices: ${input.type.indices};
+ let output_indices = ${output.offsetToIndices('global_idx')};
+ var input_indices: ${input.type.indices};
${(() => {
switch (attributes.mode) {
case 'nearest':
- return `inputIndices = calculateInputIndicesFromOutputIndices(outputIndices);
- if (checkInputIndices(inputIndices)) {
- output[global_idx] = input[${input.indicesToOffset('inputIndices')}];
+ return `input_indices = calculateInputIndicesFromOutputIndices(output_indices);
+ if (checkInputIndices(input_indices)) {
+ output[global_idx] = ${input.getByIndices('input_indices')};
} else {
output[global_idx] = ${attributes.extrapolationValue};
}`;
case 'linear':
- return 'output[global_idx] = bilinearInterpolation(outputIndices);';
+ return 'output[global_idx] = bilinearInterpolation(output_indices);';
case 'cubic':
- return 'output[global_idx] = bicubicInterpolation(outputIndices);';
+ return 'output[global_idx] = bicubicInterpolation(output_indices);';
default:
throw Error(`Unsupported resize mode: ${attributes.mode}`);
}
@@ -518,12 +521,20 @@ const createResizeProgramInfo =
name: 'Resize',
shaderCache: {
hint: `${attributes.cacheKey}|${opsetVersion}|${scales.length > 0 ? scales : ''}|${
- sizes.length > 0 ? sizes : ''}|${noScale}`
+ sizes.length > 0 ? sizes : ''}|${roi.length > 0 ? roi : ''}|${noScale}`,
+ inputDependencies: ['rank']
},
getShaderSource,
getRunData: () => ({
outputs: [{dims: outputShape, dataType: inputTensor.dataType}],
- dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}
+ dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
+ programUniforms: [
+ {type: 'uint32', data: outputSize},
+ {type: 'float32', data: scales},
+ {type: 'float32', data: roi},
+ ...createTensorShapeVariables(inputShape),
+ ...createTensorShapeVariables(outputShape),
+ ]
})
};
};
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts
index 7458579bf4340..5212c6475dce0 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/slice.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/slice.ts
@@ -7,7 +7,7 @@ import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo, ProgramUniform, TensorInfo} from '../types';
-import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common';
+import {createTensorShapeVariables, getElementAt, IndicesHelper, inputVariable, outputVariable, ShaderHelper, UniformsArrayType} from './common';
export interface SliceAttributes extends AttributeWithCacheKey {
readonly starts: number[];
@@ -77,30 +77,25 @@ const fixStartEndValues =
};
const calculateInputIndicesImpl =
- (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[], outputShape: readonly number[],
- enableInputShapeUniforms: boolean): string =>
- `fn calculateInputIndices(outputIndices: ${output.type.indices}) -> ${input.type.indices} {
- var inputIndices: ${input.type.indices};
+ (input: IndicesHelper, output: IndicesHelper, inputShape: readonly number[]): string =>
+ `fn calculateInputIndices(output_indices: ${output.type.indices}) -> ${input.type.indices} {
+ var input_indices: ${input.type.indices};
var carry = 0u;
for (var i = ${inputShape.length}; i >= 0; i--) {
- let input_shape_i = ${
- enableInputShapeUniforms ? `uniforms.input_shape${inputShape.length > 1 ? '[i]' : ''}` : 'inputShape[i]'};
- let steps_i = ${
- enableInputShapeUniforms ? `uniforms.steps${inputShape.length > 1 ? '[i]' : ''}` : 'steps[i]'};
- let signs_i = ${
- enableInputShapeUniforms ? `uniforms.signs${inputShape.length > 1 ? '[i]' : ''}` : 'signs[i]'};
- let starts_i = ${
- enableInputShapeUniforms ? `uniforms.starts${inputShape.length > 1 ? '[i]' : ''}` : 'starts[i]'};
- var outputIndex = ${outputShape.length === 1 ? 'outputIndices' : 'outputIndices[i]'};
- var inputIndex = outputIndex * steps_i + starts_i + carry;
- carry = inputIndex / input_shape_i;
- inputIndex = inputIndex % input_shape_i;
+ let input_shape_i = ${getElementAt('uniforms.input_shape', 'i', inputShape.length)};
+ let steps_i = ${getElementAt('uniforms.steps', 'i', inputShape.length)};
+ let signs_i = ${getElementAt('uniforms.signs', 'i', inputShape.length)};
+ let starts_i = ${getElementAt('uniforms.starts', 'i', inputShape.length)};
+ var output_index = ${output.indicesGet('output_indices', 'i')};
+ var input_index = output_index * steps_i + starts_i + carry;
+ carry = input_index / input_shape_i;
+ input_index = input_index % input_shape_i;
if (signs_i < 0) {
- inputIndex = input_shape_i - inputIndex - 1u + starts_i;
+ input_index = input_shape_i - input_index - 1u + starts_i;
}
- ${inputShape.length === 1 ? 'inputIndices' : 'inputIndices[i]'} = inputIndex;
+ ${input.indicesSet('input_indices', 'i', 'input_index')};
}
- return inputIndices;
+ return input_indices;
}`;
const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: SliceAttributes): ProgramInfo => {
@@ -145,60 +140,38 @@ const createSliceProgramInfo = (inputs: readonly TensorView[], attributes: Slice
}
});
// Output rank is expected to be less than or equal to the input rank.
- const enableShapeUniforms = enableShapesUniforms(inputs[0].dims.length);
- const inputShapeOrRank = enableShapeUniforms ? inputs[0].dims.length : inputs[0].dims;
-
const outputShape = inputShape.slice(0);
axes.forEach((axis, _) => {
outputShape[axis] = Math.ceil((ends[axis] - starts[axis]) / steps[axis]);
});
- const outputShapeOrRank = enableShapeUniforms ? outputShape.length : outputShape;
-
const outputTensorInfo: TensorInfo = {dims: outputShape, dataType: inputs[0].dataType};
- const output = outputVariable('output', inputs[0].dataType, outputShapeOrRank);
- const input = inputVariable('input', inputs[0].dataType, inputShapeOrRank);
+ const output = outputVariable('output', inputs[0].dataType, outputShape.length);
+ const input = inputVariable('input', inputs[0].dataType, inputs[0].dims.length);
const outputSize = ShapeUtil.size(outputShape);
- const programUniforms: ProgramUniform[] = [];
- const uniforms: UniformsArrayType = [];
- if (enableShapeUniforms) {
- uniforms.push({name: 'starts', type: starts.length > 1 ? `vec${starts.length}` : 'u32'});
- uniforms.push({name: 'signs', type: signs.length > 1 ? `vec${signs.length}` : 'i32'});
- uniforms.push({name: 'steps', type: steps.length > 1 ? `vec${steps.length}` : 'u32'});
- programUniforms.push({type: 'uint32', data: starts});
- programUniforms.push({type: 'int32', data: signs});
- programUniforms.push({type: 'uint32', data: steps});
- }
- uniforms.push({name: 'outputSize', type: 'u32'});
- programUniforms.push({type: 'uint32', data: outputSize});
- if (enableShapeUniforms) {
- programUniforms.push(...createTensorShapeVariables(inputs[0].dims));
- programUniforms.push(...createTensorShapeVariables(outputShape));
- }
+ const uniforms: UniformsArrayType = [
+ {name: 'outputSize', type: 'u32'}, {name: 'starts', type: 'u32', length: starts.length},
+ {name: 'signs', type: 'i32', length: signs.length}, {name: 'steps', type: 'u32', length: steps.length}
+ ];
+
+ const programUniforms: ProgramUniform[] = [
+ {type: 'uint32', data: outputSize}, {type: 'uint32', data: starts}, {type: 'int32', data: signs},
+ {type: 'uint32', data: steps}, ...createTensorShapeVariables(inputs[0].dims),
+ ...createTensorShapeVariables(outputShape)
+ ];
const getShaderSource = (shaderHelper: ShaderHelper) => `
${shaderHelper.registerUniforms(uniforms).declareVariables(input, output)}
- ${enableShapeUniforms ? '' : [
- `const signs = array(${signs.map(i => `${i}i`).join(',')});`,
- `const starts = array(${starts.map(i => `${i}u`).join(',')});`,
- `const steps = array(${steps.map(i => `${i}u`).join(',')});`,
- `const inputShape = array(${inputShape.map(i => `${i}u`).join(',')});`
- ].join('\n')}
-
- ${calculateInputIndicesImpl(input, output, inputShape, outputShape, enableShapeUniforms)}
+ ${calculateInputIndicesImpl(input, output, inputShape)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.outputSize')}
- let outputIndices = ${output.offsetToIndices('global_idx')};
- let inputIndices = calculateInputIndices(outputIndices);
- ${output.setByOffset('global_idx', input.getByIndices('inputIndices'))}
+ let output_indices = ${output.offsetToIndices('global_idx')};
+ let input_indices = calculateInputIndices(output_indices);
+ ${output.setByOffset('global_idx', input.getByIndices('input_indices'))}
}`;
return {
name: 'Slice',
- shaderCache: {
- hint: enableShapeUniforms ? `${signs.length}_${starts.length}_${steps.length}` :
- `${attributes.cacheKey} | ${inputs[4]?.dims ?? ''}`,
- inputDependencies: [enableShapeUniforms ? 'rank' : 'dims']
- },
+ shaderCache: {hint: `${signs.length}_${starts.length}_${steps.length}`, inputDependencies: ['rank']},
getShaderSource,
getRunData: () => ({
outputs: [outputTensorInfo],
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/split.ts b/js/web/lib/wasm/jsep/webgpu/ops/split.ts
index fd60d81b87ae1..b8582614fa214 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/split.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/split.ts
@@ -4,9 +4,9 @@
import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
-import {ComputeContext, ProgramInfo, TensorInfo} from '../types';
+import {ComputeContext, ProgramInfo, ProgramUniform, TensorInfo} from '../types';
-import {IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
+import {createTensorShapeVariables, getElementAt, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
export interface SplitAttributes extends AttributeWithCacheKey {
readonly axis: number;
@@ -34,7 +34,7 @@ const createSplitAttributesFromInputs =
const calculateOutputIndexImpl = (numberOfTensors: number): string => `
fn calculateOutputIndex(index: u32) -> u32 {
for (var i: u32 = 0u; i < ${numberOfTensors}u; i += 1u ) {
- if (index < sizeInConcatAxis[i]) {
+ if (index < ${getElementAt('uniforms.size_in_split_axis', 'i', numberOfTensors)}) {
return i;
}
}
@@ -48,15 +48,15 @@ const writeBufferDataImpl = (outputs: readonly IndicesHelper[]) => {
if (numberOfTensors === 1) {
codeLines.push(returnSnippet);
} else if (i === 0) {
- codeLines.push(`if (outputNumber == ${i}u) { ${returnSnippet} }`);
+ codeLines.push(`if (output_number == ${i}u) { ${returnSnippet} }`);
} else if (i === numberOfTensors - 1) {
codeLines.push(`else { ${returnSnippet} }`);
} else {
- codeLines.push(`else if (outputNumber == ${i}) { ${returnSnippet} }`);
+ codeLines.push(`else if (output_number == ${i}) { ${returnSnippet} }`);
}
}
return `
- fn writeBufferData(outputNumber: u32, indices: ${outputs[0].type.indices}, global_idx: u32) {
+ fn writeBufferData(output_number: u32, indices: ${outputs[0].type.indices}, global_idx: u32) {
${codeLines.join('\n')}
}`;
};
@@ -65,48 +65,54 @@ const createSplitProgramInfo = (inputs: readonly TensorView[], attributes: Split
const inputShape = inputs[0].dims;
const inputSize = ShapeUtil.size(inputShape);
const dataType = inputs[0].dataType;
- const rank = inputShape.length;
- const axis = attributes.axis;
- const adjustedAxis = (axis < 0) ? inputShape.length + axis : axis;
+ const axis = ShapeUtil.normalizeAxis(attributes.axis, inputShape.length);
const outputs = new Array(attributes.numOutputs);
const input = inputVariable('input', dataType, inputShape);
- const sizeInConcatAxis = new Array(attributes.numOutputs);
+ const sizeInSplitAxis = new Array(attributes.numOutputs);
const outputsTensorInfo: TensorInfo[] = [];
const outputShapes: number[][] = [];
let previousSum = 0;
+ const programUniforms: ProgramUniform[] = [{type: 'uint32', data: inputSize}];
for (let i = 0; i < attributes.numOutputs; i++) {
previousSum += attributes.splitSizes[i];
- sizeInConcatAxis[i] = previousSum;
+ sizeInSplitAxis[i] = previousSum;
const outputShape = inputShape.slice();
outputShape[attributes.axis] = attributes.splitSizes[i];
outputShapes.push(outputShape);
- outputs[i] = outputVariable(`output${i}`, dataType, outputShapes[i]);
+ outputs[i] = outputVariable(`output${i}`, dataType, outputShape);
outputsTensorInfo.push({dims: outputShapes[i], dataType: inputs[0].dataType});
}
- const indicesAxis = rank < 2 ? 'indices' : `indices[${adjustedAxis}]`;
+ programUniforms.push({type: 'uint32', data: sizeInSplitAxis});
+ programUniforms.push(...createTensorShapeVariables(inputShape));
+ outputShapes.forEach((outputShape) => programUniforms.push(...createTensorShapeVariables(outputShape)));
const getShaderSource = (shaderHelper: ShaderHelper) => `
- ${shaderHelper.declareVariables(input, ...outputs)}
- const sizeInConcatAxis = array(${sizeInConcatAxis.map(i => `${i}u`).join(',')});
- ${calculateOutputIndexImpl(sizeInConcatAxis.length)}
+ ${
+ shaderHelper.registerUniform('input_size', 'u32')
+ .registerUniform('size_in_split_axis', 'u32', sizeInSplitAxis.length)
+ .declareVariables(input, ...outputs)}
+ ${calculateOutputIndexImpl(sizeInSplitAxis.length)}
${writeBufferDataImpl(outputs)}
${shaderHelper.mainStart()}
- ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(inputSize)}
+ ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.input_size')}
var indices = ${input.offsetToIndices('global_idx')};
- let outputNumber = calculateOutputIndex(${indicesAxis});
- if (outputNumber != 0) {
- ${indicesAxis} -= sizeInConcatAxis[outputNumber - 1u];
+ var index = ${input.indicesGet('indices', axis)};
+ let output_number = calculateOutputIndex(index);
+ if (output_number != 0) {
+ index -= ${getElementAt('uniforms.size_in_split_axis', 'output_number - 1u', sizeInSplitAxis.length)};
+ ${input.indicesSet('indices', axis, 'index')};
}
- writeBufferData(outputNumber, indices, global_idx);
+ writeBufferData(output_number, indices, global_idx);
}`;
return {
name: 'Split',
- shaderCache: {hint: attributes.cacheKey},
+ shaderCache: {hint: attributes.cacheKey, inputDependencies: ['rank']},
getShaderSource,
getRunData: () => ({
outputs: outputsTensorInfo,
dispatchGroup: {x: Math.ceil(inputSize / 64 /* workgroup size */)},
+ programUniforms
})
};
};
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts
index e294541a775ca..90a36a7bec2a9 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/tile.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/tile.ts
@@ -6,7 +6,7 @@ import {TensorView} from '../../tensor-view';
import {ShapeUtil} from '../../util';
import {ComputeContext, ProgramInfo} from '../types';
-import {inputVariable, outputVariable, ShaderHelper} from './common';
+import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common';
const getRepeats = (repeatsTensorView: TensorView): readonly number[] =>
Array.from(repeatsTensorView.getBigInt64Array(), Number);
@@ -54,30 +54,35 @@ export const createTileProgramInfo = (inputs: readonly TensorView[]): ProgramInf
const outputSize = ShapeUtil.size(outputShape);
const dataType = inputs[0].dataType;
- const input = inputVariable('input', dataType, inputShape);
- const output = outputVariable('output', dataType, outputShape);
+ const input = inputVariable('input', dataType, inputShape.length);
+ const output = outputVariable('output', dataType, outputShape.length);
const getShaderSource = (shaderHelper: ShaderHelper) => `
const inputShape = ${input.indices(...inputShape)};
- ${shaderHelper.declareVariables(input, output)}
+ ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)}
${shaderHelper.mainStart()}
- ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
- let outputIndices = ${output.offsetToIndices('global_idx')};
- var inputIndices: ${input.type.indices};
+ ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size')}
+ let output_indices = ${output.offsetToIndices('global_idx')};
+ var input_indices: ${input.type.indices};
for (var i = 0; i < ${inputShape.length}; i++) {
- let inputDimValue = ${output.indicesGet('outputIndices', 'i')} % ${input.indicesGet('inputShape', 'i')};
+ let input_dim_i = ${input.indicesGet('uniforms.input_shape', 'i')};
+ let input_dim_value = ${output.indicesGet('output_indices', 'i')} % input_dim_i;
- ${input.indicesSet('inputIndices', 'i', 'inputDimValue')}
+ ${input.indicesSet('input_indices', 'i', 'input_dim_value')}
}
- ${output.setByOffset('global_idx', input.getByIndices('inputIndices'))}
+ ${output.setByOffset('global_idx', input.getByIndices('input_indices'))}
}`;
return {
name: 'Tile',
- shaderCache: {hint: `${repeats}`},
+ shaderCache: {hint: `${repeats}`, inputDependencies: ['rank']},
getRunData: () => ({
outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
+ programUniforms: [
+ {type: 'uint32', data: outputSize}, ...createTensorShapeVariables(inputs[0].dims),
+ ...createTensorShapeVariables(outputShape)
+ ],
}),
getShaderSource,
};
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
index 119609e06f5a3..51114d8a99dd1 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
@@ -7,7 +7,7 @@ import {MAX_CLIP, MIN_CLIP, ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo} from '../types';
-import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from './common';
+import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglValueType} from './common';
type BuiltinFunctionName = string;
type ElementwiseCustomExpression = (expression: string) => string;
@@ -132,7 +132,7 @@ const generateClipAttributesFromInputs = (inputs: readonly TensorView[]): ClipAt
export const clip = (context: ComputeContext, clipAttributes: ClipAttributes): void => {
const attributes = context.inputs.length === 1 ? clipAttributes : generateClipAttributesFromInputs(context.inputs);
- const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
+ const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
context.compute(
createElementwiseProgramInfo(
context.inputs[0], 'Clip', a => `clamp(${a}, clip_min_, clip_max_)`, `
@@ -163,15 +163,16 @@ export const parseAlphaAttributes = (attributes: Record): Alpha
createAttributeWithCacheKey(attributes as {alpha: number});
export const elu = (context: ComputeContext, attributes: AlphaAttributes): void => {
+ const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
context.compute(createElementwiseProgramInfo(
context.inputs[0], 'Elu', a => `elu_vf32(${a})`, `
- const elu_alpha_: f32 = f32(${attributes.alpha});
+ const elu_alpha_ = ${dataType}(${attributes.alpha});
- fn elu_f32(a: f32) -> f32 {
+ fn elu_f32(a: ${dataType}) -> ${dataType} {
return select((exp(a) - 1.0) * elu_alpha_, a, a >= 0.0);
}
- fn elu_vf32(v: vec4) -> vec4 {
+ fn elu_vf32(v: vec4<${dataType}>) -> vec4<${dataType}> {
return vec4(elu_f32(v.x), elu_f32(v.y), elu_f32(v.z), elu_f32(v.w));
}`,
attributes.cacheKey));
@@ -192,7 +193,7 @@ fn erf_vf32(v: ${dataType}) -> ${dataType} {
}`;
export const erf = (context: ComputeContext): void => {
- const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
+ const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
context.compute(createElementwiseProgramInfo(
context.inputs[0], 'Erf', a => `erf_vf32(${a})`, erfImpl(`vec4<${dataType}>`, dataType)));
};
@@ -206,16 +207,17 @@ export const floor = (context: ComputeContext): void => {
};
export const gelu = (context: ComputeContext): void => {
- const dataType = tensorTypeToWsglStorageType(context.inputs[0].dataType);
+ const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
context.compute(createElementwiseProgramInfo(
context.inputs[0], 'Gelu', a => `0.5 * ${a} * (1.0 + erf_vf32(${a} * 0.7071067811865475))`,
erfImpl(`vec4<${dataType}>`, dataType)));
};
export const leakyRelu = (context: ComputeContext, attributes: AlphaAttributes): void => {
+ const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
context.compute(createElementwiseProgramInfo(
- context.inputs[0], 'LeakyRelu', a => `select(leaky_relu_alpha_ * ${a}, ${a}, ${a} >= vec4(0.0))`,
- `const leaky_relu_alpha_: f32 = f32(${attributes.alpha});`, attributes.cacheKey));
+ context.inputs[0], 'LeakyRelu', a => `select(leaky_relu_alpha_ * ${a}, ${a}, ${a} >= vec4<${dataType}>(0.0))`,
+ `const leaky_relu_alpha_ = ${dataType}(${attributes.alpha});`, attributes.cacheKey));
};
export const not = (context: ComputeContext): void => {
@@ -231,8 +233,9 @@ export const reciprocal = (context: ComputeContext): void => {
};
export const relu = (context: ComputeContext): void => {
+ const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
context.compute(createElementwiseProgramInfo(
- context.inputs[0], 'Relu', a => `select(vec4(0.0), ${a}, ${a} > vec4(0.0))`));
+ context.inputs[0], 'Relu', a => `select(vec4<${dataType}>(0.0), ${a}, ${a} > vec4<${dataType}>(0.0))`));
};
export const sigmoid = (context: ComputeContext): void => {
@@ -260,9 +263,10 @@ export const tanh = (context: ComputeContext): void => {
};
export const thresholdedRelu = (context: ComputeContext, attributes: AlphaAttributes): number => {
+ const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
context.compute(createElementwiseProgramInfo(
- context.inputs[0], 'ThresholdedRelu', a => `select(vec4(0.0), ${a}, ${a} > thresholded_relu_alpha_)`,
- `const thresholded_relu_alpha_: vec4 = vec4(${attributes.alpha});`, attributes.cacheKey));
+ context.inputs[0], 'ThresholdedRelu', a => `select(vec4<${dataType}>(0.0), ${a}, ${a} > thresholded_relu_alpha_)`,
+ `const thresholded_relu_alpha_ = vec4<${dataType}>(${attributes.alpha});`, attributes.cacheKey));
return 0;
};
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/where.ts b/js/web/lib/wasm/jsep/webgpu/ops/where.ts
index 6f66dd86b4088..687ee054096cc 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/where.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/where.ts
@@ -6,18 +6,15 @@ import {TensorView} from '../../tensor-view';
import {BroadcastUtil, ShapeUtil} from '../../util';
import {ComputeContext, ProgramInfo} from '../types';
-import {inputVariable, outputVariable, ShaderHelper} from './common';
+import {createTensorShapeVariables, inputVariable, outputVariable, ShaderHelper} from './common';
const createWhereOpProgramShader =
(shaderHelper: ShaderHelper, inputs: readonly TensorView[], dimsOutput: readonly number[], isBroadcast: boolean,
typeOutput: number) => {
- const outputSize = ShapeUtil.size(dimsOutput);
- const vecSize = Math.ceil(outputSize / 4);
-
- const output = outputVariable('outputData', typeOutput, dimsOutput, 4);
- const a = inputVariable('aData', inputs[1].dataType, inputs[1].dims, 4);
- const b = inputVariable('bData', inputs[2].dataType, inputs[2].dims, 4);
- const c = inputVariable('cData', inputs[0].dataType, inputs[0].dims, 4);
+ const output = outputVariable('output_data', typeOutput, dimsOutput.length, 4);
+ const a = inputVariable('a_data', inputs[1].dataType, inputs[1].dims.length, 4);
+ const b = inputVariable('b_data', inputs[2].dataType, inputs[2].dims.length, 4);
+ const c = inputVariable('c_data', inputs[0].dataType, inputs[0].dims.length, 4);
let assignment: string;
const expression = (a: string, b: string, c: string) => `select(${b}, ${a}, ${c})`;
@@ -27,20 +24,20 @@ const createWhereOpProgramShader =
expression(a.getByOffset('global_idx'), b.getByOffset('global_idx'), c.getByOffset('global_idx')));
} else {
const singleAssignment = (resStr: string, x: number, typeCast = '') => {
- const expressionA = `aData[indexA${x}][componentA${x}]`;
- const expressionB = `bData[indexB${x}][componentB${x}]`;
+ const expressionA = `a_data[index_a${x}][component_a${x}]`;
+ const expressionB = `b_data[index_b${x}][component_b${x}]`;
// eslint-disable-next-line no-bitwise
- const expressionC = `bool(cData[indexC${x}] & ${0xff000000 >>> ((3 - x) * 8)}u)`;
+ const expressionC = `bool(c_data[index_c${x}] & ${0xff000000 >>> ((3 - x) * 8)}u)`;
return `
- let outputIndices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)};
- let offsetA${x} = ${a.broadcastedIndicesToOffset(`outputIndices${x}`, output)};
- let offsetB${x} = ${b.broadcastedIndicesToOffset(`outputIndices${x}`, output)};
- let offsetC${x} = ${c.broadcastedIndicesToOffset(`outputIndices${x}`, output)};
- let indexA${x} = offsetA${x} / 4u;
- let indexB${x} = offsetB${x} / 4u;
- let indexC${x} = offsetC${x} / 4u;
- let componentA${x} = offsetA${x} % 4u;
- let componentB${x} = offsetB${x} % 4u;
+ let output_indices${x} = ${output.offsetToIndices(`global_idx * 4u + ${x}u`)};
+ let offset_a${x} = ${a.broadcastedIndicesToOffset(`output_indices${x}`, output)};
+ let offset_b${x} = ${b.broadcastedIndicesToOffset(`output_indices${x}`, output)};
+ let offset_c${x} = ${c.broadcastedIndicesToOffset(`output_indices${x}`, output)};
+ let index_a${x} = offset_a${x} / 4u;
+ let index_b${x} = offset_b${x} / 4u;
+ let index_c${x} = offset_c${x} / 4u;
+ let component_a${x} = offset_a${x} % 4u;
+ let component_b${x} = offset_b${x} % 4u;
${resStr}[${x}] = ${typeCast}(${expression(expressionA, expressionB, expressionC)});
`;
};
@@ -51,21 +48,21 @@ const createWhereOpProgramShader =
${singleAssignment('data', 1, 'u32')}
${singleAssignment('data', 2, 'u32')}
${singleAssignment('data', 3, 'u32')}
- outputData[global_idx] = dot(vec4(0x1, 0x100, 0x10000, 0x1000000), vec4(data));`;
+ output_data[global_idx] = dot(vec4(0x1, 0x100, 0x10000, 0x1000000), vec4(data));`;
} else {
assignment = `
- ${singleAssignment('outputData[global_idx]', 0)}
- ${singleAssignment('outputData[global_idx]', 1)}
- ${singleAssignment('outputData[global_idx]', 2)}
- ${singleAssignment('outputData[global_idx]', 3)}
+ ${singleAssignment('output_data[global_idx]', 0)}
+ ${singleAssignment('output_data[global_idx]', 1)}
+ ${singleAssignment('output_data[global_idx]', 2)}
+ ${singleAssignment('output_data[global_idx]', 3)}
`;
}
}
return `
- ${shaderHelper.declareVariables(c, a, b, output)}
+ ${shaderHelper.registerUniform('vec_size', 'u32').declareVariables(c, a, b, output)}
${shaderHelper.mainStart()}
- ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(vecSize)}
+ ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.vec_size')}
${assignment}
}`;
};
@@ -79,6 +76,7 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo =>
const isBroadcast = !(ShapeUtil.areEqual(dimsA, dimsB) && ShapeUtil.areEqual(dimsB, dimsC));
let outputShape = dimsA;
let outputSize = ShapeUtil.size(dimsA);
+ const vecSize = Math.ceil(outputSize / 4);
// TODO: deal with zero-sized tensors (eg. dims=[1,0])
if (isBroadcast) {
@@ -92,11 +90,16 @@ const createWhereOpProgramInfo = (inputs: readonly TensorView[]): ProgramInfo =>
return {
name: 'Where',
+ shaderCache: {inputDependencies: ['rank', 'rank', 'rank']},
getShaderSource: (shaderHelper) =>
createWhereOpProgramShader(shaderHelper, inputs, outputShape, isBroadcast, outputDataType),
getRunData: () => ({
outputs: [{dims: outputShape, dataType: outputDataType}],
- dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* vec size */)}
+ dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* vec size */)},
+ programUniforms: [
+ {type: 'uint32', data: vecSize}, ...createTensorShapeVariables(dimsC), ...createTensorShapeVariables(dimsA),
+ ...createTensorShapeVariables(dimsB), ...createTensorShapeVariables(outputShape)
+ ],
}),
};
};
diff --git a/js/web/lib/wasm/jsep/webgpu/program-manager.ts b/js/web/lib/wasm/jsep/webgpu/program-manager.ts
index 0b0a545f46481..ae5bf68483b46 100644
--- a/js/web/lib/wasm/jsep/webgpu/program-manager.ts
+++ b/js/web/lib/wasm/jsep/webgpu/program-manager.ts
@@ -75,12 +75,11 @@ export class ProgramManager {
const kernelId = this.backend.currentKernelId!;
const kernelInfo = this.backend.kernels.get(kernelId)!;
- const kernelName = `[${kernelInfo[0]}] ${kernelInfo[1]}`;
void syncData.buffer.mapAsync(GPUMapMode.READ).then(() => {
const mappedData = new BigUint64Array(syncData.buffer.getMappedRange());
- const startTimeU64 = mappedData[0];
- const endTimeU64 = mappedData[1];
+ const [startTimeU64, endTimeU64] = mappedData;
+ const [kernelType, kernelName] = kernelInfo;
syncData.buffer.unmap();
@@ -96,17 +95,33 @@ export class ProgramManager {
}
this.backend.gpuDataManager.release(syncData.id);
- let inputShapes = '';
- inputTensorViews.forEach((value, i) => {
- inputShapes += `input[${i}]: [${value.dims}] | ${tensorDataTypeEnumToString(value.dataType)}, `;
- });
- let outputShapes = '';
- outputTensorViews.forEach((value, i) => {
- outputShapes += `output[${i}]: [${value.dims}] | ${tensorDataTypeEnumToString(value.dataType)}, `;
- });
- // eslint-disable-next-line no-console
- console.log(`[profiling] kernel "${kernelId}|${kernelName}" ${inputShapes}${outputShapes}execution time: ${
- endTime - startTime} ns`);
+ if (this.backend.env.webgpu.profiling?.ondata) {
+ this.backend.env.webgpu.profiling.ondata({
+ version: 1,
+ inputsMetadata: inputTensorViews.map(
+ value => ({dims: value.dims, dataType: tensorDataTypeEnumToString(value.dataType)})),
+ outputsMetadata: outputTensorViews.map(
+ value => ({dims: value.dims, dataType: tensorDataTypeEnumToString(value.dataType)})),
+ kernelId,
+ kernelType,
+ kernelName,
+ startTime,
+ endTime,
+ });
+ } else {
+ // if no callback is provided, print the profiling message to console
+ let inputShapes = '';
+ inputTensorViews.forEach((value, i) => {
+ inputShapes += `input[${i}]: [${value.dims}] | ${tensorDataTypeEnumToString(value.dataType)}, `;
+ });
+ let outputShapes = '';
+ outputTensorViews.forEach((value, i) => {
+ outputShapes += `output[${i}]: [${value.dims}] | ${tensorDataTypeEnumToString(value.dataType)}, `;
+ });
+ // eslint-disable-next-line no-console
+ console.log(`[profiling] kernel "${kernelId}|${kernelName}|${buildArtifact.programInfo.name}" ${inputShapes}${
+ outputShapes}execution time: ${endTime - startTime} ns`);
+ }
});
}
diff --git a/js/web/lib/wasm/session-handler-training.ts b/js/web/lib/wasm/session-handler-training.ts
index 7de3f4dc2c89e..71815f21e650a 100644
--- a/js/web/lib/wasm/session-handler-training.ts
+++ b/js/web/lib/wasm/session-handler-training.ts
@@ -6,7 +6,7 @@ import {env, InferenceSession, OnnxValue, SessionHandler, Tensor, TrainingSessio
import {SerializableModeldata, TensorMetadata} from './proxy-messages';
import {decodeTensorMetadata, encodeTensorMetadata} from './session-handler-inference';
import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl';
-import {createCheckpointHandle, createTrainingSessionHandle, getContiguousParameters, getParametersSize, loadParametersBuffer, releaseTrainingSessionAndCheckpoint, runTrainStep} from './wasm-training-core-impl';
+import {createCheckpointHandle, createTrainingSessionHandle, getContiguousParameters, getModelInputOutputNames, getParametersSize, lazyResetGrad, loadParametersBuffer, releaseTrainingSessionAndCheckpoint, runEvalStep, runOptimizerStep, runTrainStep} from './wasm-training-core-impl';
export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler {
private sessionId: number;
@@ -15,8 +15,8 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes
inputNames: string[];
outputNames: string[];
- inputEncodedNames: number[];
- outputEncodedNames: number[];
+ evalInputNames: string[] = [];
+ evalOutputNames: string[] = [];
async uriOrBufferToHeap(uriOrBuffer: string|Uint8Array): Promise {
let buffer: Uint8Array;
@@ -51,8 +51,12 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes
}
this.checkpointId = createCheckpointHandle(checkpointData);
- [[this.sessionId, this.inputNames, this.outputNames], this.inputEncodedNames, this.outputEncodedNames] =
+ this.sessionId =
createTrainingSessionHandle(this.checkpointId, trainModelData, evalModelData, optimizerModelData, options);
+ [this.inputNames, this.outputNames] = getModelInputOutputNames(this.sessionId, false);
+ if (evalModelUriOrBuffer !== '') {
+ [this.evalInputNames, this.evalOutputNames] = getModelInputOutputNames(this.sessionId, true);
+ }
}
/**
@@ -101,6 +105,10 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes
return resultMap;
}
+ async lazyResetGrad(): Promise {
+ await lazyResetGrad(this.sessionId);
+ }
+
async runTrainStep(
feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType,
options: InferenceSession.RunOptions): Promise {
@@ -118,6 +126,27 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes
return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices);
}
+ async runOptimizerStep(options: InferenceSession.RunOptions): Promise {
+ await runOptimizerStep(this.sessionId, options);
+ }
+
+ async runEvalStep(
+ feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType,
+ options: InferenceSession.RunOptions): Promise {
+ const [, inputIndices, inputs] = this.convertMapIntoValuesArrayAndIndicesArray(
+ feeds, this.evalInputNames,
+ (t, i): TensorMetadata => encodeTensorMetadata(t, () => `input "${this.evalInputNames[inputIndices[i]]}"`));
+
+ const [outputArray, outputIndices, outputs] =
+ this.convertMapIntoValuesArrayAndIndicesArray(
+ fetches, this.evalOutputNames,
+ (t, i): TensorMetadata|null =>
+ t ? encodeTensorMetadata(t, () => `output "${this.evalOutputNames[outputIndices[i]]}"`) : null);
+
+ const results = await runEvalStep(this.sessionId, inputIndices, inputs, outputIndices, outputs, options);
+ return this.convertTensorMetadataToReturnType(results, outputArray, outputIndices);
+ }
+
async getParametersSize(trainableOnly: boolean): Promise {
return getParametersSize(this.sessionId, trainableOnly);
}
@@ -131,7 +160,6 @@ export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSes
}
async dispose(): Promise {
- return releaseTrainingSessionAndCheckpoint(
- this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames);
+ return releaseTrainingSessionAndCheckpoint(this.checkpointId, this.sessionId);
}
}
diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts
index c0a4235113148..0cc28188a6093 100644
--- a/js/web/lib/wasm/wasm-training-core-impl.ts
+++ b/js/web/lib/wasm/wasm-training-core-impl.ts
@@ -3,7 +3,7 @@
import {InferenceSession, Tensor} from 'onnxruntime-common';
-import {SerializableModeldata, SerializableSessionMetadata, TensorMetadata} from './proxy-messages';
+import {SerializableModeldata, TensorMetadata} from './proxy-messages';
import {setRunOptions} from './run-options';
import {setSessionOptions} from './session-options';
import {dataLocationStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common';
@@ -77,50 +77,44 @@ const getModelInputOutputCount = (trainingSessionId: number, isEvalModel: boolea
};
const getModelInputOutputNamesLoop =
- (trainingSessionId: number, count: number, isInput: boolean, isEvalModel: boolean): [string[], number[]] => {
+ (trainingSessionId: number, count: number, isInput: boolean, isEvalModel: boolean): string[] => {
const names = [];
const wasm = getInstance();
- const namesUTF8Encoded = [];
-
for (let i = 0; i < count; i++) {
if (wasm._OrtTrainingGetModelInputOutputName) {
const name = wasm._OrtTrainingGetModelInputOutputName(trainingSessionId, i, isInput, isEvalModel);
ifErrCodeCheckLastError(name, `Can't get input or output name -- is input: ${isInput}, index ${i}`, false);
- namesUTF8Encoded.push(name);
names.push(wasm.UTF8ToString(name));
+ wasm._free(name);
} else {
throw new Error(NO_TRAIN_FUNCS_MSG);
}
}
- return [names, namesUTF8Encoded];
+ return names;
};
-const getTrainingModelInputOutputNames = (trainingSessionId: number): [string[], number[], string[], number[]] => {
- const [inputCount, outputCount] = getModelInputOutputCount(trainingSessionId, false);
+export const getModelInputOutputNames = (trainingSessionId: number, isEvalModel: boolean): [string[], string[]] => {
+ let inputNames: string[] = [];
+ let outputNames: string[] = [];
+
+ const [inputCount, outputCount] = getModelInputOutputCount(trainingSessionId, isEvalModel);
- const [inputNames, inputNamesUTF8Encoded] = getModelInputOutputNamesLoop(trainingSessionId, inputCount, true, false);
- const [outputNames, outputNamesUTF8Encoded] =
- getModelInputOutputNamesLoop(trainingSessionId, outputCount, false, false);
+ inputNames = getModelInputOutputNamesLoop(trainingSessionId, inputCount, true, isEvalModel);
+ outputNames = getModelInputOutputNamesLoop(trainingSessionId, outputCount, false, isEvalModel);
- return [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded];
+ return [inputNames, outputNames];
};
export const createTrainingSessionHandle =
(checkpointHandle: number, trainModelData: SerializableModeldata, evalModelData: SerializableModeldata,
- optimizerModelData: SerializableModeldata,
- options: InferenceSession.SessionOptions): [SerializableSessionMetadata, number[], number[]] => {
+ optimizerModelData: SerializableModeldata, options: InferenceSession.SessionOptions): number => {
const wasm = getInstance();
let trainingSessionHandle = 0;
let sessionOptionsHandle = 0;
let allocs: number[] = [];
- let inputNamesUTF8Encoded: number[] = [];
- let outputNamesUTF8Encoded: number[] = [];
-
- let inputNames: string[] = [];
- let outputNames: string[] = [];
try {
[sessionOptionsHandle, allocs] = setSessionOptions(options);
@@ -133,11 +127,7 @@ export const createTrainingSessionHandle =
}
ifErrCodeCheckLastError(trainingSessionHandle, 'Error occurred when trying to create a TrainingSession', false);
-
- [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded] =
- getTrainingModelInputOutputNames(trainingSessionHandle);
- return [[trainingSessionHandle, inputNames, outputNames], inputNamesUTF8Encoded, outputNamesUTF8Encoded];
-
+ return trainingSessionHandle;
} catch (e) {
if (wasm._OrtTrainingReleaseSession && trainingSessionHandle !== 0) {
wasm._OrtTrainingReleaseSession(trainingSessionHandle);
@@ -152,8 +142,6 @@ export const createTrainingSessionHandle =
wasm._OrtReleaseSessionOptions(sessionOptionsHandle);
}
allocs.forEach(alloc => wasm._free(alloc));
- inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
- outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
}
};
@@ -265,6 +253,17 @@ const moveOutputToTensorMetadataArr =
return output;
};
+export const lazyResetGrad = async(trainingSessionId: number): Promise => {
+ const wasm = getInstance();
+
+ if (wasm._OrtTrainingLazyResetGrad) {
+ const errorCode = wasm._OrtTrainingLazyResetGrad(trainingSessionId);
+ ifErrCodeCheckLastError(errorCode, 'Can\'t call lazyResetGrad.');
+ } else {
+ throw new Error(NO_TRAIN_FUNCS_MSG);
+ }
+};
+
export const runTrainStep = async(
trainingSessionId: number, inputIndices: number[], inputTensors: TensorMetadata[], outputIndices: number[],
outputTensors: Array, options: InferenceSession.RunOptions): Promise => {
@@ -317,6 +316,83 @@ export const runTrainStep = async(
}
};
+export const runOptimizerStep =
+ async(trainingSessionId: number, options: InferenceSession.RunOptions): Promise => {
+ const wasm = getInstance();
+
+ let runOptionsHandle = 0;
+ let runOptionsAllocs: number[] = [];
+
+ try {
+ [runOptionsHandle, runOptionsAllocs] = setRunOptions(options);
+
+ if (wasm._OrtTrainingOptimizerStep) {
+ const errCode = wasm._OrtTrainingOptimizerStep(trainingSessionId, runOptionsHandle);
+ ifErrCodeCheckLastError(errCode, 'Failed to call OrtTrainingOptimizerStep in the WebAssembly layer');
+ } else {
+ throw new Error(NO_TRAIN_FUNCS_MSG);
+ }
+ } finally {
+ if (runOptionsHandle !== 0) {
+ wasm._OrtReleaseRunOptions(runOptionsHandle);
+ }
+ runOptionsAllocs.forEach(p => wasm._free(p));
+ }
+};
+
+export const runEvalStep = async(
+ trainingSessionId: number, inputIndices: number[], inputTensors: TensorMetadata[], outputIndices: number[],
+ outputTensors: Array, options: InferenceSession.RunOptions): Promise => {
+ const wasm = getInstance();
+
+ const inputCount = inputIndices.length;
+ const outputCount = outputIndices.length;
+
+ let runOptionsHandle = 0;
+ let runOptionsAllocs: number[] = [];
+
+ const inputTensorHandles: number[] = [];
+ const outputTensorHandles: number[] = [];
+ const inputOutputAllocs: number[] = [];
+
+ const beforeRunStack = wasm.stackSave();
+
+ try {
+ // prepare parameters by moving them to heap
+ [runOptionsHandle, runOptionsAllocs] = setRunOptions(options);
+
+ // handle inputs -- you don't want anything added to the index
+ const inputValuesOffset = createAndAllocateTensors(
+ trainingSessionId, inputIndices, inputTensors, inputTensorHandles, inputOutputAllocs, 0);
+ // handle outputs
+ // you want inputCount to be added to the index of every output tensor passed to prepareInputOutputTensor
+ const outputValuesOffset = createAndAllocateTensors(
+ trainingSessionId, outputIndices, outputTensors, outputTensorHandles, inputOutputAllocs, inputCount);
+
+ if (wasm._OrtTrainingEvalStep) {
+ const errorCode = wasm._OrtTrainingEvalStep(
+ trainingSessionId, inputValuesOffset, inputCount, outputValuesOffset, outputCount, runOptionsHandle);
+
+ ifErrCodeCheckLastError(errorCode, 'failed to call OrtTrainingEvalStep in the WebAssembly layer');
+ } else {
+ throw new Error(NO_TRAIN_FUNCS_MSG);
+ }
+
+ return moveOutputToTensorMetadataArr(outputValuesOffset, outputCount, outputTensorHandles, outputTensors);
+ } finally {
+ wasm.stackRestore(beforeRunStack);
+
+ inputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v));
+ outputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v));
+ inputOutputAllocs.forEach(p => wasm._free(p));
+
+ if (runOptionsHandle !== 0) {
+ wasm._OrtReleaseRunOptions(runOptionsHandle);
+ }
+ runOptionsAllocs.forEach(p => wasm._free(p));
+ }
+};
+
export const getParametersSize = (trainingSessionId: number, trainableOnly: boolean): number => {
const wasm = getInstance();
const stack = wasm.stackSave();
@@ -439,17 +515,13 @@ export const loadParametersBuffer =
}
};
-export const releaseTrainingSessionAndCheckpoint =
- (checkpointId: number, sessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[]):
- void => {
- const wasm = getInstance();
- inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
- outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
+export const releaseTrainingSessionAndCheckpoint = (checkpointId: number, sessionId: number): void => {
+ const wasm = getInstance();
- if (wasm._OrtTrainingReleaseSession) {
- wasm._OrtTrainingReleaseSession(sessionId);
- }
- if (wasm._OrtTrainingReleaseCheckpoint) {
- wasm._OrtTrainingReleaseCheckpoint(checkpointId);
- }
- };
+ if (wasm._OrtTrainingReleaseSession) {
+ wasm._OrtTrainingReleaseSession(sessionId);
+ }
+ if (wasm._OrtTrainingReleaseCheckpoint) {
+ wasm._OrtTrainingReleaseCheckpoint(checkpointId);
+ }
+};
diff --git a/js/web/test/data/ops/cumsum.jsonc b/js/web/test/data/ops/cumsum.jsonc
new file mode 100644
index 0000000000000..b3173afb695ea
--- /dev/null
+++ b/js/web/test/data/ops/cumsum.jsonc
@@ -0,0 +1,1362 @@
+[
+ {
+ "name": "CumSum",
+ "operator": "CumSum",
+ "attributes": [
+ { "name": "exclusive", "data": 0, "type": "int" },
+ { "name": "reverse", "data": 0, "type": "int" }
+ ],
+ "opset": {
+ "domain": "",
+ "version": 11
+ },
+ "cases": [
+ {
+ "name": "CumSum 1-D; axis = 0; exclusive = 0, reverse = 0",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5],
+ "dims": [5],
+ "type": "float32"
+ },
+ {
+ "data": [0],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [1, 3, 6, 10, 15],
+ "dims": [5],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 1-D; axis = -1; exclusive = 0, reverse = 0",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5],
+ "dims": [5],
+ "type": "float32"
+ },
+ {
+ "data": [-1],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [1, 3, 6, 10, 15],
+ "dims": [5],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 2-D (2x3); axis = 0; exclusive = 0, reverse = 0",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6],
+ "dims": [2, 3],
+ "type": "float32"
+ },
+ {
+ "data": [0],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [1, 2, 3, 5, 7, 9],
+ "dims": [2, 3],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 2-D (2x3); axis = -1; exclusive = 0, reverse = 0",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6],
+ "dims": [2, 3],
+ "type": "float32"
+ },
+ {
+ "data": [-1],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [1, 3, 6, 4, 9, 15],
+ "dims": [2, 3],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 2-D (2x3); axis = 1; exclusive = 0, reverse = 0",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6],
+ "dims": [2, 3],
+ "type": "float32"
+ },
+ {
+ "data": [1],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [1, 3, 6, 4, 9, 15],
+ "dims": [2, 3],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 2-D (2x3); axis = -2; exclusive = 0, reverse = 0",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6],
+ "dims": [2, 3],
+ "type": "float32"
+ },
+ {
+ "data": [-2],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [1, 2, 3, 5, 7, 9],
+ "dims": [2, 3],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 2-D (3x3); axis = 0; exclusive = 0, reverse = 0",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8, 9],
+ "dims": [3, 3],
+ "type": "float32"
+ },
+ {
+ "data": [0],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [1, 2, 3, 5, 7, 9, 12, 15, 18],
+ "dims": [3, 3],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 2-D (3x3); axis = 1; exclusive = 0, reverse = 0",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8, 9],
+ "dims": [3, 3],
+ "type": "float32"
+ },
+ {
+ "data": [1],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [1, 3, 6, 4, 9, 15, 7, 15, 24],
+ "dims": [3, 3],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 3-D; axis = 0; exclusive = 0, reverse = 0",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ },
+ {
+ "data": [0],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [1, 2, 3, 4, 6, 8, 10, 12],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 3-D; axis = 1; exclusive = 0, reverse = 0",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ },
+ {
+ "data": [1],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [1, 2, 4, 6, 5, 6, 12, 14],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 3-D; axis = -1; exclusive = 0, reverse = 0",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ },
+ {
+ "data": [-1],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [1, 3, 3, 7, 5, 11, 7, 15],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 3-D; axis = 2; exclusive = 0, reverse = 0",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ },
+ {
+ "data": [2],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [1, 3, 3, 7, 5, 11, 7, 15],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 3-D; axis = -2; exclusive = 0, reverse = 0",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ },
+ {
+ "data": [-2],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [1, 2, 4, 6, 5, 6, 12, 14],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 3-D; axis = -3; exclusive = 0, reverse = 0",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ },
+ {
+ "data": [-3],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [1, 2, 3, 4, 6, 8, 10, 12],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "CumSum",
+ "operator": "CumSum",
+ "attributes": [
+ { "name": "exclusive", "data": 1, "type": "int" },
+ { "name": "reverse", "data": 0, "type": "int" }
+ ],
+ "opset": {
+ "domain": "",
+ "version": 11
+ },
+ "cases": [
+ {
+ "name": "CumSum 1-D; axis = 0; exclusive = 1, reverse = 0",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5],
+ "dims": [5],
+ "type": "float32"
+ },
+ {
+ "data": [0],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [0, 1, 3, 6, 10],
+ "dims": [5],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 1-D; axis = -1; exclusive = 1, reverse = 0",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5],
+ "dims": [5],
+ "type": "float32"
+ },
+ {
+ "data": [-1],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [0, 1, 3, 6, 10],
+ "dims": [5],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 2-D (2x3); axis = 0; exclusive = 1, reverse = 0",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6],
+ "dims": [2, 3],
+ "type": "float32"
+ },
+ {
+ "data": [0],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [0, 0, 0, 1, 2, 3],
+ "dims": [2, 3],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 2-D (2x3); axis = -1; exclusive = 1, reverse = 0",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6],
+ "dims": [2, 3],
+ "type": "float32"
+ },
+ {
+ "data": [-1],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [0, 1, 3, 0, 4, 9],
+ "dims": [2, 3],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 2-D (2x3); axis = 1; exclusive = 1, reverse = 0",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6],
+ "dims": [2, 3],
+ "type": "float32"
+ },
+ {
+ "data": [1],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [0, 1, 3, 0, 4, 9],
+ "dims": [2, 3],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+        "name": "CumSum 2-D (2x3); axis = -2; exclusive = 1, reverse = 0",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6],
+ "dims": [2, 3],
+ "type": "float32"
+ },
+ {
+ "data": [-2],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [0, 0, 0, 1, 2, 3],
+ "dims": [2, 3],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 2-D (3x3); axis = 0; exclusive = 1, reverse = 0",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8, 9],
+ "dims": [3, 3],
+ "type": "float32"
+ },
+ {
+ "data": [0],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [0, 0, 0, 1, 2, 3, 5, 7, 9],
+ "dims": [3, 3],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 2-D (3x3); axis = 1; exclusive = 1, reverse = 0",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8, 9],
+ "dims": [3, 3],
+ "type": "float32"
+ },
+ {
+ "data": [1],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [0, 1, 3, 0, 4, 9, 0, 7, 15],
+ "dims": [3, 3],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 3-D; axis = 0; exclusive = 1, reverse = 0",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ },
+ {
+ "data": [0],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [0, 0, 0, 0, 1, 2, 3, 4],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 3-D; axis = 1; exclusive = 1, reverse = 0",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ },
+ {
+ "data": [1],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [0, 0, 1, 2, 0, 0, 5, 6],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 3-D; axis = -1; exclusive = 1, reverse = 0",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ },
+ {
+ "data": [-1],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [0, 1, 0, 3, 0, 5, 0, 7],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 3-D; axis = 2; exclusive = 1, reverse = 0",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ },
+ {
+ "data": [2],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [0, 1, 0, 3, 0, 5, 0, 7],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 3-D; axis = -2; exclusive = 1, reverse = 0",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ },
+ {
+ "data": [-2],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [0, 0, 1, 2, 0, 0, 5, 6],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 3-D; axis = -3; exclusive = 1, reverse = 0",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ },
+ {
+ "data": [-3],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [0, 0, 0, 0, 1, 2, 3, 4],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "CumSum",
+ "operator": "CumSum",
+ "attributes": [
+ { "name": "exclusive", "data": 0, "type": "int" },
+ { "name": "reverse", "data": 1, "type": "int" }
+ ],
+ "opset": {
+ "domain": "",
+ "version": 11
+ },
+ "cases": [
+ {
+ "name": "CumSum 1-D; axis = 0; exclusive = 0, reverse = 1",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5],
+ "dims": [5],
+ "type": "float32"
+ },
+ {
+ "data": [0],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [15, 14, 12, 9, 5],
+ "dims": [5],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 1-D; axis = -1; exclusive = 0, reverse = 1",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5],
+ "dims": [5],
+ "type": "float32"
+ },
+ {
+ "data": [-1],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [15, 14, 12, 9, 5],
+ "dims": [5],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 2-D (2x3); axis = 0; exclusive = 0, reverse = 1",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6],
+ "dims": [2, 3],
+ "type": "float32"
+ },
+ {
+ "data": [0],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [5, 7, 9, 4, 5, 6],
+ "dims": [2, 3],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 2-D (2x3); axis = -1; exclusive = 0, reverse = 1",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6],
+ "dims": [2, 3],
+ "type": "float32"
+ },
+ {
+ "data": [-1],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [6, 5, 3, 15, 11, 6],
+ "dims": [2, 3],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 2-D (2x3); axis = 1; exclusive = 0, reverse = 1",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6],
+ "dims": [2, 3],
+ "type": "float32"
+ },
+ {
+ "data": [1],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [6, 5, 3, 15, 11, 6],
+ "dims": [2, 3],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 2-D (2x3); axis = -2; exclusive = 0, reverse = 1",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6],
+ "dims": [2, 3],
+ "type": "float32"
+ },
+ {
+ "data": [-2],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [5, 7, 9, 4, 5, 6],
+ "dims": [2, 3],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 2-D (3x3); axis = 0; exclusive = 0, reverse = 1",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8, 9],
+ "dims": [3, 3],
+ "type": "float32"
+ },
+ {
+ "data": [0],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [12, 15, 18, 11, 13, 15, 7, 8, 9],
+ "dims": [3, 3],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 2-D (3x3); axis = 1; exclusive = 0, reverse = 1",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8, 9],
+ "dims": [3, 3],
+ "type": "float32"
+ },
+ {
+ "data": [1],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [6, 5, 3, 15, 11, 6, 24, 17, 9],
+ "dims": [3, 3],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 3-D; axis = 0; exclusive = 0, reverse = 1",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ },
+ {
+ "data": [0],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [6, 8, 10, 12, 5, 6, 7, 8],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 3-D; axis = 1; exclusive = 0, reverse = 1",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ },
+ {
+ "data": [1],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [4, 6, 3, 4, 12, 14, 7, 8],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 3-D; axis = -1; exclusive = 0, reverse = 1",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ },
+ {
+ "data": [-1],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [3, 2, 7, 4, 11, 6, 15, 8],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 3-D; axis = 2; exclusive = 0, reverse = 1",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ },
+ {
+ "data": [2],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [3, 2, 7, 4, 11, 6, 15, 8],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 3-D; axis = -2; exclusive = 0, reverse = 1",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ },
+ {
+ "data": [-2],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [4, 6, 3, 4, 12, 14, 7, 8],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 3-D; axis = -3; exclusive = 0, reverse = 1",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ },
+ {
+ "data": [-3],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [6, 8, 10, 12, 5, 6, 7, 8],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "CumSum",
+ "operator": "CumSum",
+ "attributes": [
+ { "name": "exclusive", "data": 1, "type": "int" },
+ { "name": "reverse", "data": 1, "type": "int" }
+ ],
+ "opset": {
+ "domain": "",
+ "version": 11
+ },
+ "cases": [
+ {
+ "name": "CumSum 1-D; axis = 0; exclusive = 1, reverse = 1",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5],
+ "dims": [5],
+ "type": "float32"
+ },
+ {
+ "data": [0],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [14, 12, 9, 5, 0],
+ "dims": [5],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 1-D; axis = -1; exclusive = 1, reverse = 1",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5],
+ "dims": [5],
+ "type": "float32"
+ },
+ {
+ "data": [-1],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [14, 12, 9, 5, 0],
+ "dims": [5],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 2-D (2x3); axis = 0; exclusive = 1, reverse = 1",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6],
+ "dims": [2, 3],
+ "type": "float32"
+ },
+ {
+ "data": [0],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [4, 5, 6, 0, 0, 0],
+ "dims": [2, 3],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 2-D (2x3); axis = -1; exclusive = 1, reverse = 1",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6],
+ "dims": [2, 3],
+ "type": "float32"
+ },
+ {
+ "data": [-1],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [5, 3, 0, 11, 6, 0],
+ "dims": [2, 3],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 2-D (2x3); axis = 1; exclusive = 1, reverse = 1",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6],
+ "dims": [2, 3],
+ "type": "float32"
+ },
+ {
+ "data": [1],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [5, 3, 0, 11, 6, 0],
+ "dims": [2, 3],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 2-D (2x3); axis = -2; exclusive = 1, reverse = 1",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6],
+ "dims": [2, 3],
+ "type": "float32"
+ },
+ {
+ "data": [-2],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [4, 5, 6, 0, 0, 0],
+ "dims": [2, 3],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 2-D (3x3); axis = 0; exclusive = 1, reverse = 1",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8, 9],
+ "dims": [3, 3],
+ "type": "float32"
+ },
+ {
+ "data": [0],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [11, 13, 15, 7, 8, 9, 0, 0, 0],
+ "dims": [3, 3],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 2-D (3x3); axis = 1; exclusive = 1, reverse = 1",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8, 9],
+ "dims": [3, 3],
+ "type": "float32"
+ },
+ {
+ "data": [1],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [5, 3, 0, 11, 6, 0, 17, 9, 0],
+ "dims": [3, 3],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 3-D; axis = 0; exclusive = 1, reverse = 1",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ },
+ {
+ "data": [0],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [5, 6, 7, 8, 0, 0, 0, 0],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 3-D; axis = 1; exclusive = 1, reverse = 1",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ },
+ {
+ "data": [1],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [3, 4, 0, 0, 7, 8, 0, 0],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 3-D; axis = -1; exclusive = 1, reverse = 1",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ },
+ {
+ "data": [-1],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [2, 0, 4, 0, 6, 0, 8, 0],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 3-D; axis = 2; exclusive = 1, reverse = 1",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ },
+ {
+ "data": [2],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [2, 0, 4, 0, 6, 0, 8, 0],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 3-D; axis = -2; exclusive = 1, reverse = 1",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ },
+ {
+ "data": [-2],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [3, 4, 0, 0, 7, 8, 0, 0],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ }
+ ]
+ },
+ {
+ "name": "CumSum 3-D; axis = -3; exclusive = 1, reverse = 1",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ },
+ {
+ "data": [-3],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [5, 6, 7, 8, 0, 0, 0, 0],
+ "dims": [2, 2, 2],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "CumSum",
+ "operator": "CumSum",
+ "attributes": [
+ { "name": "exclusive", "data": 0, "type": "int" },
+ { "name": "reverse", "data": 0, "type": "int" }
+ ],
+ "opset": {
+ "domain": "",
+ "version": 11
+ },
+ "cases": [
+ {
+ "name": "CumSum 5-D; axis = 0; exclusive = 0, reverse = 0",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5],
+ "dims": [1, 1, 1, 1, 5],
+ "type": "float32"
+ },
+ {
+ "data": [4],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [1, 3, 6, 10, 15],
+ "dims": [1, 1, 1, 1, 5],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "CumSum",
+ "operator": "CumSum",
+ "attributes": [
+ { "name": "exclusive", "data": 0, "type": "int" },
+ { "name": "reverse", "data": 0, "type": "int" }
+ ],
+ "opset": {
+ "domain": "",
+ "version": 11
+ },
+ "cases": [
+ {
+ "name": "CumSum int32; axis = 0; exclusive = 0, reverse = 0",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5],
+ "dims": [1, 1, 1, 1, 5],
+ "type": "int32"
+ },
+ {
+ "data": [4],
+ "dims": [],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [1, 3, 6, 10, 15],
+ "dims": [1, 1, 1, 1, 5],
+ "type": "int32"
+ }
+ ]
+ }
+ ]
+ }
+]
diff --git a/js/web/test/data/ops/expand.jsonc b/js/web/test/data/ops/expand.jsonc
index 35888e2fc3709..22bc04d558d98 100644
--- a/js/web/test/data/ops/expand.jsonc
+++ b/js/web/test/data/ops/expand.jsonc
@@ -112,6 +112,79 @@
"type": "float32"
}
]
+ },
+ {
+ "name": "Expand 5 - shape < input.size()",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
+ "dims": [1, 1, 1, 2, 6],
+ "type": "float32"
+ },
+ {
+ "data": [2, 1, 6],
+ "dims": [3],
+ "type": "int64"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12],
+ "dims": [1, 1, 2, 2, 6],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
+ },
+ {
+ "name": "Expand - bool",
+ "operator": "Expand",
+ "attributes": [],
+ "cases": [
+ {
+ "name": "Expand - last dim is divisible by 4",
+ "inputs": [
+ {
+ "data": [true, false, false, true],
+ "dims": [4],
+ "type": "bool"
+ },
+ {
+ "data": [2, 4],
+ "dims": [2],
+ "type": "int64"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [true, false, false, true, true, false, false, true],
+ "dims": [2, 4],
+ "type": "bool"
+ }
+ ]
+ },
+ {
+ "name": "Expand - last dim is not divisible by 4",
+ "inputs": [
+ {
+ "data": [true, false, false, true, true, true, false, false, false, true, true, true],
+ "dims": [2, 6],
+ "type": "bool"
+ },
+ {
+ "data": [2, 1],
+ "dims": [2],
+ "type": "int64"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [true, false, false, true, true, true, false, false, false, true, true, true],
+ "dims": [2, 6],
+ "type": "bool"
+ }
+ ]
}
]
}
diff --git a/js/web/test/data/ops/gather.jsonc b/js/web/test/data/ops/gather.jsonc
index 3b1b0e3821832..0be077d237b88 100644
--- a/js/web/test/data/ops/gather.jsonc
+++ b/js/web/test/data/ops/gather.jsonc
@@ -93,5 +93,34 @@
]
}
]
+ },
+ {
+ "name": "Gather - bool",
+ "operator": "Gather",
+ "attributes": [],
+ "cases": [
+ {
+ "name": "data[2,4] indices[1]",
+ "inputs": [
+ {
+ "data": [true, false, false, true, false, false, true, true],
+ "dims": [2, 4],
+ "type": "bool"
+ },
+ {
+ "data": [1],
+ "dims": [1],
+ "type": "int32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [false, false, true, true],
+ "dims": [1, 4],
+ "type": "bool"
+ }
+ ]
+ }
+ ]
}
]
diff --git a/js/web/test/data/ops/global-average-pool.jsonc b/js/web/test/data/ops/global-average-pool.jsonc
index fdf3a8fe1e7a2..17aa061841b2c 100644
--- a/js/web/test/data/ops/global-average-pool.jsonc
+++ b/js/web/test/data/ops/global-average-pool.jsonc
@@ -61,6 +61,29 @@
"type": "float32"
}
]
+ },
+ {
+ "name": "T[1,3,2,2,2] T[1,3,1,1,1]",
+ "inputs": [
+ {
+ "data": [
+ 1.764052391052246, 0.40015721321105957, 0.978738009929657, 2.2408931255340576, 1.8675580024719238,
+ -0.9772778749465942, 0.9500884413719177, -0.15135720372200012, -0.10321885347366333, 0.4105985164642334,
+ 0.14404356479644775, 1.4542734622955322, 0.7610377073287964, 0.12167501449584961, 0.44386324286460876,
+ 0.3336743414402008, 1.4940791130065918, -0.2051582634449005, 0.3130677044391632, -0.8540957570075989,
+ -2.5529897212982178, 0.653618574142456, 0.8644362092018127, -0.7421650290489197
+ ],
+ "dims": [1, 3, 2, 2, 2],
+ "type": "float32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [0.8841065168380737, 0.4457433819770813, -0.12865088880062103],
+ "dims": [1, 3, 1, 1, 1],
+ "type": "float32"
+ }
+ ]
}
]
}
diff --git a/js/web/test/test-main.ts b/js/web/test/test-main.ts
index 24ab0694b32b8..9bd0ec1425f95 100644
--- a/js/web/test/test-main.ts
+++ b/js/web/test/test-main.ts
@@ -56,7 +56,7 @@ if (options.globalEnvFlags) {
ort.env.wasm.initTimeout = flags.wasm.initTimeout;
}
if (flags.webgpu?.profilingMode !== undefined) {
- ort.env.webgpu.profilingMode = flags.webgpu.profilingMode;
+ ort.env.webgpu.profiling = {mode: flags.webgpu.profilingMode};
}
if (flags.webgpu?.validateInputContent !== undefined) {
ort.env.webgpu.validateInputContent = flags.webgpu.validateInputContent;
diff --git a/js/web/tsconfig.json b/js/web/tsconfig.json
index d60d746e9328d..80d0cd0642b80 100644
--- a/js/web/tsconfig.json
+++ b/js/web/tsconfig.json
@@ -6,5 +6,5 @@
"typeRoots": ["./node_modules/@webgpu/types", "./node_modules/@types", "../node_modules/@types"]
},
"include": ["lib", "test"],
- "exclude": ["lib/wasm/proxy-worker"]
+ "exclude": ["lib/wasm/proxy-worker", "test/ort.test.js", "test/ort.test.min.js"]
}
diff --git a/objectivec/include/ort_env.h b/objectivec/include/ort_env.h
index 8456b57bfa402..67db76668b3bb 100644
--- a/objectivec/include/ort_env.h
+++ b/objectivec/include/ort_env.h
@@ -24,6 +24,9 @@ NSString* _Nullable ORTVersion(void);
/**
* The ORT environment.
+ * It maintains shared state including the default logger.
+ *
+ * @note One ORTEnv should be created before and destroyed after other ORT API usage.
*/
@interface ORTEnv : NSObject
diff --git a/objectivec/include/ort_training_session.h b/objectivec/include/ort_training_session.h
index 15c0137817ae2..2ad4fed93c331 100644
--- a/objectivec/include/ort_training_session.h
+++ b/objectivec/include/ort_training_session.h
@@ -39,7 +39,7 @@ NS_ASSUME_NONNULL_BEGIN
* session which will be moved to the device specified in the session option if needed.
*
* @param env The `ORTEnv` instance to use for the training session.
- * @param sessionOptions The `ORTSessionOptions` to use for the training session.
+ * @param sessionOptions The optional `ORTSessionOptions` to use for the training session.
* @param checkpoint Training states that are used as a starting point for training.
* @param trainModelPath The path to the training onnx model.
* @param evalModelPath The path to the evaluation onnx model.
@@ -52,7 +52,7 @@ NS_ASSUME_NONNULL_BEGIN
* keeps a strong (owning) pointer to the checkpoint state.
*/
- (nullable instancetype)initWithEnv:(ORTEnv*)env
- sessionOptions:(ORTSessionOptions*)sessionOptions
+ sessionOptions:(nullable ORTSessionOptions*)sessionOptions
checkpoint:(ORTCheckpoint*)checkpoint
trainModelPath:(NSString*)trainModelPath
evalModelPath:(nullable NSString*)evalModelPath
diff --git a/objectivec/ort_session.mm b/objectivec/ort_session.mm
index d27c3e2cefcfb..87288bd1e9dc7 100644
--- a/objectivec/ort_session.mm
+++ b/objectivec/ort_session.mm
@@ -23,6 +23,7 @@
NS_ASSUME_NONNULL_BEGIN
@implementation ORTSession {
+ ORTEnv* _env; // keep a strong reference so the ORTEnv doesn't get destroyed before this does
std::optional _session;
}
@@ -44,6 +45,7 @@ - (nullable instancetype)initWithEnv:(ORTEnv*)env
}
}
+ _env = env;
_session = Ort::Session{[env CXXAPIOrtEnv],
path.UTF8String,
[sessionOptions CXXAPIOrtSessionOptions]};
diff --git a/objectivec/ort_training_session.mm b/objectivec/ort_training_session.mm
index 285151b412bf0..5387bfda6d411 100644
--- a/objectivec/ort_training_session.mm
+++ b/objectivec/ort_training_session.mm
@@ -19,8 +19,9 @@
NS_ASSUME_NONNULL_BEGIN
@implementation ORTTrainingSession {
- std::optional _session;
+ ORTEnv* _env; // keep a strong reference so the ORTEnv doesn't get destroyed before this does
ORTCheckpoint* _checkpoint;
+ std::optional _session;
}
- (Ort::TrainingSession&)CXXAPIOrtTrainingSession {
@@ -28,7 +29,7 @@ @implementation ORTTrainingSession {
}
- (nullable instancetype)initWithEnv:(ORTEnv*)env
- sessionOptions:(ORTSessionOptions*)sessionOptions
+ sessionOptions:(nullable ORTSessionOptions*)sessionOptions
checkpoint:(ORTCheckpoint*)checkpoint
trainModelPath:(NSString*)trainModelPath
evalModelPath:(nullable NSString*)evalModelPath
@@ -39,9 +40,17 @@ - (nullable instancetype)initWithEnv:(ORTEnv*)env
}
try {
+ if (!sessionOptions) {
+ sessionOptions = [[ORTSessionOptions alloc] initWithError:error];
+ if (!sessionOptions) {
+ return nil;
+ }
+ }
+
std::optional evalPath = utils::toStdOptionalString(evalModelPath);
std::optional optimizerPath = utils::toStdOptionalString(optimizerModelPath);
+ _env = env;
_checkpoint = checkpoint;
_session = Ort::TrainingSession{
[env CXXAPIOrtEnv],
@@ -50,6 +59,7 @@ - (nullable instancetype)initWithEnv:(ORTEnv*)env
trainModelPath.UTF8String,
evalPath,
optimizerPath};
+
return self;
}
ORT_OBJC_API_IMPL_CATCH_RETURNING_NULLABLE(error)
diff --git a/objectivec/test/ort_session_test.mm b/objectivec/test/ort_session_test.mm
index f00f5db2f995f..508289f7bc748 100644
--- a/objectivec/test/ort_session_test.mm
+++ b/objectivec/test/ort_session_test.mm
@@ -295,6 +295,32 @@ - (void)testStringInputs {
XCTAssertTrue([stringData isEqualToArray:outputStringData]);
}
+- (void)testKeepORTEnvReference {
+ ORTEnv* __weak envWeak = _ortEnv;
+ // Remove sole strong reference to the ORTEnv created in setUp.
+ _ortEnv = nil;
+ // There should be no more strong references to it.
+ XCTAssertNil(envWeak);
+
+ // Create a new ORTEnv.
+ NSError* err = nil;
+ ORTEnv* env = [[ORTEnv alloc] initWithLoggingLevel:ORTLoggingLevelWarning
+ error:&err];
+ ORTAssertNullableResultSuccessful(env, err);
+
+ ORTSession* session = [[ORTSession alloc] initWithEnv:env
+ modelPath:[ORTSessionTest getAddModelPath]
+ sessionOptions:[ORTSessionTest makeSessionOptions]
+ error:&err];
+ ORTAssertNullableResultSuccessful(session, err);
+
+ envWeak = env;
+ // Remove strong reference to the ORTEnv passed to the ORTSession initializer.
+ env = nil;
+ // ORTSession should keep a strong reference to it.
+ XCTAssertNotNil(envWeak);
+}
+
@end
NS_ASSUME_NONNULL_END
diff --git a/onnxruntime/contrib_ops/cpu/image_scaler.h b/onnxruntime/contrib_ops/cpu/image_scaler.h
index 9e9d9908ab188..865bca51f1e85 100644
--- a/onnxruntime/contrib_ops/cpu/image_scaler.h
+++ b/onnxruntime/contrib_ops/cpu/image_scaler.h
@@ -16,8 +16,8 @@ template
class ImageScaler final : public OpKernel {
public:
ImageScaler(const OpKernelInfo& info) : OpKernel(info) {
- ORT_ENFORCE(info.GetAttr("scale", &scale_).IsOK());
- ORT_ENFORCE(info.GetAttrs("bias", bias_).IsOK());
+ ORT_THROW_IF_ERROR(info.GetAttr("scale", &scale_));
+ ORT_THROW_IF_ERROR(info.GetAttrs("bias", bias_));
}
Status Compute(OpKernelContext* context) const override {
diff --git a/onnxruntime/contrib_ops/cpu/math/sparse_dense_matmul.cc b/onnxruntime/contrib_ops/cpu/math/sparse_dense_matmul.cc
index b00b10ad649b1..46a8b70d289b7 100644
--- a/onnxruntime/contrib_ops/cpu/math/sparse_dense_matmul.cc
+++ b/onnxruntime/contrib_ops/cpu/math/sparse_dense_matmul.cc
@@ -47,7 +47,6 @@ struct ComputeCtx {
float alpha;
};
-#if !defined(__i386__) && !defined(_M_IX86) && !defined(__wasm__) && !defined(__ANDROID__)
template
inline void SparseDenseMatMulImpl(const ComputeCtx& ctx, const ConstSparseMatrixMap& map_A,
const ConstEigenMatrixMapRowMajor& map_B, EigenMatrixMapRowMajor& output_map) {
@@ -64,7 +63,8 @@ inline void SparseDenseMatMulImpl(const ComputeCtx& ctx, const ConstSparseMatrix
template <>
inline void SparseDenseMatMulImpl(const ComputeCtx& ctx, const ConstSparseMatrixMap& map_A,
- const ConstEigenMatrixMapRowMajor& map_B, EigenMatrixMapRowMajor& output_map) {
+ const ConstEigenMatrixMapRowMajor& map_B,
+ EigenMatrixMapRowMajor& output_map) {
if (ctx.trans_A && ctx.trans_B) {
output_map = map_A.transpose() * ctx.alpha * map_B.transpose();
} else if (ctx.trans_A && !ctx.trans_B) {
@@ -84,21 +84,47 @@ struct SparseToDenseCsr {
const auto& b_dims = B.Shape().GetDims();
const auto& out_dims = output.Shape().GetDims();
auto csr_view = A.AsCsr();
-
- ConstSparseMatrixMap map_A(a_dims[0], a_dims[1], A.NumValues(),
- csr_view.Outer().Data(),
- csr_view.Inner().Data(),
+ const Eigen::Index* inner_index_pointer = nullptr;
+ const Eigen::Index* outer_index_pointer = nullptr;
+ // Used to auto-release the above two pointers when they are not NULL.
+ std::unique_ptr buffer_holder_inner, buffer_holder_outer;
+ if constexpr (std::is_integral::value &&
+ std::is_signed::value &&
+ (sizeof(Eigen::Index) == sizeof(int64_t))) {
+ // On macOS the following reinterpret_cast is necessary because Eigen::Index is an alias of `long` but int64_t is
+ // `long long`. Though they have the same size, compilers still do not allow implicit conversion between them.
+ inner_index_pointer = reinterpret_cast(csr_view.Inner().Data());
+ outer_index_pointer = reinterpret_cast(csr_view.Outer().Data());
+ } else {
+ // In a 32-bit build we need to cast the following two tensors to 32 bits
+ gsl::span inner_data = csr_view.Inner().DataAsSpan();
+ gsl::span outer_data = csr_view.Outer().DataAsSpan();
+ buffer_holder_inner.reset(new Eigen::Index[inner_data.size()]);
+ buffer_holder_outer.reset(new Eigen::Index[outer_data.size()]);
+ inner_index_pointer = buffer_holder_inner.get();
+ outer_index_pointer = buffer_holder_outer.get();
+
+ std::transform(inner_data.begin(), inner_data.end(),
+ buffer_holder_inner.get(), [](int64_t v) -> Eigen::Index {
+ return narrow(v);
+ });
+ std::transform(outer_data.begin(), outer_data.end(),
+ buffer_holder_outer.get(), [](int64_t v) -> Eigen::Index {
+ return narrow(v);
+ });
+ }
+ ConstSparseMatrixMap map_A(narrow(a_dims[0]), narrow(a_dims[1]),
+ narrow(A.NumValues()), outer_index_pointer, inner_index_pointer,
A.Values().Data());
- ConstEigenMatrixMapRowMajor map_B(B.Data(), b_dims[0], b_dims[1]);
- EigenMatrixMapRowMajor output_map(output.MutableData(), out_dims[0], out_dims[1]);
+ ConstEigenMatrixMapRowMajor map_B(B.Data(), narrow(b_dims[0]), narrow(b_dims[1]));
+ EigenMatrixMapRowMajor output_map(output.MutableData(), narrow(out_dims[0]),
+ narrow(out_dims[1]));
// XXX: Consider re-writing it as a parallel loop as Eigen requires it to use OpenMP
// XXX: Consider vectorization
SparseDenseMatMulImpl(ctx, map_A, map_B, output_map);
}
};
-#endif //! defined(__i386__) && !defined(_M_IX86) && !defined(__wasm__) && !defined(__ANDROID__)
-
template
inline T Mul(T a_value, float, T b_value) {
return a_value * b_value;
@@ -121,9 +147,11 @@ struct SparseToDenseCoo {
auto coo_view = A.AsCoo();
const auto& ind_dims = coo_view.Indices().Shape().GetDims();
ORT_RETURN_IF_NOT(ind_dims.size() == 2, "COO indices must be 2-D, got: ", ind_dims.size());
- ConstEigenMatrixMapRowMajor a_indicies_map(coo_view.Indices().Data(), narrow(ind_dims[0]), narrow(ind_dims[1]));
+ ConstEigenMatrixMapRowMajor a_indicies_map(coo_view.Indices().Data(), narrow(ind_dims[0]),
+ narrow(ind_dims[1]));
ConstEigenMatrixMapRowMajor map_b(B.Data(), narrow(b_dims[0]), narrow(b_dims[1]));
- EigenMatrixMapRowMajor output_map(output.MutableData(), narrow(out_dims[0]), narrow(out_dims[1]));
+ EigenMatrixMapRowMajor output_map(output.MutableData(), narrow(out_dims[0]),
+ narrow(out_dims[1]));
output_map.setZero();
const auto rhs_right = (ctx.trans_B) ? b_dims[0] : b_dims[1];
@@ -140,7 +168,8 @@ struct SparseToDenseCoo {
ORT_RETURN_IF_NOT(m < out_left, "COO m index: ", m, " is out of bounds of out_left: ", out_left);
const T a_value = a_values[i];
for (int64_t n = 0; n < rhs_right; ++n) {
- const T b_value = (ctx.trans_B) ? map_b(narrow(n), narrow(k)) : map_b(narrow(k), narrow(n));
+ const T b_value =
+ (ctx.trans_B) ? map_b(narrow(n), narrow(k)) : map_b(narrow(k), narrow(n));
output_map(narrow(m), narrow(n)) += Mul(a_value, ctx.alpha, b_value);
}
}
@@ -170,8 +199,9 @@ Status SparseToDenseMatMul::Compute(OpKernelContext* ctx) const {
const auto inner_B = (trans_b_attr_) ? b_dims[1] : b_dims[0];
const auto outer_B = (trans_b_attr_) ? b_dims[0] : b_dims[1];
- ORT_RETURN_IF_NOT(inner_A == inner_B, "Can not multiply A and B as inner dimension does not match. inner_A: ",
- inner_A, " vs inner_B: ", inner_B);
+ ORT_RETURN_IF_NOT(inner_A == inner_B,
+ "Can not multiply A and B as inner dimension does not match. inner_A: ", inner_A,
+ " vs inner_B: ", inner_B);
TensorShape output_shape{outer_A, outer_B};
auto* output = ctx->Output(0, output_shape);
@@ -184,12 +214,10 @@ Status SparseToDenseMatMul::Compute(OpKernelContext* ctx) const {
auto coo_view = A->AsCoo();
const auto num_dims = coo_view.Indices().Shape().NumDimensions();
ORT_RETURN_IF_NOT(num_dims == 2, "Expecting COO 2-D indices shape");
- ORT_RETURN_IF_NOT(A->Values().Shape().Size() * 2 == coo_view.Indices().Shape().Size(), "Expecting 2xValues == indices");
+ ORT_RETURN_IF_NOT(A->Values().Shape().Size() * 2 == coo_view.Indices().Shape().Size(),
+ "Expecting 2xValues == indices");
auto status = t_disp.InvokeRet(compute_ctx, *A, *B, *output);
ORT_RETURN_IF_ERROR(status);
-// Eigen has a bug in x86 where it calculates reallocation size as -1
-// and throws bad_alloc
-#if !defined(__i386__) && !defined(_M_IX86) && !defined(__wasm__) && !defined(__ANDROID__)
} else if (A->Format() == SparseFormat::kCsrc) {
auto csr_view = A->AsCsr();
ORT_RETURN_IF_NOT(A->Values().Shape().Size() == csr_view.Inner().Shape().Size(),
@@ -199,11 +227,6 @@ Status SparseToDenseMatMul::Compute(OpKernelContext* ctx) const {
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Currently support only COO and CSR(x64) formats");
}
-#else
- } else {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "WASM and 32-bit builds support only COO format");
- }
-#endif //! defined(__i386__) && !defined(_M_IX86) && !defined(__wasm__) && !defined(__ANDROID__)
return Status::OK();
}
@@ -211,4 +234,4 @@ Status SparseToDenseMatMul::Compute(OpKernelContext* ctx) const {
} // namespace contrib
} // namespace onnxruntime
-#endif //! defined(DISABLE_SPARSE_TENSORS)
\ No newline at end of file
+#endif //! defined(DISABLE_SPARSE_TENSORS)
diff --git a/onnxruntime/contrib_ops/cuda/bert/transformer_cuda_common.h b/onnxruntime/contrib_ops/cuda/bert/transformer_cuda_common.h
index faf9310c4c3fd..a0da24210459c 100644
--- a/onnxruntime/contrib_ops/cuda/bert/transformer_cuda_common.h
+++ b/onnxruntime/contrib_ops/cuda/bert/transformer_cuda_common.h
@@ -3,7 +3,7 @@
#pragma once
-#include "core/providers/cuda/cuda_common.h"
+#include
namespace onnxruntime {
namespace contrib {
diff --git a/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc b/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc
index 574a3133de815..0f42363bca22d 100644
--- a/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.cc
@@ -24,9 +24,7 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {
-#define NCCL_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(NCCL_CALL(expr))
-
-static ncclDataType_t GetNcclDataType(onnxruntime::MLDataType type) {
+ncclDataType_t GetNcclDataType(onnxruntime::MLDataType type) {
if (type == DataTypeImpl::GetType()) {
return ncclUint8;
} else if (type == DataTypeImpl::GetType()) {
diff --git a/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h b/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h
index 7fc26e6be57b9..9ea61f2bd952d 100644
--- a/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h
+++ b/onnxruntime/contrib_ops/cuda/collective/nccl_kernels.h
@@ -7,17 +7,21 @@
#if defined(ORT_USE_NCCL)
#include
-#include
#include
-#include
+#include
#include
#include
+#include
#endif
namespace onnxruntime {
namespace contrib {
namespace cuda {
+#define NCCL_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(NCCL_CALL(expr))
+
+ncclDataType_t GetNcclDataType(onnxruntime::MLDataType type);
+
// -----------------------------------------------------------------------
// Defines a new version of nccl classes
// that independent with training::DistributedRunContext, only rely on MPI
diff --git a/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
new file mode 100644
index 0000000000000..40a667ffd5d83
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/collective/sharded_moe.cc
@@ -0,0 +1,204 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/common/safeint.h"
+#include "core/providers/cuda/cuda_common.h"
+#include "contrib_ops/cuda/bert/transformer_cuda_common.h"
+#include "sharded_moe.h"
+
+using namespace onnxruntime::cuda;
+using namespace ::onnxruntime::common;
+using namespace ONNX_NAMESPACE;
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+#if defined(ORT_USE_NCCL)
+
+#define REGISTER_KERNEL_TYPED(T) \
+ ONNX_OPERATOR_TYPED_KERNEL_EX( \
+ ShardedMoE, \
+ kMSDomain, \
+ 1, \
+ T, \
+ kCudaExecutionProvider, \
+ (*KernelDefBuilder::Create()) \
+ .MayInplace(0, 0) \
+ .TypeConstraint("T", DataTypeImpl::GetTensorType()), \
+ ShardedMoE);
+
+REGISTER_KERNEL_TYPED(float)
+REGISTER_KERNEL_TYPED(MLFloat16)
+
+using namespace ONNX_NAMESPACE;
+
+template
+ShardedMoE::ShardedMoE(const OpKernelInfo& op_kernel_info) : NcclKernel(op_kernel_info), MoEBase(op_kernel_info) {
+ ORT_ENFORCE(op_kernel_info.GetAttr("local_experts_start_index", &local_experts_start_index_).IsOK());
+ rank_to_experts_start_index_.resize(nccl_->Size());
+ // Initialize rank_to_experts_start_index_[0] to a sentinel value indicating it has not been initialized yet.
+ rank_to_experts_start_index_[0] = std::numeric_limits::min();
+}
+
+template
+Status ShardedMoE::ComputeInternal(OpKernelContext* context) const {
+ typedef typename ToCudaType::MappedType CudaT;
+ auto stream = context->GetComputeStream();
+
+ auto& device_prop = GetDeviceProp();
+ const int sm = device_prop.major * 10 + device_prop.minor;
+
+ AllocatorPtr allocator;
+ ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
+
+ // Create a {Rank, ExpertsStartIndex} map on Host.
+ AutoDestoryCudaEvent cuda_event;
+ cudaEvent_t& copy_event = cuda_event.Get();
+ ORT_RETURN_IF_ERROR(SynchronizeExpertsStartIndex(allocator, context, copy_event));
+
+ const Tensor* input = context->Input(0);
+ const Tensor* router_probs = context->Input(1);
+ const Tensor* fc1_experts_weights = context->Input(2);
+ const Tensor* fc2_experts_weights = context->Input(3);
+ const Tensor* fc1_experts_bias_optional = context->Input(4);
+ const Tensor* fc2_experts_bias_optional = context->Input(5);
+
+ MoEParameters moe_params;
+ ORT_RETURN_IF_ERROR(CheckInputs(moe_params, input, router_probs, fc1_experts_weights, fc2_experts_weights,
+ fc1_experts_bias_optional, fc2_experts_bias_optional));
+ ORT_RETURN_IF_NOT(moe_params.num_experts % nccl_->Size() == 0,
+ "num_experts should be divisible by world_size");
+
+ ort_fastertransformer::CutlassMoeFCRunner moe_runner(sm);
+
+ size_t ws_size =
+ moe_runner.getWorkspaceSize(static_cast(moe_params.num_rows), static_cast(moe_params.hidden_size),
+ static_cast(moe_params.inter_size), static_cast(moe_params.num_experts),
+ static_cast(k_));
+
+ size_t fc2_output_size = k_ * moe_params.num_rows * moe_params.hidden_size * sizeof(CudaT);
+ size_t expert_scales_size = k_ * moe_params.num_rows * sizeof(CudaT);
+ size_t expanded_source_row_to_expanded_dest_row_size = k_ * moe_params.num_rows * sizeof(int);
+ size_t expert_for_source_row_size = k_ * moe_params.num_rows * sizeof(int);
+
+ // TODO: allocate one buffer and reuse it.
+ IAllocatorUniquePtr work_space = IAllocator::MakeUniquePtr