diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index 732c0511d400f..d72b61a0859b2 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -77,6 +77,7 @@ if(WIN32) onnxruntime_add_shared_library(onnxruntime ${SYMBOL_FILE} "${ONNXRUNTIME_ROOT}/core/dll/dllmain.cc" + "${ONNXRUNTIME_ROOT}/core/dll/delay_load_hook.cc" "${ONNXRUNTIME_ROOT}/core/dll/onnxruntime.rc" ) elseif(onnxruntime_BUILD_APPLE_FRAMEWORK) diff --git a/cmake/onnxruntime_nodejs.cmake b/cmake/onnxruntime_nodejs.cmake index 376d895be34a9..60c56f4c22237 100644 --- a/cmake/onnxruntime_nodejs.cmake +++ b/cmake/onnxruntime_nodejs.cmake @@ -103,4 +103,8 @@ add_custom_target(nodejs_binding_wrapper ALL add_dependencies(js_common_npm_ci js_npm_ci) add_dependencies(nodejs_binding_wrapper js_common_npm_ci) add_dependencies(nodejs_binding_wrapper onnxruntime) +if (WIN32 AND onnxruntime_USE_WEBGPU) + add_dependencies(nodejs_binding_wrapper copy_dxil_dll) + add_dependencies(nodejs_binding_wrapper dxcompiler) +endif() endif() diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index e822f0a3655fc..9e3ab4d41f416 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -525,6 +525,9 @@ set (onnxruntime_global_thread_pools_test_SRC set (onnxruntime_webgpu_external_dawn_test_SRC ${TEST_SRC_DIR}/webgpu/external_dawn/main.cc) +set (onnxruntime_webgpu_delay_load_test_SRC + ${TEST_SRC_DIR}/webgpu/delay_load/main.cc) + # tests from lowest level library up. 
# the order of libraries should be maintained, with higher libraries being added first in the list @@ -1864,4 +1867,13 @@ if (onnxruntime_USE_WEBGPU AND onnxruntime_USE_EXTERNAL_DAWN) onnxruntime_add_include_to_target(onnxruntime_webgpu_external_dawn_test dawn::dawncpp_headers dawn::dawn_headers) endif() +if (onnxruntime_USE_WEBGPU AND WIN32 AND onnxruntime_BUILD_SHARED_LIB AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND NOT onnxruntime_MINIMAL_BUILD) + AddTest(DYN + TARGET onnxruntime_webgpu_delay_load_test + SOURCES ${onnxruntime_webgpu_delay_load_test_SRC} + LIBS ${SYS_PATH_LIB} + DEPENDS ${all_dependencies} + ) +endif() + include(onnxruntime_fuzz_test.cmake) diff --git a/js/node/CMakeLists.txt b/js/node/CMakeLists.txt index d79a82c572dc2..7e1ca8f6338e1 100644 --- a/js/node/CMakeLists.txt +++ b/js/node/CMakeLists.txt @@ -117,6 +117,14 @@ if (WIN32) file(COPY ${ONNXRUNTIME_WIN_BIN_DIR}/DirectML.dll DESTINATION ${dist_folder}) endif () + if(USE_WEBGPU) + file(COPY ${ONNXRUNTIME_WIN_BIN_DIR}/webgpu_dawn.dll + DESTINATION ${dist_folder}) + file(COPY ${ONNXRUNTIME_WIN_BIN_DIR}/dxil.dll + DESTINATION ${dist_folder}) + file(COPY ${ONNXRUNTIME_WIN_BIN_DIR}/dxcompiler.dll + DESTINATION ${dist_folder}) + endif () elseif (APPLE) file(COPY ${ONNXRUNTIME_BUILD_DIR}/libonnxruntime.dylib DESTINATION ${dist_folder} FOLLOW_SYMLINK_CHAIN) diff --git a/js/node/src/directml_load_helper.cc b/js/node/src/directml_load_helper.cc deleted file mode 100644 index 6aafe4d5fa788..0000000000000 --- a/js/node/src/directml_load_helper.cc +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#ifdef _WIN32 -#include "common.h" -#include "windows.h" - -void LoadDirectMLDll(Napi::Env env) { - DWORD pathLen = MAX_PATH; - std::wstring path(pathLen, L'\0'); - HMODULE moduleHandle = nullptr; - - GetModuleHandleEx(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, - reinterpret_cast(&LoadDirectMLDll), &moduleHandle); - - DWORD getModuleFileNameResult = GetModuleFileNameW(moduleHandle, const_cast(path.c_str()), pathLen); - while (getModuleFileNameResult == 0 || getModuleFileNameResult == pathLen) { - int ret = GetLastError(); - if (ret == ERROR_INSUFFICIENT_BUFFER && pathLen < 32768) { - pathLen *= 2; - path.resize(pathLen); - getModuleFileNameResult = GetModuleFileNameW(moduleHandle, const_cast(path.c_str()), pathLen); - } else { - ORT_NAPI_THROW_ERROR(env, "Failed getting path to load DirectML.dll, error code: ", ret); - } - } - - path.resize(path.rfind(L'\\') + 1); - path.append(L"DirectML.dll"); - HMODULE libraryLoadResult = LoadLibraryW(path.c_str()); - - if (!libraryLoadResult) { - int ret = GetLastError(); - ORT_NAPI_THROW_ERROR(env, "Failed loading bundled DirectML.dll, error code: ", ret); - } -} -#endif diff --git a/js/node/src/directml_load_helper.h b/js/node/src/directml_load_helper.h deleted file mode 100644 index 074a4f95ed476..0000000000000 --- a/js/node/src/directml_load_helper.h +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. 
- -#if defined(USE_DML) && defined(_WIN32) -void LoadDirectMLDll(Napi::Env env); -#endif diff --git a/js/node/src/inference_session_wrap.cc b/js/node/src/inference_session_wrap.cc index 23d859351f426..04ab71dc48ec2 100644 --- a/js/node/src/inference_session_wrap.cc +++ b/js/node/src/inference_session_wrap.cc @@ -4,7 +4,6 @@ #include "onnxruntime_cxx_api.h" #include "common.h" -#include "directml_load_helper.h" #include "inference_session_wrap.h" #include "run_options_helper.h" #include "session_options_helper.h" @@ -19,9 +18,6 @@ Napi::FunctionReference& InferenceSessionWrap::GetTensorConstructor() { } Napi::Object InferenceSessionWrap::Init(Napi::Env env, Napi::Object exports) { -#if defined(USE_DML) && defined(_WIN32) - LoadDirectMLDll(env); -#endif // create ONNX runtime env Ort::InitApi(); ORT_NAPI_THROW_ERROR_IF( diff --git a/onnxruntime/core/dll/delay_load_hook.cc b/onnxruntime/core/dll/delay_load_hook.cc new file mode 100644 index 0000000000000..55d5b0e270629 --- /dev/null +++ b/onnxruntime/core/dll/delay_load_hook.cc @@ -0,0 +1,91 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// == workaround for delay loading of dependencies of onnxruntime.dll == +// +// Problem: +// +// When onnxruntime.dll uses delay loading for its dependencies, the dependencies are loaded using LoadLibraryEx, +// which searches the directory of the process (.exe) instead of this library (onnxruntime.dll). This is a problem for +// usages of Node.js binding and Python binding, because Windows will try to find the dependencies in the directory +// of node.exe or python.exe, which is not the directory of onnxruntime.dll. +// +// Solution: +// +// By using the delay load hook `__pfnDliNotifyHook2`, we can intervene in the loading procedure by loading from an +// absolute path. The absolute path is constructed by appending the name of the DLL to load to the directory of +// onnxruntime.dll.
This way, we can ensure that the dependencies are loaded from the same directory as onnxruntime.dll. +// +// See also: +// - https://learn.microsoft.com/en-us/cpp/build/reference/understanding-the-helper-function?view=msvc-170#structure-and-constant-definitions +// - https://learn.microsoft.com/en-us/windows/win32/dlls/dynamic-link-library-search-order#alternate-search-order-for-unpackaged-apps +// +// The DLL DelayLoad hook is only enabled when: +// - The compiler is MSVC +// - at least one of USE_WEBGPU or USE_DML is defined +// +#if defined(_MSC_VER) && (defined(USE_WEBGPU) || defined(USE_DML)) + +#include +#include +#include +#include + +namespace { + +#define DEFINE_KNOWN_DLL(name) {#name ".dll", L#name L".dll"} + +constexpr struct { + const char* str; + const wchar_t* wstr; +} known_dlls[] = { +#if defined(USE_WEBGPU) + DEFINE_KNOWN_DLL(webgpu_dawn), +#endif +#if defined(USE_DML) + DEFINE_KNOWN_DLL(DirectML), +#endif +}; +} // namespace + +FARPROC WINAPI delay_load_hook(unsigned dliNotify, PDelayLoadInfo pdli) { + if (dliNotify == dliNotePreLoadLibrary) { + for (size_t i = 0; i < _countof(known_dlls); ++i) { + if (_stricmp(pdli->szDll, known_dlls[i].str) == 0) { + // Try to load the DLL from the same directory as onnxruntime.dll + + // First, get the path to onnxruntime.dll + DWORD pathLen = MAX_PATH; + std::wstring path(pathLen, L'\0'); + HMODULE moduleHandle = nullptr; + + GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, + reinterpret_cast(&delay_load_hook), &moduleHandle); + + DWORD getModuleFileNameResult = GetModuleFileNameW(moduleHandle, const_cast(path.c_str()), pathLen); + while (getModuleFileNameResult == 0 || getModuleFileNameResult == pathLen) { + int ret = GetLastError(); + if (ret == ERROR_INSUFFICIENT_BUFFER && pathLen < 32768) { + pathLen *= 2; + path.resize(pathLen); + getModuleFileNameResult = GetModuleFileNameW(moduleHandle, const_cast(path.c_str()), pathLen); + } else { + // Failed to 
get the path to onnxruntime.dll. In this case, we will just return NULL and let the system + // search for the DLL in the default search order. + return NULL; + } + } + + path.resize(path.rfind(L'\\') + 1); + path.append(known_dlls[i].wstr); + + return FARPROC(LoadLibraryExW(path.c_str(), NULL, LOAD_LIBRARY_SEARCH_DEFAULT_DIRS | LOAD_LIBRARY_SEARCH_DLL_LOAD_DIR)); + } + } + } + return NULL; +} + +extern "C" const PfnDliHook __pfnDliNotifyHook2 = delay_load_hook; + +#endif diff --git a/onnxruntime/core/dll/dllmain.cc b/onnxruntime/core/dll/dllmain.cc index 2e7bdafd0599f..ac5dcd9c96084 100644 --- a/onnxruntime/core/dll/dllmain.cc +++ b/onnxruntime/core/dll/dllmain.cc @@ -13,7 +13,7 @@ #pragma GCC diagnostic pop #endif -// dllmain.cpp : Defines the entry point for the DLL application. +// dllmain.cc : Defines the entry point for the DLL application. BOOL APIENTRY DllMain(HMODULE /*hModule*/, DWORD ul_reason_for_call, LPVOID /*lpReserved*/ diff --git a/onnxruntime/core/providers/webgpu/dll_delay_load_helper.cc b/onnxruntime/core/providers/webgpu/dll_delay_load_helper.cc new file mode 100644 index 0000000000000..9bf9200c2a5d4 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/dll_delay_load_helper.cc @@ -0,0 +1,87 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/dll_delay_load_helper.h" + +#if defined(_WIN32) && defined(_MSC_VER) && !defined(__EMSCRIPTEN__) + +#include +#include +#include +#include +#include + +namespace onnxruntime { +namespace webgpu { + +namespace { + +// Get the directory of the current DLL (usually it's onnxruntime.dll). 
+std::wstring GetCurrentDllDir() { + DWORD pathLen = MAX_PATH; + std::wstring path(pathLen, L'\0'); + HMODULE moduleHandle = nullptr; + + GetModuleHandleExW(GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, + reinterpret_cast(&GetCurrentDllDir), &moduleHandle); + + DWORD getModuleFileNameResult = GetModuleFileNameW(moduleHandle, const_cast(path.c_str()), pathLen); + while (getModuleFileNameResult == 0 || getModuleFileNameResult == pathLen) { + int ret = GetLastError(); + if (ret == ERROR_INSUFFICIENT_BUFFER && pathLen < 32768) { + pathLen *= 2; + path.resize(pathLen); + getModuleFileNameResult = GetModuleFileNameW(moduleHandle, const_cast(path.c_str()), pathLen); + } else { + // Failed to get the path to onnxruntime.dll. Returns an empty string. + return std::wstring{}; + } + } + path.resize(path.rfind(L'\\') + 1); + return path; +} + +std::once_flag run_once_before_load_deps_mutex; +std::once_flag run_once_after_load_deps_mutex; +bool dll_dir_set = false; + +} // namespace + +DllDelayLoadHelper::DllDelayLoadHelper() { + // Setup DLL search directory + std::call_once(run_once_before_load_deps_mutex, []() { + std::wstring path = GetCurrentDllDir(); + if (!path.empty()) { + SetDllDirectoryW(path.c_str()); + dll_dir_set = true; + } + }); +} + +DllDelayLoadHelper::~DllDelayLoadHelper() { + // Restore DLL search directory + std::call_once(run_once_after_load_deps_mutex, []() { + if (dll_dir_set) { + SetDllDirectoryW(NULL); + } + }); +} + +} // namespace webgpu +} // namespace onnxruntime + +#else // defined(_WIN32) && defined(_MSC_VER) && !defined(__EMSCRIPTEN__) + +namespace onnxruntime { +namespace webgpu { + +DllDelayLoadHelper::DllDelayLoadHelper() { +} + +DllDelayLoadHelper::~DllDelayLoadHelper() { +} + +} // namespace webgpu +} // namespace onnxruntime + +#endif diff --git a/onnxruntime/core/providers/webgpu/dll_delay_load_helper.h b/onnxruntime/core/providers/webgpu/dll_delay_load_helper.h new file mode 100644 index 
0000000000000..7dfd9cac43013 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/dll_delay_load_helper.h @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace onnxruntime { +namespace webgpu { + +// The DLL delay load helper is a RAII style guard to ensure DLL loading is done correctly. +// +// - On Windows, the helper sets the DLL search path to the directory of the current DLL. +// - On other platforms, the helper does nothing. +// +struct DllDelayLoadHelper final { + DllDelayLoadHelper(); + ~DllDelayLoadHelper(); +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index d66c2a79d28a8..95991d3df9323 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -19,6 +19,7 @@ #include "core/providers/webgpu/program_cache_key.h" #include "core/providers/webgpu/program_manager.h" #include "core/providers/webgpu/string_macros.h" +#include "core/providers/webgpu/dll_delay_load_helper.h" namespace onnxruntime { namespace webgpu { @@ -50,6 +51,10 @@ void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info // Initialization.Step.2 - Create wgpu::Adapter if (adapter_ == nullptr) { + // DLL delay loading happens inside wgpuRequestAdapter(). + // Use this helper as RAII to ensure the DLL search path is set correctly. 
+ DllDelayLoadHelper helper{}; + wgpu::RequestAdapterOptions req_adapter_options = {}; wgpu::DawnTogglesDescriptor adapter_toggles_desc = {}; req_adapter_options.nextInChain = &adapter_toggles_desc; diff --git a/onnxruntime/test/webgpu/delay_load/main.cc b/onnxruntime/test/webgpu/delay_load/main.cc new file mode 100644 index 0000000000000..357cb0055d5a8 --- /dev/null +++ b/onnxruntime/test/webgpu/delay_load/main.cc @@ -0,0 +1,140 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include +#define ORT_API_MANUAL_INIT +#include "core/session/onnxruntime_cxx_api.h" + +// This program is to test the delay loading of onnxruntime.dll. +// +// To verify the delay loading actually works, we need to do the test in 2 steps: +// +// 1. Prepare a folder structure like below: +// +// ├── webgpu_delay_load_test_root (newly created folder) +// │ ├── dlls +// │ │ ├── onnxruntime.dll +// │ │ ├── webgpu_dawn.dll +// │ │ ├── dxil.dll +// │ │ └── dxcompiler.dll +// │ └── test.exe +// └── onnxruntime_webgpu_delay_load_test.exe (this binary) +// +// This folder structure ensures no DLLs are in the same folder as the executable. +// +// 2. Launch the test binary from the root folder of the above structure. +// +// So, there are 2 modes of this program: +// 1. "Prepare" mode: Do the step 1 above. (default) +// 2. "Test" mode: Do the step 2 above. 
(specified by --test argument) + +int prepare_main(); +int test_main(); + +int wmain(int argc, wchar_t* argv[]) { + if (argc == 2 && wcscmp(argv[1], L"--test") == 0) { + return test_main(); + } else { + return prepare_main(); + } +} + +int prepare_main() { + WCHAR path[32768]; + GetModuleFileNameW(NULL, path, 32768); + namespace fs = std::filesystem; + std::wstring path_str(path); + fs::path exe_full_path{path_str}; // {test_dir}/onnxruntime_webgpu_delay_load_test.exe + fs::path test_dir = exe_full_path.parent_path(); // {test_dir}/ + fs::path exe_name = exe_full_path.filename(); // onnxruntime_webgpu_delay_load_test.exe + fs::path root_folder = test_dir / L"webgpu_delay_load_test_root\\"; // {test_dir}/webgpu_delay_load_test_root/ + fs::path dlls_folder = root_folder / L"dlls\\"; // {test_dir}/webgpu_delay_load_test_root/dlls/ + + // ensure the test folder exists and is empty + if (fs::exists(root_folder)) { + fs::remove_all(root_folder); + } + fs::create_directories(dlls_folder); + + fs::current_path(test_dir); + + // copy the required DLLs to the dlls folder + fs::copy_file(L"onnxruntime.dll", dlls_folder / L"onnxruntime.dll"); + fs::copy_file(L"webgpu_dawn.dll", dlls_folder / L"webgpu_dawn.dll"); + fs::copy_file(L"dxil.dll", dlls_folder / L"dxil.dll"); + fs::copy_file(L"dxcompiler.dll", dlls_folder / L"dxcompiler.dll"); + + // copy the test binary to the root folder + fs::copy_file(exe_full_path, root_folder / L"test.exe"); + + // run the copied test binary ("test.exe --test") from the test root folder + fs::current_path(root_folder); + return _wsystem(L"test.exe --test"); +} + +int run() { + Ort::Env env{nullptr}; + int retval = 0; + try { + env = Ort::Env{ORT_LOGGING_LEVEL_WARNING, "Default"}; + + // model is https://github.com/onnx/onnx/blob/v1.15.0/onnx/backend/test/data/node/test_abs/model.onnx + constexpr uint8_t MODEL_DATA[] = {8, 7, 18, 12, 98, 97, 99, 107, 101, 110, + 100, 45, 116, 101, 115, 116, 58, 73, 10, 11, + 10, 1, 120, 18, 1, 121, 34, 3, 65, 98, + 115, 18, 8, 116, 101, 115, 116,
95, 97, 98, + 115, 90, 23, 10, 1, 120, 18, 18, 10, 16, + 8, 1, 18, 12, 10, 2, 8, 3, 10, 2, + 8, 4, 10, 2, 8, 5, 98, 23, 10, 1, + 121, 18, 18, 10, 16, 8, 1, 18, 12, 10, + 2, 8, 3, 10, 2, 8, 4, 10, 2, 8, + 5, 66, 4, 10, 0, 16, 13}; + + Ort::SessionOptions session_options; + session_options.DisableMemPattern(); + std::unordered_map provider_options; + session_options.AppendExecutionProvider("WebGPU", provider_options); + Ort::Session session{env, MODEL_DATA, sizeof(MODEL_DATA), session_options}; + + // successfully initialized + std::cout << "Successfully initialized WebGPU EP." << std::endl; + retval = 0; + } catch (const std::exception& ex) { + std::cerr << ex.what() << std::endl; + + std::cerr << "Unexpected exception." << std::endl; + retval = -1; + } + + return retval; +} + +int test_main() { + HMODULE hModule = LoadLibraryA("dlls\\onnxruntime.dll"); + if (hModule == NULL) { + std::cout << "Failed to load dlls\\onnxruntime.dll" << std::endl; + return 1; + } + + int retval = 0; + + using OrtGetApiBaseFunction = decltype(&OrtGetApiBase); + auto fnOrtGetApiBase = (OrtGetApiBaseFunction)GetProcAddress(hModule, "OrtGetApiBase"); + if (fnOrtGetApiBase == NULL) { + std::cout << "Failed to get OrtGetApiBase" << std::endl; + retval = 1; + goto cleanup; + } + Ort::InitApi(fnOrtGetApiBase()->GetApi(ORT_API_VERSION)); + + retval = run(); + +cleanup: + if (hModule != NULL) { + FreeLibrary(hModule); + } + return retval; +} diff --git a/onnxruntime/test/webgpu/external_dawn/main.cc b/onnxruntime/test/webgpu/external_dawn/main.cc index ed8d2eab94ce9..1cb22b131d76b 100644 --- a/onnxruntime/test/webgpu/external_dawn/main.cc +++ b/onnxruntime/test/webgpu/external_dawn/main.cc @@ -1,5 +1,4 @@ // Copyright (c) Microsoft Corporation. All rights reserved. -// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates // Licensed under the MIT License. #include