diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt
index cc63844f46d28..b03a3019764ca 100644
--- a/cmake/CMakeLists.txt
+++ b/cmake/CMakeLists.txt
@@ -73,6 +73,11 @@ option(onnxruntime_ENABLE_PYTHON "Enable python buildings" OFF)
 # Enable it may cause LNK1169 error
 option(onnxruntime_ENABLE_MEMLEAK_CHECKER "Experimental: Enable memory leak checker in Windows debug build" OFF)
 option(onnxruntime_USE_CUDA "Build with CUDA support" OFF)
+# Enable ONNX Runtime CUDA EP's internal unit tests that directly access the EP's internal functions instead of through
+# OpKernels. When the option is ON, we will have two copies of GTest library in the same process. It is not a typical
+# use. If you hit any problem with that, please do not report it to GTest. Turn OFF the following build option instead.
+cmake_dependent_option(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS "Build with CUDA unit tests" OFF "onnxruntime_USE_CUDA;onnxruntime_BUILD_UNIT_TESTS;LINUX" OFF)
+
 option(onnxruntime_ENABLE_CUDA_LINE_NUMBER_INFO "When building with CUDA support, generate device code line number information." OFF)
 option(onnxruntime_USE_OPENVINO "Build with OpenVINO support" OFF)
 option(onnxruntime_USE_COREML "Build with CoreML support" OFF)
@@ -146,10 +151,11 @@ option(onnxruntime_DISABLE_SPARSE_TENSORS "Disable sparse tensors data types" OFF)
 option(onnxruntime_DISABLE_OPTIONAL_TYPE "Disable optional type" OFF)
 option(onnxruntime_DISABLE_FLOAT8_TYPES "Disable float 8 types" OFF)
 option(onnxruntime_MINIMAL_BUILD "Exclude as much as possible from the build. Support ORT format models. No support for ONNX format models." OFF)
-cmake_dependent_option(onnxruntime_DISABLE_RTTI "Disable RTTI" ON "NOT onnxruntime_ENABLE_PYTHON" OFF)
+cmake_dependent_option(onnxruntime_DISABLE_RTTI "Disable RTTI" ON "NOT onnxruntime_ENABLE_PYTHON;NOT onnxruntime_USE_CUDA" OFF)
 # For now onnxruntime_DISABLE_EXCEPTIONS will only work with onnxruntime_MINIMAL_BUILD, more changes (ONNX, non-CPU EP, ...) are required to run this standalone
 cmake_dependent_option(onnxruntime_DISABLE_EXCEPTIONS "Disable exception handling. Requires onnxruntime_MINIMAL_BUILD currently." ON "onnxruntime_MINIMAL_BUILD;NOT onnxruntime_ENABLE_PYTHON" OFF)
-option(onnxruntime_DISABLE_ABSEIL "Do not link to Abseil. Redefine Inlined containers to STD containers." OFF)
+# Even when onnxruntime_DISABLE_ABSEIL is ON, ONNX Runtime still needs to link to abseil.
+option(onnxruntime_DISABLE_ABSEIL "Do not use Abseil data structures in ONNX Runtime source code. Redefine Inlined containers to STD containers." OFF)
 option(onnxruntime_EXTENDED_MINIMAL_BUILD "onnxruntime_MINIMAL_BUILD with support for execution providers that compile kernels." OFF)
 option(onnxruntime_MINIMAL_BUILD_CUSTOM_OPS "Add custom operator kernels support to a minimal build."
        OFF)
@@ -269,10 +275,6 @@ if (onnxruntime_ENABLE_TRAINING_APIS)
   endif()
 endif()
 
-if (onnxruntime_USE_CUDA)
-  set(onnxruntime_DISABLE_RTTI OFF)
-endif()
-
 if (onnxruntime_USE_ROCM)
   if (WIN32)
     message(FATAL_ERROR "ROCM does not support build in Windows!")
diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake
index e1671bcf43ed9..019c6341d2e46 100644
--- a/cmake/external/onnxruntime_external_deps.cmake
+++ b/cmake/external/onnxruntime_external_deps.cmake
@@ -37,8 +37,12 @@ if (onnxruntime_BUILD_UNIT_TESTS)
     set(gtest_disable_pthreads ON)
   endif()
   set(INSTALL_GTEST OFF CACHE BOOL "" FORCE)
-  # Set it to ON will cause crashes in onnxruntime_test_all when onnxruntime_USE_CUDA is ON
-  set(GTEST_HAS_ABSL OFF CACHE BOOL "" FORCE)
+  if (CMAKE_SYSTEM_NAME STREQUAL "iOS")
+    # Needs to update onnxruntime/test/xctest/xcgtest.mm
+    set(GTEST_HAS_ABSL OFF CACHE BOOL "" FORCE)
+  else()
+    set(GTEST_HAS_ABSL ON CACHE BOOL "" FORCE)
+  endif()
   # gtest and gmock
   FetchContent_Declare(
     googletest
diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake
index 4861643832cab..96c05e5282bb5 100644
--- a/cmake/onnxruntime_providers.cmake
+++ b/cmake/onnxruntime_providers.cmake
@@ -460,13 +460,17 @@ if (onnxruntime_USE_CUDA)
     if (onnxruntime_REDUCED_OPS_BUILD)
       substitute_op_reduction_srcs(onnxruntime_providers_cuda_src)
     endif()
-    # cuda_provider_interface.cc is removed from the object target: onnxruntime_providers_cuda_obj and
-    # add to the lib onnxruntime_providers_cuda separatedly.
-    # onnxruntime_providers_cuda_ut can share all the object files with onnxruntime_providers_cuda except cuda_provider_interface.cc.
-    set(cuda_provider_interface_src ${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_provider_interface.cc)
-    list(REMOVE_ITEM onnxruntime_providers_cuda_src ${cuda_provider_interface_src})
-    onnxruntime_add_object_library(onnxruntime_providers_cuda_obj ${onnxruntime_providers_cuda_src})
-    onnxruntime_add_shared_library_module(onnxruntime_providers_cuda ${cuda_provider_interface_src} $<TARGET_OBJECTS:onnxruntime_providers_cuda_obj>)
+    if(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS)
+      # cuda_provider_interface.cc is removed from the object target: onnxruntime_providers_cuda_obj and
+      # added to the lib onnxruntime_providers_cuda separately.
+      # onnxruntime_providers_cuda_ut can share all the object files with onnxruntime_providers_cuda except cuda_provider_interface.cc.
+      set(cuda_provider_interface_src ${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_provider_interface.cc)
+      list(REMOVE_ITEM onnxruntime_providers_cuda_src ${cuda_provider_interface_src})
+      onnxruntime_add_object_library(onnxruntime_providers_cuda_obj ${onnxruntime_providers_cuda_src})
+      onnxruntime_add_shared_library_module(onnxruntime_providers_cuda ${cuda_provider_interface_src} $<TARGET_OBJECTS:onnxruntime_providers_cuda_obj>)
+    else()
+      onnxruntime_add_shared_library_module(onnxruntime_providers_cuda ${onnxruntime_providers_cuda_src})
+    endif()
     # config_cuda_provider_shared_module can be used to config onnxruntime_providers_cuda_obj, onnxruntime_providers_cuda & onnxruntime_providers_cuda_ut.
     # This function guarantees that all 3 targets have the same configurations.
     function(config_cuda_provider_shared_module target)
@@ -600,7 +604,9 @@ if (onnxruntime_USE_CUDA)
       target_compile_definitions(${target} PRIVATE ENABLE_ATEN)
     endif()
   endfunction()
-  config_cuda_provider_shared_module(onnxruntime_providers_cuda_obj)
+  if(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS)
+    config_cuda_provider_shared_module(onnxruntime_providers_cuda_obj)
+  endif()
   config_cuda_provider_shared_module(onnxruntime_providers_cuda)
 
   install(TARGETS onnxruntime_providers_cuda
diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake
index ec83eb2095071..0e642c5a7e0aa 100644
--- a/cmake/onnxruntime_unittests.cmake
+++ b/cmake/onnxruntime_unittests.cmake
@@ -35,13 +35,17 @@ function(AddTest)
   if (MSVC AND NOT CMAKE_SIZEOF_VOID_P EQUAL 8)
     #TODO: fix the warnings, they are dangerous
-    target_compile_options(${_UT_TARGET} PRIVATE "/wd4244")
+    target_compile_options(${_UT_TARGET} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /wd4244>"
+                    "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd4244>")
   endif()
   if (MSVC)
-    target_compile_options(${_UT_TARGET} PRIVATE "/wd6330")
+    target_compile_options(${_UT_TARGET} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /wd6330>"
+                    "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd6330>")
     #Abseil has a lot of C4127/C4324 warnings.
-    target_compile_options(${_UT_TARGET} PRIVATE "/wd4127")
-    target_compile_options(${_UT_TARGET} PRIVATE "/wd4324")
+    target_compile_options(${_UT_TARGET} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /wd4127>"
+                    "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd4127>")
+    target_compile_options(${_UT_TARGET} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /wd4324>"
+                    "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd4324>")
   endif()
   set_target_properties(${_UT_TARGET} PROPERTIES FOLDER "ONNXRuntimeTest")
@@ -60,6 +64,11 @@ function(AddTest)
         Threads::Threads)
     target_compile_definitions(${_UT_TARGET} PRIVATE -DUSE_ONNXRUNTIME_DLL)
   else()
+    if(onnxruntime_USE_CUDA)
+      #XXX: we should not need to do this. onnxruntime_test_all.exe should not have direct dependency on CUDA DLLs,
+      # otherwise it will impact when CUDA DLLs can be unloaded.
+      target_link_libraries(${_UT_TARGET} PRIVATE cudart)
+    endif()
     target_link_libraries(${_UT_TARGET} PRIVATE ${_UT_LIBS} GTest::gtest GTest::gmock ${onnxruntime_EXTERNAL_LIBRARIES})
   endif()
@@ -85,29 +94,22 @@ function(AddTest)
     # include dbghelp in case tests throw an ORT exception, as that exception includes a stacktrace, which requires dbghelp.
     target_link_libraries(${_UT_TARGET} PRIVATE debug dbghelp)
 
-    if (onnxruntime_USE_CUDA)
-      # disable a warning from the CUDA headers about unreferenced local functions
-      if (MSVC)
-        target_compile_options(${_UT_TARGET} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler /wd4505>"
-                "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd4505>")
-      endif()
-    endif()
     if (MSVC)
       # warning C6326: Potential comparison of a constant with another constant.
       # Lot of such things came from gtest
-      target_compile_options(${_UT_TARGET} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler /wd6326>"
+      target_compile_options(${_UT_TARGET} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /wd6326>"
                 "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd6326>")
       # Raw new and delete. A lot of such things came from googletest.
-      target_compile_options(${_UT_TARGET} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler /wd26409>"
+      target_compile_options(${_UT_TARGET} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /wd26409>"
                 "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd26409>")
       # "Global initializer calls a non-constexpr function."
-      target_compile_options(${_UT_TARGET} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler /wd26426>"
+      target_compile_options(${_UT_TARGET} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /wd26426>"
                 "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd26426>")
     endif()
     target_compile_options(${_UT_TARGET} PRIVATE ${disabled_warnings})
   else()
     target_compile_options(${_UT_TARGET} PRIVATE ${DISABLED_WARNINGS_FOR_TVM})
-    target_compile_options(${_UT_TARGET} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:-Xcompiler -Wno-error=sign-compare>"
+    target_compile_options(${_UT_TARGET} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options -Wno-error=sign-compare>"
             "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:-Wno-error=sign-compare>")
     target_compile_options(${_UT_TARGET} PRIVATE "-Wno-error=uninitialized")
   endif()
@@ -698,7 +700,7 @@ onnxruntime_add_static_library(onnxruntime_test_utils ${onnxruntime_test_utils_src})
 if(MSVC)
   target_compile_options(onnxruntime_test_utils PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /utf-8>"
                   "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/utf-8>")
-  target_compile_options(onnxruntime_test_utils PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler /wd6326>"
+  target_compile_options(onnxruntime_test_utils PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /wd6326>"
                   "$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd6326>")
 else()
   target_compile_definitions(onnxruntime_test_utils PUBLIC -DNSYNC_ATOMIC_CPP11)
@@ -755,13 +757,8 @@ set_target_properties(onnx_test_runner_common PROPERTIES FOLDER "ONNXRuntimeTest")
 
 set(all_tests ${onnxruntime_test_common_src} ${onnxruntime_test_ir_src} ${onnxruntime_test_optimizer_src}
   ${onnxruntime_test_framework_src} ${onnxruntime_test_providers_src} ${onnxruntime_test_quantiztion_src})
-if(NOT TARGET onnxruntime AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
-  list(APPEND all_tests ${onnxruntime_shared_lib_test_SRC})
-endif()
 
-if (onnxruntime_USE_CUDA)
-  onnxruntime_add_static_library(onnxruntime_test_cuda_ops_lib ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu)
-  list(APPEND onnxruntime_test_common_libs onnxruntime_test_cuda_ops_lib)
+if (onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS)
   file(GLOB onnxruntime_test_providers_cuda_ut_src CONFIGURE_DEPENDS
     "${TEST_SRC_DIR}/providers/cuda/test_cases/*"
   )
@@ -822,7 +819,9 @@ if (onnxruntime_USE_TENSORRT)
   # made test name contain the "ep" and "model path" information, so we can easily filter the tests using cuda ep or other ep with *cpu_* or *xxx_*.
   list(APPEND test_all_args "--gtest_filter=-*cpu_*:*cuda_*" )
 endif ()
-
+if(NOT onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS)
+  list(REMOVE_ITEM all_tests ${TEST_SRC_DIR}/providers/cuda/cuda_provider_test.cc)
+endif()
 AddTest(
   TARGET onnxruntime_test_all
   SOURCES ${all_tests} ${onnxruntime_unittest_main_src}
@@ -832,11 +831,15 @@ AddTest(
   DEPENDS ${all_dependencies}
   TEST_ARGS ${test_all_args}
 )
+
 if (MSVC)
   # The warning means the type of two integral values around a binary operator is narrow than their result.
   # If we promote the two input values first, it could be more tolerant to integer overflow.
   # However, this is test code. We are less concerned.
- target_compile_options(onnxruntime_test_all PRIVATE "/wd26451" "/wd4244") + target_compile_options(onnxruntime_test_all PRIVATE "$<$:SHELL:--compiler-options /wd26451>" + "$<$>:/wd26451>") + target_compile_options(onnxruntime_test_all PRIVATE "$<$:SHELL:--compiler-options /wd4244>" + "$<$>:/wd4244>") else() target_compile_options(onnxruntime_test_all PRIVATE "-Wno-parentheses") endif() @@ -1092,18 +1095,18 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) target_include_directories(onnxruntime_benchmark PRIVATE ${ONNXRUNTIME_ROOT} ${onnxruntime_graph_header} ${ONNXRUNTIME_ROOT}/core/mlas/inc) target_compile_definitions(onnxruntime_benchmark PRIVATE BENCHMARK_STATIC_DEFINE) if(WIN32) - target_compile_options(onnxruntime_benchmark PRIVATE "$<$:-Xcompiler /wd4141>" + target_compile_options(onnxruntime_benchmark PRIVATE "$<$:SHELL:--compiler-options /wd4141>" "$<$>:/wd4141>") # Avoid using new and delete. But this is a benchmark program, it's ok if it has a chance to leak. - target_compile_options(onnxruntime_benchmark PRIVATE "$<$:-Xcompiler /wd26409>" + target_compile_options(onnxruntime_benchmark PRIVATE "$<$:SHELL:--compiler-options /wd26409>" "$<$>:/wd26409>") - target_compile_options(onnxruntime_benchmark PRIVATE "$<$:-Xcompiler /wd26400>" + target_compile_options(onnxruntime_benchmark PRIVATE "$<$:SHELL:--compiler-options /wd26400>" "$<$>:/wd26400>") - target_compile_options(onnxruntime_benchmark PRIVATE "$<$:-Xcompiler /wd26814>" + target_compile_options(onnxruntime_benchmark PRIVATE "$<$:SHELL:--compiler-options /wd26814>" "$<$>:/wd26814>") - target_compile_options(onnxruntime_benchmark PRIVATE "$<$:-Xcompiler /wd26814>" + target_compile_options(onnxruntime_benchmark PRIVATE "$<$:SHELL:--compiler-options /wd26814>" "$<$>:/wd26497>") - target_compile_options(onnxruntime_benchmark PRIVATE "$<$:-Xcompiler /wd26426>" + target_compile_options(onnxruntime_benchmark PRIVATE "$<$:SHELL:--compiler-options /wd26426>" "$<$>:/wd26426>") target_compile_options(onnxruntime_benchmark PRIVATE "$<$:SHELL:--compiler-options /utf-8>" "$<$>:/utf-8>") @@ -1255,7 +1258,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) list(APPEND onnxruntime_shared_lib_test_LIBS cpuinfo) endif() if (onnxruntime_USE_CUDA) - list(APPEND onnxruntime_shared_lib_test_LIBS onnxruntime_test_cuda_ops_lib cudart) + list(APPEND onnxruntime_shared_lib_test_LIBS cudart) endif() if (onnxruntime_USE_TENSORRT) list(APPEND onnxruntime_shared_lib_test_LIBS ${TENSORRT_LIBRARY_INFER}) @@ -1270,7 +1273,10 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) LIBS ${onnxruntime_shared_lib_test_LIBS} DEPENDS ${all_dependencies} ) - + if (onnxruntime_USE_CUDA) + target_include_directories(onnxruntime_shared_lib_test PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + target_sources(onnxruntime_shared_lib_test PRIVATE ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu) + endif() if (CMAKE_SYSTEM_NAME STREQUAL "Android") target_sources(onnxruntime_shared_lib_test PRIVATE "${ONNXRUNTIME_ROOT}/core/platform/android/cxa_demangle.cc" @@ -1356,13 +1362,13 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) ) onnxruntime_add_executable(onnxruntime_mlas_test ${onnxruntime_mlas_test_src}) if(MSVC) - target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:-Xcompiler /wd26409>" + target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /wd26409>" "$<$>:/wd26409>") target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /utf-8>" "$<$>:/utf-8>") - 
target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:-Xcompiler /wd6326>" + target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /wd6326>" "$<$>:/wd6326>") - target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:-Xcompiler /wd26426>" + target_compile_options(onnxruntime_mlas_test PRIVATE "$<$:SHELL:--compiler-options /wd26426>" "$<$>:/wd26426>") endif() if(${CMAKE_SYSTEM_NAME} STREQUAL "iOS") @@ -1476,7 +1482,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") "${TEST_SRC_DIR}/testdata/custom_op_library/cuda/cuda_ops.*") list(APPEND custom_op_lib_include ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} ${onnxruntime_CUDNN_HOME}/include) if (HAS_QSPECTRE) - list(APPEND custom_op_lib_option "$<$:SHELL:-Xcompiler /Qspectre>") + list(APPEND custom_op_lib_option "$<$:SHELL:--compiler-options /Qspectre>") endif() endif() @@ -1503,7 +1509,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten") else() set(ONNXRUNTIME_CUSTOM_OP_LIB_LINK_FLAG "-DEF:${TEST_SRC_DIR}/testdata/custom_op_library/custom_op_library.def") if (NOT onnxruntime_USE_CUDA) - target_compile_options(custom_op_library PRIVATE "$<$:-Xcompiler /wd26409>" + target_compile_options(custom_op_library PRIVATE "$<$:SHELL:--compiler-options /wd26409>" "$<$>:/wd26409>") endif() endif() diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 2e247df227cbf..b8070f1a7268a 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -746,6 +746,8 @@ struct OrtApi { /** \brief Create an OrtEnv * + * \note Invoking this function will return the same instance of the environment as that returned by a previous call + * to another env creation function; all arguments to this function will be ignored. * \param[in] log_severity_level The log severity level. * \param[in] logid The log identifier. * \param[out] out Returned newly created OrtEnv. Must be freed with OrtApi::ReleaseEnv @@ -756,17 +758,20 @@ struct OrtApi { /** \brief Create an OrtEnv * + * \note Invoking this function will return the same instance of the environment as that returned by a previous call + * to another env creation function; all arguments to this function will be ignored. If you want to provide your + * own logging function, consider setting it using the SetUserLoggingFunction API instead. * \param[in] logging_function A pointer to a logging function. * \param[in] logger_param A pointer to arbitrary data passed as the ::OrtLoggingFunction `param` parameter to - * `logging_function`. + * `logging_function`. This parameter is optional. * \param[in] log_severity_level The log severity level. * \param[in] logid The log identifier. * \param[out] out Returned newly created OrtEnv. 
Must be freed with OrtApi::ReleaseEnv * * \snippet{doc} snippets.dox OrtStatus Return Value */ - ORT_API2_STATUS(CreateEnvWithCustomLogger, OrtLoggingFunction logging_function, _In_opt_ void* logger_param, - OrtLoggingLevel log_severity_level, _In_ const char* logid, _Outptr_ OrtEnv** out); + ORT_API2_STATUS(CreateEnvWithCustomLogger, _In_ OrtLoggingFunction logging_function, _In_opt_ void* logger_param, + _In_ OrtLoggingLevel log_severity_level, _In_ const char* logid, _Outptr_ OrtEnv** out); /** \brief Enable Telemetry * @@ -4415,13 +4420,58 @@ struct OrtApi { */ ORT_API2_STATUS(KernelContext_GetResource, _In_ const OrtKernelContext* context, _In_ int resouce_version, _In_ int resource_id, _Outptr_ void** resource); + /** \brief Set user logging function + * + * By default the logger created by the CreateEnv* functions is used to create the session logger as well. + * This function allows a user to override this default session logger with a logger of their own choosing. This way + * the user doesn't have to create a separate environment with a custom logger. This addresses the problem when + * the user already created an env but now wants to use a different logger for a specific session (for debugging or + * other reasons). + * + * \param[in] options + * \param[in] user_logging_function A pointer to a logging function. + * \param[in] user_logging_param A pointer to arbitrary data passed as the ::OrtLoggingFunction `param` parameter to + * `user_logging_function`. This parameter is optional. + * + * \snippet{doc} snippets.dox OrtStatus Return Value + * + * \since Version 1.17. + */ + ORT_API2_STATUS(SetUserLoggingFunction, _Inout_ OrtSessionOptions* options, + _In_ OrtLoggingFunction user_logging_function, _In_opt_ void* user_logging_param); - // todo - doc + /** + * Get number of input from OrtShapeInferContext + * + * \param[in] context + * \param[out] out The number of inputs + */ ORT_API2_STATUS(ShapeInferContext_GetInputCount, _In_ const OrtShapeInferContext* context, _Out_ size_t* out); - // todo - return const ... *? 
+ + /** + * Get type and shape info of an input + * + * \param[in] context + * \param[in] index The index of the input + * \param[out] info Type shape info of the input + */ ORT_API2_STATUS(ShapeInferContext_GetInputTypeShape, _In_ const OrtShapeInferContext* context, _In_ size_t index, _Outptr_ OrtTensorTypeAndShapeInfo** info); - // customer needs to manage the lifetime of info + + /** + * Set type and shape info of an ouput + * + * \param[in] context + * \param[in] index The index of the ouput + * \param[out] info Type shape info of the output + */ ORT_API2_STATUS(ShapeInferContext_SetOutputTypeShape, _In_ const OrtShapeInferContext* context, _In_ size_t index, _In_ const OrtTensorTypeAndShapeInfo* info); + + /** + * Set symbolic shape to type shape info + * \param[in] context + * \param[in] dim_params Symbolic strings + * \param[in] dim_params_length Number of strings + */ ORT_API2_STATUS(SetSymbolicDimensions, _In_ OrtTensorTypeAndShapeInfo* info, _In_ const char* dim_params[], _In_ size_t dim_params_length); }; diff --git a/js/common/lib/backend-impl.ts b/js/common/lib/backend-impl.ts index 75feba1d0ae08..e129c6971a85c 100644 --- a/js/common/lib/backend-impl.ts +++ b/js/common/lib/backend-impl.ts @@ -26,7 +26,7 @@ const backendsSortedByPriority: string[] = []; * @ignore */ export const registerBackend = (name: string, backend: Backend, priority: number): void => { - if (backend && typeof backend.init === 'function' && typeof backend.createSessionHandler === 'function') { + if (backend && typeof backend.init === 'function' && typeof backend.createInferenceSessionHandler === 'function') { const currentBackend = backends.get(name); if (currentBackend === undefined) { backends.set(name, {backend, priority}); diff --git a/js/common/lib/backend.ts b/js/common/lib/backend.ts index 804f33f00d103..dd04ef3f15997 100644 --- a/js/common/lib/backend.ts +++ b/js/common/lib/backend.ts @@ -3,6 +3,7 @@ import {InferenceSession} from './inference-session.js'; import {OnnxValue} from './onnx-value.js'; +import {TrainingSession} from './training-session.js'; /** * @ignore @@ -14,16 +15,23 @@ export declare namespace SessionHandler { } /** - * Represent a handler instance of an inference session. + * Represents shared SessionHandler functionality * * @ignore */ -export interface SessionHandler { +interface SessionHandler { dispose(): Promise; readonly inputNames: readonly string[]; readonly outputNames: readonly string[]; +} +/** + * Represent a handler instance of an inference session. + * + * @ignore + */ +export interface InferenceSessionHandler extends SessionHandler { startProfiling(): void; endProfiling(): void; @@ -31,6 +39,20 @@ export interface SessionHandler { options: InferenceSession.RunOptions): Promise; } +/** + * Represent a handler instance of a training inference session. + * + * @ignore + */ +export interface TrainingSessionHandler extends SessionHandler { + runTrainStep( + feeds: SessionHandler.FeedsType, fetches: SessionHandler.FetchesType, + options: InferenceSession.RunOptions): Promise; + + loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise; + getContiguousParameters(trainableOnly: boolean): Promise; +} + /** * Represent a backend that provides implementation of model inferencing. 
* @@ -42,8 +64,13 @@ export interface Backend { */ init(): Promise; - createSessionHandler(uriOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): - Promise; + createInferenceSessionHandler(uriOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): + Promise; + + createTrainingSessionHandler? + (checkpointStateUriOrBuffer: TrainingSession.URIorBuffer, trainModelUriOrBuffer: TrainingSession.URIorBuffer, + evalModelUriOrBuffer: TrainingSession.URIorBuffer, optimizerModelUriOrBuffer: TrainingSession.URIorBuffer, + options: InferenceSession.SessionOptions): Promise; } export {registerBackend} from './backend-impl.js'; diff --git a/js/common/lib/env.ts b/js/common/lib/env.ts index 525272294c587..76575ef7b9368 100644 --- a/js/common/lib/env.ts +++ b/js/common/lib/env.ts @@ -9,6 +9,7 @@ export declare namespace Env { 'ort-wasm.wasm'?: string; 'ort-wasm-threaded.wasm'?: string; 'ort-wasm-simd.wasm'?: string; + 'ort-training-wasm-simd.wasm'?: string; 'ort-wasm-simd-threaded.wasm'?: string; /* eslint-enable @typescript-eslint/naming-convention */ }; @@ -105,6 +106,12 @@ export declare namespace Env { * see comments on {@link GpuBufferType} for more details about why not use types defined in "@webgpu/types". */ readonly device: unknown; + /** + * Set or get whether validate input content. + * + * @defaultValue `false` + */ + validateInputContent?: boolean; } } diff --git a/js/common/lib/index.ts b/js/common/lib/index.ts index 85df1747f8576..9cbfcc4e8bcdc 100644 --- a/js/common/lib/index.ts +++ b/js/common/lib/index.ts @@ -22,3 +22,4 @@ export * from './env.js'; export * from './inference-session.js'; export * from './tensor.js'; export * from './onnx-value.js'; +export * from './training-session.js'; diff --git a/js/common/lib/inference-session-impl.ts b/js/common/lib/inference-session-impl.ts index 06949b4a26c0d..9bc2088f2088a 100644 --- a/js/common/lib/inference-session-impl.ts +++ b/js/common/lib/inference-session-impl.ts @@ -2,7 +2,7 @@ // Licensed under the MIT License. import {resolveBackend} from './backend-impl.js'; -import {SessionHandler} from './backend.js'; +import {InferenceSessionHandler} from './backend.js'; import {InferenceSession as InferenceSessionInterface} from './inference-session.js'; import {OnnxValue} from './onnx-value.js'; import {Tensor} from './tensor.js'; @@ -14,7 +14,7 @@ type FetchesType = InferenceSessionInterface.FetchesType; type ReturnType = InferenceSessionInterface.ReturnType; export class InferenceSession implements InferenceSessionInterface { - private constructor(handler: SessionHandler) { + private constructor(handler: InferenceSessionHandler) { this.handler = handler; } run(feeds: FeedsType, options?: RunOptions): Promise; @@ -195,7 +195,7 @@ export class InferenceSession implements InferenceSessionInterface { const eps = options.executionProviders || []; const backendHints = eps.map(i => typeof i === 'string' ? 
i : i.name); const backend = await resolveBackend(backendHints); - const handler = await backend.createSessionHandler(filePathOrUint8Array, options); + const handler = await backend.createInferenceSessionHandler(filePathOrUint8Array, options); return new InferenceSession(handler); } @@ -213,5 +213,5 @@ export class InferenceSession implements InferenceSessionInterface { return this.handler.outputNames; } - private handler: SessionHandler; + private handler: InferenceSessionHandler; } diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts new file mode 100644 index 0000000000000..f06d06bda035f --- /dev/null +++ b/js/common/lib/training-session-impl.ts @@ -0,0 +1,49 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {TrainingSessionHandler} from './backend.js'; +import {InferenceSession as InferenceSession} from './inference-session.js'; +import {TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions} from './training-session.js'; + +type SessionOptions = InferenceSession.SessionOptions; + +export class TrainingSession implements TrainingSessionInterface { + private constructor(handler: TrainingSessionHandler) { + this.handler = handler; + } + private handler: TrainingSessionHandler; + + get inputNames(): readonly string[] { + return this.handler.inputNames; + } + get outputNames(): readonly string[] { + return this.handler.outputNames; + } + + static async create(_trainingOptions: TrainingSessionCreateOptions, _sessionOptions?: SessionOptions): + Promise { + throw new Error('Method not implemented'); + } + + async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { + throw new Error('Method not implemented.'); + } + + async getContiguousParameters(_trainableOnly: boolean): Promise { + throw new Error('Method not implemented.'); + } + + runTrainStep(feeds: InferenceSession.OnnxValueMapType, options?: InferenceSession.RunOptions|undefined): + Promise; + runTrainStep( + feeds: InferenceSession.OnnxValueMapType, fetches: InferenceSession.FetchesType, + options?: InferenceSession.RunOptions|undefined): Promise; + async runTrainStep(_feeds: unknown, _fetches?: unknown, _options?: unknown): + Promise { + throw new Error('Method not implemented.'); + } + + async release(): Promise { + return this.handler.dispose(); + } +} diff --git a/js/common/lib/training-session.ts b/js/common/lib/training-session.ts new file mode 100644 index 0000000000000..0967d79b33434 --- /dev/null +++ b/js/common/lib/training-session.ts @@ -0,0 +1,134 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {InferenceSession} from './inference-session.js'; +import {TrainingSession as TrainingSessionImpl} from './training-session-impl.js'; + +/* eslint-disable @typescript-eslint/no-redeclare */ + +export declare namespace TrainingSession { + /** + * Either URI file path (string) or Uint8Array containing model or checkpoint information. + */ + type URIorBuffer = string|Uint8Array; +} + +/** + * Represent a runtime instance of an ONNX training session, + * which contains a model that can be trained, and, optionally, + * an eval and optimizer model. + */ +export interface TrainingSession { + // #region run() + + /** + * Run TrainStep asynchronously with the given feeds and options. + * + * @param feeds - Representation of the model input. See type description of `InferenceSession.InputType` for + detail. + * @param options - Optional. 
A set of options that controls the behavior of model training. + * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding values. + */ + runTrainStep(feeds: InferenceSession.FeedsType, options?: InferenceSession.RunOptions): + Promise; + + /** + * Run a single train step with the given inputs and options. + * + * @param feeds - Representation of the model input. + * @param fetches - Representation of the model output. + * detail. + * @param options - Optional. A set of options that controls the behavior of model inference. + * @returns A promise that resolves to a map, which uses output names as keys and OnnxValue as corresponding + values. + */ + runTrainStep( + feeds: InferenceSession.FeedsType, fetches: InferenceSession.FetchesType, + options?: InferenceSession.RunOptions): Promise; + + // #endregion + + // #region copy parameters + /** + * Copies from a buffer containing parameters to the TrainingSession parameters. + * + * @param buffer - buffer containing parameters + * @param trainableOnly - True if trainable parameters only to be modified, false otherwise. + */ + loadParametersBuffer(array: Uint8Array, trainableOnly: boolean): Promise; + + /** + * Copies from the TrainingSession parameters to a buffer. + * + * @param trainableOnly - True if trainable parameters only to be copied, false othrwise. + * @returns A promise that resolves to a buffer of the requested parameters. + */ + getContiguousParameters(trainableOnly: boolean): Promise; + // #endregion + + // #region release() + + /** + * Release the inference session and the underlying resources. + */ + release(): Promise; + // #endregion + + // #region metadata + + /** + * Get input names of the loaded model. + */ + readonly inputNames: readonly string[]; + + /** + * Get output names of the loaded model. + */ + readonly outputNames: readonly string[]; + // #endregion +} + +/** + * Represents the optional parameters that can be passed into the TrainingSessionFactory. + */ +export interface TrainingSessionCreateOptions { + /** + * URI or buffer for a .ckpt file that contains the checkpoint for the training model. + */ + checkpointState: TrainingSession.URIorBuffer; + /** + * URI or buffer for the .onnx training file. + */ + trainModel: TrainingSession.URIorBuffer; + /** + * Optional. URI or buffer for the .onnx optimizer model file. + */ + optimizerModel?: TrainingSession.URIorBuffer; + /** + * Optional. URI or buffer for the .onnx eval model file. + */ + evalModel?: TrainingSession.URIorBuffer; +} + +/** + * Defines method overload possibilities for creating a TrainingSession. 
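For context, a hedged sketch of how the TrainingSession API declared in this file is meant to be consumed (illustrative only: no concrete handler ships with this change and training-session-impl.ts above still throws 'Method not implemented'; every file path below is a hypothetical placeholder):

import {Tensor, TrainingSession} from 'onnxruntime-common';

const trainOneStep = async () => {
  // checkpointState and trainModel are the required TrainingSessionCreateOptions fields;
  // evalModel and optimizerModel are optional.
  const session = await TrainingSession.create({
    checkpointState: 'checkpoint.ckpt',      // hypothetical path
    trainModel: 'training_model.onnx',       // hypothetical path
    optimizerModel: 'optimizer_model.onnx',  // hypothetical path
  });

  // runTrainStep takes the same feeds/fetches/options shapes as InferenceSession.run.
  const feeds = {input: new Tensor('float32', new Float32Array(4), [1, 4])};
  const results = await session.runTrainStep(feeds);
  console.log(Object.keys(results));

  // Copy the trainable parameters out as a contiguous buffer, then release the session.
  const params = await session.getContiguousParameters(true);
  console.log(params);
  await session.release();
};

void trainOneStep();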
+ */ +export interface TrainingSessionFactory { + // #region create() + + /** + * Creates a new TrainingSession and asynchronously loads any models passed in through trainingOptions + * + * @param trainingOptions specify models and checkpoints to load into the Training Session + * @param sessionOptions specify configuration for training session behavior + * + * @returns Promise that resolves to a TrainingSession object + */ + create(trainingOptions: TrainingSessionCreateOptions, sessionOptions?: InferenceSession.SessionOptions): + Promise; + + // #endregion +} + +// eslint-disable-next-line @typescript-eslint/naming-convention +export const TrainingSession: TrainingSessionFactory = TrainingSessionImpl; diff --git a/js/node/lib/backend.ts b/js/node/lib/backend.ts index d3680f9d44236..5f5ad49a2dea8 100644 --- a/js/node/lib/backend.ts +++ b/js/node/lib/backend.ts @@ -1,11 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Backend, InferenceSession, SessionHandler} from 'onnxruntime-common'; +import {Backend, InferenceSession, InferenceSessionHandler, SessionHandler} from 'onnxruntime-common'; import {Binding, binding} from './binding'; -class OnnxruntimeSessionHandler implements SessionHandler { +class OnnxruntimeSessionHandler implements InferenceSessionHandler { #inferenceSession: Binding.InferenceSession; constructor(pathOrBuffer: string|Uint8Array, options: InferenceSession.SessionOptions) { @@ -53,8 +53,8 @@ class OnnxruntimeBackend implements Backend { return Promise.resolve(); } - async createSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): - Promise { + async createInferenceSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): + Promise { return new Promise((resolve, reject) => { process.nextTick(() => { try { diff --git a/js/react_native/e2e/package.json b/js/react_native/e2e/package.json index 969c70c110123..cd97ec1d099e4 100644 --- a/js/react_native/e2e/package.json +++ b/js/react_native/e2e/package.json @@ -10,7 +10,8 @@ }, "dependencies": { "react": "^18.1.0", - "react-native": "^0.69.1" + "react-native": "^0.69.1", + "react-native-fs": "^2.20.0" }, "devDependencies": { "@babel/core": "^7.17.0", diff --git a/js/react_native/e2e/src/App.tsx b/js/react_native/e2e/src/App.tsx index f3e415f0c5a55..8a76edabc613e 100644 --- a/js/react_native/e2e/src/App.tsx +++ b/js/react_native/e2e/src/App.tsx @@ -8,6 +8,7 @@ import { Image, Text, TextInput, View } from 'react-native'; import { InferenceSession, Tensor } from 'onnxruntime-react-native'; import MNIST, { MNISTInput, MNISTOutput, MNISTResult, } from './mnist-data-handler'; import { Buffer } from 'buffer'; +import { readFile } from 'react-native-fs'; interface State { session: @@ -39,10 +40,21 @@ export default class App extends React.PureComponent<{}, State> { this.setState({ imagePath }); const modelPath = await MNIST.getLocalModelPath(); - const session: InferenceSession = await InferenceSession.create(modelPath); + + // test creating session with path + console.log('Creating with path'); + const pathSession: InferenceSession = await InferenceSession.create(modelPath); + pathSession.release(); + + // and with bytes + console.log('Creating with bytes'); + const base64Str = await readFile(modelPath, 'base64'); + const bytes = Buffer.from(base64Str, 'base64'); + const session: InferenceSession = await InferenceSession.create(bytes); this.setState({ session }); - void this.infer(); + 
console.log('Test session created'); + void await this.infer(); } catch (err) { console.log(err.message); } diff --git a/js/react_native/lib/backend.ts b/js/react_native/lib/backend.ts index b3f0c466308a5..3b1852699ac48 100644 --- a/js/react_native/lib/backend.ts +++ b/js/react_native/lib/backend.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Backend, InferenceSession, SessionHandler, Tensor,} from 'onnxruntime-common'; +import {Backend, InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} from 'onnxruntime-common'; import {Platform} from 'react-native'; import {binding, Binding, JSIBlob, jsiHelper} from './binding'; @@ -43,7 +43,7 @@ const normalizePath = (path: string): string => { return path; }; -class OnnxruntimeSessionHandler implements SessionHandler { +class OnnxruntimeSessionHandler implements InferenceSessionHandler { #inferenceSession: Binding.InferenceSession; #key: string; @@ -66,12 +66,14 @@ class OnnxruntimeSessionHandler implements SessionHandler { let results: Binding.ModelLoadInfoType; // load a model if (typeof this.#pathOrBuffer === 'string') { + // load model from model path results = await this.#inferenceSession.loadModel(normalizePath(this.#pathOrBuffer), options); } else { + // load model from buffer if (!this.#inferenceSession.loadModelFromBlob) { throw new Error('Native module method "loadModelFromBlob" is not defined'); } - const modelBlob = jsiHelper.storeArrayBuffer(this.#pathOrBuffer); + const modelBlob = jsiHelper.storeArrayBuffer(this.#pathOrBuffer.buffer); results = await this.#inferenceSession.loadModelFromBlob(modelBlob, options); } // resolve promise if onnxruntime session is successfully created @@ -163,8 +165,8 @@ class OnnxruntimeBackend implements Backend { return Promise.resolve(); } - async createSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): - Promise { + async createInferenceSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): + Promise { const handler = new OnnxruntimeSessionHandler(pathOrBuffer); await handler.loadModel(options || {}); return handler; diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index a87a894e3b3c5..f8ac29e5f82ca 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -62,6 +62,7 @@ Do not modify directly.* | Not | ai.onnx(1+) | | | Pad | ai.onnx(2-10,11-12,13-17,18,19+) | | | Pow | ai.onnx(7-11,12,13-14,15+) | | +| Range | ai.onnx(11+) | | | Reciprocal | ai.onnx(6-12,13+) | | | ReduceL1 | ai.onnx(1-10,11-12,13-17,18+) | | | ReduceL2 | ai.onnx(1-10,11-12,13-17,18+) | | diff --git a/js/web/lib/backend-onnxjs.ts b/js/web/lib/backend-onnxjs.ts index 18a068e0ced8b..5ea7de809a495 100644 --- a/js/web/lib/backend-onnxjs.ts +++ b/js/web/lib/backend-onnxjs.ts @@ -2,7 +2,7 @@ // Licensed under the MIT License. 
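To make the renamed contract concrete (the node and react-native backends above, and the web backends below, all implement it), here is a hedged sketch of a minimal do-nothing backend; it is not part of this change and all names in it are made up for illustration:

import {Backend, InferenceSession, InferenceSessionHandler, registerBackend, SessionHandler} from 'onnxruntime-common';

// Hypothetical handler: run/startProfiling/endProfiling now live on InferenceSessionHandler,
// while dispose/inputNames/outputNames come from the shared (no longer exported) SessionHandler base.
class NoopSessionHandler implements InferenceSessionHandler {
  readonly inputNames: readonly string[] = [];
  readonly outputNames: readonly string[] = [];
  async dispose(): Promise<void> {}
  startProfiling(): void {}
  endProfiling(): void {}
  async run(_feeds: SessionHandler.FeedsType, _fetches: SessionHandler.FetchesType,
            _options: InferenceSession.RunOptions): Promise<SessionHandler.ReturnType> {
    return {};
  }
}

const noopBackend: Backend = {
  async init() {},
  // Renamed from createSessionHandler; registerBackend() now checks for this method.
  async createInferenceSessionHandler(_uriOrBuffer: string|Uint8Array,
                                      _options?: InferenceSession.SessionOptions) {
    return new NoopSessionHandler();
  },
};

registerBackend('noop', noopBackend, 1);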
/* eslint-disable import/no-internal-modules */ -import {Backend, InferenceSession, SessionHandler} from 'onnxruntime-common'; +import {Backend, InferenceSession, InferenceSessionHandler} from 'onnxruntime-common'; import {Session} from './onnxjs/session'; import {OnnxjsSessionHandler} from './onnxjs/session-handler'; @@ -11,8 +11,8 @@ class OnnxjsBackend implements Backend { // eslint-disable-next-line @typescript-eslint/no-empty-function async init(): Promise {} - async createSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): - Promise { + async createInferenceSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): + Promise { // NOTE: Session.Config(from onnx.js) is not compatible with InferenceSession.SessionOptions(from // onnxruntime-common). // In future we should remove Session.Config and use InferenceSession.SessionOptions. diff --git a/js/web/lib/backend-wasm.ts b/js/web/lib/backend-wasm.ts index ceb20044d97b6..04108c2ad0f66 100644 --- a/js/web/lib/backend-wasm.ts +++ b/js/web/lib/backend-wasm.ts @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Backend, env, InferenceSession, SessionHandler} from 'onnxruntime-common'; +import {Backend, env, InferenceSession, InferenceSessionHandler} from 'onnxruntime-common'; import {cpus} from 'os'; import {initializeWebAssemblyInstance} from './wasm/proxy-wrapper'; @@ -40,10 +40,12 @@ class OnnxruntimeWebAssemblyBackend implements Backend { // init wasm await initializeWebAssemblyInstance(); } - createSessionHandler(path: string, options?: InferenceSession.SessionOptions): Promise; - createSessionHandler(buffer: Uint8Array, options?: InferenceSession.SessionOptions): Promise; - async createSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): - Promise { + createInferenceSessionHandler(path: string, options?: InferenceSession.SessionOptions): + Promise; + createInferenceSessionHandler(buffer: Uint8Array, options?: InferenceSession.SessionOptions): + Promise; + async createInferenceSessionHandler(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): + Promise { const handler = new OnnxruntimeWebAssemblySessionHandler(); await handler.loadModel(pathOrBuffer, options); return Promise.resolve(handler); diff --git a/js/web/lib/onnxjs/session-handler.ts b/js/web/lib/onnxjs/session-handler.ts index 0b06a7a747a44..47e50aeab673a 100644 --- a/js/web/lib/onnxjs/session-handler.ts +++ b/js/web/lib/onnxjs/session-handler.ts @@ -1,12 +1,12 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {InferenceSession, SessionHandler, Tensor} from 'onnxruntime-common'; +import {InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} from 'onnxruntime-common'; import {Session} from './session'; import {Tensor as OnnxjsTensor} from './tensor'; -export class OnnxjsSessionHandler implements SessionHandler { +export class OnnxjsSessionHandler implements InferenceSessionHandler { constructor(private session: Session) { this.inputNames = this.session.inputNames; this.outputNames = this.session.outputNames; diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index 59da1369e152e..b7b2ff4537095 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
// Licensed under the MIT License. +import type {Tensor} from 'onnxruntime-common'; + export declare namespace JSEP { type BackendType = unknown; type AllocFunction = (size: number) => number; @@ -9,11 +11,8 @@ export declare namespace JSEP { type DownloadFunction = (gpuDataId: number, dataOffset: number, size: number) => Promise; type CreateKernelFunction = (name: string, kernel: number, attribute: unknown) => void; type ReleaseKernelFunction = (kernel: number) => void; - type RunFunction = (kernel: number, contextDataOffset: number, sessionState: SessionState) => number; - export interface SessionState { - sessionId: number; - errors: Array>; - } + type RunFunction = + (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array>) => number; } export interface OrtWasmModule extends EmscriptenModule { @@ -40,14 +39,23 @@ export interface OrtWasmModule extends EmscriptenModule { _OrtFree(stringHandle: number): void; - _OrtCreateTensor(dataType: number, dataOffset: number, dataLength: number, dimsOffset: number, dimsLength: number): - number; + _OrtCreateTensor( + dataType: number, dataOffset: number, dataLength: number, dimsOffset: number, dimsLength: number, + dataLocation: number): number; _OrtGetTensorData(tensorHandle: number, dataType: number, dataOffset: number, dimsOffset: number, dimsLength: number): number; _OrtReleaseTensor(tensorHandle: number): void; + _OrtCreateBinding(sessionHandle: number): number; + _OrtBindInput(bindingHandle: number, nameOffset: number, tensorHandle: number): Promise; + _OrtBindOutput(bindingHandle: number, nameOffset: number, tensorHandle: number, location: number): number; + _OrtClearBoundOutputs(ioBindingHandle: number): void; + _OrtReleaseBinding(ioBindingHandle: number): void; + _OrtRunWithBinding( + sessionHandle: number, ioBindingHandle: number, outputCount: number, outputsOffset: number, + runOptionsHandle: number): Promise; _OrtRun( sessionHandle: number, inputNamesOffset: number, inputsOffset: number, inputCount: number, - outputNamesOffset: number, outputCount: number, outputsOffset: number, runOptionsHandle: number): number; + outputNamesOffset: number, outputCount: number, outputsOffset: number, runOptionsHandle: number): Promise; _OrtCreateSessionOptions( graphOptimizationLevel: number, enableCpuMemArena: boolean, enableMemPattern: boolean, executionMode: number, @@ -102,17 +110,67 @@ export interface OrtWasmModule extends EmscriptenModule { // #endregion // #region JSEP + /** + * This is the entry of JSEP initialization. This function is called once when initializing ONNX Runtime. + * This function initializes WebGPU backend and registers a few callbacks that will be called in C++ code. + */ jsepInit? (backend: JSEP.BackendType, alloc: JSEP.AllocFunction, free: JSEP.FreeFunction, upload: JSEP.UploadFunction, download: JSEP.DownloadFunction, createKernel: JSEP.CreateKernelFunction, releaseKernel: JSEP.ReleaseKernelFunction, run: JSEP.RunFunction): void; + /** + * [exported from wasm] Specify a kernel's output when running OpKernel::Compute(). + * + * @param context - specify the kernel context pointer. + * @param index - specify the index of the output. + * @param data - specify the pointer to encoded data of type and dims. + */ _JsepOutput(context: number, index: number, data: number): number; + /** + * [exported from wasm] Get name of an operator node. + * + * @param kernel - specify the kernel pointer. + * @returns the pointer to a C-style UTF8 encoded string representing the node name. 
+ */ _JsepGetNodeName(kernel: number): number; - jsepOnRunStart?(sessionId: number): void; - jsepOnRunEnd?(sessionId: number): Promise; - jsepRunPromise?: Promise; + /** + * [exported from js_internal_api.js] Register a user GPU buffer for usage of a session's input or output. + * + * @param sessionId - specify the session ID. + * @param index - specify an integer to represent which input/output it is registering for. For input, it is the + * input_index corresponding to the session's inputNames. For output, it is the inputCount + output_index + * corresponding to the session's ouputNames. + * @param buffer - specify the GPU buffer to register. + * @param size - specify the original data size in byte. + * @returns the GPU data ID for the registered GPU buffer. + */ + jsepRegisterBuffer: (sessionId: number, index: number, buffer: GPUBuffer, size: number) => number; + /** + * [exported from js_internal_api.js] Unregister all user GPU buffers for a session. + * + * @param sessionId - specify the session ID. + */ + jsepUnregisterBuffers?: (sessionId: number) => void; + /** + * [exported from js_internal_api.js] Get the GPU buffer by GPU data ID. + * + * @param dataId - specify the GPU data ID + * @returns the GPU buffer. + */ + jsepGetBuffer: (dataId: number) => GPUBuffer; + /** + * [exported from js_internal_api.js] Create a function to be used to create a GPU Tensor. + * + * @param gpuBuffer - specify the GPU buffer + * @param size - specify the original data size in byte. + * @param type - specify the tensor type. + * @returns the generated downloader function. + */ + jsepCreateDownloader: + (gpuBuffer: GPUBuffer, size: number, + type: Tensor.GpuBufferDataTypes) => () => Promise; // #endregion } diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 5e77a0343b4ee..5bec562b157ac 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -1,11 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {Env} from 'onnxruntime-common'; +import {Env, Tensor} from 'onnxruntime-common'; import {configureLogger, LOG_DEBUG} from './log'; -import {TensorView} from './tensor-view'; -import {createGpuDataManager, GpuDataManager} from './webgpu/gpu-data-manager'; +import {createView, TensorView} from './tensor-view'; +import {createGpuDataManager, downloadGpuData, GpuDataManager} from './webgpu/gpu-data-manager'; import {RunFunction, WEBGPU_OP_RESOLVE_RULES} from './webgpu/op-resolve-rules'; import {ProgramManager} from './webgpu/program-manager'; import {ComputeContext, GpuData, ProgramInfo, ProgramInfoLoader} from './webgpu/types'; @@ -98,6 +98,11 @@ export class WebGpuBackend { env: Env; + /** + * a SessionID -> a Map of (InputOutputIndex -> [ID, GPUBuffer]) mapping. + */ + sessionExternalDataMapping: Map> = new Map(); + async initialize(env: Env): Promise { if (!navigator.gpu) { // WebGPU is not available. 
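A hedged sketch of how these module entry points fit together for IO binding (illustrative only: the import path and the wiring around session ids are assumptions, and in the real code this is driven by the wasm session layer rather than called by users directly):

import type {OrtWasmModule} from './binding/ort-wasm';  // relative path is illustrative

// Register a caller-owned GPUBuffer as a pre-allocated output of a session and build a
// downloader for it. For outputs the index is inputCount + outputIndex, per the jsdoc above.
const bindOutputBuffer =
    (module: OrtWasmModule, sessionId: number, inputCount: number, outputIndex: number,
     buffer: GPUBuffer, byteSize: number) => {
      const gpuDataId = module.jsepRegisterBuffer(sessionId, inputCount + outputIndex, buffer, byteSize);
      // The downloader only copies GPU -> CPU data when it is actually invoked.
      const download = module.jsepCreateDownloader(buffer, byteSize, 'float32');
      return {gpuDataId, download};
    };

// All external buffers of a session must be unregistered when the session is released.
const releaseExternalBuffers = (module: OrtWasmModule, sessionId: number) =>
    module.jsepUnregisterBuffers?.(sessionId);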
@@ -192,11 +197,13 @@ export class WebGpuBackend { } flush(): void { - this.endComputePass(); - this.device.queue.submit([this.getCommandEncoder().finish()]); - this.gpuDataManager.refreshPendingBuffers(); - this.commandEncoder = null; - this.pendingDispatchNumber = 0; + if (this.commandEncoder) { + this.endComputePass(); + this.device.queue.submit([this.getCommandEncoder().finish()]); + this.gpuDataManager.refreshPendingBuffers(); + this.commandEncoder = null; + this.pendingDispatchNumber = 0; + } } /** @@ -304,12 +311,9 @@ export class WebGpuBackend { } async download(gpuDataId: number, getTargetBuffer: () => Uint8Array): Promise { - const arrayBuffer = await this.gpuDataManager.download(gpuDataId); - // the underlying buffer may be changed after the async function is called. so we use a getter function to make sure // the buffer is up-to-date. - const data = getTargetBuffer(); - data.set(new Uint8Array(arrayBuffer, 0, data.byteLength)); + await this.gpuDataManager.download(gpuDataId, getTargetBuffer); } alloc(size: number): number { @@ -372,7 +376,7 @@ export class WebGpuBackend { kernelEntry(context, attributes[1]); return 0; // ORT_OK } catch (e) { - LOG_DEBUG('warning', `[WebGPU] Kernel "[${opType}] ${nodeName}" failed. Error: ${e}`); + errors.push(Promise.resolve(`[WebGPU] Kernel "[${opType}] ${nodeName}" failed. ${e}`)); return 1; // ORT_FAIL } finally { if (useErrorScope) { @@ -387,4 +391,40 @@ export class WebGpuBackend { this.currentKernelId = null; } } + + // #region external buffer + registerBuffer(sessionId: number, index: number, buffer: GPUBuffer, size: number): number { + let sessionInputOutputMapping = this.sessionExternalDataMapping.get(sessionId); + if (!sessionInputOutputMapping) { + sessionInputOutputMapping = new Map(); + this.sessionExternalDataMapping.set(sessionId, sessionInputOutputMapping); + } + + const previousBuffer = sessionInputOutputMapping.get(index); + const id = this.gpuDataManager.registerExternalBuffer(buffer, size, previousBuffer?.[1]); + sessionInputOutputMapping.set(index, [id, buffer]); + return id; + } + unregisterBuffers(sessionId: number): void { + const sessionInputOutputMapping = this.sessionExternalDataMapping.get(sessionId); + if (sessionInputOutputMapping) { + sessionInputOutputMapping.forEach(bufferInfo => this.gpuDataManager.unregisterExternalBuffer(bufferInfo[1])); + this.sessionExternalDataMapping.delete(sessionId); + } + } + getBuffer(gpuDataId: number): GPUBuffer { + const gpuData = this.gpuDataManager.get(gpuDataId); + if (!gpuData) { + throw new Error(`no GPU data for buffer: ${gpuDataId}`); + } + return gpuData.buffer; + } + createDownloader(gpuBuffer: GPUBuffer, size: number, type: Tensor.GpuBufferDataTypes): + () => Promise { + return async () => { + const data = await downloadGpuData(this, gpuBuffer, size); + return createView(data.buffer, type); + }; + } + // #endregion } diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index 78316cbe1c825..6ff3971d720fd 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -3,7 +3,7 @@ import {Env} from 'onnxruntime-common'; -import {JSEP, OrtWasmModule} from '../binding/ort-wasm'; +import {OrtWasmModule} from '../binding/ort-wasm'; import {DataType, getTensorElementSize} from '../wasm-common'; import {WebGpuBackend} from './backend-webgpu'; @@ -120,6 +120,11 @@ class ComputeContextImpl implements ComputeContext { this.module.HEAPU32[offset++] = dims[i]; } return this.module._JsepOutput(this.opKernelContext, index, data); + } catch (e) { + 
throw new Error( + `Failed to generate kernel's output[${index}] with dims [${dims}]. ` + + 'If you are running with pre-allocated output, please make sure the output type/dims are correct. ' + + `Error: ${e}`); } finally { this.module.stackRestore(stack); } @@ -138,7 +143,7 @@ export const init = async(module: OrtWasmModule, env: Env): Promise => { init( // backend - {backend}, + backend, // jsepAlloc() (size: number) => backend.alloc(size), @@ -178,13 +183,13 @@ export const init = async(module: OrtWasmModule, env: Env): Promise => { (kernel: number) => backend.releaseKernel(kernel), // jsepRun - (kernel: number, contextDataOffset: number, sessionState: JSEP.SessionState) => { + (kernel: number, contextDataOffset: number, sessionHandle: number, errors: Array>) => { LOG_DEBUG( 'verbose', - () => `[WebGPU] jsepRun: sessionId=${sessionState.sessionId}, kernel=${kernel}, contextDataOffset=${ + () => `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${ contextDataOffset}`); const context = new ComputeContextImpl(module, backend, contextDataOffset); - return backend.computeKernel(kernel, context, sessionState.errors); + return backend.computeKernel(kernel, context, errors); }); } }; diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts index 92fdd5abc3892..131f7a9bfa29b 100644 --- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts @@ -35,7 +35,7 @@ export interface GpuDataManager { /** * copy data from GPU to CPU. */ - download(id: GpuDataId): Promise; + download(id: GpuDataId, getTargetBuffer: () => Uint8Array): Promise; /** * refresh the buffers that marked for release. @@ -46,6 +46,19 @@ export interface GpuDataManager { */ refreshPendingBuffers(): void; + /** + * register an external buffer for IO Binding. If the buffer is already registered, return the existing GPU data ID. + * + * GPU data manager only manages a mapping between the buffer and the GPU data ID. It will not manage the lifecycle of + * the external buffer. + */ + registerExternalBuffer(buffer: GPUBuffer, originalSize: number, previousBuffer?: GPUBuffer): number; + + /** + * unregister an external buffer for IO Binding. + */ + unregisterExternalBuffer(buffer: GPUBuffer): void; + /** * destroy all gpu buffers. Call this when the session.release is called. */ @@ -62,12 +75,56 @@ interface StorageCacheValue { */ const calcNormalizedBufferSize = (size: number) => Math.ceil(size / 16) * 16; -let guid = 0; +let guid = 1; const createNewGpuDataId = () => guid++; +/** + * exported standard download function. This function is used by the session to download the data from GPU, and also by + * factory to create GPU tensors with the capacity of downloading data from GPU. + * + * @param backend - the WebGPU backend + * @param gpuBuffer - the GPU buffer to download + * @param originalSize - the original size of the data + * @param getTargetBuffer - optional. If provided, the data will be copied to the target buffer. Otherwise, a new buffer + * will be created and returned. 
+ */ +export const downloadGpuData = + async(backend: WebGpuBackend, gpuBuffer: GPUBuffer, originalSize: number, getTargetBuffer?: () => Uint8Array): + Promise => { + const bufferSize = calcNormalizedBufferSize(originalSize); + const gpuReadBuffer = backend.device.createBuffer( + // eslint-disable-next-line no-bitwise + {size: bufferSize, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ}); + try { + const commandEncoder = backend.getCommandEncoder(); + backend.endComputePass(); + commandEncoder.copyBufferToBuffer( + gpuBuffer /* source buffer */, 0 /* source offset */, gpuReadBuffer /* destination buffer */, + 0 /* destination offset */, bufferSize /* size */ + ); + backend.flush(); + + await gpuReadBuffer.mapAsync(GPUMapMode.READ); + + const arrayBuffer = gpuReadBuffer.getMappedRange(); + if (getTargetBuffer) { + // if we already have a CPU buffer to accept the data, no need to clone the ArrayBuffer. + const targetBuffer = getTargetBuffer(); + targetBuffer.set(new Uint8Array(arrayBuffer, 0, originalSize)); + return targetBuffer; + } else { + // the mapped ArrayBuffer will be released when the GPU buffer is destroyed. Need to clone the + // ArrayBuffer. + return new Uint8Array(arrayBuffer.slice(0, originalSize)); + } + } finally { + gpuReadBuffer.destroy(); + } + }; + class GpuDataManagerImpl implements GpuDataManager { // GPU Data ID => GPU Data ( storage buffer ) - storageCache: Map; + private storageCache: Map; // pending buffers for uploading ( data is unmapped ) private buffersForUploadingPending: GPUBuffer[]; @@ -77,11 +134,15 @@ class GpuDataManagerImpl implements GpuDataManager { // The reusable storage buffers for computing. private freeBuffers: Map; + // The external buffers registered users for IO Binding. + private externalBuffers: Map; + constructor(private backend: WebGpuBackend) { this.storageCache = new Map(); this.freeBuffers = new Map(); this.buffersForUploadingPending = []; this.buffersPending = []; + this.externalBuffers = new Map(); } upload(id: GpuDataId, data: Uint8Array): void { @@ -143,6 +204,42 @@ class GpuDataManagerImpl implements GpuDataManager { sourceGpuDataCache.gpuData.buffer, 0, destinationGpuDataCache.gpuData.buffer, 0, size); } + registerExternalBuffer(buffer: GPUBuffer, originalSize: number, previousBuffer?: GPUBuffer): number { + let id: number|undefined; + if (previousBuffer) { + id = this.externalBuffers.get(previousBuffer); + if (id === undefined) { + throw new Error('previous buffer is not registered'); + } + if (buffer === previousBuffer) { + LOG_DEBUG( + 'verbose', + () => `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${ + id}, buffer is the same, skip.`); + return id; + } + this.externalBuffers.delete(previousBuffer); + } else { + id = createNewGpuDataId(); + } + + this.storageCache.set(id, {gpuData: {id, type: GpuDataType.default, buffer}, originalSize}); + this.externalBuffers.set(buffer, id); + LOG_DEBUG( + 'verbose', + () => `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${id}, registered.`); + return id; + } + + unregisterExternalBuffer(buffer: GPUBuffer): void { + const id = this.externalBuffers.get(buffer); + if (id !== undefined) { + this.storageCache.delete(id); + this.externalBuffers.delete(buffer); + LOG_DEBUG('verbose', () => `[WebGPU] GpuDataManager.unregisterExternalBuffer() => id=${id}`); + } + } + // eslint-disable-next-line no-bitwise create(size: number, usage = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST): GpuData { const 
bufferSize = calcNormalizedBufferSize(size); @@ -193,31 +290,13 @@ class GpuDataManagerImpl implements GpuDataManager { return cachedData.originalSize; } - async download(id: GpuDataId): Promise { + async download(id: GpuDataId, getTargetBuffer: () => Uint8Array): Promise { const cachedData = this.storageCache.get(id); if (!cachedData) { throw new Error('data does not exist'); } - const commandEncoder = this.backend.getCommandEncoder(); - this.backend.endComputePass(); - const bufferSize = calcNormalizedBufferSize(cachedData.originalSize); - const gpuReadBuffer = this.backend.device.createBuffer( - // eslint-disable-next-line no-bitwise - {size: bufferSize, usage: GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ}); - commandEncoder.copyBufferToBuffer( - cachedData.gpuData.buffer /* source buffer */, 0 /* source offset */, gpuReadBuffer /* destination buffer */, - 0 /* destination offset */, bufferSize /* size */ - ); - this.backend.flush(); - - return new Promise((resolve) => { - gpuReadBuffer.mapAsync(GPUMapMode.READ).then(() => { - const data = gpuReadBuffer.getMappedRange().slice(0); - gpuReadBuffer.destroy(); - resolve(data); - }); - }); + await downloadGpuData(this.backend, cachedData.gpuData.buffer, cachedData.originalSize, getTargetBuffer); } refreshPendingBuffers(): void { diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index e92e6696d9a78..cbe845b882468 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -16,6 +16,7 @@ import {layerNorm, parseLayerNormAttributes} from './ops/layer-norm'; import {matMul} from './ops/matmul'; import {pad, parsePadAttributes} from './ops/pad'; import * as pool from './ops/pool'; +import {range} from './ops/range'; import {parseReduceAttributes, reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce'; import {parseResizeAttributes, resize} from './ops/resize'; import {parseSkipLayerNormAttributes, skipLayerNorm} from './ops/skip-layer-norm'; @@ -83,6 +84,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Not', [unaryOps.not]], ['Pad', [pad, parsePadAttributes]], ['Pow', [binaryOps.pow]], + ['Range', [range]], ['Reciprocal', [unaryOps.reciprocal]], ['ReduceMin', [reduceMin, parseReduceAttributes]], ['ReduceMean', [reduceMean, parseReduceAttributes]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts index dd4f13e76ee04..22b91d680a9b4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/activation_util.ts @@ -21,16 +21,16 @@ export declare type Activation = 'linear' | 'relu' | 'prelu' | 'elu' | 'relu6' | 'leakyrelu' | 'sigmoid' | 'gelu'; -export const typeSnippet = (component: number) => { +export const typeSnippet = (component: number, dataType: string) => { switch (component) { case 1: - return 'f32'; + return dataType; case 2: - return 'vec2'; + return `vec2<${dataType}>`; case 3: - return 'vec3'; + return `vec3<${dataType}>`; case 4: - return 'vec4'; + return `vec4<${dataType}>`; default: throw new Error(`${component}-component is not supported.`); } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts index 08b1d1f30b233..e6d4039d8131b 100644 --- 
a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv2d_mm_webgpu.ts @@ -23,6 +23,7 @@ import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types'; +import {tensorTypeToWsglStorageType} from '../common'; import {ConvAttributes} from '../conv'; import {Activation, activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util'; @@ -32,13 +33,13 @@ import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packe const conv2dCommonSnippet = (isChannelsLast: boolean, fitAOuter: boolean, fitBOuter: boolean, fitInner: boolean, addBias = false, activation?: Activation, hasPreluActivationWeights = false, innerElementSizeX = 4, innerElementSizeW = 4, - innerElementSize = 4): string => { + innerElementSize = 4, dataType = 'f32'): string => { const getXSnippet = (innerElementSize: number) => { switch (innerElementSize) { case 1: return 'resData = x[xIndex];'; case 3: - return 'resData = vec3(x[xIndex], x[xIndex + 1], x[xIndex + 2]);'; + return `resData = vec3<${dataType}>(x[xIndex], x[xIndex + 1], x[xIndex + 2]);`; case 4: return 'resData = x[xIndex / 4];'; default: @@ -92,7 +93,7 @@ const conv2dCommonSnippet = let xRow = outRow * stride[0] + dilation[0] * WRow - pad[0]; let xCol = outCol * stride[1] + dilation[1] * WCol - pad[1]; let xCh = ${col} % inChannels; - var resData = ${typeSnippet(innerElementSizeX)}(0.0); + var resData = ${typeSnippet(innerElementSizeX, dataType)}(0.0); // The bounds checking is always needed since we use it to pad zero for // the 'same' padding type. if (xRow >= 0 && xRow < ${xHeight} && xCol >= 0 && xCol < ${xWidth}) { @@ -110,7 +111,7 @@ const conv2dCommonSnippet = if (row < dimAOuter && col < dimInner) { ${readXSnippet} } - return ${typeSnippet(innerElementSizeX)}(0.0);`) : + return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`) : (fitInner && fitBOuter ? ` let col = colIn * ${innerElementSizeX}; ${readXSnippet}` : @@ -119,13 +120,15 @@ const conv2dCommonSnippet = if (row < dimInner && col < dimBOuter) { ${readXSnippet} } - return ${typeSnippet(innerElementSizeX)}(0.0);`); + return ${typeSnippet(innerElementSizeX, dataType)}(0.0);`); const sampleW = `${getWSnippet(innerElementSizeW)}`; - const resType = typeSnippet(innerElementSize); - const aType = isChannelsLast ? typeSnippet(innerElementSizeX) : typeSnippet(innerElementSizeW); - const bType = isChannelsLast ? typeSnippet(innerElementSizeW) : typeSnippet(innerElementSizeX); + const resType = typeSnippet(innerElementSize, dataType); + const aType = + isChannelsLast ? typeSnippet(innerElementSizeX, dataType) : typeSnippet(innerElementSizeW, dataType); + const bType = + isChannelsLast ? typeSnippet(innerElementSizeW, dataType) : typeSnippet(innerElementSizeX, dataType); const userCode = ` ${activationFnSnippet(activation, hasPreluActivationWeights, innerElementSize === 4, 4)} fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${aType} { @@ -190,23 +193,24 @@ export const createConv2DMatMulProgramInfo = const fitInner = dimInner % tileInner === 0; const elementsSize = isVec4 ? [innerElementSize, 4, 4] : [1, 1, 1]; + const t = tensorTypeToWsglStorageType(inputs[0].dataType); const declareInputs = [ - `@group(0) @binding(0) var x: array<${isVec4 && innerElementSize === 4 ? 'vec4' : 'f32'}>;`, - `@group(0) @binding(1) var w: array<${isVec4 ? 
'vec4' : 'f32'}>;` + `@group(0) @binding(0) var x: array<${isVec4 && innerElementSize === 4 ? `vec4<${t}>` : t}>;`, + `@group(0) @binding(1) var w: array<${isVec4 ? `vec4<${t}>` : t}>;` ]; let declareFunctions = ` - fn setOutputAtIndex(flatIndex : i32, value : ${isVec4 ? 'vec4' : 'f32'}) { - result[flatIndex] = ${isVec4 ? 'vec4' : 'f32'}(value); + fn setOutputAtIndex(flatIndex : i32, value : ${isVec4 ? `vec4<${t}>` : t}) { + result[flatIndex] = ${isVec4 ? `vec4<${t}>` : t}(value); } - fn setOutputAtCoords(d0 : i32, d1 : i32, d2 : i32, d3 : i32, value : ${isVec4 ? 'vec4' : 'f32'}) { + fn setOutputAtCoords(d0 : i32, d1 : i32, d2 : i32, d3 : i32, value : ${isVec4 ? `vec4<${t}>` : t}) { let flatIndex = getOutputIndexFromCoords(vec4(d0, d1, d2, d3)); setOutputAtIndex(flatIndex ${isVec4 ? '/ 4' : ''}, value); }`; if (hasBias) { - declareInputs.push(`@group(0) @binding(2) var bias: array<${isVec4 ? 'vec4' : 'f32'}>;`); + declareInputs.push(`@group(0) @binding(2) var bias: array<${isVec4 ? `vec4<${t}>` : t}>;`); declareFunctions += ` - fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { + fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? `vec4<${t}>` : t} { return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; }`; } @@ -222,7 +226,7 @@ export const createConv2DMatMulProgramInfo = // dilation : vec2, dimAOuter : i32, dimBOuter : i32, dimInner : i32 }; ${declareInputs.join('')} @group(0) @binding(${declareInputs.length}) var result: array<${ - isVec4 ? 'vec4' : 'f32'}>; + isVec4 ? `vec4<${t}>` : t}>; //@group(0) @binding(${declareInputs.length + 1}) var uniforms: Uniforms; const xShape : vec4 = vec4(${inputs[0].dims.join(',')}); @@ -240,12 +244,12 @@ export const createConv2DMatMulProgramInfo = ${ conv2dCommonSnippet( isChannelsLast, fitAOuter, fitBOuter, fitInner, hasBias, undefined, false, elementsSize[0], - elementsSize[1], elementsSize[2])} + elementsSize[1], elementsSize[2], t)} ${ isVec4 ? - makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, undefined, !isChannelsLast, tileInner) : + makeMatMulPackedVec4Source(elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner) : makeMatMulPackedSource( - elementsPerThread, workGroupSize, undefined, !isChannelsLast, tileInner, false, undefined, + elementsPerThread, workGroupSize, t, undefined, !isChannelsLast, tileInner, false, undefined, sequentialAccessByThreads)}` }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts new file mode 100644 index 0000000000000..f41d0d058a624 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_mm_webgpu.ts @@ -0,0 +1,244 @@ +/** + * @license + * Copyright 2021 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +// sampled from [@tensorflow/tfjs] tfjs-backend-webgpu/src/conv_backprop_mm_webgpu.ts +// +// modified to fit the needs of the project + +import {LOG_DEBUG} from '../../../log'; +import {TensorView} from '../../../tensor-view'; +import {ShapeUtil} from '../../../util'; +import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types'; +import {ConvTransposeAttributes} from '../conv-transpose'; + +import {Activation, activationFnSnippet, biasActivationSnippet, typeSnippet} from './activation_util'; +import {utilFunctions} from './conv_util'; +import {makeMatMulPackedSource, makeMatMulPackedVec4Source} from './matmul_packed_webgpu'; + +const conv2dTransposeCommonSnippet = + (isChannelsLast: boolean, addBias = false, activation?: Activation, hasPreluActivationWeights = false, + innerElementSize = 4): string => { + const type = typeSnippet(innerElementSize, 'f32'); + const getWSnippet = (innerElementSize: number) => { + switch (innerElementSize) { + case 1: + return 'return W[getIndexFromCoords4D(coord, wShape)];'; + case 4: + return ` + let coord1 = vec4(coordX, coordY, col + 1, rowInner); + let coord2 = vec4(coordX, coordY, col + 2, rowInner); + let coord3 = vec4(coordX, coordY, col + 3, rowInner); + let v0 = W[getIndexFromCoords4D(coord, wShape)]; + let v1 = W[getIndexFromCoords4D(coord1, wShape)]; + let v2 = W[getIndexFromCoords4D(coord2, wShape)]; + let v3 = W[getIndexFromCoords4D(coord3, wShape)]; + return vec4(v0, v1, v2, v3); + `; + default: + throw new Error(`innerElementSize ${innerElementSize} is not supported.`); + } + }; + const coordASnippet = isChannelsLast ? ` + let coord = vec4(batch, iXR, iXC, xCh); + ` : + ` + let coord = vec4(batch, xCh, iXR, iXC); + `; + + const coordResSnippet = isChannelsLast ? ` + let coords = vec4( + batch, + row / outWidth, + row % outWidth, + col); + ` : + ` + let coords = vec4( + batch, + row, + col / outWidth, + col % outWidth); + `; + + const xHeight = isChannelsLast ? 'outBackprop[1]' : 'outBackprop[2]'; + const xWidth = isChannelsLast ? 'outBackprop[2]' : 'outBackprop[3]'; + const row = isChannelsLast ? 'row' : 'col'; + const col = isChannelsLast ? 'col' : 'row'; + + const readASnippet = ` + let inChannels = ${isChannelsLast ? 'outBackprop[3]' : 'outBackprop[1]'}; + let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'}; + let outRow = ${row} / outWidth; + let outCol = ${row} % outWidth; + + let WRow = ${col} / (filterDims[1] * inChannels); + let WCol = ${col} / inChannels % filterDims[1]; + let xR = f32(outRow - pads[0] + dilation[0] * WRow) / f32(strides[0]); + let xC = f32(outCol - pads[1] + dilation[1] * WCol) / f32(strides[1]); + if (xR < 0.0 || xR >= f32(${xHeight}) || fract(xR) > 0.0) { + return ${type}(0.0); + } + if (xC < 0.0 || xC >= f32(${xWidth}) || fract(xC) > 0.0) { + return ${type}(0.0); + } + let iXR = i32(xR); + let iXC = i32(xC); + let xCh = ${col} % inChannels; + ${coordASnippet} + return x[getIndexFromCoords4D(coord, xShape)/${innerElementSize}];`; + + const sampleA = isChannelsLast ? ` + let col = colIn * ${innerElementSize}; + if (row < dimAOuter && col < dimInner) { + ${readASnippet} + } + return ${type}(0.0);` : + ` + let col = colIn * ${innerElementSize}; + if (row < dimInner && col < dimBOuter) { + ${readASnippet} + } + return ${type}(0.0);`; + + const sampleW = ` + let col = colIn * ${innerElementSize}; + let inChannels = ${isChannelsLast ? 
'outBackprop[3]' : 'outBackprop[1]'}; + let coordX = filterDims.x - 1 - row / (filterDims[1] * inChannels); + let coordY = filterDims.y - 1 - (row / inChannels) % filterDims[1]; + if (${ + isChannelsLast ? 'row < dimInner && col < dimBOuter' : + 'row < dimInner && col < dimAOuter'} && coordX >= 0 && coordY >= 0) { + let rowInner = row % inChannels; + let coord = vec4(coordX, coordY, col, rowInner); + ${getWSnippet(innerElementSize)} + } + return ${type}(0.0); + `; + + + const userCode = ` + ${activationFnSnippet(activation, hasPreluActivationWeights, innerElementSize === 4, 4)} + fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${type} { + ${isChannelsLast ? sampleA : sampleW} + } + + fn mm_readB(batch: i32, row : i32, colIn : i32) -> ${type} { + ${isChannelsLast ? sampleW : sampleA} + } + + fn mm_write(batch: i32, row : i32, colIn : i32, valueInput : ${type}) { + let col = colIn * ${innerElementSize}; + if (row < dimAOuter && col < dimBOuter) { + var value = valueInput; + let outWidth = ${isChannelsLast ? 'outShape[2]' : 'outShape[3]'}; + ${coordResSnippet} + ${biasActivationSnippet(addBias, activation)} + result[getIndexFromCoords4D(coords, outShape)/${innerElementSize}] = value; + } + }`; + return userCode; + }; + +export const createConv2DTransposeMatMulProgramInfo = + (inputs: readonly TensorView[], metadata: ProgramMetadata, attributes: ConvTransposeAttributes, + outputShape: readonly number[], dimAOuter: number, dimBOuter: number, dimInner: number, hasBias: boolean, + sequentialAccessByThreads: boolean): ProgramInfo => { + const isChannelsLast = attributes.format === 'NHWC'; + const inChannels = isChannelsLast ? inputs[0].dims[3] : inputs[0].dims[1]; + const batchSize = outputShape[0]; + const outWidth = isChannelsLast ? outputShape[2] : outputShape[3]; + const outHeight = isChannelsLast ? outputShape[1] : outputShape[2]; + const outChannels = isChannelsLast ? outputShape[3] : outputShape[1]; + const isVec4 = + isChannelsLast ? inChannels % 4 === 0 && outChannels % 4 === 0 : outWidth % 4 === 0 && outChannels % 4 === 0; + + // TODO: fine tune size + const dispatchX = isChannelsLast ? outChannels : outWidth * outHeight; + const dispatchY = isChannelsLast ? outWidth * outHeight : outChannels; + const workGroupSize: [number, number, number] = isVec4 ? + [8, 8, 1] : + [(dispatchX <= 4 || dispatchY <= 4) ? 4 : 16, dispatchX > 4 && dispatchY <= 4 ? 4 : 16, 1]; + const elementsPerThread = + isVec4 ? [4, 4, 1] : [dispatchX <= 4 ? 1 : 4, dispatchX > 4 && dispatchY <= 4 ? 1 : 4, 1]; + const dispatch = [ + Math.ceil(dispatchX / workGroupSize[0] / elementsPerThread[0]), + Math.ceil(dispatchY / workGroupSize[1] / elementsPerThread[1]), + Math.ceil(batchSize / workGroupSize[2] / elementsPerThread[2]) + ]; + + LOG_DEBUG('verbose', () => `[conv_backprop_mm_webgpu] dispatch = ${dispatch}`); + + const innerElementSize = isVec4 ? 4 : 1; + const tileInner = Math.max(workGroupSize[0] * innerElementSize, workGroupSize[1]); + + + const declareInputs = [ + `@group(0) @binding(0) var x: array<${isVec4 ? 'vec4' : 'f32'}>;`, + '@group(0) @binding(1) var W: array;' + ]; + let declareFunctions = ''; + if (hasBias) { + declareInputs.push(`@group(0) @binding(2) var bias: array<${isVec4 ? 'vec4' : 'f32'}>;`); + declareFunctions += ` + fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { + return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? 
'/ 4' : ''}]; + }`; + } + return { + ...metadata, + outputs: [{dims: outputShape, dataType: inputs[0].dataType, gpuDataType: GpuDataType.default}], + dispatchGroup: () => ({x: dispatch[0], y: dispatch[1], z: dispatch[2]}), + getShaderSource: () => ` + ${utilFunctions} + ${declareInputs.join('\n')} + @group(0) @binding(${declareInputs.length}) var result: array<${ + isVec4 ? 'vec4' : 'f32'}>; + const outBackprop : vec4 = vec4(${inputs[0].dims.join(',')}); + const xShape : vec4 = vec4(${inputs[0].dims.join(',')}); + const wShape : vec4 = vec4(${inputs[1].dims.join(',')}); + const outShape : vec4 = vec4(${outputShape.join(',')}); + const outShapeStrides : vec3 = vec3(${ShapeUtil.computeStrides(outputShape).slice(0, 3).join(',')}); + const filterDims : vec2 = vec2(${attributes.kernelShape[isChannelsLast ? 1 : 2]}, ${ + attributes.kernelShape[isChannelsLast ? 2 : 3]}); + const effectiveFilterDims : vec2 = filterDims + vec2( + ${ + attributes.dilations[0] <= 1 ? + 0 : + (attributes.kernelShape[isChannelsLast ? 1 : 2] - 1) * (attributes.dilations[0] - 1)}, + ${ + attributes.dilations[1] <= 1 ? + 0 : + (attributes.kernelShape[isChannelsLast ? 2 : 3] - 1) * (attributes.dilations[1] - 1)}); + const pads : vec2 = vec2(i32(effectiveFilterDims[0]) - 1 - (${ + attributes.pads[0] + attributes.pads[2]})/2, + i32(effectiveFilterDims[1]) - 1 - (${ + attributes.pads[1] + attributes.pads[3]})/2); + const strides : vec2 = vec2(${attributes.strides[0]}, ${attributes.strides[1]}); + const dilation : vec2 = vec2(${attributes.dilations[0]}, ${attributes.dilations[1]}); + const dimAOuter : i32 = ${dimAOuter}; + const dimBOuter : i32 = ${dimBOuter}; + const dimInner : i32 = ${dimInner}; + ${declareFunctions} + ${conv2dTransposeCommonSnippet(isChannelsLast, hasBias, undefined, false, innerElementSize)} + ${ + isVec4 ? makeMatMulPackedVec4Source( + elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner) : + makeMatMulPackedSource( + elementsPerThread, workGroupSize, 'f32', undefined, !isChannelsLast, tileInner, false, + undefined, sequentialAccessByThreads)}` + }; + }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts index ec6df438129fb..c60b94056a360 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv_backprop_webgpu.ts @@ -21,12 +21,13 @@ import {LOG_DEBUG} from '../../../log'; import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types'; -import {inputVariable, outputVariable, ShaderHelper} from '../common'; +import {inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; import {ConvTransposeAttributes} from '../conv-transpose'; const createConvTranspose2DOpProgramShaderSource = (shaderHelper: ShaderHelper, inputs: readonly TensorView[], attributes: ConvTransposeAttributes, - outputShape: readonly number[], hasBias: boolean, is1DimensionDispatch: boolean, isVec4 = false): string => { + outputShape: readonly number[], hasBias: boolean, is1DimensionDispatch: boolean, isVec4 = false, + dataType: string): string => { const isChannelsLast = attributes.format === 'NHWC'; const rowDim = isChannelsLast ? 1 : 2; const colDim = isChannelsLast ? 
2 : 3; @@ -39,12 +40,12 @@ const createConvTranspose2DOpProgramShaderSource = const outputChannelsPerGroup = wShape[1]; let declareFunctions = ` - fn setOutputAtIndex(flatIndex : u32, value : ${isVec4 ? 'vec4' : 'f32'}) { - result[flatIndex] = ${isVec4 ? 'vec4' : 'f32'}(value); + fn setOutputAtIndex(flatIndex : u32, value : ${isVec4 ? `vec4<${dataType}>` : dataType}) { + result[flatIndex] = ${isVec4 ? `vec4<${dataType}>` : dataType}(value); }`; if (hasBias) { declareFunctions += ` - fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? 'vec4' : 'f32'} { + fn getBiasByOutputCoords(coords : vec4) -> ${isVec4 ? `vec4<${dataType}>` : dataType} { return bias[coords.${isChannelsLast ? 'w' : 'y'}${isVec4 ? '/ 4' : ''}]; }`; } @@ -66,33 +67,33 @@ const createConvTranspose2DOpProgramShaderSource = // Convolve dy(?, ?, d2) with w(:, :, d1, d2) to compute dx(xR, xC, d1). // ? = to be determined. : = across all values in that axis. - var dotProd: array, ${workPerThread}>; + var dotProd: array, ${workPerThread}>; for (var i = 0; i < ${workPerThread}; i++) { - dotProd[i] = vec4(0.0); + dotProd[i] = vec4<${dataType}>(0.0); } for (var wR: u32 = 0; wR < filterDims[0]; wR = wR + 1) { - var dyR = (f32(dyCorner.x) + f32(wR)) / f32(strides.x); + var dyR = (${dataType}(dyCorner.x) + ${dataType}(wR)) / ${dataType}(strides.x); let wRPerm = filterDims[0] - 1 - wR; - if (dyR < 0.0 || dyR >= f32(outBackprop[1]) || + if (dyR < 0.0 || dyR >= ${dataType}(outBackprop[1]) || fract(dyR) > 0.0 || wRPerm < 0) { continue; } let idyR: u32 = u32(dyR); for (var wC: u32 = 0; wC < filterDims[1]; wC = wC + 1) { - let dyC = (f32(dyCorner.y) + f32(wC)) / f32(strides.y); - let dyC2 = (f32(dyCorner.y) + 1.0 + f32(wC)) / f32(strides.y); + let dyC = (${dataType}(dyCorner.y) + ${dataType}(wC)) / ${dataType}(strides.y); + let dyC2 = (${dataType}(dyCorner.y) + 1.0 + ${dataType}(wC)) / ${dataType}(strides.y); let wCPerm = filterDims[1] - 1 - wC; if (wCPerm < 0) { continue; } var bDyCVal = true; var bDyCVal2 = true; - if (dyC < 0.0 || dyC >= f32(outBackprop[2]) || + if (dyC < 0.0 || dyC >= ${dataType}(outBackprop[2]) || fract(dyC) > 0.0) { bDyCVal = false; } - if (dyC2 < 0.0 || dyC2 >= f32(outBackprop[2]) || + if (dyC2 < 0.0 || dyC2 >= ${dataType}(outBackprop[2]) || fract(dyC2) > 0.0) { bDyCVal2 = false; } @@ -108,7 +109,7 @@ const createConvTranspose2DOpProgramShaderSource = let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')}; var xValue = ${dy.get('batch', 'idyR', 'idyC', 'd2')}; - let tmpval = vec4(dot(xValue, wValue0), + let tmpval = vec4<${dataType}>(dot(xValue, wValue0), dot(xValue, wValue1), dot(xValue, wValue2), dot(xValue, wValue3)); @@ -116,7 +117,7 @@ const createConvTranspose2DOpProgramShaderSource = xValue = ${dy.get('batch', 'idyR', 'idyC2', 'd2')}; - dotProd[1] = dotProd[1] + vec4(dot(xValue, wValue0), + dotProd[1] = dotProd[1] + vec4<${dataType}>(dot(xValue, wValue0), dot(xValue, wValue1), dot(xValue, wValue2), dot(xValue, wValue3)); @@ -130,7 +131,7 @@ const createConvTranspose2DOpProgramShaderSource = let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')}; var xValue = ${dy.get('batch', 'idyR', 'idyC', 'd2')}; - let tmpval = vec4(dot(xValue, wValue0), + let tmpval = vec4<${dataType}>(dot(xValue, wValue0), dot(xValue, wValue1), dot(xValue, wValue2), dot(xValue, wValue3)); @@ -145,7 +146,7 @@ const createConvTranspose2DOpProgramShaderSource = let wValue3 = ${w.get('u32(wRPerm)', 'u32(wCPerm)', 'd1 + 3', 'd2')}; var xValue = ${dy.get('batch', 'idyR', 'idyC2', 'd2')}; - let tmpval = 
vec4(dot(xValue, wValue0), + let tmpval = vec4<${dataType}>(dot(xValue, wValue0), dot(xValue, wValue1), dot(xValue, wValue2), dot(xValue, wValue3)); @@ -178,9 +179,9 @@ const createConvTranspose2DOpProgramShaderSource = if (wR % dilations.x != 0) { continue; } - let dyR = (f32(dyRCorner) + f32(wR)) / f32(strides[0]); + let dyR = (${dataType}(dyRCorner) + ${dataType}(wR)) / ${dataType}(strides[0]); let wRPerm = filterDims.x - 1 - wR / dilations.x; - if (dyR < 0.0 || dyR >= f32(outBackprop[${rowDim}]) || fract(dyR) > 0.0 || + if (dyR < 0.0 || dyR >= ${dataType}(outBackprop[${rowDim}]) || fract(dyR) > 0.0 || wRPerm < 0) { continue; } @@ -190,21 +191,21 @@ const createConvTranspose2DOpProgramShaderSource = if (wC % dilations.y != 0) { continue; } - let dyC = (f32(dyCCorner) + f32(wC)) / f32(strides.y); + let dyC = (${dataType}(dyCCorner) + ${dataType}(wC)) / ${dataType}(strides.y); let wCPerm = filterDims.y - 1 - wC / dilations.y; - if (dyC < 0.0 || dyC >= f32(outBackprop[${colDim}]) || + if (dyC < 0.0 || dyC >= ${dataType}(outBackprop[${colDim}]) || fract(dyC) > 0.0 || wCPerm < 0) { continue; } let idyC: u32 = u32(dyC); - + var inputChannel = groupId * ${inputChannelsPerGroup}; for (var d2: u32 = 0; d2 < ${inputChannelsPerGroup}; d2 = d2 + 1) { - let inputChannel = groupId * ${inputChannelsPerGroup} + d2; let xValue = ${ isChannelsLast ? dy.get('batch', 'idyR', 'idyC', 'inputChannel') : dy.get('batch', 'inputChannel', 'idyR', 'idyC')}; let wValue = ${w.get('inputChannel', 'wOutChannel', 'u32(wRPerm)', 'u32(wCPerm)')}; dotProd = dotProd + xValue * wValue; + inputChannel = inputChannel + 1; } } } @@ -256,6 +257,7 @@ export const createConvTranspose2DProgramInfo = ]; LOG_DEBUG('verbose', () => `[conv2d_backprop_webgpu] dispatch = ${dispatch}`); + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); return { ...metadata, outputs: [{ @@ -265,6 +267,7 @@ export const createConvTranspose2DProgramInfo = }], dispatchGroup: () => ({x: dispatch[0], y: dispatch[1], z: dispatch[2]}), getShaderSource: (shaderHelper: ShaderHelper) => createConvTranspose2DOpProgramShaderSource( - shaderHelper, inputs, attributes, outputShape, hasBias, dispatch[1] === 1 && dispatch[2] === 1), + shaderHelper, inputs, attributes, outputShape, hasBias, dispatch[1] === 1 && dispatch[2] === 1, false, + dataType), }; }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts index 8d43dbb378a69..82f8c82291f4b 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/matmul_packed_webgpu.ts @@ -22,7 +22,7 @@ import {TensorView} from '../../../tensor-view'; import {ShapeUtil} from '../../../util'; import {GpuDataType, ProgramInfo, ProgramMetadata} from '../../types'; -import {getBroadcastDims, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from '../common'; +import {getBroadcastDims, IndicesHelper, inputVariable, outputVariable, ShaderHelper, tensorTypeToWsglStorageType} from '../common'; import {getActicationSnippet, InternalActivationAttributes} from '../fuse-utils'; import {typeSnippet} from './activation_util'; @@ -70,8 +70,8 @@ const calculateResultSnippet = (transposeA: boolean, innerElementSize: number) = }; export const makeMatMulPackedVec4Source = - (workPerThread: number[], workgroupSize: [number, number, number], batchDims?: IndicesHelper, transposeA = false, - tileInner = 32, splitK = false, splitedDimInner = 32): string => { + 
(workPerThread: number[], workgroupSize: [number, number, number], type = 'f32', batchDims?: IndicesHelper, + transposeA = false, tileInner = 32, splitK = false, splitedDimInner = 32): string => { const tileAOuter = workgroupSize[1] * workPerThread[1]; const tileBOuter = workgroupSize[0] * workPerThread[0]; const tileAWidth = transposeA ? tileAOuter : tileInner; @@ -90,8 +90,8 @@ export const makeMatMulPackedVec4Source = workPerThread[0]} must be 4.`); } return ` -var mm_Asub : array, ${tileAWidth / innerElementSize}>, ${tileAHight}>; -var mm_Bsub : array, ${tileBOuter / workPerThread[0]}>, ${tileInner}>; +var mm_Asub : array, ${tileAWidth / innerElementSize}>, ${tileAHight}>; +var mm_Bsub : array, ${tileBOuter / workPerThread[0]}>, ${tileInner}>; const rowPerThread = ${workPerThread[1]}; const colPerThread = ${workPerThread[0]}; @@ -115,7 +115,7 @@ fn main(@builtin(local_invocation_id) localId : vec3, let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(dimInner - 1) / tileInner + 1'}; var kStart = ${splitK ? `i32(globalId.z) * ${splitedDimInner}` : '0'}; - var acc: array, rowPerThread>; + var acc: array, rowPerThread>; // Loop over shared dimension. let tileRowB = localRow * ${rowPerThreadB}; @@ -179,8 +179,9 @@ const readDataFromSubASnippet = (transposeA: boolean) => // sequentialAccessByThreads means sequential data in memory is accessed by // threads, instead of a single thread (default behavior). export const makeMatMulPackedSource = - (workPerThread: number[], workgroupSize: [number, number, number], batchDims?: IndicesHelper, transposeA = false, - tileInner = 32, splitK = false, splitedDimInner = 32, sequentialAccessByThreads = false): string => { + (workPerThread: number[], workgroupSize: [number, number, number], type = 'f32', batchDims?: IndicesHelper, + transposeA = false, tileInner = 32, splitK = false, splitedDimInner = 32, + sequentialAccessByThreads = false): string => { const tileAOuter = workPerThread[1] * workgroupSize[1]; const tileBOuter = workPerThread[0] * workgroupSize[0]; const tileAWidth = transposeA ? tileAOuter : tileInner; @@ -222,7 +223,7 @@ export const makeMatMulPackedSource = workgroupBarrier(); // Compute acc values for a single thread. - var BCached : array; + var BCached : array<${type}, colPerThread>; for (var k = 0; k < tileInner; k = k + 1) { for (var inner = 0; inner < colPerThread; inner = inner + 1) { BCached[inner] = mm_Bsub[k][localCol + inner * ${workgroupSize[0]}]; @@ -283,7 +284,7 @@ for (var t = 0; t < numTiles; t = t + 1) { workgroupBarrier(); // Compute acc values for a single thread. - var BCached : array; + var BCached : array<${type}, colPerThread>; for (var k = 0; k < tileInner; k = k + 1) { for (var inner = 0; inner < colPerThread; inner = inner + 1) { BCached[inner] = mm_Bsub[k][tileCol + inner]; @@ -309,8 +310,8 @@ for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { `; return ` - var mm_Asub : array, ${tileAHight}>; - var mm_Bsub : array, ${tileInner}>; + var mm_Asub : array, ${tileAHight}>; + var mm_Bsub : array, ${tileInner}>; const rowPerThread = ${workPerThread[1]}; const colPerThread = ${workPerThread[0]}; const tileInner = ${tileInner}; @@ -324,7 +325,7 @@ fn main(@builtin(local_invocation_id) localId : vec3, let numTiles = ${splitK ? `${Math.ceil(splitedDimInner / tileInner)}` : '(dimInner - 1) / tileInner + 1'}; var kStart = ${splitK ? 
`i32(globalId.z) * ${splitedDimInner}` : '0'}; - var acc : array, rowPerThread>; + var acc : array, rowPerThread>; // Without this initialization strange values show up in acc. for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) { @@ -347,6 +348,7 @@ const matMulReadWriteFnSource = const outputVariable = variables[5]; const broadCastADims = getBroadcastDims(batchAVariable.shape, batchVariable.shape); const broadCastBDims = getBroadcastDims(batchBVariable.shape, batchVariable.shape); + const dataType = tensorTypeToWsglStorageType(variables[0].type.tensor); const getAIndices = () => { const aRank = aVariable.shape.length; const batchRank = batchVariable.shape.length; @@ -377,8 +379,8 @@ const matMulReadWriteFnSource = }; const source = ` fn mm_readA(batch: i32, row: i32, colIn: i32, batchIndices: ${batchVariable.type.indices}) -> ${ - typeSnippet(component)} { - var value = ${typeSnippet(component)}(0.0); + typeSnippet(component, dataType)} { + var value = ${typeSnippet(component, dataType)}(0.0); let col = colIn * ${component}; if(row < dimAOuter && col < dimInner) { @@ -389,8 +391,8 @@ const matMulReadWriteFnSource = } fn mm_readB(batch: i32, row: i32, colIn: i32, batchIndices: ${batchVariable.type.indices}) -> ${ - typeSnippet(component)} { - var value = ${typeSnippet(component)}(0.0); + typeSnippet(component, dataType)} { + var value = ${typeSnippet(component, dataType)}(0.0); let col = colIn * ${component}; if(row < dimInner && col < dimBOuter) { @@ -400,7 +402,7 @@ const matMulReadWriteFnSource = return value; } - fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: ${typeSnippet(component)}) { + fn mm_write(batch: i32, row: i32, colIn: i32, valueIn: ${typeSnippet(component, dataType)}) { let col = colIn * ${component}; if (row < dimAOuter && col < dimBOuter) { var value = valueIn; @@ -444,6 +446,7 @@ export const createMatmulProgramInfo = Math.ceil(batchSize / workgroupSize[2] / elementsPerThread[2]) ]; + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); const components = isVec4 ? 4 : 1; const A = inputVariable('a', inputs[0].dataType, [...outerDimsA, dimAOuter, dimInner / components], components); const B = inputVariable('b', inputs[1].dataType, [...outerDimsB, dimInner, dimBOuter / components], components); @@ -466,8 +469,8 @@ export const createMatmulProgramInfo = ${declareFunctions} ${activationFunction} ${ - isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, batchDims) : - makeMatMulPackedSource(elementsPerThread, workgroupSize, batchDims)} + isVec4 ? makeMatMulPackedVec4Source(elementsPerThread, workgroupSize, dataType, batchDims) : + makeMatMulPackedSource(elementsPerThread, workgroupSize, dataType, batchDims)} ${batchDims.impl()}`; return { ...metadata, diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts index e7d1ddf771650..59f11d6d9abba 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts @@ -1,14 +1,15 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
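For reference, the effect of threading a WGSL storage type through these shader templates can be seen from typeSnippet alone. The sketch below restates typeSnippet as changed in activation_util.ts above and shows the strings it now produces; the assumption that tensorTypeToWsglStorageType maps float16 tensors to the WGSL type 'f16' is not part of this diff.

    // Sketch: typeSnippet as modified above, parameterized on the WGSL storage type.
    const typeSnippet = (component: number, dataType: string) => {
      switch (component) {
        case 1: return dataType;
        case 2: return `vec2<${dataType}>`;
        case 3: return `vec3<${dataType}>`;
        case 4: return `vec4<${dataType}>`;
        default: throw new Error(`${component}-component is not supported.`);
      }
    };

    typeSnippet(4, 'f32');  // 'vec4<f32>' -- the previously hard-coded result
    typeSnippet(4, 'f16');  // 'vec4<f16>' -- assumed mapping for float16 tensors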
-import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, GpuDataType, ProgramInfoLoader, ProgramMetadata} from '../types'; import {createConvTranspose2DProgramInfo} from './3rd-party/conv_backprop_webgpu'; import {ConvAttributes} from './conv'; +import {createConv2DTransposeMatMulProgramInfoLoader} from './conv2dtranspose-mm'; import {parseInternalActivationAttributes} from './fuse-utils'; +import {createTransposeProgramInfo, TransposeAttributes, transposeProgramMetadata} from './transpose'; const computeTotalPad = (inDim: number, stride: number, adj: number, kernel: number, dilation: number, outSize: number) => @@ -63,7 +64,7 @@ const getAdjustedConvTransposeAttributes = (attributes: T, inputs: readonly TensorView[]): T => { const kernelShape = attributes.kernelShape.slice(); // if kernelShape is not specified in the attributes of this op, infer it from the weight tensor dims - if (attributes.kernelShape.length === 0 || attributes.kernelShape.reduce((a, b) => a * b, 0) === 0) { + if (attributes.kernelShape.length === 0 || attributes.kernelShape.reduce((a, b) => a * b, 1) === 0) { kernelShape.length = 0; for (let i = 2; i < inputs[1].dims.length; ++i) { kernelShape.push(inputs[1].dims[i]); @@ -95,9 +96,11 @@ const getAdjustedConvTransposeAttributes = // always return a new object so does not modify the original attributes const newAttributes: T = Object.assign({}, attributes); - Object.assign( - newAttributes, - {kernelShape, pads, outputPadding, outputShape, dilations, strides, cacheKey: attributes.cacheKey}); + const cacheKey = attributes.cacheKey + [ + kernelShape.join('n,'), pads.join(','), strides.join(','), outputPadding.join(','), outputShape.join(','), + dilations.join(',') + ].join('_'); + Object.assign(newAttributes, {kernelShape, pads, outputPadding, outputShape, dilations, strides, cacheKey}); return newAttributes; }; @@ -197,15 +200,6 @@ const validateInputs = (inputs: readonly TensorView[], attributes: ConvTranspose if (attributes.outputShape.length !== 0 && attributes.outputShape.length !== inputs[0].dims.length - 2) { throw new Error('invalid output shape'); } - - // TODO : Need to add support for float64 - if (inputs[0].dataType !== DataType.float || inputs[1].dataType !== DataType.float) { - throw new Error('ConvTranspose input(X,W) should be float tensor'); - } - - if (inputs.length === 3 && inputs[2].dataType !== DataType.float) { - throw new Error('ConvTranspose input(bias) should be float tensor'); - } }; const createConvTranspose2DProgramMetadata = (hasBias: boolean, cacheHint: string): ProgramMetadata => ({ @@ -226,12 +220,64 @@ const createConvTranspose2DProgramInfoLoader = }; }; +// for transposing weight tensor from [C, M/group, KH, KW] to [KH, KW, M/group, C] +const weightTransposeAttribute: TransposeAttributes = createAttributeWithCacheKey({perm: [2, 3, 1, 0]}); + const convTranspose2d = (context: ComputeContext, inputs: readonly TensorView[], attributes: ConvTransposeAttributes): void => { const adjustedAttributes = getAdjustedConvTransposeAttributes(attributes, inputs); + const isChannelsLast = attributes.format === 'NHWC'; + const hasBias = inputs.length === 3; + if (adjustedAttributes.group !== 1) { + context.compute(createConvTranspose2DProgramInfoLoader(inputs, adjustedAttributes)); + return; + } + const outputShape = adjustedAttributes.outputShape; + const outHeight = outputShape[isChannelsLast ? 
1 : 2]; + const outWidth = outputShape[isChannelsLast ? 2 : 3]; + const outChannels = outputShape[isChannelsLast ? 3 : 1]; + const weightHeight = inputs[1].dims[2]; + const weightWidth = inputs[1].dims[3]; + const inputChannels = inputs[0].dims[isChannelsLast ? 3 : 1]; + + const dimAOuter = isChannelsLast ? outHeight * outWidth : outChannels; + const dimBOuter = isChannelsLast ? outChannels : outHeight * outWidth; + const dimInner = weightHeight * weightWidth * inputChannels; - context.compute(createConvTranspose2DProgramInfoLoader(inputs, adjustedAttributes)); + const sequentialAccessByThreads = /* backend.adapterInfo.isIntel() */ true; + + + // STEP.1: transpose weight + const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ?? + context.compute( + { + ...transposeProgramMetadata, + cacheHint: weightTransposeAttribute.cacheKey, + get: () => createTransposeProgramInfo(inputs[1], weightTransposeAttribute.perm) + }, + {inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0]; + if (attributes.wIsConst && !context.kernelCustomData.wT) { + context.kernelCustomData.wT = transposedWeight; + } + + // STEP.2: prepare reshaped inputs + const convTransposeInputs = [inputs[0], transposedWeight]; + if (hasBias) { + if (!isChannelsLast && inputs[2].dims.length === 1) { + convTransposeInputs.push(inputs[2].reshape([inputs[2].dims[0], 1, 1])); + } else { + convTransposeInputs.push(inputs[2]); + } + } + + // STEP.3: compute matmul + context.compute( + createConv2DTransposeMatMulProgramInfoLoader( + convTransposeInputs, adjustedAttributes, outputShape, dimAOuter, dimBOuter, dimInner, hasBias, + sequentialAccessByThreads), + {inputs: convTransposeInputs}); }; + const convTranspose1d = (context: ComputeContext, attributes: ConvTransposeAttributes): void => { // extend the input to 2D by adding H dimension const isChannelLast = attributes.format === 'NHWC'; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts index 95a64e5787841..7afc3ce1b9d77 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {PoolConvUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; @@ -93,15 +92,6 @@ const validateInputs = (inputs: readonly TensorView[], attributes: ConvAttribute if (attributes.kernelShape.length !== 0 && attributes.kernelShape.length !== inputs[1].dims.length - 2) { throw new Error('invalid kernel shape'); } - - // TODO : Need to add support for float64 - if (inputs[0].dataType !== DataType.float || inputs[1].dataType !== DataType.float) { - throw new Error('Conv input(X,W) should be float tensor'); - } - - if (inputs.length === 3 && inputs[2].dataType !== DataType.float) { - throw new Error('Conv input(bias) should be float tensor'); - } }; const getAdjustedConvAttributes = (attributes: T, inputs: readonly TensorView[]): T => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv2dtranspose-mm.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv2dtranspose-mm.ts new file mode 100644 index 0000000000000..da04b5063a9f0 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv2dtranspose-mm.ts @@ -0,0 +1,29 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
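To make the matmul formulation in convTranspose2d above concrete, the following worked example traces the dimension mapping for the channels-last case. The shapes are hypothetical and only illustrate the arithmetic already present in the code; they are not taken from this change.

    // Hypothetical NHWC input [1, 8, 8, 16] and weight [16, 32, 3, 3] ([C, M/group, KH, KW]),
    // producing an output of [1, 16, 16, 32]:
    const outputShape = [1, 16, 16, 32];
    const outHeight = outputShape[1], outWidth = outputShape[2], outChannels = outputShape[3];
    const weightHeight = 3, weightWidth = 3, inputChannels = 16;

    const dimAOuter = outHeight * outWidth;                       // 256: rows of the implicit matmul
    const dimBOuter = outChannels;                                // 32: columns
    const dimInner = weightHeight * weightWidth * inputChannels;  // 144: shared (reduction) dimension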
+ +import {TensorView} from '../../tensor-view'; +import {GpuDataType, ProgramInfoLoader, ProgramMetadata} from '../types'; + +import {createConv2DTransposeMatMulProgramInfo} from './3rd-party/conv_backprop_mm_webgpu'; +import {ConvTransposeAttributes} from './conv-transpose'; + + +const createConv2DTransposeMatMulProgramMetadata = (hasBias: boolean, cacheHint: string): ProgramMetadata => ({ + name: 'Conv2DTransposeMatMul', + inputTypes: hasBias ? [GpuDataType.default, GpuDataType.default, GpuDataType.default] : + [GpuDataType.default, GpuDataType.default], + cacheHint +}); + +export const createConv2DTransposeMatMulProgramInfoLoader = + (inputs: readonly TensorView[], attributes: ConvTransposeAttributes, outputShape: readonly number[], + dimAOuter: number, dimBOuter: number, dimInner: number, hasBias: boolean, + sequentialAccessByThreads: boolean): ProgramInfoLoader => { + const metadata = createConv2DTransposeMatMulProgramMetadata(hasBias, attributes.cacheKey); + return { + ...metadata, + get: () => createConv2DTransposeMatMulProgramInfo( + inputs, metadata, attributes, outputShape, dimAOuter, dimBOuter, dimInner, hasBias, + sequentialAccessByThreads) + }; + }; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts index 837ac8410f291..7dadf9a6205ea 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/matmul.ts @@ -1,7 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {BroadcastUtil} from '../../util'; import {ComputeContext, GpuDataType, ProgramInfoLoader} from '../types'; @@ -9,7 +8,6 @@ import {ComputeContext, GpuDataType, ProgramInfoLoader} from '../types'; import {createMatmulProgramInfo} from './3rd-party/matmul_packed_webgpu'; import {InternalActivationAttributes} from './fuse-utils'; - const createMatmulProgramMetadata = (hasBias: boolean, cacheHint: string) => ({ name: 'MatMul', inputTypes: hasBias ? [GpuDataType.default, GpuDataType.default, GpuDataType.default] : @@ -35,10 +33,6 @@ const validateInputs = (inputs: readonly TensorView[]): void => { if (inputs[0].dims[inputs[0].dims.length - 1] !== inputs[1].dims[inputs[1].dims.length - 2]) { throw new Error('shared dimension does not match.'); } - - if (inputs[0].dataType !== DataType.float || inputs[1].dataType !== DataType.float) { - throw new Error('inputs should be float type'); - } }; export const matMul = (context: ComputeContext): void => { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/range.ts b/js/web/lib/wasm/jsep/webgpu/ops/range.ts new file mode 100644 index 0000000000000..3ecb3308b1899 --- /dev/null +++ b/js/web/lib/wasm/jsep/webgpu/ops/range.ts @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
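The next file adds a WebGPU kernel for the Range operator. As a sanity check of the element-count formula it uses, a host-side equivalent is sketched here with made-up values; only the formula itself is taken from the implementation that follows.

    // Host-side equivalent of the output produced by the new Range kernel (illustrative only).
    const rangeOnCpu = (start: number, limit: number, delta: number): number[] => {
      const numElements = Math.abs(Math.ceil((limit - start) / delta));  // same formula as createRangeProgramInfo
      return Array.from({length: numElements}, (_, i) => start + i * delta);
    };

    rangeOnCpu(1, 10, 3);   // [1, 4, 7]   -- 3 elements: ceil((10 - 1) / 3)
    rangeOnCpu(10, 4, -2);  // [10, 8, 6]  -- 3 elements: abs(ceil((4 - 10) / -2))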
+ +import {env} from 'onnxruntime-common'; + +import {DataType} from '../../../wasm-common'; +import {ComputeContext, GpuDataType, ProgramInfo, ProgramMetadata} from '../types'; + +import {outputVariable, ShaderHelper} from './common'; + +const validateInputsContent = (start: number, limit: number, delta: number): void => { + const sameStartLimit = start === limit; + const increasingRangeNegativeStep = start < limit && delta < 0; + const decreasingRangePositiveStep = start > limit && delta > 0; + + if (sameStartLimit || increasingRangeNegativeStep || decreasingRangePositiveStep) { + throw new Error('Range these inputs\' contents are invalid.'); + } +}; + +const createRangeProgramInfo = + (metadata: ProgramMetadata, start: number, limit: number, delta: number, dataType: DataType): ProgramInfo => { + const numElements = Math.abs(Math.ceil((limit - start) / delta)); + const outputShape: number[] = [numElements]; + const outputSize = numElements; + + const output = outputVariable('output', dataType, outputShape); + const wgslType = output.type.storage; + + const getShaderSource = (shaderHelper: ShaderHelper) => ` + ${shaderHelper.declareVariables(output)} + ${shaderHelper.mainStart()} + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)} + output[global_idx] = ${wgslType}(${start}) + ${wgslType}(global_idx) * ${wgslType}(${delta}); + }`; + return { + ...metadata, + getShaderSource, + outputs: [{dims: outputShape, dataType, gpuDataType: GpuDataType.default}], + dispatchGroup: () => ({x: Math.ceil(outputSize / 64 /* workgroup size */)}) + }; + }; + +export const range = (context: ComputeContext): void => { + let start = 0; + let limit = 0; + let delta = 0; + if (context.inputs[0].dataType === DataType.int32) { + start = context.inputs[0].getInt32Array()[0]; + limit = context.inputs[1].getInt32Array()[0]; + delta = context.inputs[2].getInt32Array()[0]; + } else if (context.inputs[0].dataType === DataType.float) { + start = context.inputs[0].getFloat32Array()[0]; + limit = context.inputs[1].getFloat32Array()[0]; + delta = context.inputs[2].getFloat32Array()[0]; + } + if (env.webgpu.validateInputContent) { + validateInputsContent(start, limit, delta); + } + + const cacheHint = [start, limit, delta].map(x => x.toString()).join('_'); + const metadata: ProgramMetadata = {name: 'Range', inputTypes: [], cacheHint}; + context.compute( + {...metadata, get: () => createRangeProgramInfo(metadata, start, limit, delta, context.inputs[0].dataType)}, + {inputs: []}); +}; diff --git a/js/web/lib/wasm/proxy-messages.ts b/js/web/lib/wasm/proxy-messages.ts index e5a2d8c2351b8..43f70c23f7193 100644 --- a/js/web/lib/wasm/proxy-messages.ts +++ b/js/web/lib/wasm/proxy-messages.ts @@ -3,20 +3,24 @@ import {Env, InferenceSession, Tensor} from 'onnxruntime-common'; -/** - * tuple elements are: ORT element type; dims; tensor data - */ -export type SerializableTensor = [Tensor.Type, readonly number[], Tensor.DataType]; +export type SerializableTensorMetadata = + [dataType: Tensor.Type, dims: readonly number[], data: Tensor.DataType, location: 'cpu']; -/** - * tuple elements are: InferenceSession handle; input names; output names - */ -export type SerializableSessionMetadata = [number, string[], string[]]; +export type GpuBufferMetadata = { + gpuBuffer: Tensor.GpuBufferType; + download?: () => Promise; + dispose?: () => void; +}; -/** - * tuple elements are: modeldata.offset, modeldata.length - */ -export type SerializableModeldata = [number, number]; +export type UnserializableTensorMetadata = + [dataType: 
Tensor.Type, dims: readonly number[], data: GpuBufferMetadata, location: 'gpu-buffer']| + [dataType: Tensor.Type, dims: readonly number[], data: Tensor.DataType, location: 'cpu-pinned']; + +export type TensorMetadata = SerializableTensorMetadata|UnserializableTensorMetadata; + +export type SerializableSessionMetadata = [sessionHandle: number, inputNames: string[], outputNames: string[]]; + +export type SerializableModeldata = [modelDataOffset: number, modelDataLength: number]; interface MessageError { err?: string; @@ -58,10 +62,10 @@ interface MessageReleaseSession extends MessageError { interface MessageRun extends MessageError { type: 'run'; in ?: { - sessionId: number; inputIndices: number[]; inputs: SerializableTensor[]; outputIndices: number[]; + sessionId: number; inputIndices: number[]; inputs: SerializableTensorMetadata[]; outputIndices: number[]; options: InferenceSession.RunOptions; }; - out?: SerializableTensor[]; + out?: SerializableTensorMetadata[]; } interface MesssageEndProfiling extends MessageError { diff --git a/js/web/lib/wasm/proxy-wrapper.ts b/js/web/lib/wasm/proxy-wrapper.ts index 815b223e40379..202209ed3bfed 100644 --- a/js/web/lib/wasm/proxy-wrapper.ts +++ b/js/web/lib/wasm/proxy-wrapper.ts @@ -3,7 +3,7 @@ import {Env, env, InferenceSession} from 'onnxruntime-common'; -import {OrtWasmMessage, SerializableModeldata, SerializableSessionMetadata, SerializableTensor} from './proxy-messages'; +import {OrtWasmMessage, SerializableModeldata, SerializableSessionMetadata, SerializableTensorMetadata, TensorMetadata} from './proxy-messages'; import * as core from './wasm-core-impl'; import {initializeWebAssembly} from './wasm-factory'; @@ -22,7 +22,7 @@ const createSessionAllocateCallbacks: Array> = []; const createSessionCallbacks: Array> = []; const releaseSessionCallbacks: Array> = []; -const runCallbacks: Array> = []; +const runCallbacks: Array> = []; const endProfilingCallbacks: Array> = []; const ensureWorker = (): void => { @@ -177,6 +177,10 @@ export const createSessionFinalize = async(modeldata: SerializableModeldata, opt export const createSession = async(model: Uint8Array, options?: InferenceSession.SessionOptions): Promise => { if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { + // check unsupported options + if (options?.preferredOutputLocation) { + throw new Error('session option "preferredOutputLocation" is not supported for proxy.'); + } ensureWorker(); return new Promise((resolve, reject) => { createSessionCallbacks.push([resolve, reject]); @@ -202,17 +206,27 @@ export const releaseSession = async(sessionId: number): Promise => { }; export const run = async( - sessionId: number, inputIndices: number[], inputs: SerializableTensor[], outputIndices: number[], - options: InferenceSession.RunOptions): Promise => { + sessionId: number, inputIndices: number[], inputs: TensorMetadata[], outputIndices: number[], + outputs: Array, options: InferenceSession.RunOptions): Promise => { if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { + // check inputs location + if (inputs.some(t => t[3] !== 'cpu')) { + throw new Error('input tensor on GPU is not supported for proxy.'); + } + // check outputs location + if (outputs.some(t => t)) { + throw new Error('pre-allocated output tensor is not supported for proxy.'); + } ensureWorker(); - return new Promise((resolve, reject) => { + return new Promise((resolve, reject) => { runCallbacks.push([resolve, reject]); - const message: OrtWasmMessage = {type: 'run', in : {sessionId, inputIndices, inputs, outputIndices, options}}; - 
proxyWorker!.postMessage(message, core.extractTransferableBuffers(inputs)); + const serializableInputs = inputs as SerializableTensorMetadata[]; // every input is on CPU. + const message: OrtWasmMessage = + {type: 'run', in : {sessionId, inputIndices, inputs: serializableInputs, outputIndices, options}}; + proxyWorker!.postMessage(message, core.extractTransferableBuffers(serializableInputs)); }); } else { - return core.run(sessionId, inputIndices, inputs, outputIndices, options); + return core.run(sessionId, inputIndices, inputs, outputIndices, outputs, options); } }; diff --git a/js/web/lib/wasm/session-handler.ts b/js/web/lib/wasm/session-handler.ts index d8c5ae7886fe4..7bc467449c33a 100644 --- a/js/web/lib/wasm/session-handler.ts +++ b/js/web/lib/wasm/session-handler.ts @@ -2,16 +2,45 @@ // Licensed under the MIT License. import {readFile} from 'fs'; -import {env, InferenceSession, SessionHandler, Tensor} from 'onnxruntime-common'; +import {env, InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} from 'onnxruntime-common'; import {promisify} from 'util'; -import {SerializableModeldata} from './proxy-messages'; +import {SerializableModeldata, TensorMetadata} from './proxy-messages'; import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, releaseSession, run} from './proxy-wrapper'; +import {isGpuBufferSupportedType} from './wasm-common'; let runtimeInitialized: boolean; let runtimeInitializationPromise: Promise|undefined; -export class OnnxruntimeWebAssemblySessionHandler implements SessionHandler { +const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => { + switch (tensor.location) { + case 'cpu': + return [tensor.type, tensor.dims, tensor.data, 'cpu']; + case 'gpu-buffer': + return [tensor.type, tensor.dims, {gpuBuffer: tensor.gpuBuffer}, 'gpu-buffer']; + default: + throw new Error(`invalid data location: ${tensor.location} for ${getName()}`); + } +}; + +const decodeTensorMetadata = (tensor: TensorMetadata): Tensor => { + switch (tensor[3]) { + case 'cpu': + return new Tensor(tensor[0], tensor[2], tensor[1]); + case 'gpu-buffer': { + const dataType = tensor[0]; + if (!isGpuBufferSupportedType(dataType)) { + throw new Error(`not supported data type: ${dataType} for deserializing GPU tensor`); + } + const {gpuBuffer, download, dispose} = tensor[2]; + return Tensor.fromGpuBuffer(gpuBuffer, {dataType, dims: tensor[1], download, dispose}); + } + default: + throw new Error(`invalid data location: ${tensor[3]}`); + } +}; + +export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHandler { private sessionId: number; inputNames: string[]; @@ -74,25 +103,31 @@ export class OnnxruntimeWebAssemblySessionHandler implements SessionHandler { inputIndices.push(index); }); + const outputArray: Array = []; const outputIndices: number[] = []; Object.entries(fetches).forEach(kvp => { const name = kvp[0]; - // TODO: support pre-allocated output + const tensor = kvp[1]; const index = this.outputNames.indexOf(name); if (index === -1) { throw new Error(`invalid output '${name}'`); } + outputArray.push(tensor); outputIndices.push(index); }); - const outputs = - await run(this.sessionId, inputIndices, inputArray.map(t => [t.type, t.dims, t.data]), outputIndices, options); + const inputs = + inputArray.map((t, i) => encodeTensorMetadata(t, () => `input "${this.inputNames[inputIndices[i]]}"`)); + const outputs = outputArray.map( + (t, i) => t ? 
encodeTensorMetadata(t, () => `output "${this.outputNames[outputIndices[i]]}"`) : null); + + const results = await run(this.sessionId, inputIndices, inputs, outputIndices, outputs, options); - const result: SessionHandler.ReturnType = {}; - for (let i = 0; i < outputs.length; i++) { - result[this.outputNames[outputIndices[i]]] = new Tensor(outputs[i][0], outputs[i][2], outputs[i][1]); + const resultMap: SessionHandler.ReturnType = {}; + for (let i = 0; i < results.length; i++) { + resultMap[this.outputNames[outputIndices[i]]] = outputArray[i] ?? decodeTensorMetadata(results[i]); } - return result; + return resultMap; } startProfiling(): void { diff --git a/js/web/lib/wasm/wasm-common.ts b/js/web/lib/wasm/wasm-common.ts index 389773f3e8884..b9eff45e890c4 100644 --- a/js/web/lib/wasm/wasm-common.ts +++ b/js/web/lib/wasm/wasm-common.ts @@ -164,3 +164,35 @@ export const logLevelStringToEnum = (logLevel?: 'verbose'|'info'|'warning'|'erro throw new Error(`unsupported logging level: ${logLevel}`); } }; + +/** + * Check whether the given tensor type is supported by GPU buffer + */ +export const isGpuBufferSupportedType = (type: Tensor.Type): type is Tensor.GpuBufferDataTypes => type === 'float32' || + type === 'int32' || type === 'int64' || type === 'bool' || type === 'float16' || type === 'uint32'; + +/** + * Map string data location to integer value + */ +export const dataLocationStringToEnum = (location: Tensor.DataLocation): number => { + switch (location) { + case 'none': + return 0; + case 'cpu': + return 1; + case 'cpu-pinned': + return 2; + case 'texture': + return 3; + case 'gpu-buffer': + return 4; + default: + throw new Error(`unsupported data location: ${location}`); + } +}; + +/** + * Map integer data location to string value + */ +export const dataLocationEnumToString = (location: number): Tensor.DataLocation|undefined => + (['none', 'cpu', 'cpu-pinned', 'texture', 'gpu-buffer'] as const)[location]; diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index fcca82ab2aa54..5b49a1d4202e3 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -3,10 +3,10 @@ import {Env, InferenceSession, Tensor} from 'onnxruntime-common'; -import {SerializableModeldata, SerializableSessionMetadata, SerializableTensor} from './proxy-messages'; +import {SerializableModeldata, SerializableSessionMetadata, SerializableTensorMetadata, TensorMetadata} from './proxy-messages'; import {setRunOptions} from './run-options'; import {setSessionOptions} from './session-options'; -import {logLevelStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common'; +import {dataLocationStringToEnum, getTensorElementSize, isGpuBufferSupportedType, logLevelStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common'; import {getInstance} from './wasm-factory'; import {allocWasmString, checkLastError} from './wasm-utils'; @@ -60,9 +60,36 @@ export const initRuntime = async(env: Env): Promise => { }; /** - * tuple elements are: InferenceSession ID; inputNamesUTF8Encoded; outputNamesUTF8Encoded + * valid data locations for input/output tensors. */ -type SessionMetadata = [number, number[], number[]]; +type SupportedTensorDataLocationForInputOutput = 'cpu'|'cpu-pinned'|'gpu-buffer'; + +type IOBindingState = { + /** + * the handle of IO binding. + */ + readonly handle: number; + + /** + * the preferred location for each output tensor. 
+ * + * value is one of 'cpu', 'cpu-pinned', 'gpu-buffer'. + */ + readonly outputPreferredLocations: readonly SupportedTensorDataLocationForInputOutput[]; + + /** + * enum value of the preferred location for each output tensor. + */ + readonly outputPreferredLocationsEncoded: readonly number[]; +}; + +/** + * tuple elements are: InferenceSession ID; inputNamesUTF8Encoded; outputNamesUTF8Encoded; bindingState + */ +type SessionMetadata = [ + inferenceSessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[], + bindingState: IOBindingState|null +]; const activeSessions = new Map(); @@ -92,6 +119,7 @@ export const createSessionFinalize = let sessionHandle = 0; let sessionOptionsHandle = 0; + let ioBindingHandle = 0; let allocs: number[] = []; const inputNamesUTF8Encoded = []; const outputNamesUTF8Encoded = []; @@ -108,6 +136,7 @@ export const createSessionFinalize = const inputNames = []; const outputNames = []; + const outputPreferredLocations: SupportedTensorDataLocationForInputOutput[] = []; for (let i = 0; i < inputCount; i++) { const name = wasm._OrtGetInputName(sessionHandle, i); if (name === 0) { @@ -122,15 +151,45 @@ export const createSessionFinalize = checkLastError('Can\'t get an output name.'); } outputNamesUTF8Encoded.push(name); - outputNames.push(wasm.UTF8ToString(name)); + const nameString = wasm.UTF8ToString(name); + outputNames.push(nameString); + + if (!BUILD_DEFS.DISABLE_WEBGPU) { + const location = typeof options?.preferredOutputLocation === 'string' ? + options.preferredOutputLocation : + options?.preferredOutputLocation?.[nameString] ?? 'cpu'; + if (location !== 'cpu' && location !== 'cpu-pinned' && location !== 'gpu-buffer') { + throw new Error(`Not supported preferred output location: ${location}.`); + } + outputPreferredLocations.push(location); + } + } + + // use IO binding only when at least one output is preferred to be on GPU. + let bindingState: IOBindingState|null = null; + if (!BUILD_DEFS.DISABLE_WEBGPU && outputPreferredLocations.some(l => l === 'gpu-buffer')) { + ioBindingHandle = wasm._OrtCreateBinding(sessionHandle); + if (ioBindingHandle === 0) { + checkLastError('Can\'t create IO binding.'); + } + + bindingState = { + handle: ioBindingHandle, + outputPreferredLocations, + outputPreferredLocationsEncoded: outputPreferredLocations.map(l => dataLocationStringToEnum(l)), + }; } - activeSessions.set(sessionHandle, [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded]); + activeSessions.set(sessionHandle, [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, bindingState]); return [sessionHandle, inputNames, outputNames]; } catch (e) { inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + if (ioBindingHandle !== 0) { + wasm._OrtReleaseBinding(ioBindingHandle); + } + if (sessionHandle !== 0) { wasm._OrtReleaseSession(sessionHandle); } @@ -161,7 +220,13 @@ export const releaseSession = (sessionId: number): void => { if (!session) { throw new Error(`cannot release session.
invalid session id: ${sessionId}`); } - const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded] = session; + const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState] = session; + + if (ioBindingState) { + wasm._OrtReleaseBinding(ioBindingState.handle); + } + + wasm.jsepUnregisterBuffers?.(sessionId); inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); @@ -169,18 +234,84 @@ export const releaseSession = (sessionId: number): void => { activeSessions.delete(sessionId); }; +const prepareInputOutputTensor = + (tensor: TensorMetadata|null, tensorHandles: number[], allocs: number[], sessionId: number, index: number): + void => { + if (!tensor) { + tensorHandles.push(0); + return; + } + + const wasm = getInstance(); + + const dataType = tensor[0]; + const dims = tensor[1]; + const location = tensor[3]; + + let rawData: number; + let dataByteLength: number; + + if (dataType === 'string' && location === 'gpu-buffer') { + throw new Error('String tensor is not supported on GPU.'); + } + + if (location === 'gpu-buffer') { + const gpuBuffer = tensor[2].gpuBuffer as GPUBuffer; + const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(dataType))!; + dataByteLength = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes; + rawData = wasm.jsepRegisterBuffer(sessionId, index, gpuBuffer, dataByteLength); + } else { + const data = tensor[2]; + + if (Array.isArray(data)) { + // string tensor + dataByteLength = 4 * data.length; + rawData = wasm._malloc(dataByteLength); + allocs.push(rawData); + let dataIndex = rawData / 4; + for (let i = 0; i < data.length; i++) { + if (typeof data[i] !== 'string') { + throw new TypeError(`tensor data at index ${i} is not a string`); + } + wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], allocs); + } + } else { + dataByteLength = data.byteLength; + rawData = wasm._malloc(dataByteLength); + allocs.push(rawData); + wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData); + } + } + + const stack = wasm.stackSave(); + const dimsOffset = wasm.stackAlloc(4 * dims.length); + try { + let dimIndex = dimsOffset / 4; + dims.forEach(d => wasm.HEAP32[dimIndex++] = d); + const tensor = wasm._OrtCreateTensor( + tensorDataTypeStringToEnum(dataType), rawData, dataByteLength, dimsOffset, dims.length, + dataLocationStringToEnum(location)); + if (tensor === 0) { + checkLastError(`Can't create tensor for input/output. session=${sessionId}, index=${index}.`); + } + tensorHandles.push(tensor); + } finally { + wasm.stackRestore(stack); + } + }; + /** * perform inference run */ export const run = async( - sessionId: number, inputIndices: number[], inputs: SerializableTensor[], outputIndices: number[], - options: InferenceSession.RunOptions): Promise => { + sessionId: number, inputIndices: number[], inputTensors: TensorMetadata[], outputIndices: number[], + outputTensors: Array, options: InferenceSession.RunOptions): Promise => { const wasm = getInstance(); const session = activeSessions.get(sessionId); if (!session) { throw new Error(`cannot run inference. 
invalid session id: ${sessionId}`); } - const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded] = session; + const [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState] = session; const inputCount = inputIndices.length; const outputCount = outputIndices.length; @@ -188,171 +319,200 @@ export const run = async( let runOptionsHandle = 0; let runOptionsAllocs: number[] = []; - const inputValues: number[] = []; - const inputAllocs: number[] = []; + const inputTensorHandles: number[] = []; + const outputTensorHandles: number[] = []; + const inputOutputAllocs: number[] = []; + + const beforeRunStack = wasm.stackSave(); + const inputValuesOffset = wasm.stackAlloc(inputCount * 4); + const inputNamesOffset = wasm.stackAlloc(inputCount * 4); + const outputValuesOffset = wasm.stackAlloc(outputCount * 4); + const outputNamesOffset = wasm.stackAlloc(outputCount * 4); try { [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); // create input tensors for (let i = 0; i < inputCount; i++) { - const dataType = inputs[i][0]; - const dims = inputs[i][1]; - const data = inputs[i][2]; - - let dataOffset: number; - let dataByteLength: number; - - if (Array.isArray(data)) { - // string tensor - dataByteLength = 4 * data.length; - dataOffset = wasm._malloc(dataByteLength); - inputAllocs.push(dataOffset); - let dataIndex = dataOffset / 4; - for (let i = 0; i < data.length; i++) { - if (typeof data[i] !== 'string') { - throw new TypeError(`tensor data at index ${i} is not a string`); - } - wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], inputAllocs); - } - } else { - dataByteLength = data.byteLength; - dataOffset = wasm._malloc(dataByteLength); - inputAllocs.push(dataOffset); - wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), dataOffset); - } + prepareInputOutputTensor(inputTensors[i], inputTensorHandles, inputOutputAllocs, sessionId, inputIndices[i]); + } - const stack = wasm.stackSave(); - const dimsOffset = wasm.stackAlloc(4 * dims.length); - try { - let dimIndex = dimsOffset / 4; - dims.forEach(d => wasm.HEAP32[dimIndex++] = d); - const tensor = wasm._OrtCreateTensor( - tensorDataTypeStringToEnum(dataType), dataOffset, dataByteLength, dimsOffset, dims.length); - if (tensor === 0) { - checkLastError(`Can't create tensor for input[${i}].`); - } - inputValues.push(tensor); - } finally { - wasm.stackRestore(stack); - } + // create output tensors + for (let i = 0; i < outputCount; i++) { + prepareInputOutputTensor( + outputTensors[i], outputTensorHandles, inputOutputAllocs, sessionId, inputCount + outputIndices[i]); + } + + let inputValuesIndex = inputValuesOffset / 4; + let inputNamesIndex = inputNamesOffset / 4; + let outputValuesIndex = outputValuesOffset / 4; + let outputNamesIndex = outputNamesOffset / 4; + for (let i = 0; i < inputCount; i++) { + wasm.HEAPU32[inputValuesIndex++] = inputTensorHandles[i]; + wasm.HEAPU32[inputNamesIndex++] = inputNamesUTF8Encoded[inputIndices[i]]; + } + for (let i = 0; i < outputCount; i++) { + wasm.HEAPU32[outputValuesIndex++] = outputTensorHandles[i]; + wasm.HEAPU32[outputNamesIndex++] = outputNamesUTF8Encoded[outputIndices[i]]; } - const beforeRunStack = wasm.stackSave(); - const inputValuesOffset = wasm.stackAlloc(inputCount * 4); - const inputNamesOffset = wasm.stackAlloc(inputCount * 4); - const outputValuesOffset = wasm.stackAlloc(outputCount * 4); - const outputNamesOffset = wasm.stackAlloc(outputCount * 4); - - try { - let inputValuesIndex = inputValuesOffset / 4; - let inputNamesIndex = 
inputNamesOffset / 4; - let outputValuesIndex = outputValuesOffset / 4; - let outputNamesIndex = outputNamesOffset / 4; + if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) { + const {handle, outputPreferredLocations, outputPreferredLocationsEncoded} = ioBindingState; + + if (inputNamesUTF8Encoded.length !== inputCount) { + throw new Error(`input count from feeds (${ + inputCount}) is expected to be always equal to model's input count (${inputNamesUTF8Encoded.length}).`); + } + + // process inputs for (let i = 0; i < inputCount; i++) { - wasm.HEAPU32[inputValuesIndex++] = inputValues[i]; - wasm.HEAPU32[inputNamesIndex++] = inputNamesUTF8Encoded[inputIndices[i]]; + const index = inputIndices[i]; + const errorCode = await wasm._OrtBindInput(handle, inputNamesUTF8Encoded[index], inputTensorHandles[i]); + if (errorCode !== 0) { + checkLastError(`Can't bind input[${i}] for session=${sessionId}.`); + } } + + // process pre-allocated outputs for (let i = 0; i < outputCount; i++) { - wasm.HEAPU32[outputValuesIndex++] = 0; - wasm.HEAPU32[outputNamesIndex++] = outputNamesUTF8Encoded[outputIndices[i]]; + const index = outputIndices[i]; + const location = outputTensors[i]?.[3]; // undefined means output is not pre-allocated. + + if (location) { + // output is pre-allocated. bind the tensor. + const errorCode = wasm._OrtBindOutput(handle, outputNamesUTF8Encoded[index], outputTensorHandles[i], 0); + if (errorCode !== 0) { + checkLastError(`Can't bind pre-allocated output[${i}] for session=${sessionId}.`); + } + } else { + // output is not pre-allocated. reset preferred location. + const errorCode = + wasm._OrtBindOutput(handle, outputNamesUTF8Encoded[index], 0, outputPreferredLocationsEncoded[index]); + if (errorCode !== 0) { + checkLastError(`Can't bind output[${i}] to ${outputPreferredLocations[i]} for session=${sessionId}.`); + } + } } + } - // jsepOnRunStart is only available when JSEP is enabled. - wasm.jsepOnRunStart?.(sessionId); + let errorCode: number; - // support RunOptions - let errorCode = wasm._OrtRun( + if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) { + errorCode = await wasm._OrtRunWithBinding( + sessionHandle, ioBindingState.handle, outputCount, outputValuesOffset, runOptionsHandle); + } else { + errorCode = await wasm._OrtRun( sessionHandle, inputNamesOffset, inputValuesOffset, inputCount, outputNamesOffset, outputCount, outputValuesOffset, runOptionsHandle); + } - const runPromise = wasm.jsepRunPromise; - if (runPromise) { - // jsepRunPromise is a Promise object. It is only available when JSEP is enabled. - // - // OrtRun() is a synchrnous call, but it internally calls async functions. Emscripten's ASYNCIFY allows it to - // work in this way. However, OrtRun() does not return a promise, so when code reaches here, it is earlier than - // the async functions are finished. - // - // To make it work, we created a Promise and resolve the promise when the C++ code actually reaches the end of - // OrtRun(). If the promise exists, we need to await for the promise to be resolved. - errorCode = await runPromise; - } + if (errorCode !== 0) { + checkLastError('failed to call OrtRun().'); + } - const jsepOnRunEnd = wasm.jsepOnRunEnd; - if (jsepOnRunEnd) { - // jsepOnRunEnd is only available when JSEP is enabled. - // - // it returns a promise, which is resolved or rejected when the following async functions are finished: - // - collecting GPU validation errors. 
- await jsepOnRunEnd(sessionId); + const output: TensorMetadata[] = []; + + for (let i = 0; i < outputCount; i++) { + const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i]; + if (tensor === outputTensorHandles[i]) { + // output tensor is pre-allocated. no need to copy data. + output.push(outputTensors[i]!); + continue; } - const output: SerializableTensor[] = []; + const beforeGetTensorDataStack = wasm.stackSave(); + // stack allocate 4 pointer value + const tensorDataOffset = wasm.stackAlloc(4 * 4); - if (errorCode !== 0) { - checkLastError('failed to call OrtRun().'); - } + let keepOutputTensor = false; + let type: Tensor.Type|undefined, dataOffset = 0; + try { + const errorCode = wasm._OrtGetTensorData( + tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12); + if (errorCode !== 0) { + checkLastError(`Can't access output tensor data on index ${i}.`); + } + let tensorDataIndex = tensorDataOffset / 4; + const dataType = wasm.HEAPU32[tensorDataIndex++]; + dataOffset = wasm.HEAPU32[tensorDataIndex++]; + const dimsOffset = wasm.HEAPU32[tensorDataIndex++]; + const dimsLength = wasm.HEAPU32[tensorDataIndex++]; + const dims = []; + for (let i = 0; i < dimsLength; i++) { + dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); + } + wasm._OrtFree(dimsOffset); - for (let i = 0; i < outputCount; i++) { - const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i]; + const size = dims.reduce((a, b) => a * b, 1); + type = tensorDataTypeEnumToString(dataType); - const beforeGetTensorDataStack = wasm.stackSave(); - // stack allocate 4 pointer value - const tensorDataOffset = wasm.stackAlloc(4 * 4); + const preferredLocation = ioBindingState?.outputPreferredLocations[outputIndices[i]]; - let type: Tensor.Type|undefined, dataOffset = 0; - try { - errorCode = wasm._OrtGetTensorData( - tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12); - if (errorCode !== 0) { - checkLastError(`Can't access output tensor data on index ${i}.`); + if (type === 'string') { + if (preferredLocation === 'gpu-buffer') { + throw new Error('String tensor is not supported on GPU.'); } - let tensorDataIndex = tensorDataOffset / 4; - const dataType = wasm.HEAPU32[tensorDataIndex++]; - dataOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsLength = wasm.HEAPU32[tensorDataIndex++]; - const dims = []; - for (let i = 0; i < dimsLength; i++) { - dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); + const stringData: string[] = []; + let dataIndex = dataOffset / 4; + for (let i = 0; i < size; i++) { + const offset = wasm.HEAPU32[dataIndex++]; + const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset; + stringData.push(wasm.UTF8ToString(offset, maxBytesToRead)); } - wasm._OrtFree(dimsOffset); - - const size = dims.length === 0 ? 1 : dims.reduce((a, b) => a * b); - type = tensorDataTypeEnumToString(dataType); - if (type === 'string') { - const stringData: string[] = []; - let dataIndex = dataOffset / 4; - for (let i = 0; i < size; i++) { - const offset = wasm.HEAPU32[dataIndex++]; - const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset; - stringData.push(wasm.UTF8ToString(offset, maxBytesToRead)); + output.push([type, dims, stringData, 'cpu']); + } else { + // If a certain output's preferred location is GPU but the tensor is empty, we still need to create a CPU + // tensor for it. There is no mapping GPU buffer for an empty tensor. 
+ if (preferredLocation === 'gpu-buffer' && size > 0) { + const gpuBuffer = wasm.jsepGetBuffer(dataOffset); + const elementSize = getTensorElementSize(dataType); + if (elementSize === undefined || !isGpuBufferSupportedType(type)) { + throw new Error(`Unsupported data type: ${type}`); } - output.push([type, dims, stringData]); + + // do not release the tensor right now. it will be released when user calls tensor.dispose(). + keepOutputTensor = true; + + output.push([ + type, dims, { + gpuBuffer, + download: wasm.jsepCreateDownloader(gpuBuffer, size * elementSize, type), + dispose: () => { + wasm._OrtReleaseTensor(tensor); + } + }, + 'gpu-buffer' + ]); } else { const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type); const data = new typedArrayConstructor(size); new Uint8Array(data.buffer, data.byteOffset, data.byteLength) .set(wasm.HEAPU8.subarray(dataOffset, dataOffset + data.byteLength)); - output.push([type, dims, data]); - } - } finally { - wasm.stackRestore(beforeGetTensorDataStack); - if (type === 'string' && dataOffset) { - wasm._free(dataOffset); + output.push([type, dims, data, 'cpu']); } + } + } finally { + wasm.stackRestore(beforeGetTensorDataStack); + if (type === 'string' && dataOffset) { + wasm._free(dataOffset); + } + if (!keepOutputTensor) { wasm._OrtReleaseTensor(tensor); } } + } - return output; - } finally { - wasm.stackRestore(beforeRunStack); + if (ioBindingState) { + wasm._OrtClearBoundOutputs(ioBindingState.handle); } + + return output; } finally { - inputValues.forEach(v => wasm._OrtReleaseTensor(v)); - inputAllocs.forEach(p => wasm._free(p)); + wasm.stackRestore(beforeRunStack); + + inputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v)); + outputTensorHandles.forEach(v => wasm._OrtReleaseTensor(v)); + inputOutputAllocs.forEach(p => wasm._free(p)); if (runOptionsHandle !== 0) { wasm._OrtReleaseRunOptions(runOptionsHandle); @@ -380,11 +540,11 @@ export const endProfiling = (sessionId: number): void => { wasm._OrtFree(profileFileName); }; -export const extractTransferableBuffers = (tensors: readonly SerializableTensor[]): ArrayBufferLike[] => { +export const extractTransferableBuffers = (tensors: readonly SerializableTensorMetadata[]): ArrayBufferLike[] => { const buffers: ArrayBufferLike[] = []; for (const tensor of tensors) { const data = tensor[2]; - if (!Array.isArray(data) && data.buffer) { + if (!Array.isArray(data) && 'buffer' in data) { buffers.push(data.buffer); } } diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts index f90f568879146..31bca8b94306d 100644 --- a/js/web/script/test-runner-cli-args.ts +++ b/js/web/script/test-runner-cli-args.ts @@ -51,6 +51,10 @@ Options: -P[=<...>], --perf[=<...>] Generate performance number. Cannot be used with flag --debug. This flag can be used with a number as value, specifying the total count of test cases to run. The test cases may be used multiple times. Default value is 10. -c, --file-cache Enable file cache. + -i=<...>, --io-binding=<...> Specify the IO binding testing type. Should be one of the following: + none (default) + gpu-tensor use pre-allocated GPU tensors for inputs and outputs + gpu-location use pre-allocated GPU tensors for inputs and set preferredOutputLocation to 'gpu-buffer' *** Session Options *** -u=<...>, --optimized-model-file-path=<...> Specify whether to dump the optimized model. 
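For reference, the 'gpu-location' mode documented above exercises the user-facing flow that the session handler and wasm-core-impl changes enable: set `preferredOutputLocation` when creating the session, then read GPU-resident outputs back on demand. The sketch below is illustrative only and not part of the patch; the helper name and model URL are hypothetical, and it assumes an onnxruntime-web build with the WebGPU EP where `Tensor.location` and `Tensor.getData()` behave as the test runner exercises them.

import * as ort from 'onnxruntime-web';

// Hypothetical helper: run a model on the WebGPU EP and keep outputs in GPU buffers,
// downloading the data only when it is actually needed.
async function runWithGpuOutputs(modelUrl: string, feeds: Record<string, ort.Tensor>) {
  const session = await ort.InferenceSession.create(modelUrl, {
    executionProviders: ['webgpu'],
    // a per-output map such as {output_0: 'gpu-buffer'} is also accepted
    preferredOutputLocation: 'gpu-buffer',
  });

  const results = await session.run(feeds);
  for (const [name, tensor] of Object.entries(results)) {
    if (tensor.location === 'gpu-buffer') {
      // getData(true) downloads the output and releases the underlying GPU buffer
      console.log(name, tensor.dims, await tensor.getData(true));
    }
  }
  await session.release();
}
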
@@ -109,6 +113,7 @@ export declare namespace TestRunnerCliArgs { type Backend = 'cpu'|'webgl'|'webgpu'|'wasm'|'onnxruntime'|'xnnpack'|'webnn'; type Environment = 'chrome'|'edge'|'firefox'|'electron'|'safari'|'node'|'bs'; type BundleMode = 'prod'|'dev'|'perf'; + type IOBindingMode = 'none'|'gpu-tensor'|'gpu-location'; } export interface TestRunnerCliArgs { @@ -140,6 +145,8 @@ export interface TestRunnerCliArgs { */ bundleMode: TestRunnerCliArgs.BundleMode; + ioBindingMode: TestRunnerCliArgs.IOBindingMode; + logConfig: Test.Config['log']; /** @@ -326,7 +333,11 @@ function parseWebgpuFlags(args: minimist.ParsedArgs): Partial { if (profilingMode !== undefined && profilingMode !== 'off' && profilingMode !== 'default') { throw new Error('Flag "webgpu-profiling-mode" is invalid'); } - return {profilingMode}; + const validateInputContent = args['webgpu-validate-input-content']; + if (validateInputContent !== undefined && typeof validateInputContent !== 'boolean') { + throw new Error('Flag "webgpu-validate-input-content" is invalid'); + } + return {profilingMode, validateInputContent}; } function parseGlobalEnvFlags(args: minimist.ParsedArgs): NonNullable { @@ -416,6 +427,13 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs logConfig.push({category: 'TestRunner.Perf', config: {minimalSeverity: 'verbose'}}); } + // Option: -i=<...>, --io-binding=<...> + const ioBindingArg = args['io-binding'] || args.i; + const ioBindingMode = (typeof ioBindingArg !== 'string') ? 'none' : ioBindingArg; + if (['none', 'gpu-tensor', 'gpu-location'].indexOf(ioBindingMode) === -1) { + throw new Error(`not supported io binding mode ${ioBindingMode}`); + } + // Option: -u, --optimized-model-file-path const optimizedModelFilePath = args['optimized-model-file-path'] || args.u || undefined; if (typeof optimizedModelFilePath !== 'undefined' && typeof optimizedModelFilePath !== 'string') { @@ -455,6 +473,7 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs npmlog.verbose('TestRunnerCli.Init', ` Env: ${env}`); npmlog.verbose('TestRunnerCli.Init', ` Debug: ${debug}`); npmlog.verbose('TestRunnerCli.Init', ` Backend: ${backend}`); + npmlog.verbose('TestRunnerCli.Init', ` IO Binding Mode: ${ioBindingMode}`); npmlog.verbose('TestRunnerCli.Init', 'Parsing commandline arguments... DONE'); return { @@ -467,6 +486,7 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs logConfig, profile, times: perf ? 
times : undefined, + ioBindingMode: ioBindingMode as TestRunnerCliArgs['ioBindingMode'], optimizedModelFilePath, graphOptimizationLevel: graphOptimizationLevel as TestRunnerCliArgs['graphOptimizationLevel'], fileCache, diff --git a/js/web/script/test-runner-cli.ts b/js/web/script/test-runner-cli.ts index f3764e63fcf45..d8fecec1b8084 100644 --- a/js/web/script/test-runner-cli.ts +++ b/js/web/script/test-runner-cli.ts @@ -257,7 +257,7 @@ async function main() { times?: number): Test.ModelTest { if (times === 0) { npmlog.verbose('TestRunnerCli.Init.Model', `Skip test data from folder: ${testDataRootFolder}`); - return {name: path.basename(testDataRootFolder), backend, modelUrl: '', cases: []}; + return {name: path.basename(testDataRootFolder), backend, modelUrl: '', cases: [], ioBinding: args.ioBindingMode}; } let modelUrl: string|null = null; @@ -323,6 +323,16 @@ async function main() { } } + let ioBinding: Test.IOBindingMode; + if (backend !== 'webgpu' && args.ioBindingMode !== 'none') { + npmlog.warn( + 'TestRunnerCli.Init.Model', `Ignoring IO Binding Mode "${args.ioBindingMode}" for backend "${backend}".`); + ioBinding = 'none'; + } else { + ioBinding = args.ioBindingMode; + } + + npmlog.verbose('TestRunnerCli.Init.Model', 'Finished preparing test data.'); npmlog.verbose('TestRunnerCli.Init.Model', '==============================================================='); npmlog.verbose('TestRunnerCli.Init.Model', ` Model file: ${modelUrl}`); @@ -330,7 +340,7 @@ async function main() { npmlog.verbose('TestRunnerCli.Init.Model', ` Test set(s): ${cases.length} (${caseCount})`); npmlog.verbose('TestRunnerCli.Init.Model', '==============================================================='); - return {name: path.basename(testDataRootFolder), platformCondition, modelUrl, backend, cases}; + return {name: path.basename(testDataRootFolder), platformCondition, modelUrl, backend, cases, ioBinding}; } function tryLocateModelTestFolder(searchPattern: string): string { @@ -390,6 +400,13 @@ async function main() { for (const test of tests) { test.backend = backend; test.opset = test.opset || {domain: '', version: MAX_OPSET_VERSION}; + if (backend !== 'webgpu' && args.ioBindingMode !== 'none') { + npmlog.warn( + 'TestRunnerCli.Init.Op', `Ignoring IO Binding Mode "${args.ioBindingMode}" for backend "${backend}".`); + test.ioBinding = 'none'; + } else { + test.ioBinding = args.ioBindingMode; + } } npmlog.verbose('TestRunnerCli.Init.Op', 'Finished preparing test data.'); npmlog.verbose('TestRunnerCli.Init.Op', '==============================================================='); diff --git a/js/web/test/data/ops/conv-transpose.jsonc b/js/web/test/data/ops/conv-transpose.jsonc index a249dc807fa0b..7038e2a4f8766 100644 --- a/js/web/test/data/ops/conv-transpose.jsonc +++ b/js/web/test/data/ops/conv-transpose.jsonc @@ -28,6 +28,37 @@ } ] }, + { + "name": "ConvTranspose without bias addition A - NHWC", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "operator": "ConvTranspose", + "attributes": [{ "name": "kernel_shape", "data": [2, 2], "type": "ints" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, 30, 40], + "dims": [1, 1, 2, 2], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [10, 40, 40, 60, 200, 160, 90, 240, 160], + "dims": [1, 1, 3, 3], + "type": "float32" + } + ] + } + ] + }, { "name": "ConvTranspose without bias addition B", "operator": "ConvTranspose", @@ -74,26 
+105,22 @@ }, { "data": [ - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, - 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64 ], "dims": [4, 4, 2, 2], "type": "float32" }, { - "data": [0.1, 0.2, 0.3, 0.4], + "data": [65, 66, 67, 68], "dims": [4], "type": "float32" } ], "outputs": [ { - "data": [ - 100.0999984741211, 100.0999984741211, 100.0999984741211, 100.0999984741211, 100.19999694824219, - 100.19999694824219, 100.19999694824219, 100.19999694824219, 100.30000305175781, 100.30000305175781, - 100.30000305175781, 100.30000305175781, 100.4000015258789, 100.4000015258789, 100.4000015258789, - 100.4000015258789 - ], + "data": [3365, 3465, 3565, 3665, 3766, 3866, 3966, 4066, 4167, 4267, 4367, 4467, 4568, 4668, 4768, 4868], "dims": [1, 4, 2, 2], "type": "float32" } @@ -115,7 +142,43 @@ "type": "float32" }, { - "data": [1, 1, 1, 1], + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + }, + { + "data": [5], + "dims": [1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [11, 25, 28, 19, 32, 86, 99, 55, 40, 114, 131, 67, 29, 73, 80, 41], + "dims": [1, 1, 4, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "ConvTranspose with bias addition B - NHWC", + "operator": "ConvTranspose", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "attributes": [{ "name": "kernel_shape", "data": [2, 2], "type": "ints" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [6, 8, 7, 9, 15, 11, 8, 12, 9], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], "dims": [1, 1, 2, 2], "type": "float32" }, @@ -127,7 +190,7 @@ ], "outputs": [ { - "data": [11, 19, 20, 12, 20, 43, 46, 23, 22, 49, 52, 25, 13, 25, 26, 14], + "data": [11, 25, 28, 19, 32, 86, 99, 55, 40, 114, 131, 67, 29, 73, 80, 41], "dims": [1, 1, 4, 4], "type": "float32" } @@ -251,7 +314,6 @@ } ] }, - { "name": "ConvTranspose- pointwise", "operator": "ConvTranspose", @@ -285,5 +347,50 @@ ] } ] + }, + { + "name": "ConvTranspose with bias addition C", + "operator": "ConvTranspose", + "inputShapeDefinitions": "rankOnly", + "opset": { "domain": "", "version": 17 }, + "attributes": [{ "name": "kernel_shape", "data": [1, 1], "type": "ints" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, + 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, + 56, 57, 58, 59, 60, 61, 62, 63, 64 + ], + "dims": [1, 4, 4, 4], + "type": "float32" + }, + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + "dims": [4, 4, 1, 1], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + 1021, 1049, 1077, 1105, 1133, 1161, 1189, 1217, 1245, 1273, 1301, 1329, 1357, 1385, 1413, 1441, 1122, + 1154, 1186, 1218, 1250, 1282, 1314, 1346, 1378, 1410, 1442, 1474, 1506, 1538, 1570, 1602, 1223, 1259, + 1295, 1331, 1367, 1403, 1439, 1475, 1511, 1547, 1583, 1619, 1655, 1691, 1727, 1763, 1324, 1364, 1404, + 1444, 1484, 1524, 1564, 1604, 1644, 1684, 1724, 1764, 1804, 1844, 1884, 1924 + ], + 
"dims": [1, 4, 4, 4], + "type": "float32" + } + ] + } + ] } ] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 6e65645ef4756..96ced2bdf9216 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -885,10 +885,10 @@ // // "test_qlinearmatmul_3D", // // "test_quantizelinear_axis", // // "test_quantizelinear", - // "test_range_float_type_positive_delta_expanded", - // "test_range_float_type_positive_delta", - // "test_range_int32_type_negative_delta_expanded", - // "test_range_int32_type_negative_delta", + "test_range_float_type_positive_delta_expanded", + "test_range_float_type_positive_delta", + "test_range_int32_type_negative_delta_expanded", + "test_range_int32_type_negative_delta", "test_reciprocal_example", "test_reciprocal", "test_reduce_l1_default_axes_keepdims_example", diff --git a/js/web/test/test-main.ts b/js/web/test/test-main.ts index 49d0ac225be2f..d3592875bb6c7 100644 --- a/js/web/test/test-main.ts +++ b/js/web/test/test-main.ts @@ -57,6 +57,9 @@ if (options.globalEnvFlags) { if (flags.webgpu?.profilingMode !== undefined) { ort.env.webgpu.profilingMode = flags.webgpu.profilingMode; } + if (flags.webgpu?.validateInputContent !== undefined) { + ort.env.webgpu.validateInputContent = flags.webgpu.validateInputContent; + } } // Set logging configuration diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index 46d80a9f56f35..628e5408150f8 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -14,7 +14,8 @@ import {Operator} from '../lib/onnxjs/operators'; import {onnx} from '../lib/onnxjs/ort-schema/protobuf/onnx'; import {Tensor} from '../lib/onnxjs/tensor'; import {ProtoUtil} from '../lib/onnxjs/util'; -import {tensorDataTypeStringToEnum} from '../lib/wasm/wasm-common'; +import {createView} from '../lib/wasm/jsep/tensor-view'; +import {getTensorElementSize, isGpuBufferSupportedType, tensorDataTypeStringToEnum} from '../lib/wasm/wasm-common'; import {base64toBuffer, createMockGraph, readFile} from './test-shared'; import {Test} from './test-types'; @@ -136,8 +137,8 @@ async function loadTensors( } async function initializeSession( - modelFilePath: string, backendHint: string, profile: boolean, sessionOptions: ort.InferenceSession.SessionOptions, - fileCache?: FileCacheBuffer): Promise { + modelFilePath: string, backendHint: string, ioBindingMode: Test.IOBindingMode, profile: boolean, + sessionOptions: ort.InferenceSession.SessionOptions, fileCache?: FileCacheBuffer): Promise { const preloadModelData: Uint8Array|undefined = fileCache && fileCache[modelFilePath] ? fileCache[modelFilePath] : undefined; Logger.verbose( @@ -146,8 +147,14 @@ async function initializeSession( preloadModelData ? ` [preloaded(${preloadModelData.byteLength})]` : ''}`); const profilerConfig = profile ? {maxNumberEvents: 65536} : undefined; - const sessionConfig = - {...sessionOptions, executionProviders: [backendHint], profiler: profilerConfig, enableProfiling: profile}; + const sessionConfig = { + ...sessionOptions, + executionProviders: [backendHint], + profiler: profilerConfig, + enableProfiling: profile, + preferredOutputLocation: ioBindingMode === 'gpu-location' ? 
('gpu-buffer' as const) : undefined + }; + let session: ort.InferenceSession; try { @@ -181,6 +188,7 @@ export class ModelTestContext { readonly session: ort.InferenceSession, readonly backend: string, readonly perfData: ModelTestContext.ModelTestPerfData, + readonly ioBinding: Test.IOBindingMode, private readonly profile: boolean, ) {} @@ -232,8 +240,8 @@ export class ModelTestContext { this.initializing = true; const initStart = now(); - const session = - await initializeSession(modelTest.modelUrl, modelTest.backend!, profile, sessionOptions || {}, this.cache); + const session = await initializeSession( + modelTest.modelUrl, modelTest.backend!, modelTest.ioBinding, profile, sessionOptions || {}, this.cache); const initEnd = now(); for (const testCase of modelTest.cases) { @@ -244,6 +252,7 @@ export class ModelTestContext { session, modelTest.backend!, {init: initEnd - initStart, firstRun: -1, runs: [], count: 0}, + modelTest.ioBinding, profile, ); } finally { @@ -481,6 +490,130 @@ export class TensorResultValidator { } } +function createGpuTensorForInput(cpuTensor: ort.Tensor): ort.Tensor { + if (!isGpuBufferSupportedType(cpuTensor.type) || Array.isArray(cpuTensor.data)) { + throw new Error(`createGpuTensorForInput can not work with ${cpuTensor.type} tensor`); + } + const device = ort.env.webgpu.device as GPUDevice; + const gpuBuffer = device.createBuffer({ + // eslint-disable-next-line no-bitwise + usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE, + size: Math.ceil(cpuTensor.data.byteLength / 16) * 16, + mappedAtCreation: true + }); + const arrayBuffer = gpuBuffer.getMappedRange(); + new Uint8Array(arrayBuffer) + .set(new Uint8Array(cpuTensor.data.buffer, cpuTensor.data.byteOffset, cpuTensor.data.byteLength)); + gpuBuffer.unmap(); + + // TODO: how to "await" for the copy to finish, so that we can get more accurate performance data? 
+ + return ort.Tensor.fromGpuBuffer( + gpuBuffer, {dataType: cpuTensor.type, dims: cpuTensor.dims, dispose: () => gpuBuffer.destroy()}); +} + +function createGpuTensorForOutput(type: ort.Tensor.Type, dims: readonly number[]) { + if (!isGpuBufferSupportedType(type)) { + throw new Error(`createGpuTensorForOutput can not work with ${type} tensor`); + } + + const elementSizeInBytes = getTensorElementSize(tensorDataTypeStringToEnum(type))!; + const size = dims.reduce((a, b) => a * b, 1) * elementSizeInBytes; + + const device = ort.env.webgpu.device as GPUDevice; + const gpuBuffer = device.createBuffer({ + // eslint-disable-next-line no-bitwise + usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.COPY_DST | GPUBufferUsage.STORAGE, + size: Math.ceil(size / 16) * 16 + }); + + return ort.Tensor.fromGpuBuffer(gpuBuffer, { + dataType: type, + dims, + dispose: () => gpuBuffer.destroy(), + download: async () => { + const stagingBuffer = device.createBuffer({ + // eslint-disable-next-line no-bitwise + usage: GPUBufferUsage.MAP_READ | GPUBufferUsage.COPY_DST, + size: gpuBuffer.size + }); + const encoder = device.createCommandEncoder(); + encoder.copyBufferToBuffer(gpuBuffer, 0, stagingBuffer, 0, gpuBuffer.size); + device.queue.submit([encoder.finish()]); + + await stagingBuffer.mapAsync(GPUMapMode.READ); + const arrayBuffer = stagingBuffer.getMappedRange().slice(0, size); + stagingBuffer.destroy(); + + return createView(arrayBuffer, type) as ort.Tensor.DataTypeMap[ort.Tensor.GpuBufferDataTypes]; + } + }); +} + +export async function sessionRun(options: { + session: ort.InferenceSession; feeds: Record; + outputsMetaInfo: Record>; + ioBinding: Test.IOBindingMode; +}): Promise<[number, number, ort.InferenceSession.OnnxValueMapType]> { + const session = options.session; + const feeds = options.feeds; + const fetches: Record = {}; + + // currently we only support IO Binding for WebGPU + // + // For inputs, we create GPU tensors on both 'gpu-tensor' and 'gpu-location' binding testing mode. + // For outputs, we create GPU tensors on 'gpu-tensor' binding testing mode only. + // in 'gpu-location' binding mode, outputs are not pre-allocated. + const shouldUploadInput = options.ioBinding === 'gpu-tensor' || options.ioBinding === 'gpu-location'; + const shouldUploadOutput = options.ioBinding === 'gpu-tensor'; + try { + if (shouldUploadInput) { + // replace the CPU tensors in feeds with GPU tensors + for (const name in feeds) { + if (Object.hasOwnProperty.call(feeds, name)) { + feeds[name] = createGpuTensorForInput(feeds[name]); + } + } + } + + if (shouldUploadOutput) { + for (const name in options.outputsMetaInfo) { + if (Object.hasOwnProperty.call(options.outputsMetaInfo, name)) { + const {type, dims} = options.outputsMetaInfo[name]; + fetches[name] = createGpuTensorForOutput(type, dims); + } + } + } + + const start = now(); + Logger.verbose('TestRunner', `Timestamp before session run: ${start}`); + const outputs = await ( + shouldUploadOutput ?
session.run(feeds, fetches) : + session.run(feeds, Object.getOwnPropertyNames(options.outputsMetaInfo))); + const end = now(); + Logger.verbose('TestRunner', `Timestamp after session run: ${end}`); + + // download each output tensor if needed + for (const name in outputs) { + if (Object.hasOwnProperty.call(outputs, name)) { + const tensor = outputs[name]; + // Tensor.getData(true) release the underlying resource + await tensor.getData(true); + } + } + + return [start, end, outputs]; + } finally { + // dispose the GPU tensors in feeds + for (const name in feeds) { + if (Object.hasOwnProperty.call(feeds, name)) { + const tensor = feeds[name]; + tensor.dispose(); + } + } + } +} + /** * run a single model test case. the inputs/outputs tensors should already been prepared. */ @@ -491,12 +624,11 @@ export async function runModelTestSet( const validator = new TensorResultValidator(context.backend); try { const feeds: Record = {}; + const outputsMetaInfo: Record = {}; testCase.inputs!.forEach((tensor, i) => feeds[context.session.inputNames[i]] = tensor); - const start = now(); - Logger.verbose('TestRunner', `Timestamp before session run: ${start}`); - const outputs = await context.session.run(feeds); - const end = now(); - Logger.verbose('TestRunner', `Timestamp after session run: ${end}`); + testCase.outputs!.forEach((tensor, i) => outputsMetaInfo[context.session.outputNames[i]] = tensor); + const [start, end, outputs] = + await sessionRun({session: context.session, feeds, outputsMetaInfo, ioBinding: context.ioBinding}); if (context.perfData.count === 0) { context.perfData.firstRun = end - start; } else { @@ -575,6 +707,7 @@ export class ProtoOpTestContext { private readonly loadedData: Uint8Array; // model data, inputs, outputs session: ort.InferenceSession; readonly backendHint: string; + readonly ioBindingMode: Test.IOBindingMode; constructor(test: Test.OperatorTest, private readonly sessionOptions: ort.InferenceSession.SessionOptions = {}) { const opsetImport = onnx.OperatorSetIdProto.create(test.opset); const operator = test.operator; @@ -713,6 +846,7 @@ export class ProtoOpTestContext { model.graph.name = test.name; this.backendHint = test.backend!; + this.ioBindingMode = test.ioBinding; this.loadedData = onnx.ModelProto.encode(model).finish(); // in debug mode, open a new tab in browser for the generated onnx model. @@ -729,8 +863,11 @@ export class ProtoOpTestContext { } } async init(): Promise { - this.session = await ort.InferenceSession.create( - this.loadedData, {executionProviders: [this.backendHint], ...this.sessionOptions}); + this.session = await ort.InferenceSession.create(this.loadedData, { + executionProviders: [this.backendHint], + preferredOutputLocation: this.ioBindingMode === 'gpu-location' ? 
('gpu-buffer' as const) : undefined, + ...this.sessionOptions + }); } async dispose(): Promise { @@ -739,10 +876,11 @@ export class ProtoOpTestContext { } async function runProtoOpTestcase( - session: ort.InferenceSession, testCase: Test.OperatorTestCase, validator: TensorResultValidator): Promise { + session: ort.InferenceSession, testCase: Test.OperatorTestCase, ioBindingMode: Test.IOBindingMode, + validator: TensorResultValidator): Promise { const feeds: Record = {}; - const fetches: string[] = []; - testCase.inputs!.forEach((input, i) => { + const fetches: Record> = {}; + testCase.inputs.forEach((input, i) => { if (input.data) { let data: number[]|BigUint64Array|BigInt64Array = input.data; if (input.type === 'uint64') { @@ -756,7 +894,7 @@ async function runProtoOpTestcase( const outputs: ort.Tensor[] = []; const expectedOutputNames: string[] = []; - testCase.outputs!.forEach((output, i) => { + testCase.outputs.forEach((output, i) => { if (output.data) { let data: number[]|BigUint64Array|BigInt64Array = output.data; if (output.type === 'uint64') { @@ -766,11 +904,11 @@ async function runProtoOpTestcase( } outputs.push(new ort.Tensor(output.type, data, output.dims)); expectedOutputNames.push(`output_${i}`); - fetches.push(`output_${i}`); + fetches[`output_${i}`] = {dims: output.dims, type: output.type}; } }); - const results = await session.run(feeds, fetches); + const [, , results] = await sessionRun({session, feeds, outputsMetaInfo: fetches, ioBinding: ioBindingMode}); const actualOutputNames = Object.getOwnPropertyNames(results); expect(actualOutputNames.length).to.equal(expectedOutputNames.length); @@ -821,7 +959,8 @@ async function runOpTestcase( export async function runOpTest( testcase: Test.OperatorTestCase, context: ProtoOpTestContext|OpTestContext): Promise { if (context instanceof ProtoOpTestContext) { - await runProtoOpTestcase(context.session, testcase, new TensorResultValidator(context.backendHint)); + await runProtoOpTestcase( + context.session, testcase, context.ioBindingMode, new TensorResultValidator(context.backendHint)); } else { await runOpTestcase( context.inferenceHandler, context.createOperator(), testcase, new TensorResultValidator(context.backendHint)); diff --git a/js/web/test/test-types.ts b/js/web/test/test-types.ts index 1f95d1cd8e682..88915e7972383 100644 --- a/js/web/test/test-types.ts +++ b/js/web/test/test-types.ts @@ -43,6 +43,18 @@ export declare namespace Test { */ export type PlatformCondition = string; + /** + * The IOBindingMode represents how to test a model with GPU data. + * + * - none: inputs will be pre-allocated as CPU tensors; no output will be pre-allocated; `preferredOutputLocation` + * will not be set. + * - gpu-location: inputs will be pre-allocated as GPU tensors; no output will be pre-allocated; + * `preferredOutputLocation` will be set to `gpu-buffer`. + * - gpu-tensor: inputs and outputs will all be pre-allocated as GPU tensors. `preferredOutputLocation` + * will not be set. 
+ */ + export type IOBindingMode = 'none'|'gpu-tensor'|'gpu-location'; + export interface ModelTestCase { name: string; dataFiles: readonly string[]; @@ -54,6 +66,7 @@ export declare namespace Test { name: string; modelUrl: string; backend?: string; // value should be populated at build time + ioBinding: IOBindingMode; platformCondition?: PlatformCondition; cases: readonly ModelTestCase[]; } @@ -82,6 +95,7 @@ export declare namespace Test { inputShapeDefinitions?: 'none'|'rankOnly'|'static'|ReadonlyArray; opset?: OperatorTestOpsetImport; backend?: string; // value should be populated at build time + ioBinding: IOBindingMode; platformCondition?: PlatformCondition; attributes?: readonly AttributeValue[]; cases: readonly OperatorTestCase[]; diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index fd147eaa11f3f..0ed7d887fc5e5 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -42,6 +42,7 @@ from onnxruntime.capi._pybind_state import get_build_info # noqa: F401 from onnxruntime.capi._pybind_state import get_device # noqa: F401 from onnxruntime.capi._pybind_state import get_version_string # noqa: F401 + from onnxruntime.capi._pybind_state import has_collective_ops # noqa: F401 from onnxruntime.capi._pybind_state import set_default_logger_severity # noqa: F401 from onnxruntime.capi._pybind_state import set_default_logger_verbosity # noqa: F401 from onnxruntime.capi._pybind_state import set_seed # noqa: F401 diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index f0385ea5abdfb..48af0a9a6adec 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -140,27 +140,29 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { #endif if (!use_flash_attention) { - if (is_unidirectional_ && enable_fused_causal_attention_) { // GPT - // GPT fused kernels requires left side padding. mask can be: - // none (no padding), 1D sequence lengths or 2d mask. - // Fused kernels don't support different sequence lengths of q and kv, so only apply to the first token - // where past state is empty. - bool is_mask_2d_key_padding = parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING; - bool use_causal_fused_runner = (nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) && - nullptr == relative_position_bias && - parameters.past_sequence_length == 0 && - parameters.hidden_size == parameters.v_hidden_size && - FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length, - enable_trt_flash_attention_, true); - if (use_causal_fused_runner) { - // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node. - if (nullptr == fused_fp16_runner_.get()) { - fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_, - enable_trt_flash_attention_, parameters.scale); + if (is_unidirectional_) { // GPT + if (enable_fused_causal_attention_) { + // GPT fused kernels requires left side padding. mask can be: + // none (no padding), 1D sequence lengths or 2d mask. + // Fused kernels don't support different sequence lengths of q and kv, so only apply to the first token + // where past state is empty. 
+ bool is_mask_2d_key_padding = parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING; + bool use_causal_fused_runner = (nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) && + nullptr == relative_position_bias && + parameters.past_sequence_length == 0 && + parameters.hidden_size == parameters.v_hidden_size && + FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length, + enable_trt_flash_attention_, true); + if (use_causal_fused_runner) { + // Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node. + if (nullptr == fused_fp16_runner_.get()) { + fused_fp16_runner_ = FusedMHARunnerFP16v2::Create(num_heads_, parameters.head_size, sm, is_unidirectional_, + enable_trt_flash_attention_, parameters.scale); + } + + // Here we assume all causal kernels can be loaded into shared memory. TODO: add a function to check. + fused_runner = fused_fp16_runner_.get(); } - - // Here we assume all causal kernels can be loaded into shared memory. TODO: add a function to check. - fused_runner = fused_fp16_runner_.get(); } } else { // BERT bool use_fused_runner = !disable_fused_self_attention_ && diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index dede1ecc95885..1b492a3561396 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -177,9 +177,9 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) { } #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - // Run layout transformer only for EPs other than CPU EP and provided the preferred layout is NHWC + // Run layout transformer for EPs with preferred layout of NHWC // CPU EP layout transformation happens later when level 3 transformers are run. - if (params.mode != GraphPartitioner::Mode::kAssignOnly && + if (params.mode != GraphPartitioner::Mode::kAssignOnly && params.transform_layout.get() && current_ep.GetPreferredLayout() == DataLayout::NHWC) { for (auto& capability : capabilities) { TryAssignNodes(graph, *capability->sub_graph, ep_type); diff --git a/onnxruntime/core/framework/kernel_registry_manager.cc b/onnxruntime/core/framework/kernel_registry_manager.cc index 38c8a4a4e3d5e..c4eef5b27c1bb 100644 --- a/onnxruntime/core/framework/kernel_registry_manager.cc +++ b/onnxruntime/core/framework/kernel_registry_manager.cc @@ -63,8 +63,9 @@ Status KernelRegistryManager::SearchKernelRegistry(const Node& node, auto create_error_message = [&node, &status](const std::string& prefix) { std::ostringstream errormsg; errormsg << prefix << node.OpType() << "(" << node.SinceVersion() << ")"; - if (!node.Name().empty()) errormsg << " (node " << node.Name() << "). "; - if (!status.IsOK()) errormsg << status.ErrorMessage(); + errormsg << " (node:'" << node.Name() << "' ep:'" << node.GetExecutionProviderType() << "'). 
"; + if (!status.IsOK()) + errormsg << status.ErrorMessage(); return errormsg.str(); }; diff --git a/onnxruntime/core/framework/session_options.h b/onnxruntime/core/framework/session_options.h index aff90b8d40bde..8deeb4c2b8b64 100644 --- a/onnxruntime/core/framework/session_options.h +++ b/onnxruntime/core/framework/session_options.h @@ -148,6 +148,10 @@ struct SessionOptions { std::shared_ptr custom_op_libs; void AddCustomOpLibraryHandle(PathString library_name, void* library_handle); #endif + + // User specified logging func and param + OrtLoggingFunction user_logging_function = nullptr; + void* user_logging_param = nullptr; }; } // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc b/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc index 3ce7c40e754dc..d3fc5873cb274 100644 --- a/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc +++ b/onnxruntime/core/graph/contrib_ops/internal_nhwc_onnx_schemas.cc @@ -94,7 +94,7 @@ void OpSet_Internal_NHWC_ONNX::ForEachSchema(const std::function> GenerateTransformers( const InlinedHashSet cpu_ep = {onnxruntime::kCpuExecutionProvider}; #endif const InlinedHashSet dml_ep = {onnxruntime::kDmlExecutionProvider}; + AllocatorPtr cpu_allocator = std::make_shared(); + switch (level) { case TransformerLevel::Level1: { // RewriteRule optimizations are the simplest (they generally remove unnecessary nodes and are cheap to run) @@ -240,13 +242,14 @@ InlinedVector> GenerateTransformers( // run TransposeOptimizer last as it works in a slightly different way by moving Transpose nodes around. // shouldn't affect the end result - just easier to debug any issue if it's last. - // local CPU allocator is enough as this allocator is finally passed to a local tensor. - // We will also benefit by using a local allocator as we don't need to pass allocator as parameter for EP API refactor - AllocatorPtr cpu_allocator = std::make_shared(); transformers.emplace_back(std::make_unique(std::move(cpu_allocator))); } break; case TransformerLevel::Level2: { + // we run TransposeOptimizer again in Level2 for some CPU EP specific optimizations that can only be + // applied once nodes are assigned to the CPU EP (which happens between level 1 and level 2). + transformers.emplace_back(std::make_unique(std::move(cpu_allocator), kCpuExecutionProvider)); + const bool enable_quant_qdq_cleanup = session_options.config_options.GetConfigOrDefault(kOrtSessionOptionsEnableQuantQDQCleanup, "0") == "1"; #if !defined(DISABLE_CONTRIB_OPS) @@ -366,16 +369,16 @@ InlinedVector> GenerateTransformers( if (MlasNchwcGetBlockSize() > 1) { transformers.emplace_back(std::make_unique()); } - AllocatorPtr cpu_allocator = std::make_shared(); + auto cpu_registry = cpu_execution_provider.GetKernelRegistry(); auto nhwc_transformer = std::make_unique(std::move(cpu_allocator), std::move(cpu_registry)); if (nhwc_transformer->IsActive()) { transformers.emplace_back(std::move(nhwc_transformer)); } - // NCHWCtransformer should have a higher priority versus this. Because NCHWCtransformer also do the similar things - // of fusion patterns and target on CPU. However, NCHWCtransformer will reorder the layout to nchwc which is only available for - // x86-64 cpu, not edge cpu like arm. But This transformer could be used by opencl-ep/cpu-ep. So - // we will prefer NhwcTransformer once ort runs on x86-64 CPU, otherwise ConvAddActivationFusion is enabled. + + // NchwcTransformer must have a higher priority than ConvAddActivationFusion. 
NchwcTransformer does similar + // fusions targeting CPU but also reorders the layout to NCHWc which is expected to be more efficient but is + // only available on x86-64. // PR #6351 implemented similar fusion-pattern for CUDA only, and can only fuse conv-add-relu, // while we can fuse more activation. transformers.emplace_back(std::make_unique(cpu_ep)); diff --git a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc index 2d12c407e6e31..6c91949e467ae 100644 --- a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc +++ b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc @@ -13,27 +13,91 @@ using namespace onnx_transpose_optimization; namespace onnxruntime { namespace layout_transformation { +namespace { +// Cost check for aggressively pushing the Transpose nodes involved in the layout transformation further out. +CostCheckResult PostLayoutTransformCostCheck(const api::GraphRef& graph, const api::NodeRef& node, + const std::vector& perm, + const std::unordered_set& outputs_leading_to_transpose) { + // we aggressively push the layout transpose nodes. + // Exception: pushing through a Concat can result in Transpose nodes being added to multiple other inputs which + // can potentially be worse for performance. Use the cost check in that case. + if (node.OpType() != "Concat" && + (perm == ChannelFirstToLastPerm(perm.size()) || perm == ChannelLastToFirstPerm(perm.size()))) { + return CostCheckResult::kPushTranspose; + } + + // for other nodes use the default ORT cost check + return OrtEPCostCheck(graph, node, perm, outputs_leading_to_transpose); +} + +/// +/// Default function for checking if a node should have its layout changed. Allows EP specific adjustments to the +/// default set of layout sensitive operators if required. +/// +/// Longer term, if required, the EP API could allow the EP to provide a delegate to plugin EP specific logic so we +/// don't hardcode it here. +/// +/// Node to check +/// true if the node should have its layout converted to NHWC. +bool ConvertNodeLayout(const api::NodeRef& node) { + // skip if op is not an ONNX or contrib op + auto domain = node.Domain(); + if (domain != kOnnxDomain && domain != kMSDomain) { + return false; + } + + const auto& layout_sensitive_ops = GetORTLayoutSensitiveOps(); + + // handle special cases +#if defined(USE_XNNPACK) + if (node.GetExecutionProviderType() == kXnnpackExecutionProvider) { + if (node.OpType() == "Resize") { + // XNNPACK supports NCHW and NHWC for Resize so we don't need to use the internal NHWC domain and wrap the Resize + // with Transpose nodes. EPAwareHandleResize will allow an NCHW <-> NHWC Transpose to be pushed through + // the Resize during transpose optimization. + return false; + } + } +#endif + +#if defined(USE_JSEP) + // TODO(fs-eire): Remove special case handing of JSEP once NHWC Resize implementation is fixed + if (node.GetExecutionProviderType() == kJsExecutionProvider) { + if (node.OpType() == "Resize") { + // leave Resize as-is pending bugfix for NHWC implementation. this means the node will remain in the ONNX domain + // with the original input layout. 
+ return false; + } + } +#endif + + // #if defined(USE_CUDA) + // if (node.GetExecutionProviderType() == kCudaExecutionProvider) { + // Update as per https://github.com/microsoft/onnxruntime/pull/17200 with CUDA ops that support NHWC + // } + // #endif + + return layout_sensitive_ops.count(node.OpType()) != 0; +} +} // namespace // Layout sensitive NCHW ops. TransformLayoutForEP will wrap these with Transpose nodes to convert the input // data to NHWC and output data back to NCHW, and move the op to the internal NHWC domain (kMSInternalNHWCDomain). -// The EP requesting these ops MUST be able to handle the node with the operator in the kMSInternalNHWCDomain. +// The EP requesting these ops MUST be able to handle the node with the operator in the kMSInternalNHWCDomain domain. // Once all the layout sensitive ops requested by the EP are wrapped the transpose optimizer will attempt to remove // as many of the layout transposes as possible. const std::unordered_set& GetORTLayoutSensitiveOps() { static std::unordered_set ort_layout_sensitive_ops = []() { const auto& layout_sensitive_ops = onnx_transpose_optimization::GetLayoutSensitiveOps(); std::unordered_set ort_specific_ops = - { "FusedConv", - "QLinearAveragePool", - "QLinearGlobalAveragePool" -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_QNN) || defined(USE_WEBNN) - // The CUDA/ROCM Resize kernel is layout sensitive as it only handles NCHW input. - // The CPU kernel and ONNX spec are not limited to handling NCHW input so are not layout sensitive, and - // onnx_layout_transformation::HandleResize is used. - , - "Resize" -#endif - }; + { + "FusedConv", + "QLinearAveragePool", + "QLinearGlobalAveragePool", + // Whilst the ONNX spec doesn't specify a layout for Resize, we treat it as layout sensitive by default + // as EPs tend to only support one layout. + "Resize", + }; ort_specific_ops.insert(layout_sensitive_ops.cbegin(), layout_sensitive_ops.cend()); return ort_specific_ops; @@ -42,45 +106,21 @@ const std::unordered_set& GetORTLayoutSensitiveOps() { return ort_layout_sensitive_ops; } -// Cost check for aggressively pushing the Transpose nodes involved in the layout transformation further out. -static CostCheckResult -PostLayoutTransformCostCheck(const api::GraphRef& graph, const api::NodeRef& node, - const std::vector& perm, - const std::unordered_set& outputs_leading_to_transpose) { - // we aggressively push the layout transpose nodes. - // Exception: pushing through a Concat can result in Transpose nodes being added to multiple other inputs which - // can potentially be worse for performance. Use the cost check in that case. - if (node.OpType() != "Concat" && - (perm == ChannelFirstToLastPerm(perm.size()) || perm == ChannelLastToFirstPerm(perm.size()))) { - return CostCheckResult::kPushTranspose; - } - - // for other nodes use the default ORT cost check - return OrtEPCostCheck(graph, node, perm, outputs_leading_to_transpose); -} - Status TransformLayoutForEP(Graph& graph, bool& modified, const IExecutionProvider& execution_provider, AllocatorPtr cpu_allocator, const DebugGraphFn& debug_graph_fn) { // We pass in nullptr for the new_node_ep param as new nodes will be assigned by the graph partitioner after // TransformLayoutForEP returns. - // sub graph recurse will be added later. 
+ // sub graph recurse will be added later auto api_graph = MakeApiGraph(graph, cpu_allocator, /*new_node_ep*/ nullptr); - const auto& layout_sensitive_ops = GetORTLayoutSensitiveOps(); // to convert to NHWC we need to wrap layout sensitive nodes to Transpose from NCHW to NHWC and back. for (auto& node : api_graph->Nodes()) { - if (layout_sensitive_ops.count(node->OpType())) { - if (node->GetExecutionProviderType() != execution_provider.Type()) { - continue; - } - - auto domain = node->Domain(); - // Skip if domain is incorrect - if (domain != kOnnxDomain && domain != kMSDomain) { - continue; - } + if (node->GetExecutionProviderType() != execution_provider.Type()) { + continue; + } + if (ConvertNodeLayout(*node)) { // if already transformed then change the domain to kMSInternalNHWCDomain this way the EP // knows this op is in the expected format. if (node->GetAttributeIntDefault("channels_last", 0) == 1) { @@ -137,7 +177,6 @@ Status TransformLayoutForEP(Graph& graph, bool& modified, const IExecutionProvid WrapTransposesAroundNode(*api_graph, *node, {&input_perm}, {&output_perm}); } - // TODO: Technically Resize doesn't need to change domain as the ONNX Resize spec is not layout sensitive. SwapNodeOpTypeAndDomain(*api_graph, *node, node->OpType(), kMSInternalNHWCDomain); modified = true; } diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc index f6d9a60726ccc..81b415c2e40ae 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.cc @@ -1242,18 +1242,7 @@ static void PermuteInput(api::GraphRef& graph, api::NodeRef& node, size_t i, con node.SetInput(i, gather_output); } -static bool HandleResize([[maybe_unused]] HandlerArgs& args) { -#if defined(USE_CUDA) || defined(USE_ROCM) || defined(USE_QNN) || defined(USE_WEBNN) - // The CUDA Resize kernel requires that the input is NCHW, so we can't push a Transpose through a Resize - // in ORT builds with CUDA enabled. - // The ROCm EP is generated from the CUDA EP kernel so the same applies to builds with ROCm enabled. - // The QNN EP requires the input to be NHWC, so the Resize handler is also not enabled for QNN builds. - // - // TODO: Remove this special case once the CUDA Resize kernel is implemented "generically" (i.e.) aligning with the - // generic nature of the ONNX spec. - // See https://github.com/microsoft/onnxruntime/pull/10824 for a similar fix applied to the CPU Resize kernel. - return false; -#else +bool HandleResize([[maybe_unused]] HandlerArgs& args) { auto inputs = args.node.Inputs(); int64_t rank_int = gsl::narrow_cast(args.perm.size()); @@ -1279,10 +1268,10 @@ static bool HandleResize([[maybe_unused]] HandlerArgs& args) { TransposeOutputs(args.ctx, args.node, args.perm); return true; -#endif } -constexpr HandlerInfo resize_handler = {&FirstInput, &HandleResize}; +// Not currently registered by default. +// constexpr HandlerInfo resize_handler = {&FirstInput, &HandleResize}; static bool HandlePad(HandlerArgs& args) { size_t rank = args.perm.size(); @@ -2034,8 +2023,11 @@ static const std::unordered_map handler_ma {"Split", split_handler}, {"Shape", shape_handler}, {"Pad", pad_handler}, - {"Resize", resize_handler}, - {"ReduceSum", reduce_op_handler}, + + // Execution providers tend to only implement Resize for specific layouts. 
Due to that, it's safer to not + // push a Transpose through a Resize unless the EP specifically checks that it can handle the change via an + // extended handler. + // {"Resize", resize_handler}, {"ReduceLogSum", reduce_op_handler}, {"ReduceLogSumExp", reduce_op_handler}, @@ -2043,6 +2035,7 @@ static const std::unordered_map handler_ma {"ReduceMean", reduce_op_handler}, {"ReduceMin", reduce_op_handler}, {"ReduceProd", reduce_op_handler}, + {"ReduceSum", reduce_op_handler}, {"ReduceSumSquare", reduce_op_handler}, {"ReduceL1", reduce_op_handler}, {"ReduceL2", reduce_op_handler}, @@ -2385,6 +2378,8 @@ OptimizeResult OptimizeImpl(OptimizerCtx& ctx) { continue; } + // NOTE: this bleeds ORT specific logic into the base optimizer, however we justify that for now because we expect + // the types that the ORT DQ provides to be added to the ONNX spec, at which point this special case can go away. if (IsMSDomain(dq_domain) && !TransposeQuantizeDequantizeAxis(ctx.graph, perm_inv, *dq_node)) { continue; } diff --git a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h index f8aaeca915171..cc1552704c187 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h +++ b/onnxruntime/core/optimizer/transpose_optimization/onnx_transpose_optimization.h @@ -98,6 +98,7 @@ bool HandleSimpleNodeWithAxis(HandlerArgs& args, std::optional default_ // base handlers that are used by extended handlers. add from transpose_optimizer.cc as needed. bool HandleReduceOps(HandlerArgs& args); +bool HandleResize([[maybe_unused]] HandlerArgs& args); void TransposeInput(api::GraphRef& graph, api::NodeRef& node, size_t i, const std::vector& perm, diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc b/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc index ead82a6b56741..f4f3505128737 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.cc @@ -5,12 +5,35 @@ #include #include "core/graph/constants.h" +#include "core/framework/utils.h" #include "core/optimizer/transpose_optimization/ort_optimizer_utils.h" using namespace onnx_transpose_optimization; namespace onnxruntime { +static bool EPAwareHandleResize(HandlerArgs& args) { + // Whilst Resize is not technically layout sensitive, execution providers typically implement handling for only one + // layout. Due to that, only push a Transpose through a Resize once it is assigned and we know it's being handled + // by an EP that supports multiple layouts. Currently that's the CPU and XNNPACK EPs. + const auto ep_type = args.node.GetExecutionProviderType(); + if (ep_type == kCpuExecutionProvider || ep_type == kXnnpackExecutionProvider) { + // allow NCHW <-> NHWC for now. 
not clear any other sort of transpose has a valid usage in a real model + int64_t rank_int = gsl::narrow_cast(args.perm.size()); + if (rank_int == 4) { + static const std::vector nchw_to_nhwc_perm{0, 2, 3, 1}; + static const std::vector nhwc_to_nchw_perm{0, 3, 1, 2}; + if (args.perm == nchw_to_nhwc_perm || args.perm == nhwc_to_nchw_perm) { + return HandleResize(args); + } + } + } + + return false; +} + +constexpr HandlerInfo ep_aware_resize_handler = {&FirstInput, &EPAwareHandleResize}; + static bool HandleQLinearConcat(HandlerArgs& args) { return HandleSimpleNodeWithAxis(args); } @@ -62,7 +85,7 @@ static bool HandleMaxPool(HandlerArgs& args) { ORT_UNUSED_PARAMETER(args); return false; #else - if (args.node.GetExecutionProviderType() != "CPUExecutionProvider") { + if (args.node.GetExecutionProviderType() != kCpuExecutionProvider) { return false; } @@ -103,6 +126,7 @@ static bool HandleContribQuantizeDequantizeLinear(HandlerArgs& args) { } constexpr HandlerInfo max_pool_op_handler = {&FirstInput, &HandleMaxPool}; + constexpr HandlerInfo node_1_inp_handler = {&FirstInput, &HandleSimpleNode}; constexpr HandlerInfo reduce_op_handler = {&FirstInput, &HandleReduceOps}; constexpr HandlerInfo contrib_quantize_dequantize_linear_handler = {&FirstInput, @@ -113,6 +137,7 @@ const HandlerMap& OrtExtendedHandlers() { static const HandlerMap extended_handler_map = []() { HandlerMap map = { {"MaxPool", max_pool_op_handler}, + {"Resize", ep_aware_resize_handler}, {"com.microsoft.QuantizeLinear", contrib_quantize_dequantize_linear_handler}, {"com.microsoft.DequantizeLinear", contrib_quantize_dequantize_linear_handler}, {"com.microsoft.QLinearAdd", q_linear_binary_op_handler}, diff --git a/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.h b/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.h index 0a5dbd6d13d06..8245d8c3b4eae 100644 --- a/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.h +++ b/onnxruntime/core/optimizer/transpose_optimization/ort_transpose_optimization.h @@ -10,7 +10,6 @@ namespace onnxruntime { /// /// Get the extended handlers for ORT specific transpose optimization. /// These include handlers for contrib ops, and where we have an NHWC version of a layout sensitive op. -/// Extends the handlers returned by OrtHandlers. /// /// HandlerMap const onnx_transpose_optimization::HandlerMap& OrtExtendedHandlers(); diff --git a/onnxruntime/core/optimizer/transpose_optimizer.cc b/onnxruntime/core/optimizer/transpose_optimizer.cc index 33e3f5eeaf0fa..092df9cc7dcfb 100644 --- a/onnxruntime/core/optimizer/transpose_optimizer.cc +++ b/onnxruntime/core/optimizer/transpose_optimizer.cc @@ -18,10 +18,18 @@ namespace onnxruntime { Status TransposeOptimizer::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { - auto api_graph = MakeApiGraph(graph, cpu_allocator_, /*new_node_ep*/ nullptr); - - OptimizeResult result = onnx_transpose_optimization::Optimize(*api_graph, "", /* default cost check*/ nullptr, - OrtExtendedHandlers()); + OptimizeResult result; + + if (ep_.empty()) { + // basic usage - no EP specific optimizations + auto api_graph = MakeApiGraph(graph, cpu_allocator_, /*new_node_ep*/ nullptr); + result = onnx_transpose_optimization::Optimize(*api_graph, "", /* default cost check*/ nullptr, + OrtExtendedHandlers()); + } else { + // EP specific optimizations enabled. Currently only used for CPU EP. 
+ auto api_graph = MakeApiGraph(graph, cpu_allocator_, /*new_node_ep*/ ep_.c_str()); + result = onnx_transpose_optimization::Optimize(*api_graph, ep_, OrtEPCostCheck, OrtExtendedHandlers()); + } if (result.error_msg) { // currently onnx_layout_transformation::Optimize only fails if we hit an unsupported opset. diff --git a/onnxruntime/core/optimizer/transpose_optimizer.h b/onnxruntime/core/optimizer/transpose_optimizer.h index 1ae6d611d2f0e..97d7ab4d0e220 100644 --- a/onnxruntime/core/optimizer/transpose_optimizer.h +++ b/onnxruntime/core/optimizer/transpose_optimizer.h @@ -15,10 +15,14 @@ Push transposes through ops and eliminate them. class TransposeOptimizer : public GraphTransformer { private: AllocatorPtr cpu_allocator_; + const std::string ep_; public: - explicit TransposeOptimizer(AllocatorPtr cpu_allocator) noexcept - : GraphTransformer("TransposeOptimizer"), cpu_allocator_(std::move(cpu_allocator)) {} + explicit TransposeOptimizer(AllocatorPtr cpu_allocator, + const std::string& ep = {}) noexcept + : GraphTransformer(ep.empty() ? "TransposeOptimizer" : "TransposeOptimizer_" + ep), + cpu_allocator_(std::move(cpu_allocator)), + ep_{ep} {} Status ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const override; diff --git a/onnxruntime/core/platform/windows/debug_alloc.cc b/onnxruntime/core/platform/windows/debug_alloc.cc index b08d189f79866..ff6a059607367 100644 --- a/onnxruntime/core/platform/windows/debug_alloc.cc +++ b/onnxruntime/core/platform/windows/debug_alloc.cc @@ -12,7 +12,7 @@ // #ifndef NDEBUG #ifdef ONNXRUNTIME_ENABLE_MEMLEAK_CHECK -constexpr int c_callstack_limit = 16; // Maximum depth of callstack in leak trace +constexpr int c_callstack_limit = 32; // Maximum depth of callstack in leak trace #define VALIDATE_HEAP_EVERY_ALLOC 0 // Call HeapValidate on every new/delete #pragma warning(disable : 4073) // initializers put in library initialization area (this is intentional) @@ -223,6 +223,11 @@ Memory_LeakCheck::~Memory_LeakCheck() { // empty_group_names = new std::map; }); if (string.find("RtlRunOnceExecuteOnce") == std::string::npos && string.find("re2::RE2::Init") == std::string::npos && + string.find("dynamic initializer for 'FLAGS_") == std::string::npos && + string.find("AbslFlagDefaultGenForgtest_") == std::string::npos && + string.find("AbslFlagDefaultGenForundefok::Gen") == std::string::npos && + string.find("::SetProgramUsageMessage") == std::string::npos && + string.find("testing::internal::ParseGoogleTestFlagsOnly") == std::string::npos && string.find("testing::internal::Mutex::ThreadSafeLazyInit") == std::string::npos && string.find("testing::internal::ThreadLocalRegistryImpl::GetThreadLocalsMapLocked") == std::string::npos && string.find("testing::internal::ThreadLocalRegistryImpl::GetValueOnCurrentThread") == std::string::npos && diff --git a/onnxruntime/core/platform/windows/stacktrace.cc b/onnxruntime/core/platform/windows/stacktrace.cc index cac6f4f29043b..d7d423e4a483e 100644 --- a/onnxruntime/core/platform/windows/stacktrace.cc +++ b/onnxruntime/core/platform/windows/stacktrace.cc @@ -10,7 +10,6 @@ #include #endif #endif -#include #include "core/common/logging/logging.h" #include "core/common/gsl.h" diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 72e36a161e9aa..ae33fb752fe00 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -232,18 +232,18 @@ class 
ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Uns class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, Transpose); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Transpose); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, float, Conv); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, float, ConvTranspose); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, Conv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, ConvTranspose); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, 11, MaxPool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 12, MaxPool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 11, AveragePool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, GlobalAveragePool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 1, GlobalMaxPool); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, Conv); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, Conv); -class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, float, ConvTranspose); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, ConvTranspose); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, Conv); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, Conv); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 10, ConvTranspose); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, ConvTranspose); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, Gemm); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, Gemm); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, Gemm); @@ -318,6 +318,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Til class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 17, float, LayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kMSInternalNHWCDomain, 6, float, InstanceNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, float, InstanceNormalization); + +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, Range); + class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 12, float, Einsum); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 2, 10, Pad); @@ -496,18 +499,18 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + 
BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -584,7 +587,11 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/js/js_kernel.h b/onnxruntime/core/providers/js/js_kernel.h index 177c0a9e691ed..fdd5c7dee5bfc 100644 --- a/onnxruntime/core/providers/js/js_kernel.h +++ b/onnxruntime/core/providers/js/js_kernel.h @@ -196,7 +196,7 @@ class JsKernel : public OpKernel { } int status_code = EM_ASM_INT( - { return Module.jsepRunKernel($0, $1, Module.jsepSessionState); }, + { return Module.jsepRunKernel($0, $1, Module.jsepSessionState.sessionHandle, Module.jsepSessionState.errors); }, this, reinterpret_cast(p_serialized_kernel_context)); LOGS_DEFAULT(VERBOSE) << "outputs = " << context->OutputCount() << ". Y.data=" diff --git a/onnxruntime/core/providers/js/operators/conv.cc b/onnxruntime/core/providers/js/operators/conv.cc index c7c9f7f7c3f0e..2e07124dcd901 100644 --- a/onnxruntime/core/providers/js/operators/conv.cc +++ b/onnxruntime/core/providers/js/operators/conv.cc @@ -9,33 +9,27 @@ namespace onnxruntime { namespace js { -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - Conv, \ - kMSInternalNHWCDomain, \ - 11, \ - T, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Conv); \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - Conv, \ - kOnnxDomain, \ - 11, \ - T, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Conv); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - Conv, \ - kOnnxDomain, \ - 1, 10, \ - T, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - Conv); - -REGISTER_KERNEL_TYPED(float) +ONNX_OPERATOR_KERNEL_EX( + Conv, + kMSInternalNHWCDomain, + 11, + kJsExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), + Conv); +ONNX_OPERATOR_KERNEL_EX( + Conv, + kOnnxDomain, + 11, + kJsExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), + Conv); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + Conv, + kOnnxDomain, + 1, 10, + kJsExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), + Conv); } // namespace js } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/conv.h b/onnxruntime/core/providers/js/operators/conv.h index 22f7721276677..fdf3e5b6c6b66 100644 --- a/onnxruntime/core/providers/js/operators/conv.h +++ b/onnxruntime/core/providers/js/operators/conv.h @@ -9,7 +9,7 @@ namespace onnxruntime { namespace js { -template +template class Conv : public JsKernel { public: Conv(const OpKernelInfo& info) : JsKernel(info), conv_attrs_(info), w_is_const_(false) { diff --git a/onnxruntime/core/providers/js/operators/conv_transpose.cc b/onnxruntime/core/providers/js/operators/conv_transpose.cc index 1a2fc99eada6a..2228343e1e6e3 100644 --- a/onnxruntime/core/providers/js/operators/conv_transpose.cc +++ b/onnxruntime/core/providers/js/operators/conv_transpose.cc @@ -7,33 +7,28 @@ #include "conv_transpose.h" namespace onnxruntime { namespace js { -#define REGISTER_KERNEL_TYPED(T) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - ConvTranspose, \ - kMSInternalNHWCDomain, \ - 11, \ - T, \ - 
kJsExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - ConvTranspose); \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - ConvTranspose, \ - kOnnxDomain, \ - 11, \ - T, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - ConvTranspose); \ - ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \ - ConvTranspose, \ - kOnnxDomain, \ - 1, 10, \ - T, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - ConvTranspose); -REGISTER_KERNEL_TYPED(float) +ONNX_OPERATOR_KERNEL_EX( + ConvTranspose, + kMSInternalNHWCDomain, + 11, + kJsExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), + ConvTranspose); +ONNX_OPERATOR_KERNEL_EX( + ConvTranspose, + kOnnxDomain, + 11, + kJsExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), + ConvTranspose); +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + ConvTranspose, + kOnnxDomain, + 1, 10, + kJsExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", JsepSupportedFloatTypes()), + ConvTranspose); } // namespace js } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/conv_transpose.h b/onnxruntime/core/providers/js/operators/conv_transpose.h index a5aeae8646373..18ef73268005d 100644 --- a/onnxruntime/core/providers/js/operators/conv_transpose.h +++ b/onnxruntime/core/providers/js/operators/conv_transpose.h @@ -9,7 +9,7 @@ #include "core/providers/js/js_kernel.h" namespace onnxruntime { namespace js { -template +template class ConvTranspose : public JsKernel { public: ConvTranspose(const OpKernelInfo& info) : JsKernel(info), conv_transpose_attrs_(info), w_is_const_(false) { @@ -108,6 +108,23 @@ class ConvTranspose : public JsKernel { } } + Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, + /*out*/ bool& is_packed, + /*out*/ PrePackedWeights* /* prepacked_weights */) override { + is_packed = false; + + if (input_idx == 1) { + // Only handle the common case of conv2D + if (tensor.Shape().NumDimensions() != 4 || tensor.SizeInBytes() == 0) { + return Status::OK(); + } + + w_is_const_ = true; + } + + return Status::OK(); + } + protected: ConvTransposeAttributes conv_transpose_attrs_; bool w_is_const_; diff --git a/onnxruntime/core/providers/js/operators/matmul.cc b/onnxruntime/core/providers/js/operators/matmul.cc index ddfbb454def07..6e6f906f7b42c 100644 --- a/onnxruntime/core/providers/js/operators/matmul.cc +++ b/onnxruntime/core/providers/js/operators/matmul.cc @@ -9,11 +9,11 @@ namespace js { JSEP_KERNEL_IMPL(MatMul, MatMul) ONNX_OPERATOR_VERSIONED_KERNEL_EX(MatMul, kOnnxDomain, 1, 12, kJsExecutionProvider, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + KernelDefBuilder().TypeConstraint("T", JsepSupportedFloatTypes()), MatMul); ONNX_OPERATOR_KERNEL_EX(MatMul, kOnnxDomain, 13, kJsExecutionProvider, - KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType()), + KernelDefBuilder().TypeConstraint("T", JsepSupportedFloatTypes()), MatMul); } // namespace js diff --git a/onnxruntime/core/providers/js/operators/range.cc b/onnxruntime/core/providers/js/operators/range.cc new file mode 100644 index 0000000000000..e15861f7f227a --- /dev/null +++ b/onnxruntime/core/providers/js/operators/range.cc @@ -0,0 +1,22 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/providers/js/js_kernel.h" + +#include "range.h" + +namespace onnxruntime { +namespace js { +ONNX_OPERATOR_KERNEL_EX( + Range, + kOnnxDomain, + 11, + kJsExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}) + .InputMemoryType(OrtMemTypeCPU, 0) + .InputMemoryType(OrtMemTypeCPU, 1) + .InputMemoryType(OrtMemTypeCPU, 2), + Range); +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/operators/range.h b/onnxruntime/core/providers/js/operators/range.h new file mode 100644 index 0000000000000..8b32bfc3d984b --- /dev/null +++ b/onnxruntime/core/providers/js/operators/range.h @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/js/js_kernel.h" + +namespace onnxruntime { +namespace js { + +JSEP_KERNEL_IMPL(Range, Range); + +} // namespace js +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc index b207b804416aa..6a86ca7aca6e9 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc @@ -35,30 +35,79 @@ Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, std::vector input_shape; ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); const auto input_size = input_shape.size(); - // WebNN Softmax only support 2d input shape, reshape input to 2d. - if (input_size != 2) { - NodeAttrHelper helper(node); + NodeAttrHelper helper(node); + if (node.SinceVersion() < 13) { int32_t axis = helper.Get("axis", 1); - if (node.SinceVersion() >= 13) - // Opset 13 has default value -1. - axis = helper.Get("axis", -1); + axis = static_cast(HandleNegativeAxis(axis, input_size)); // Coerce the input into a 2-dimensional tensor with dimensions [a_0 * ... * a_{k-1}, a_k * ... * a_{n-1}]. + if (input_size != 2) { + int32_t first_dim = static_cast(std::reduce(input_shape.begin(), input_shape.begin() + axis, + 1, std::multiplies())); + int32_t second_dim = static_cast(std::reduce(input_shape.begin() + axis, input_shape.end(), + 1, std::multiplies())); + emscripten::val new_shape = emscripten::val::array(std::vector{first_dim, second_dim}); + input = model_builder.GetBuilder().call("reshape", input, new_shape); + } + + output = model_builder.GetBuilder().call("softmax", input); + + // Reshape output to the same shape of input. + if (input_size != 2) { + emscripten::val new_shape = emscripten::val::array(); + for (size_t i = 0; i < input_size; i++) { + new_shape.call("push", static_cast(input_shape[i])); + } + output = model_builder.GetBuilder().call("reshape", output, new_shape); + } + } else { + int32_t axis = helper.Get("axis", -1); axis = static_cast(HandleNegativeAxis(axis, input_size)); - int32_t first_dim = static_cast(std::reduce(input_shape.begin(), input_shape.begin() + axis, - 1, std::multiplies())); - int32_t second_dim = static_cast(std::reduce(input_shape.begin() + axis, input_shape.end(), - 1, std::multiplies())); - emscripten::val new_shape = emscripten::val::array(std::vector{first_dim, second_dim}); - input = model_builder.GetBuilder().call("reshape", input, new_shape); - } - output = model_builder.GetBuilder().call("softmax", input); - // Reshape output to the same shape of input. 
- if (input_size != 2) { - emscripten::val new_shape = emscripten::val::array(); - for (size_t i = 0; i < input_size; i++) { - new_shape.call("push", static_cast(input_shape[i])); + // Workaround to transpose the target axis to the last position. + // WebNN computes the softmax values of the 2-D input tensor along axis 1. + // https://www.w3.org/TR/webnn/#api-mlgraphbuilder-softmax-method + if (axis != static_cast(input_shape.size() - 1)) { + emscripten::val options = emscripten::val::object(); + std::vector permutation(input_shape.size()); + std::iota(permutation.begin(), permutation.end(), 0); + permutation.erase(permutation.begin() + axis); + permutation.push_back(axis); + options.set("permutation", emscripten::val::array(permutation)); + input = model_builder.GetBuilder().call("transpose", input, options); + } + // Workaround to reshape the input tensor to 2-D. + if (input_shape.size() != 2) { + uint32_t first_dim = static_cast(std::reduce(input_shape.begin(), input_shape.begin() + axis, + 1, std::multiplies())); + first_dim *= static_cast(std::reduce(input_shape.begin() + axis + 1, input_shape.end(), + 1, std::multiplies())); + uint32_t second_dim = static_cast(input_shape[axis]); + emscripten::val new_shape = emscripten::val::array(std::vector{first_dim, second_dim}); + input = model_builder.GetBuilder().call("reshape", input, new_shape); + } + + output = model_builder.GetBuilder().call("softmax", input); + + // Reshape the 2-D output back to the transposed (axis-last) N-D shape. + if (input_shape.size() != 2) { + std::vector new_shape; + std::transform(input_shape.begin(), input_shape.begin() + axis, std::back_inserter(new_shape), + [](int64_t dim) -> uint32_t { return static_cast(dim); }); + std::transform(input_shape.begin() + axis + 1, input_shape.end(), std::back_inserter(new_shape), + [](int64_t dim) -> uint32_t { return static_cast(dim); }); + new_shape.push_back(static_cast(input_shape[axis])); + output = model_builder.GetBuilder().call("reshape", + output, emscripten::val::array(new_shape)); + } + // Transpose back to the original axis order. + if (axis != static_cast(input_shape.size() - 1)) { + emscripten::val options = emscripten::val::object(); + std::vector permutation(input_shape.size()); + std::iota(permutation.begin(), permutation.end(), 0); + permutation.pop_back(); + permutation.insert(permutation.begin() + axis, input_shape.size() - 1); + options.set("permutation", emscripten::val::array(permutation)); + output = model_builder.GetBuilder().call("transpose", output, options); } - output = model_builder.GetBuilder().call("reshape", output, new_shape); } model_builder.AddOperand(node.OutputDefs()[0]->Name(), std::move(output)); return Status::OK(); @@ -80,14 +129,6 @@ bool SoftmaxOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali << input_size << "d shape"; return false; } - NodeAttrHelper helper(node); - const int64_t axis = helper.Get("axis", 1); - // WebNN softmax only support reshape for the last axis or version before 13. - // TODO: support opset 13 by composing into: Exp(input) / ReduceSum(Exp(input), axis=axis, keepdims=1).
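As a worked example of the new opset-13 path above (the shape values are illustrative only, chosen to show how first_dim, second_dim and the two permutations are computed): for an input of shape [2, 3, 4, 5] with axis = 1, the first transpose uses permutation {0, 2, 3, 1} to move the softmax axis last, giving [2, 4, 5, 3]; the reshape then collapses the remaining dims to [2*4*5, 3] = [40, 3] (first_dim = 40, second_dim = 3); WebNN's 2-D softmax runs along axis 1 of that tensor; the output is reshaped back to [2, 4, 5, 3]; and the final transpose with permutation {0, 3, 1, 2} restores the original [2, 3, 4, 5] layout.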
- if (axis != -1 && axis != input_shape.size() - 1 && node.SinceVersion() >= 13) { - LOGS(logger, VERBOSE) << "SoftMax only support axis 1 or -1, input axis: " << axis; - return false; - } return true; } diff --git a/onnxruntime/core/providers/xnnpack/nn/resize.cc b/onnxruntime/core/providers/xnnpack/nn/resize.cc index 672b2597279db..76c6b6acbfe32 100644 --- a/onnxruntime/core/providers/xnnpack/nn/resize.cc +++ b/onnxruntime/core/providers/xnnpack/nn/resize.cc @@ -331,7 +331,13 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX(Resize, kOnnxDomain, 13, 17, kXnnpackExecution DataTypeImpl::GetTensorType()}), Resize); -ONNX_OPERATOR_KERNEL_EX(Resize, kOnnxDomain, 18, kXnnpackExecutionProvider, +ONNX_OPERATOR_VERSIONED_KERNEL_EX(Resize, kOnnxDomain, 18, 18, kXnnpackExecutionProvider, + KernelDefBuilder().TypeConstraint("T1", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + Resize); + +ONNX_OPERATOR_KERNEL_EX(Resize, kOnnxDomain, 19, kXnnpackExecutionProvider, KernelDefBuilder().TypeConstraint("T1", {DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType()}), diff --git a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc index ba577ac38d48c..494c718cde081 100644 --- a/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc +++ b/onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc @@ -46,7 +46,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWC class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 10, 10, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 11, 12, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 13, 17, Resize); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 18, Resize); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 18, 18, Resize); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 19, Resize); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, 11, MaxPool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 12, MaxPool); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, AveragePool); @@ -84,7 +85,9 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo< ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 1, 12, Softmax)>, BuildKernelCreateInfo< - ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 18, Resize)>, + ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 19, Resize)>, + BuildKernelCreateInfo< + ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 18, 18, Resize)>, BuildKernelCreateInfo< ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 13, 17, Resize)>, BuildKernelCreateInfo< diff --git a/onnxruntime/core/session/abi_session_options.cc b/onnxruntime/core/session/abi_session_options.cc index 4fcc6de561f8c..fb314b161f1ad 100644 --- a/onnxruntime/core/session/abi_session_options.cc +++ b/onnxruntime/core/session/abi_session_options.cc @@ -143,6 +143,14 @@ ORT_API_STATUS_IMPL(OrtApis::SetSessionLogId, _In_ OrtSessionOptions* options, c return nullptr; } +///< logging function 
and optional logging param to use for session output +ORT_API_STATUS_IMPL(OrtApis::SetUserLoggingFunction, _In_ OrtSessionOptions* options, + _In_ OrtLoggingFunction user_logging_function, _In_opt_ void* user_logging_param) { + options->value.user_logging_function = user_logging_function; + options->value.user_logging_param = user_logging_param; + return nullptr; +} + ///< applies to session load, initialization, etc ORT_API_STATUS_IMPL(OrtApis::SetSessionLogVerbosityLevel, _In_ OrtSessionOptions* options, int session_log_verbosity_level) { options->value.session_log_verbosity_level = session_log_verbosity_level; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 21c8fbe0cd2c9..b4d47652942b7 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -56,6 +56,7 @@ #include "core/providers/dml/dml_session_options_config_keys.h" #endif #include "core/session/environment.h" +#include "core/session/user_logging_sink.h" #include "core/session/IOBinding.h" #include "core/session/inference_session_utils.h" #include "core/session/onnxruntime_session_options_config_keys.h" @@ -298,6 +299,35 @@ static Status FinalizeSessionOptions(const SessionOptions& user_provided_session return Status::OK(); } +logging::Severity GetSeverity(const SessionOptions& session_options) { + logging::Severity severity = logging::Severity::kWARNING; + if (session_options.session_log_severity_level == -1) { + severity = logging::LoggingManager::DefaultLogger().GetSeverity(); + } else { + ORT_ENFORCE(session_options.session_log_severity_level >= 0 && + session_options.session_log_severity_level <= static_cast(logging::Severity::kFATAL), + "Invalid session log severity level. Not a valid onnxruntime::logging::Severity value: ", + session_options.session_log_severity_level); + severity = static_cast(session_options.session_log_severity_level); + } + return severity; +} + +void InferenceSession::SetLoggingManager(const SessionOptions& session_options, + const Environment& session_env) { + logging_manager_ = session_env.GetLoggingManager(); + if (session_options.user_logging_function) { + std::unique_ptr user_sink = std::make_unique(session_options.user_logging_function, + session_options.user_logging_param); + user_logging_manager_ = std::make_unique(std::move(user_sink), + GetSeverity(session_options), + false, + logging::LoggingManager::InstanceType::Temporal, + &session_options.session_logid); + logging_manager_ = user_logging_manager_.get(); + } +} + void InferenceSession::ConstructorCommon(const SessionOptions& session_options, const Environment& session_env) { auto status = FinalizeSessionOptions(session_options, model_proto_, is_model_proto_parsed_, session_options_); @@ -306,6 +336,8 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, ORT_ENFORCE(status.IsOK(), "Could not finalize session options while constructing the inference session. Error Message: ", status.ErrorMessage()); + SetLoggingManager(session_options, session_env); + // The call to InitLogger depends on the final state of session_options_. Hence it should be invoked // after the invocation of FinalizeSessionOptions. InitLogger(logging_manager_); // this sets session_logger_ so that it can be used for logging after this point. 
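The new SetUserLoggingFunction entry point stores the callback and its opaque parameter on the session options; InferenceSession then wraps them in a UserLoggingSink so per-session log output is routed to the user callback. A minimal sketch of driving this through the public C API follows; it assumes an already created OrtSessionOptions*, the names CollectLogs, EnablePerSessionLogging and log_lines are invented for the example, and the include path may vary depending on how the headers are consumed.

#include <string>
#include <vector>

#include "core/session/onnxruntime_c_api.h"

// Matches the OrtLoggingFunction signature; appends each message to the std::vector<std::string> passed as `param`.
static void ORT_API_CALL CollectLogs(void* param, OrtLoggingLevel /*severity*/, const char* /*category*/,
                                     const char* /*logid*/, const char* /*code_location*/, const char* message) {
  static_cast<std::vector<std::string>*>(param)->push_back(message);
}

// Registers the callback on existing session options. Returns the OrtStatus* from the API (nullptr on success).
OrtStatus* EnablePerSessionLogging(OrtSessionOptions* session_options, std::vector<std::string>* log_lines) {
  const OrtApi* ort = OrtGetApiBase()->GetApi(ORT_API_VERSION);
  return ort->SetUserLoggingFunction(session_options, CollectLogs, log_lines);
}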
@@ -427,7 +459,6 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, const #if !defined(ORT_MINIMAL_BUILD) graph_transformer_mgr_(session_options.max_num_graph_transformation_steps), #endif - logging_manager_(session_env.GetLoggingManager()), environment_(session_env) { // Initialize assets of this session instance ConstructorCommon(session_options, session_env); @@ -441,7 +472,6 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, #if !defined(ORT_MINIMAL_BUILD) graph_transformer_mgr_(session_options.max_num_graph_transformation_steps), #endif - logging_manager_(session_env.GetLoggingManager()), external_intra_op_thread_pool_(external_intra_op_thread_pool), external_inter_op_thread_pool_(external_inter_op_thread_pool), environment_(session_env) { @@ -454,7 +484,6 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, const const PathString& model_uri) : model_location_(model_uri), graph_transformer_mgr_(session_options.max_num_graph_transformation_steps), - logging_manager_(session_env.GetLoggingManager()), environment_(session_env) { auto status = Model::Load(model_location_, model_proto_); ORT_ENFORCE(status.IsOK(), "Given model could not be parsed while creating inference session. Error message: ", @@ -475,7 +504,6 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, InferenceSession::InferenceSession(const SessionOptions& session_options, const Environment& session_env, std::istream& model_istream) : graph_transformer_mgr_(session_options.max_num_graph_transformation_steps), - logging_manager_(session_env.GetLoggingManager()), environment_(session_env) { Status st = Model::Load(model_istream, &model_proto_); ORT_ENFORCE(st.IsOK(), "Could not parse model successfully while constructing the inference session"); @@ -487,7 +515,6 @@ InferenceSession::InferenceSession(const SessionOptions& session_options, const InferenceSession::InferenceSession(const SessionOptions& session_options, const Environment& session_env, const void* model_data, int model_data_len) : graph_transformer_mgr_(session_options.max_num_graph_transformation_steps), - logging_manager_(session_env.GetLoggingManager()), environment_(session_env) { const bool result = model_proto_.ParseFromArray(model_data, model_data_len); ORT_ENFORCE(result, "Could not parse model successfully while constructing the inference session"); @@ -2815,17 +2842,7 @@ const logging::Logger& InferenceSession::CreateLoggerForRun(const RunOptions& ru void InferenceSession::InitLogger(logging::LoggingManager* logging_manager) { // create logger for session, using provided logging manager if possible if (logging_manager != nullptr) { - logging::Severity severity = logging::Severity::kWARNING; - if (session_options_.session_log_severity_level == -1) { - severity = logging::LoggingManager::DefaultLogger().GetSeverity(); - } else { - ORT_ENFORCE(session_options_.session_log_severity_level >= 0 && - session_options_.session_log_severity_level <= static_cast(logging::Severity::kFATAL), - "Invalid session log severity level. 
Not a valid onnxruntime::logging::Severity value: ", - session_options_.session_log_severity_level); - severity = static_cast(session_options_.session_log_severity_level); - } - + logging::Severity severity = GetSeverity(session_options_); owned_session_logger_ = logging_manager_->CreateLogger(session_options_.session_logid, severity, false, session_options_.session_log_verbosity_level); session_logger_ = owned_session_logger_.get(); diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 9259e014b9860..4db436f132d11 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -595,7 +595,8 @@ class InferenceSession { private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(InferenceSession); - + void SetLoggingManager(const SessionOptions& session_options, + const Environment& session_env); void ConstructorCommon(const SessionOptions& session_options, const Environment& session_env); @@ -698,7 +699,10 @@ class InferenceSession { SessionOptions session_options_; /// Logging manager if provided. - logging::LoggingManager* const logging_manager_; + logging::LoggingManager* logging_manager_; + + /// User specified logging mgr; logging_manager_ is simply the ptr in this unique_ptr when available + std::unique_ptr user_logging_manager_; /// Logger for this session. WARNING: Will contain nullptr if logging_manager_ is nullptr. std::unique_ptr owned_session_logger_ = nullptr; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 8167bef1b9c00..a3e79628b7581 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2713,6 +2713,8 @@ static constexpr OrtApi ort_api_1_to_17 = { &OrtApis::GetCUDAProviderOptionsByName, &OrtApis::KernelContext_GetResource, // End of Version 16 - DO NOT MODIFY ABOVE (see above text for more information) + + &OrtApis::SetUserLoggingFunction, &OrtApis::ShapeInferContext_GetInputCount, &OrtApis::ShapeInferContext_GetInputTypeShape, &OrtApis::ShapeInferContext_SetOutputTypeShape, @@ -2746,6 +2748,7 @@ static_assert(offsetof(OrtApi, ReleaseCANNProviderOptions) / sizeof(void*) == 22 static_assert(offsetof(OrtApi, GetSessionConfigEntry) / sizeof(void*) == 238, "Size of version 14 API cannot change"); static_assert(offsetof(OrtApi, GetBuildInfoString) / sizeof(void*) == 254, "Size of version 15 API cannot change"); static_assert(offsetof(OrtApi, KernelContext_GetResource) / sizeof(void*) == 265, "Size of version 16 API cannot change"); +static_assert(offsetof(OrtApi, SetUserLoggingFunction) / sizeof(void*) == 266, "Size of version 17 API cannot change"); // So that nobody forgets to finish an API version, this check will serve as a reminder: static_assert(std::string_view(ORT_VERSION) == "1.17.0", diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 9ed367c3b03ba..b54c98be175fa 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -492,6 +492,8 @@ ORT_API_STATUS_IMPL(UpdateCUDAProviderOptionsWithValue, _Inout_ OrtCUDAProviderO ORT_API_STATUS_IMPL(GetCUDAProviderOptionsByName, _In_ const OrtCUDAProviderOptionsV2* cuda_options, _In_ const char* key, _Outptr_ void** ptr); ORT_API_STATUS_IMPL(KernelContext_GetResource, _In_ const OrtKernelContext* context, _In_ int resource_version, _In_ int resource_id, _Outptr_ void** stream); +ORT_API_STATUS_IMPL(SetUserLoggingFunction, _Inout_ 
OrtSessionOptions* options, + _In_ OrtLoggingFunction user_logging_function, _In_opt_ void* user_logging_param); ORT_API_STATUS_IMPL(ShapeInferContext_GetInputCount, _In_ const OrtShapeInferContext* context, _Out_ size_t* out); ORT_API_STATUS_IMPL(ShapeInferContext_GetInputTypeShape, _In_ const OrtShapeInferContext* context, _In_ size_t index, _Outptr_ OrtTensorTypeAndShapeInfo** info); ORT_API_STATUS_IMPL(ShapeInferContext_SetOutputTypeShape, _In_ const OrtShapeInferContext* context, _In_ size_t index, _In_ const OrtTensorTypeAndShapeInfo* info); diff --git a/onnxruntime/core/session/ort_env.cc b/onnxruntime/core/session/ort_env.cc index eb78d5d799a55..e3957baa990f8 100644 --- a/onnxruntime/core/session/ort_env.cc +++ b/onnxruntime/core/session/ort_env.cc @@ -9,6 +9,7 @@ #include "core/session/ort_apis.h" #include "core/session/environment.h" #include "core/session/allocator_adapters.h" +#include "core/session/user_logging_sink.h" #include "core/common/logging/logging.h" #include "core/framework/provider_shutdown.h" #include "core/platform/logging/make_platform_default_log_sink.h" @@ -20,17 +21,6 @@ std::unique_ptr OrtEnv::p_instance_; int OrtEnv::ref_count_ = 0; onnxruntime::OrtMutex OrtEnv::m_; -LoggingWrapper::LoggingWrapper(OrtLoggingFunction logging_function, void* logger_param) - : logging_function_(logging_function), logger_param_(logger_param) { -} - -void LoggingWrapper::SendImpl(const onnxruntime::logging::Timestamp& /*timestamp*/, const std::string& logger_id, - const onnxruntime::logging::Capture& message) { - std::string s = message.Location().ToString(); - logging_function_(logger_param_, static_cast(message.Severity()), message.Category(), - logger_id.c_str(), s.c_str(), message.Message().c_str()); -} - OrtEnv::OrtEnv(std::unique_ptr value1) : value_(std::move(value1)) { } @@ -50,8 +40,8 @@ OrtEnv* OrtEnv::GetInstance(const OrtEnv::LoggingManagerConstructionInfo& lm_inf std::unique_ptr lmgr; std::string name = lm_info.logid; if (lm_info.logging_function) { - std::unique_ptr logger = std::make_unique(lm_info.logging_function, - lm_info.logger_param); + std::unique_ptr logger = std::make_unique(lm_info.logging_function, + lm_info.logger_param); lmgr = std::make_unique(std::move(logger), static_cast(lm_info.default_warning_level), false, diff --git a/onnxruntime/core/session/ort_env.h b/onnxruntime/core/session/ort_env.h index 7d609acb2db5d..444134d0612e9 100644 --- a/onnxruntime/core/session/ort_env.h +++ b/onnxruntime/core/session/ort_env.h @@ -5,27 +5,15 @@ #include #include #include "core/session/onnxruntime_c_api.h" -#include "core/common/logging/isink.h" #include "core/platform/ort_mutex.h" #include "core/common/status.h" +#include "core/common/logging/logging.h" #include "core/framework/allocator.h" namespace onnxruntime { class Environment; } -class LoggingWrapper : public onnxruntime::logging::ISink { - public: - LoggingWrapper(OrtLoggingFunction logging_function, void* logger_param); - - void SendImpl(const onnxruntime::logging::Timestamp& /*timestamp*/, const std::string& logger_id, - const onnxruntime::logging::Capture& message) override; - - private: - OrtLoggingFunction logging_function_; - void* logger_param_; -}; - struct OrtEnv { public: struct LoggingManagerConstructionInfo { diff --git a/onnxruntime/core/session/user_logging_sink.h b/onnxruntime/core/session/user_logging_sink.h new file mode 100644 index 0000000000000..5a9ceb21d6500 --- /dev/null +++ b/onnxruntime/core/session/user_logging_sink.h @@ -0,0 +1,28 @@ +// Copyright (c) Microsoft Corporation. 
All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include + +#include "core/session/onnxruntime_c_api.h" +#include "core/common/logging/isink.h" + +namespace onnxruntime { +class UserLoggingSink : public onnxruntime::logging::ISink { + public: + UserLoggingSink(OrtLoggingFunction logging_function, void* logger_param) + : logging_function_(logging_function), logger_param_(logger_param) { + } + + void SendImpl(const onnxruntime::logging::Timestamp& /*timestamp*/, const std::string& logger_id, + const onnxruntime::logging::Capture& message) override { + std::string s = message.Location().ToString(); + logging_function_(logger_param_, static_cast(message.Severity()), message.Category(), + logger_id.c_str(), s.c_str(), message.Message().c_str()); + } + + private: + OrtLoggingFunction logging_function_{}; + void* logger_param_{}; +}; +} // namespace onnxruntime diff --git a/onnxruntime/python/onnxruntime_pybind_module.cc b/onnxruntime/python/onnxruntime_pybind_module.cc index f320707697c9e..6824a5d0bf98f 100644 --- a/onnxruntime/python/onnxruntime_pybind_module.cc +++ b/onnxruntime/python/onnxruntime_pybind_module.cc @@ -10,6 +10,12 @@ namespace onnxruntime { namespace python { namespace py = pybind11; +#if defined(USE_MPI) && defined(ORT_USE_NCCL) +static constexpr bool HAS_COLLECTIVE_OPS = true; +#else +static constexpr bool HAS_COLLECTIVE_OPS = false; +#endif + void CreateInferencePybindStateModule(py::module& m); PYBIND11_MODULE(onnxruntime_pybind11_state, m) { @@ -23,6 +29,7 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) { m.def("get_version_string", []() -> std::string { return ORT_VERSION; }); m.def("get_build_info", []() -> std::string { return ORT_BUILD_INFO; }); + m.def("has_collective_ops", []() -> bool { return HAS_COLLECTIVE_OPS; }); } } // namespace python } // namespace onnxruntime diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py index 2d1e418f9d2b4..c3c3d2a837336 100644 --- a/onnxruntime/python/tools/quantization/onnx_quantizer.py +++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py @@ -112,8 +112,8 @@ def __init__( False if "ActivationSymmetric" not in self.extra_options else self.extra_options["ActivationSymmetric"] ) - self.activation_qType = activation_qType.tensor_type - self.weight_qType = weight_qType.tensor_type + self.activation_qType = getattr(activation_qType, "tensor_type", activation_qType) + self.weight_qType = getattr(weight_qType, "tensor_type", weight_qType) """ Dictionary specifying the min and max values for tensors. 
It has following format: { diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index 40f2aee875382..a7460157ba409 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -78,7 +78,15 @@ def process_mask(self, input: str) -> str: # ReduceSum-13: axes is moved from attribute to input axes_name = "ort_const_1_reduce_sum_axes" if self.model.get_initializer(axes_name) is None: - self.add_initializer(name=axes_name, data_type=TensorProto.INT64, dims=[1], vals=[1], raw=False) + self.model.add_initializer( + helper.make_tensor( + name=axes_name, + data_type=TensorProto.INT64, + dims=[1], + vals=[1], + raw=False, + ) + ) mask_index_node = helper.make_node( "ReduceSum", inputs=[input_name, axes_name], diff --git a/onnxruntime/test/contrib_ops/qordered_attention_test.cc b/onnxruntime/test/contrib_ops/qordered_attention_test.cc index 24e4bff528285..1dd0162ad722f 100644 --- a/onnxruntime/test/contrib_ops/qordered_attention_test.cc +++ b/onnxruntime/test/contrib_ops/qordered_attention_test.cc @@ -240,12 +240,6 @@ static std::vector transpose(const T& src, size_t h, size_t w) { } TEST(QOrderedTest, Attention_WithData_ROW_ORDER) { - int cuda_runtime_version = 0; - // Need 11.4 or higher cuda runtime - if ((cudaRuntimeGetVersion(&cuda_runtime_version) != cudaSuccess) || (cuda_runtime_version < 11040)) { - return; - } - // Needs Turing architecture if (NeedSkipIfCudaArchLowerThan(750) || NeedSkipIfCudaArchGreaterEqualThan(800)) { return; diff --git a/onnxruntime/test/contrib_ops/qordered_longformer_attention_op_test.cc b/onnxruntime/test/contrib_ops/qordered_longformer_attention_op_test.cc index 55209d9422fdd..06fe42ca989d7 100644 --- a/onnxruntime/test/contrib_ops/qordered_longformer_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/qordered_longformer_attention_op_test.cc @@ -34,12 +34,6 @@ static void run_qordered_longformer_attention_op_test( const int64_t head_size, const int64_t window, int64_t input_hidden_size = 0) { - int cuda_runtime_version = 0; - // Need 11.4 or higher cuda runtime - if ((cudaRuntimeGetVersion(&cuda_runtime_version) != cudaSuccess) || (cuda_runtime_version < 11040)) { - return; - } - // Needs Turing architecture if (NeedSkipIfCudaArchLowerThan(750) || NeedSkipIfCudaArchGreaterEqualThan(800)) { return; diff --git a/onnxruntime/test/contrib_ops/qordered_matmul_op_test.cc b/onnxruntime/test/contrib_ops/qordered_matmul_op_test.cc index e5b3d59ef86e3..e3905db6355d9 100644 --- a/onnxruntime/test/contrib_ops/qordered_matmul_op_test.cc +++ b/onnxruntime/test/contrib_ops/qordered_matmul_op_test.cc @@ -21,12 +21,6 @@ static void RunQOrdered_MatMul_Test( OrderCublasLt weight_order, float scale_A, float scale_B, float scale_C, float scale_Y, bool add_bias = false, bool broadcast_c_batch = false, bool per_channel = false) { - int cuda_runtime_version = 0; - // Need 11.4 or higher cuda runtime - if ((cudaRuntimeGetVersion(&cuda_runtime_version) != cudaSuccess) || (cuda_runtime_version < 11040)) { - return; - } - // Needs Turing architecture if (NeedSkipIfCudaArchLowerThan(750) || NeedSkipIfCudaArchGreaterEqualThan(800)) { return; diff --git a/onnxruntime/test/contrib_ops/qordered_qdq_op_test.cc b/onnxruntime/test/contrib_ops/qordered_qdq_op_test.cc index 15e97751acf2d..0f3f702695b80 100644 --- a/onnxruntime/test/contrib_ops/qordered_qdq_op_test.cc +++ b/onnxruntime/test/contrib_ops/qordered_qdq_op_test.cc @@ -73,12 +73,6 @@ 
static void RunQOrdered_Quantize_Test( std::vector const& shape, OrderCublasLt order_q, float scale) { - int cuda_runtime_version = 0; - // Need 11.4 or higher cuda runtime - if ((cudaRuntimeGetVersion(&cuda_runtime_version) != cudaSuccess) || (cuda_runtime_version < 11040)) { - return; - } - auto qvec = QuantizeTransform(shape, scale, fvec, order_q); std::vector> execution_providers; @@ -153,12 +147,6 @@ static void RunQOrdered_Dequantize_Test( std::vector const& shape, OrderCublasLt order_q, float scale) { - int cuda_runtime_version = 0; - // Need 11.4 or higher cuda runtime - if ((cudaRuntimeGetVersion(&cuda_runtime_version) != cudaSuccess) || (cuda_runtime_version < 11040)) { - return; - } - auto fvec = DequantizeTransform(shape, scale, qvec, order_q); std::vector> execution_providers; diff --git a/onnxruntime/test/framework/inference_session_test.cc b/onnxruntime/test/framework/inference_session_test.cc index 2298e4afa6de0..486ec37d1eebd 100644 --- a/onnxruntime/test/framework/inference_session_test.cc +++ b/onnxruntime/test/framework/inference_session_test.cc @@ -890,6 +890,47 @@ TEST(InferenceSessionTests, ConfigureVerbosityLevel) { #endif } +TEST(InferenceSessionTests, UseUserSpecifiedLoggingFunctionInSession) { + SessionOptions so; + /* + typedef void(ORT_API_CALL* OrtLoggingFunction)( + void* param, OrtLoggingLevel severity, const char* category, const char* logid, const char* code_location, + const char* message); + */ + std::vector log_msgs; + so.user_logging_function = [](void* param, OrtLoggingLevel severity, const char* category, const char* logid, const char* code_location, + const char* message) { + ORT_UNUSED_PARAMETER(severity); + ORT_UNUSED_PARAMETER(category); + ORT_UNUSED_PARAMETER(logid); + ORT_UNUSED_PARAMETER(code_location); + std::vector* v_ptr = reinterpret_cast*>(param); + std::vector& msg_vector = *v_ptr; + msg_vector.push_back(std::string(message)); + }; + so.user_logging_param = &log_msgs; + so.session_log_severity_level = static_cast(Severity::kVERBOSE); + so.session_log_verbosity_level = 1; + so.session_logid = "InferenceSessionTests.UseUserSpecifiedLoggingFunctionInSession"; + + InferenceSession session_object{so, GetEnvironment()}; + ASSERT_STATUS_OK(session_object.Load(MODEL_URI)); + ASSERT_STATUS_OK(session_object.Initialize()); + + RunOptions run_options; + run_options.run_tag = "one session/one tag"; + RunModel(session_object, run_options); + +// vlog output is disabled in release builds +#ifndef NDEBUG + bool have_log_entry_with_vlog_session_msg = + (std::find_if(log_msgs.begin(), log_msgs.end(), + [&](std::string msg) { return msg.find("Added input argument with name") != string::npos; }) != + log_msgs.end()); + ASSERT_TRUE(have_log_entry_with_vlog_session_msg); +#endif +} + TEST(InferenceSessionTests, TestWithIstream) { SessionOptions so; diff --git a/onnxruntime/test/framework/tunable_op_test.cc b/onnxruntime/test/framework/tunable_op_test.cc index 6aa7c5ee9f8ac..0d9e557ebc813 100644 --- a/onnxruntime/test/framework/tunable_op_test.cc +++ b/onnxruntime/test/framework/tunable_op_test.cc @@ -263,8 +263,7 @@ TEST(TunableOp, OpWrapsMutableFunctor) { class VecAddMoveOnlyFunctor { public: - VecAddMoveOnlyFunctor() { - } + VecAddMoveOnlyFunctor() = default; VecAddMoveOnlyFunctor(VecAddMoveOnlyFunctor&&) = default; ORT_DISALLOW_COPY_AND_ASSIGNMENT(VecAddMoveOnlyFunctor); @@ -290,8 +289,7 @@ TEST(TunableOp, OpWrapsMoveOnlyFunctor) { class VecAddWithIsSupportedMethod { public: - VecAddWithIsSupportedMethod() { - } + VecAddWithIsSupportedMethod() = default; 
VecAddWithIsSupportedMethod(VecAddWithIsSupportedMethod&&) = default; ORT_DISALLOW_COPY_AND_ASSIGNMENT(VecAddWithIsSupportedMethod); diff --git a/onnxruntime/test/optimizer/qdq_transformer_test.cc b/onnxruntime/test/optimizer/qdq_transformer_test.cc index d3616a14d8a5d..1bf1cbacf479e 100644 --- a/onnxruntime/test/optimizer/qdq_transformer_test.cc +++ b/onnxruntime/test/optimizer/qdq_transformer_test.cc @@ -3085,12 +3085,12 @@ TEST(QDQTransformerTests, QDQPropagation_Per_Layer_No_Propagation) { check_graph, TransformerLevel::Default, TransformerLevel::Level1, - 18); // disable TransposeOptimizer for simplicity + 18); TransformerTester(build_test_case, check_graph, TransformerLevel::Default, TransformerLevel::Level1, - 19); // disable TransposeOptimizer for simplicity + 19); }; test_case({1, 13, 13, 23}, {0, 2, 3, 1}, false /*use_contrib_qdq*/); diff --git a/onnxruntime/test/optimizer/transpose_optimizer_test.cc b/onnxruntime/test/optimizer/transpose_optimizer_test.cc index 0d66e6f8d5f6e..4f4157bd7b1cf 100644 --- a/onnxruntime/test/optimizer/transpose_optimizer_test.cc +++ b/onnxruntime/test/optimizer/transpose_optimizer_test.cc @@ -320,20 +320,6 @@ TEST(TransposeOptimizerTests, TestPadNonconst) { /*opset_version*/ {11, 18}); } -// The CUDA Resize kernel assumes that the input is NCHW and -// Resize can't be supported in ORT builds with CUDA enabled. -// TODO: Enable this once the CUDA Resize kernel is implemented -// "generically" (i.e.) aligning with the generic nature of the -// ONNX spec. -// See https://github.com/microsoft/onnxruntime/pull/10824 for -// a similar fix applied to the CPU Resize kernel. -// Per tests included in #10824, the ROCM EP also generates -// incorrect results when this handler is used, so the Resize -// handler is not enabled even for those builds. -// -// The QNN EP requires the input to be NHWC, so the Resize handler is also not enabled -// for QNN builds. 
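For background on the Resize hunks that follow: pushing a Transpose through a Resize amounts to reordering the axis-dependent inputs (scales, sizes, roi) with the same permutation the Transpose applies. A minimal standalone sketch of that bookkeeping, assuming a 4-D NCHW-to-NHWC permutation (illustrative only, not ORT's actual TransposeOptimizer handler):

#include <array>
#include <cstddef>
#include <iostream>

int main() {
  // perm[i] is the input axis that feeds output axis i (NCHW -> NHWC: {0, 2, 3, 1}).
  const std::array<std::size_t, 4> perm{0, 2, 3, 1};

  // Per-axis Resize scales for an NCHW input: N, C, H, W.
  const std::array<float, 4> scales_nchw{1.0f, 1.0f, 2.0f, 2.0f};

  // Moving a Transpose from the Resize output to its input means the Resize now
  // sees transposed data, so its per-axis scales must be reordered the same way.
  std::array<float, 4> scales_nhwc{};
  for (std::size_t i = 0; i < perm.size(); ++i) {
    scales_nhwc[i] = scales_nchw[perm[i]];
  }

  for (float s : scales_nhwc) {
    std::cout << s << ' ';  // prints: 1 2 2 1 (N, H, W, C)
  }
  std::cout << '\n';
  return 0;
}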
-#if !defined(USE_CUDA) && !defined(USE_ROCM) && !defined(USE_QNN) TEST(TransposeOptimizerTests, TestResize) { auto build_test_case_1 = [&](ModelTestBuilder& builder) { auto* input0_arg = MakeInput(builder, {{4, -1, 2, -1}}, {4, 6, 2, 10}, 0.0, 1.0); @@ -362,7 +348,9 @@ TEST(TransposeOptimizerTests, TestResize) { TransformerTester(build_test_case_1, check_optimized_graph_1, TransformerLevel::Default, - TransformerLevel::Level1, + // need the level 2 TransposeOptimizer as pushing a Transpose through a Resize requires it to be + // assigned to the CPU EP first + TransformerLevel::Level2, /*opset_version*/ {10, 18}); } @@ -390,7 +378,9 @@ TEST(TransposeOptimizerTests, TestResizeOpset11) { TransformerTester(build_test_case_1, check_optimized_graph_1, TransformerLevel::Default, - TransformerLevel::Level1, + // need the level 2 TransposeOptimizer as pushing a Transpose through a Resize requires it to be + // assigned to the CPU EP first + TransformerLevel::Level2, /*opset_version*/ {11, 18}); } @@ -418,7 +408,9 @@ TEST(TransposeOptimizerTests, TestResizeOpset15) { TransformerTester(build_test_case_1, check_optimized_graph_1, TransformerLevel::Default, - TransformerLevel::Level1, + // need the level 2 TransposeOptimizer as pushing a Transpose through a Resize requires it to be + // assigned to the CPU EP first + TransformerLevel::Level2, /*opset_version*/ {15, 18}); } @@ -448,7 +440,9 @@ TEST(TransposeOptimizerTests, TestResizeSizeRoi) { TransformerTester(build_test_case_1, check_optimized_graph_1, TransformerLevel::Default, - TransformerLevel::Level1, + // need the level 2 TransposeOptimizer as pushing a Transpose through a Resize requires it to be + // assigned to the CPU EP first + TransformerLevel::Level2, /*opset_version*/ {15, 18}); } @@ -482,7 +476,9 @@ TEST(TransposeOptimizerTests, TestResizeRoiScalesZeroRank0) { TransformerTester(build_test_case_1, check_optimized_graph_1, TransformerLevel::Default, - TransformerLevel::Level1, + // need the level 2 TransposeOptimizer as pushing a Transpose through a Resize requires it to be + // assigned to the CPU EP first + TransformerLevel::Level2, {12, 18}); } @@ -511,7 +507,9 @@ TEST(TransposeOptimizerTests, TestResizeNonconst) { TransformerTester(build_test_case_1, check_optimized_graph_1, TransformerLevel::Default, - TransformerLevel::Level1, + // need the level 2 TransposeOptimizer as pushing a Transpose through a Resize requires it to be + // assigned to the CPU EP first + TransformerLevel::Level2, /*opset_version*/ {11, 18}); } @@ -540,11 +538,12 @@ TEST(TransposeOptimizerTests, TestResizeNonconstOpset13) { TransformerTester(build_test_case_1, check_optimized_graph_1, TransformerLevel::Default, - TransformerLevel::Level1, + // need the level 2 TransposeOptimizer as pushing a Transpose through a Resize requires it to be + // assigned to the CPU EP first + TransformerLevel::Level2, /*opset_version*/ {13, 18}); } -#endif TEST(TransposeOptimizerTests, TestAdd) { auto build_test_case_1 = [&](ModelTestBuilder& builder) { auto* input0_arg = builder.MakeInput({4, 6, 10}, 0.0, 1.0); @@ -4454,12 +4453,13 @@ TEST(TransposeOptimizerTests, RegressionTest_GitHubIssue12151) { testing::ContainerEq(fetches[0].Get().DataAsSpan())); } +// These tests use the internal testing EP with static kernels, which requires a full build, +// and the NHWC Conv, which requires contrib ops +#if !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_CONTRIB_OPS) + // Test a Transpose node followed by a Reshape that is logically equivalent to a Transpose can be merged.
// The test graph was extracted from a model we were trying to use with the QNN EP. TEST(TransposeOptimizerTests, QnnTransposeReshape) { - // test uses internal testing EP with static kernels which requires a full build, - // and the NHWC Conv with requires contrib ops -#if !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_CONTRIB_OPS) Status status; auto model_uri = ORT_TSTR("testdata/layout_transform_reshape.onnx"); @@ -4509,13 +4509,9 @@ TEST(TransposeOptimizerTests, QnnTransposeReshape) { EXPECT_TRUE(inputs[1]->Exists()); } } -#endif } TEST(TransposeOptimizerTests, QnnTransposeReshapeQDQ) { - // test uses internal testing EP with static kernels which requires a full build, - // and the NHWC Conv with requires contrib ops -#if !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_CONTRIB_OPS) Status status; auto model_uri = ORT_TSTR("testdata/layout_transform_reshape.qdq.onnx"); @@ -4552,9 +4548,49 @@ TEST(TransposeOptimizerTests, QnnTransposeReshapeQDQ) { EXPECT_TRUE(node.GetExecutionProviderType() == expected_ep) << node.OpType() << " node named '" << node.Name() << "' was not assigned to the internal testing EP."; } -#endif } +// Validate handling for EP with layout specific Resize that prefers NHWC +TEST(TransposeOptimizerTests, QnnResizeOpset11) { + Status status; + auto model_uri = ORT_TSTR("testdata/nhwc_resize_scales_opset11.onnx"); + + SessionOptions so; + // Uncomment to debug + // ASSERT_STATUS_OK(so.config_options.AddConfigEntry(kDebugLayoutTransformation, "1")); + + using InternalTestingEP = onnxruntime::internal_testing_ep::InternalTestingExecutionProvider; + + // set the test EP to support all ops in the model so that the layout transform applies to all nodes + const std::unordered_set empty_set; + auto internal_testing_ep = std::make_unique(empty_set, empty_set, DataLayout::NHWC); + internal_testing_ep->EnableStaticKernels().TakeAllNodes(); + + InferenceSessionWrapper session{so, GetEnvironment()}; + ASSERT_STATUS_OK(session.RegisterExecutionProvider(std::move(internal_testing_ep))); + ASSERT_STATUS_OK(session.Load(model_uri)); + ASSERT_STATUS_OK(session.Initialize()); + + const auto& graph = session.GetGraph(); + // all nodes should be assigned to the internal testing EP, which also means they should be in NHWC layout + std::string expected_ep(onnxruntime::utils::kInternalTestingExecutionProvider); + for (const auto& node : graph.Nodes()) { + EXPECT_TRUE(node.GetExecutionProviderType() == expected_ep) << node.OpType() << " node named '" << node.Name() + << "' was not assigned to the internal testing EP."; + if (node.OpType() == "Resize") { + EXPECT_EQ(node.Domain(), kMSInternalNHWCDomain) << "Resize was not converted to NHWC layout"; + } + } + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["Transpose"], 2) << "Resize should have been wrapped in 2 Transpose nodes to convert to NHWC"; + + // And the post-Resize Transpose should have been pushed all the way to the end + GraphViewer viewer(graph); + EXPECT_EQ(graph.GetNode(viewer.GetNodesInTopologicalOrder().back())->OpType(), "Transpose"); +} +#endif // !defined(ORT_MINIMAL_BUILD) && !defined(DISABLE_CONTRIB_OPS) + static void CheckSharedInitializerHandling(bool broadcast) { auto model_uri = broadcast ? 
ORT_TSTR("testdata/transpose_optimizer_shared_initializers_broadcast.onnx") : ORT_TSTR("testdata/transpose_optimizer_shared_initializers.onnx"); diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index 9b41ba8c0d2ba..da906ebf76f79 100644 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++ b/onnxruntime/test/providers/cpu/model_tests.cc @@ -1172,6 +1172,12 @@ ::std::vector<::std::basic_string> GetParameterStrings() { ORT_TSTR("bvlc_alexnet"), ORT_TSTR("bvlc_reference_caffenet"), ORT_TSTR("coreml_VGG16_ImageNet"), + ORT_TSTR("VGG 16-fp32"), + ORT_TSTR("VGG 19-caffe2"), + ORT_TSTR("VGG 19-bn"), + ORT_TSTR("VGG 16-bn"), + ORT_TSTR("VGG 19"), + ORT_TSTR("VGG 16"), ORT_TSTR("faster_rcnn"), ORT_TSTR("GPT2"), ORT_TSTR("GPT2_LM_HEAD"), diff --git a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc index 0434b16dc66ce..2ead9ec91f93f 100644 --- a/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/resize_op_test.cc @@ -780,7 +780,7 @@ TEST(ResizeOpTest, ResizeOpLinearUpSampleTest_5DTrilinear_pytorch_half_pixel) { } TEST(ResizeOpTest, ResizeOpLinearScalesNoOpTest) { - // To test NNAPI EP, we need the sclaes/sizes to be in initializers + // To test NNAPI EP, we need the scales/sizes to be in initializers auto run_test = [](bool scales_in_initializer) { OpTester test("Resize", 13); std::vector roi{}; diff --git a/onnxruntime/test/python/onnxruntime_test_collective.py b/onnxruntime/test/python/onnxruntime_test_collective.py index db1ebb5384730..4882b403c3c91 100644 --- a/onnxruntime/test/python/onnxruntime_test_collective.py +++ b/onnxruntime/test/python/onnxruntime_test_collective.py @@ -155,6 +155,7 @@ def _create_alltoall_ut_model_for_boolean_tensor( ) return ORTBertPretrainTest._create_model_with_opsets(graph_def) + @unittest.skipIf(not ort.has_collective_ops(), reason="onnx not compiled with mpi support") @parameterized.expand( [ (np.float32, TensorProto.FLOAT), @@ -193,6 +194,7 @@ def test_all_reduce(self, np_elem_type, elem_type): outputs[0], size * input, err_msg=f"{rank}: AllGather ({np_elem_type}, {elem_type}): results mismatch" ) + @unittest.skipIf(not ort.has_collective_ops(), reason="onnx not compiled with mpi support") @parameterized.expand( [ (np.float32, TensorProto.FLOAT, TensorProto.FLOAT), @@ -231,6 +233,7 @@ def test_all_gather(self, np_elem_type, elem_type, communication_elem_type): err_msg=f"{rank}: AllGather (axis0) ({np_elem_type}, {elem_type}, {communication_elem_type}): results mismatch", ) + @unittest.skipIf(not ort.has_collective_ops(), reason="onnx not compiled with mpi support") def test_all_gather_bool(self): model = self._create_allgather_ut_model((4,), 0, TensorProto.INT64, TensorProto.INT64) rank, _ = self._get_rank_size() @@ -250,6 +253,7 @@ def test_all_gather_bool(self): np.testing.assert_allclose(y, y_expected, err_msg=f"{rank}: AllGather (bool): results mismatch") + @unittest.skipIf(not ort.has_collective_ops(), reason="onnx not compiled with mpi support") def test_all_gather_axis1(self): model = self._create_allgather_ut_model((128, 128), 1) rank, size = self._get_rank_size() @@ -268,6 +272,7 @@ def test_all_gather_axis1(self): np.testing.assert_allclose(outputs[0], expected_output, err_msg=f"{rank}: AllGather (axis1): results mismatch") + @unittest.skipIf(not ort.has_collective_ops(), reason="onnx not compiled with mpi support") @parameterized.expand( [ (np.float32, TensorProto.FLOAT, 
TensorProto.FLOAT), @@ -349,6 +354,7 @@ def test_all_to_all(self, np_elem_type, elem_type, communication_elem_type): err_msg=f"{rank}: AllToAll ({np_elem_type}, {elem_type}, {communication_elem_type}): results mismatch", ) + @unittest.skipIf(not ort.has_collective_ops(), reason="onnx not compiled with mpi support") def test_all_to_all_bool(self): rank, _ = self._get_rank_size() diff --git a/onnxruntime/test/xctest/xcgtest.mm b/onnxruntime/test/xctest/xcgtest.mm index 5367f3e89c07c..c02f18d906cbe 100644 --- a/onnxruntime/test/xctest/xcgtest.mm +++ b/onnxruntime/test/xctest/xcgtest.mm @@ -201,7 +201,7 @@ + (void)registerTestClasses { delete listeners.Release(listeners.default_result_printer()); free(argv); - BOOL runDisabledTests = testing::GTEST_FLAG(also_run_disabled_tests); + BOOL runDisabledTests = GTEST_FLAG_GET(also_run_disabled_tests); NSMutableDictionary* testFilterMap = [NSMutableDictionary dictionary]; NSCharacterSet* decimalDigitCharacterSet = [NSCharacterSet decimalDigitCharacterSet]; diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index 174edabbc91fe..968eece361724 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -9,6 +9,7 @@ #include "api.h" #include +#include #include namespace { @@ -17,6 +18,14 @@ OrtErrorCode g_last_error_code; std::string g_last_error_message; } // namespace +enum DataLocation { + DATA_LOCATION_NONE = 0, + DATA_LOCATION_CPU = 1, + DATA_LOCATION_CPU_PINNED = 2, + DATA_LOCATION_TEXTURE = 3, + DATA_LOCATION_GPU_BUFFER = 4 +}; + static_assert(sizeof(const char*) == sizeof(size_t), "size of a pointer and a size_t value should be the same."); static_assert(sizeof(size_t) == 4, "size of size_t should be 4 in this build (wasm32)."); @@ -223,13 +232,23 @@ void OrtFree(void* ptr) { } } -OrtValue* OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* dims, size_t dims_length) { +OrtValue* OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* dims, size_t dims_length, int data_location) { + if (data_location != DATA_LOCATION_CPU && + data_location != DATA_LOCATION_CPU_PINNED && + data_location != DATA_LOCATION_GPU_BUFFER) { + std::ostringstream ostr; + ostr << "Invalid data location: " << data_location; + CheckStatus(Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, ostr.str().c_str())); + return nullptr; + } + std::vector shapes(dims_length); for (size_t i = 0; i < dims_length; i++) { shapes[i] = dims[i]; } if (data_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) { + // data_location is ignored for string tensor. It is always CPU. 
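A note on the new data_location argument used above and below: the accepted values come from the DataLocation enum added at the top of api.cc, and OrtCreateTensor rejects everything except the CPU, CPU-pinned and GPU-buffer locations. A small standalone restatement of that rule (IsValidTensorDataLocation is a hypothetical helper, not part of this change):

enum DataLocation {
  DATA_LOCATION_NONE = 0,
  DATA_LOCATION_CPU = 1,
  DATA_LOCATION_CPU_PINNED = 2,
  DATA_LOCATION_TEXTURE = 3,
  DATA_LOCATION_GPU_BUFFER = 4
};

// Mirrors the guard at the top of OrtCreateTensor: NONE and TEXTURE are rejected with
// ORT_INVALID_ARGUMENT; CPU/CPU_PINNED use a CPU OrtMemoryInfo, while GPU_BUFFER uses
// the "WebGPU_Buffer" OrtMemoryInfo (see the else-branch below).
inline bool IsValidTensorDataLocation(int data_location) {
  return data_location == DATA_LOCATION_CPU ||
         data_location == DATA_LOCATION_CPU_PINNED ||
         data_location == DATA_LOCATION_GPU_BUFFER;
}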
OrtAllocator* allocator = nullptr; RETURN_NULLPTR_IF_ERROR(GetAllocatorWithDefaultOptions, &allocator); @@ -244,12 +263,16 @@ OrtValue* OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* return UNREGISTER_AUTO_RELEASE(value); } else { - OrtMemoryInfo* memoryInfo = nullptr; - RETURN_NULLPTR_IF_ERROR(CreateCpuMemoryInfo, OrtDeviceAllocator, OrtMemTypeDefault, &memoryInfo); - REGISTER_AUTO_RELEASE_HANDLE(MemoryInfo, memoryInfo); + OrtMemoryInfo* memory_info = nullptr; + if (data_location != DATA_LOCATION_GPU_BUFFER) { + RETURN_NULLPTR_IF_ERROR(CreateCpuMemoryInfo, OrtDeviceAllocator, OrtMemTypeDefault, &memory_info); + } else { + RETURN_NULLPTR_IF_ERROR(CreateMemoryInfo, "WebGPU_Buffer", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info); + } + REGISTER_AUTO_RELEASE_HANDLE(MemoryInfo, memory_info); OrtValue* value = nullptr; - int error_code = CHECK_STATUS(CreateTensorWithDataAsOrtValue, memoryInfo, data, data_length, + int error_code = CHECK_STATUS(CreateTensorWithDataAsOrtValue, memory_info, data, data_length, dims_length > 0 ? shapes.data() : nullptr, dims_length, static_cast(data_type), &value); @@ -373,15 +396,85 @@ void OrtReleaseRunOptions(OrtRunOptions* run_options) { Ort::GetApi().ReleaseRunOptions(run_options); } +OrtIoBinding* OrtCreateBinding(OrtSession* session) { + OrtIoBinding* binding = nullptr; + int error_code = CHECK_STATUS(CreateIoBinding, session, &binding); + return (error_code == ORT_OK) ? binding : nullptr; +} + +int EMSCRIPTEN_KEEPALIVE OrtBindInput(OrtIoBinding* io_binding, + const char* name, + OrtValue* input) { + return CHECK_STATUS(BindInput, io_binding, name, input); +} + +int EMSCRIPTEN_KEEPALIVE OrtBindOutput(OrtIoBinding* io_binding, + const char* name, + OrtValue* output, + int output_location) { + if (output) { + return CHECK_STATUS(BindOutput, io_binding, name, output); + } else { + if (output_location != DATA_LOCATION_NONE && + output_location != DATA_LOCATION_CPU && + output_location != DATA_LOCATION_CPU_PINNED && + output_location != DATA_LOCATION_GPU_BUFFER) { + std::ostringstream ostr; + ostr << "Invalid data location (" << output_location << ") for output: \"" << name << "\"."; + return CheckStatus(Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, ostr.str().c_str())); + } + + OrtMemoryInfo* memory_info = nullptr; + if (output_location != DATA_LOCATION_GPU_BUFFER) { + RETURN_ERROR_CODE_IF_ERROR(CreateCpuMemoryInfo, OrtDeviceAllocator, OrtMemTypeDefault, &memory_info); + } else { + RETURN_ERROR_CODE_IF_ERROR(CreateMemoryInfo, "WebGPU_Buffer", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info); + } + REGISTER_AUTO_RELEASE_HANDLE(MemoryInfo, memory_info); + return CHECK_STATUS(BindOutputToDevice, io_binding, name, memory_info); + } +} + +void OrtClearBoundOutputs(OrtIoBinding* io_binding) { + Ort::GetApi().ClearBoundOutputs(io_binding); +} + +void OrtReleaseBinding(OrtIoBinding* io_binding) { + Ort::GetApi().ReleaseIoBinding(io_binding); +} + +int OrtRunWithBinding(OrtSession* session, + OrtIoBinding* io_binding, + size_t output_count, + OrtValue** outputs, + OrtRunOptions* run_options) { + RETURN_ERROR_CODE_IF_ERROR(RunWithBinding, session, run_options, io_binding); + + OrtAllocator* allocator = nullptr; + RETURN_ERROR_CODE_IF_ERROR(GetAllocatorWithDefaultOptions, &allocator); + + size_t binding_output_count = 0; + OrtValue** binding_outputs = nullptr; + RETURN_ERROR_CODE_IF_ERROR(GetBoundOutputValues, io_binding, allocator, &binding_outputs, &binding_output_count); + REGISTER_AUTO_RELEASE_BUFFER(OrtValue*, binding_outputs, 
allocator); + + if (binding_output_count != output_count) { + return CheckStatus( + Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, "Output count is inconsistent with IO Binding output data.")); + } + + for (size_t i = 0; i < output_count; i++) { + outputs[i] = binding_outputs[i]; + } + + return ORT_OK; +} + int OrtRun(OrtSession* session, const char** input_names, const ort_tensor_handle_t* inputs, size_t input_count, const char** output_names, size_t output_count, ort_tensor_handle_t* outputs, OrtRunOptions* run_options) { - auto status_code = CHECK_STATUS(Run, session, run_options, input_names, inputs, input_count, output_names, output_count, outputs); -#if defined(USE_JSEP) - EM_ASM({ Module.jsepRunPromiseResolve ?.($0); }, status_code); -#endif - return status_code; + return CHECK_STATUS(Run, session, run_options, input_names, inputs, input_count, output_names, output_count, outputs); } char* OrtEndProfiling(ort_session_handle_t session) { diff --git a/onnxruntime/wasm/api.h b/onnxruntime/wasm/api.h index 398c901e0e5ed..9a0664697f0ff 100644 --- a/onnxruntime/wasm/api.h +++ b/onnxruntime/wasm/api.h @@ -15,6 +15,9 @@ struct OrtSession; using ort_session_handle_t = OrtSession*; +struct OrtIoBinding; +using ort_io_binding_handle_t = OrtIoBinding*; + struct OrtSessionOptions; using ort_session_options_handle_t = OrtSessionOptions*; @@ -164,9 +167,10 @@ void EMSCRIPTEN_KEEPALIVE OrtFree(void* ptr); * @param data_length size of the buffer 'data' in bytes. * @param dims a pointer to an array of dims. the array should contain (dims_length) element(s). * @param dims_length the length of the tensor's dimension + * @param data_location specifies the memory location of the tensor data. Must be one of the DataLocation values accepted by OrtCreateTensor: 1 (CPU), 2 (CPU pinned) or 4 (GPU buffer). * @returns a tensor handle. Caller must release it after use by calling OrtReleaseTensor(). */ -ort_tensor_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* dims, size_t dims_length); +ort_tensor_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* dims, size_t dims_length, int data_location); /** * get type, shape info and data of the specified tensor. @@ -216,6 +220,58 @@ int EMSCRIPTEN_KEEPALIVE OrtAddRunConfigEntry(ort_run_options_handle_t run_optio */ void EMSCRIPTEN_KEEPALIVE OrtReleaseRunOptions(ort_run_options_handle_t run_options); +/** + * create an instance of ORT IO binding. + */ +ort_io_binding_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateBinding(ort_session_handle_t session); + +/** + * bind an input tensor to the IO binding instance. A cross-device copy will be performed if necessary. + * @param io_binding handle of the IO binding + * @param name name of the input + * @param input handle of the input tensor + * @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message. + */ +int EMSCRIPTEN_KEEPALIVE OrtBindInput(ort_io_binding_handle_t io_binding, + const char* name, + ort_tensor_handle_t input); + +/** + * bind an output tensor or location to the IO binding instance. + * @param io_binding handle of the IO binding + * @param name name of the output + * @param output handle of the output tensor. nullptr for output location binding. + * @param output_location specifies the memory location of the output tensor data. + * @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message.
+ */ +int EMSCRIPTEN_KEEPALIVE OrtBindOutput(ort_io_binding_handle_t io_binding, + const char* name, + ort_tensor_handle_t output, + int output_location); + +/** + * clear all bound outputs. + */ +void EMSCRIPTEN_KEEPALIVE OrtClearBoundOutputs(ort_io_binding_handle_t io_binding); + +/** + * release the specified ORT IO binding. + */ +void EMSCRIPTEN_KEEPALIVE OrtReleaseBinding(ort_io_binding_handle_t io_binding); + +/** + * run inference on the model using the specified IO binding. + * @param session handle of the specified session + * @param io_binding handle of the IO binding + * @param run_options handle of the run options + * @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message. + */ +int EMSCRIPTEN_KEEPALIVE OrtRunWithBinding(ort_session_handle_t session, + ort_io_binding_handle_t io_binding, + size_t output_count, + ort_tensor_handle_t* outputs, + ort_run_options_handle_t run_options); + /** * inference the model. * @param session handle of the specified session diff --git a/onnxruntime/wasm/js_internal_api.js b/onnxruntime/wasm/js_internal_api.js index 15d393f4ce62d..427ad6f6d14f3 100644 --- a/onnxruntime/wasm/js_internal_api.js +++ b/onnxruntime/wasm/js_internal_api.js @@ -14,40 +14,156 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea Module.jsepReleaseKernel = releaseKernel; Module.jsepRunKernel = runKernel; - Module['jsepOnRunStart'] = sessionId => { - Module['jsepRunPromise'] = new Promise(r => { - Module.jsepRunPromiseResolve = r; - }); - - if (Module.jsepSessionState) { - throw new Error('Session already started'); - } - - Module.jsepSessionState = { - sessionId, - errors: [] + // This is a simplified version of cwrap() with options.async === true (-sASYNCIFY=1) + // It removes some overhead in cwrap() and ccall() that we don't need. + // + // Currently in JSEP build, we only use this for the following functions: + // - OrtRun() + // - OrtRunWithBinding() + // - OrtBindInput() + // + // Note: about parameters "getFunc" and "setFunc": + // - Emscripten has different behaviors for Debug and Release builds for generating exported function wrappers. + // + // - In Debug build, it will generate a wrapper function for each exported function. For example, it generates a + // wrapper for OrtRun() like this (minified): + // ``` + // var _OrtRun = Module["_OrtRun"] = createExportWrapper("OrtRun"); + // ``` + // + // - In Release build, it will generate a lazy loading wrapper for each exported function. For example, it generates + // a wrapper for OrtRun() like this (minified): + // ``` + // d._OrtRun = (a, b, c, e, f, h, l, q) => (d._OrtRun = J.ka)(a, b, c, e, f, h, l, q); + // ``` + // + // The behavior of these two wrappers is different. The debug build will assign `Module["_OrtRun"]` only once + // because `createExportWrapper()` does not reset `Module["_OrtRun"]` inside. The release build, however, will + // reset d._OrtRun to J.ka the first time it is called. + // + // The difference is important because we need to design the async wrapper in a way that it can handle both cases. + // + // Now, let's look at how the async wrapper is designed to work for both cases: + // + // - Debug build: + // 1. When the WebAssembly module is being loaded, `Module["_OrtRun"]` is assigned to `createExportWrapper("OrtRun")`. + // 2. The first time `Module["jsepInit"]` is called, `Module["_OrtRun"]` is re-assigned to a new async + // wrapper function. + // Value of `Module["_OrtRun"]` will not be changed again. + // + // - Release build: + // 1.
When the WebAssembly module is being loaded, `Module["_OrtRun"]` is assigned to a lazy loading wrapper function. + // 2. The first time `Module["jsepInit"]` is called, `Module["_OrtRun"]` is re-assigned to a new async + // wrapper function. + // 3. The first time `Module["_OrtRun"]` is called, the async wrapper will be called. It will call into this + // function: + // ``` + // (a, b, c, e, f, h, l, q) => (d._OrtRun = J.ka)(a, b, c, e, f, h, l, q); + // ``` + // This function will assign d._OrtRun (i.e. the minified `Module["_OrtRun"]`) to the real function (J.ka). + // 4. Since d._OrtRun is re-assigned, we need to update the async wrapper to re-assign its stored + // function to the updated value (J.ka), and re-assign the value of `d._OrtRun` back to the async wrapper. + // Value of `Module["_OrtRun"]` will not be changed again. + // + // The value of `Module["_OrtRun"]` needs to be assigned 2 times for a debug build and 4 times for a release + // build. + // + // This is why we need the `getFunc` and `setFunc` parameters. They are used to get the current value of an + // exported function and set the new value of an exported function. + // + const jsepWrapAsync = (func, getFunc, setFunc) => { + return (...args) => { + // cache the async data before calling the function. + const previousAsync = Asyncify.currData; + + const previousFunc = getFunc?.(); + const ret = func(...args); + const newFunc = getFunc?.(); + if (previousFunc !== newFunc) { + // The exported function has been updated. + // Set the sync function reference to the new function. + func = newFunc; + // Set the exported function back to the async wrapper. + setFunc(previousFunc); + // Remove getFunc and setFunc. They are no longer needed. + setFunc = null; + getFunc = null; + } + + // If the async data has been changed, it means that the function started an async operation. + if (Asyncify.currData != previousAsync) { + // returns the promise + return Asyncify.whenDone(); + } + // the function is synchronous. returns the result. + return ret; }; }; - Module['jsepOnRunEnd'] = sessionId => { - if (Module.jsepSessionState.sessionId !== sessionId) { - throw new Error('Session ID mismatch'); - } - - const errorPromises = Module.jsepSessionState.errors; - Module.jsepSessionState = null; - - return errorPromises.length === 0 ? Promise.resolve() : new Promise((resolve, reject) => { - Promise.all(errorPromises).then(errors => { - errors = errors.filter(e => e); - if (errors.length > 0) { - reject(new Error(errors.join('\n'))); - } else { - resolve(); + // This is a wrapper for OrtRun() and OrtRunWithBinding() to ensure that Promises are handled correctly. + const runAsync = (runAsyncFunc) => { + return async (...args) => { + try { + // Module.jsepSessionState should be null, unless we are in the middle of a session. + // If it is not null, it means that the previous session has not finished yet. + if (Module.jsepSessionState) { + throw new Error('Session already started'); + } + const state = Module.jsepSessionState = {sessionHandle: args[0], errors: []}; + + // Run the asyncified function: OrtRun() or OrtRunWithBinding() + const ret = await runAsyncFunc(...args); + + // Check if the session is still valid. This object should be the same as the one we set above. + if (Module.jsepSessionState !== state) { + throw new Error('Session mismatch'); + } + + // Flush the backend. This will submit all pending commands to the GPU. + backend['flush'](); + + // Await all pending promises.
This includes GPU validation promises for diagnostic purposes. + const errorPromises = state.errors; + if (errorPromises.length > 0) { + let errors = await Promise.all(errorPromises); + errors = errors.filter(e => e); + if (errors.length > 0) { + throw new Error(errors.join('\n')); + } } - }, reason => { - reject(reason); - }); - }); + + return ret; + } finally { + Module.jsepSessionState = null; + } + }; + }; + + // replace the original functions with asyncified versions + Module['_OrtRun'] = runAsync(jsepWrapAsync( + Module['_OrtRun'], + () => Module['_OrtRun'], + v => Module['_OrtRun'] = v)); + Module['_OrtRunWithBinding'] = runAsync(jsepWrapAsync( + Module['_OrtRunWithBinding'], + () => Module['_OrtRunWithBinding'], + v => Module['_OrtRunWithBinding'] = v)); + Module['_OrtBindInput'] = jsepWrapAsync( + Module['_OrtBindInput'], + () => Module['_OrtBindInput'], + v => Module['_OrtBindInput'] = v); + + // expose webgpu backend functions + Module['jsepRegisterBuffer'] = (sessionId, index, buffer, size) => { + return backend['registerBuffer'](sessionId, index, buffer, size); + }; + Module['jsepUnregisterBuffers'] = sessionId => { + backend['unregisterBuffers'](sessionId); + }; + Module['jsepGetBuffer'] = (dataId) => { + return backend['getBuffer'](dataId); + }; + Module['jsepCreateDownloader'] = (gpuBuffer, size, type) => { + return backend['createDownloader'](gpuBuffer, size, type); }; }; diff --git a/orttraining/orttraining/python/orttraining_python_module.cc b/orttraining/orttraining/python/orttraining_python_module.cc index 7024244629c3e..88ef90a7feaa8 100644 --- a/orttraining/orttraining/python/orttraining_python_module.cc +++ b/orttraining/orttraining/python/orttraining_python_module.cc @@ -15,6 +15,12 @@ namespace onnxruntime { namespace python { namespace py = pybind11; +#if defined(USE_MPI) && defined(ORT_USE_NCCL) +static constexpr bool HAS_COLLECTIVE_OPS = true; +#else +static constexpr bool HAS_COLLECTIVE_OPS = false; +#endif + using namespace onnxruntime::logging; std::unique_ptr CreateExecutionProviderInstance( @@ -361,6 +367,8 @@ PYBIND11_MODULE(onnxruntime_pybind11_state, m) { }, "Clean the execution provider instances used in ort training module."); + m.def("has_collective_ops", []() -> bool { return HAS_COLLECTIVE_OPS; }); + // See documentation for class TrainingEnvInitialzer earlier in this module // for an explanation as to why this is needed. 
auto atexit = py::module_::import("atexit"); diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml index 96a0ebd753d8e..fe7f752513f3c 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-wasm-ci.yml @@ -98,7 +98,6 @@ jobs: cd '$(Build.SourcesDirectory)/cmake/external/emsdk' ./emsdk install 3.1.44 ccache-git-emscripten-64bit ./emsdk activate 3.1.44 ccache-git-emscripten-64bit - ln -s $(Build.SourcesDirectory)/cmake/external/emsdk/ccache/git-emscripten_64bit/bin/ccache /usr/local/bin/ccache displayName: 'emsdk install and activate ccache for emscripten' condition: eq('${{ parameters.WithCache }}', 'true') diff --git a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml index d737376eb99b5..788b02f539821 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-web-ci.yml @@ -29,6 +29,7 @@ jobs: pool: ${{ parameters.PoolName }} variables: + webgpuCommandlineExtraFlags: '--chromium-flags=--ignore-gpu-blocklist --chromium-flags=--gpu-vendor-id=0x10de' runCodesignValidationInjection: false timeoutInMinutes: 60 workspace: @@ -159,12 +160,22 @@ jobs: npm test -- -e=edge -b=webgl,wasm,xnnpack workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests (wasm,webgl,xnnpack backend)' - condition: ne('${{ parameters.RunWebGpuTests }}', 'true') + condition: eq('${{ parameters.RunWebGpuTests }}', 'false') - script: | - npm test -- -e=edge -b=webgl,wasm,xnnpack,webgpu --chromium-flags=--ignore-gpu-blocklist --chromium-flags=--gpu-vendor-id=0x10de + npm test -- -e=edge -b=webgl,wasm,xnnpack,webgpu $(webgpuCommandlineExtraFlags) workingDirectory: '$(Build.SourcesDirectory)\js\web' displayName: 'Run ort-web tests (ALL backends)' - condition: ne('${{ parameters.RunWebGpuTests }}', 'false') + condition: eq('${{ parameters.RunWebGpuTests }}', 'true') + - script: | + npm test -- suite1 -e=edge -b=webgpu --io-binding=gpu-tensor $(webgpuCommandlineExtraFlags) + workingDirectory: '$(Build.SourcesDirectory)\js\web' + displayName: 'Run ort-web tests (Suite1, webgpu, IO-binding=gpu-tensor)' + condition: eq('${{ parameters.RunWebGpuTests }}', 'true') + - script: | + npm test -- suite1 -e=edge -b=webgpu --io-binding=gpu-location $(webgpuCommandlineExtraFlags) + workingDirectory: '$(Build.SourcesDirectory)\js\web' + displayName: 'Run ort-web tests (Suite1, webgpu, IO-binding=gpu-location)' + condition: eq('${{ parameters.RunWebGpuTests }}', 'true') - script: | npm test -- --webgl-texture-pack-mode -b=webgl -e=edge workingDirectory: '$(Build.SourcesDirectory)\js\web' diff --git a/winml/adapter/winml_adapter_environment.cpp b/winml/adapter/winml_adapter_environment.cpp index 43babdf43967e..e2da473c7d5b5 100644 --- a/winml/adapter/winml_adapter_environment.cpp +++ b/winml/adapter/winml_adapter_environment.cpp @@ -9,6 +9,7 @@ #include "winml_adapter_apis.h" #include "core/framework/error_code_helper.h" #include "core/session/ort_env.h" +#include "core/session/user_logging_sink.h" #ifdef USE_DML #include "abi_custom_registry_impl.h" @@ -18,12 +19,12 @@ #endif USE_DML namespace winmla = Windows::AI::MachineLearning::Adapter; -class WinmlAdapterLoggingWrapper : public LoggingWrapper { +class WinmlAdapterLoggingWrapper : public onnxruntime::UserLoggingSink { public: 
WinmlAdapterLoggingWrapper( OrtLoggingFunction logging_function, OrtProfilingFunction profiling_function, void* logger_param ) - : LoggingWrapper(logging_function, logger_param), + : onnxruntime::UserLoggingSink(logging_function, logger_param), profiling_function_(profiling_function) {} void SendProfileEvent(onnxruntime::profiling::EventRecord& event_record) const override { diff --git a/winml/test/model/model_tests.cpp b/winml/test/model/model_tests.cpp index 0b4c10eac9142..5057f74046638 100644 --- a/winml/test/model/model_tests.cpp +++ b/winml/test/model/model_tests.cpp @@ -380,13 +380,6 @@ std::string GetFullNameOfTest(ITestCase* testCase, winml::LearningModelDeviceKin name += tokenizedModelPath[tokenizedModelPath.size() - 2] += "_"; // model name name += tokenizedModelPath[tokenizedModelPath.size() - 3]; // opset version - // To introduce models from model zoo, the model path is structured like this "///?.onnx" - std::string source = tokenizedModelPath[tokenizedModelPath.size() - 4]; - // `models` means the root of models, to be ompatible with the old structure, that is, the source name is empty. - if (source != "models") { - name += "_" + source; - } - std::replace_if( name.begin(), name.end(), [](char c) { return !absl::ascii_isalnum(c); }, '_' ); @@ -405,6 +398,13 @@ std::string GetFullNameOfTest(ITestCase* testCase, winml::LearningModelDeviceKin ModifyNameIfDisabledTest(/*inout*/ name, deviceKind); } + // To introduce models from model zoo, the model path is structured like this "///?.onnx" + std::string source = tokenizedModelPath[tokenizedModelPath.size() - 4]; + // `models` means the root of models, to be compatible with the old structure, that is, the source name is empty. + if (source != "models") { + name += "_" + source; + } + return name; }
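To tie the wasm API additions together, here is a rough caller-side sketch of the new IO-binding flow using only the functions declared in api.h in this diff; the tensor shape, buffer handle and the "input"/"output" names are made-up placeholders, and error handling is reduced to early returns:

#include <cstddef>

#include "api.h"  // OrtCreateTensor, OrtCreateBinding, OrtBindInput, OrtBindOutput, OrtRunWithBinding, ...

namespace {
constexpr int kDataLocationGpuBuffer = 4;  // DATA_LOCATION_GPU_BUFFER
}

// Hypothetical helper: bind one GPU input, let ORT allocate one GPU output, run, and
// hand the single output tensor back to the caller. Returns 0 (ORT_OK) on success.
int RunWithGpuIoBinding(ort_session_handle_t session, ort_run_options_handle_t run_options,
                        void* gpu_input_buffer, size_t input_bytes, ort_tensor_handle_t* output) {
  size_t dims[] = {1, 3, 224, 224};
  // data_type 1 == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT.
  ort_tensor_handle_t input = OrtCreateTensor(1, gpu_input_buffer, input_bytes, dims, 4, kDataLocationGpuBuffer);
  if (input == nullptr) return -1;

  ort_io_binding_handle_t binding = OrtCreateBinding(session);
  if (binding == nullptr) {
    OrtReleaseTensor(input);
    return -1;
  }

  // Bind the input tensor; bind the output by location only (nullptr tensor), so ORT
  // allocates it in a WebGPU buffer as described in the OrtBindOutput documentation.
  int err = OrtBindInput(binding, "input", input);
  if (err == 0) err = OrtBindOutput(binding, "output", nullptr, kDataLocationGpuBuffer);

  // OrtRunWithBinding copies the bound output values into the caller-provided array;
  // output_count must match the number of outputs bound above.
  if (err == 0) err = OrtRunWithBinding(session, binding, 1, output, run_options);

  OrtClearBoundOutputs(binding);
  OrtReleaseBinding(binding);
  OrtReleaseTensor(input);
  return err;
}

The asyncify wrappers added in js_internal_api.js wrap exactly the OrtRun, OrtRunWithBinding and OrtBindInput exports that a sequence like this goes through.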