From be69d25bd640ad5ac5f403f1b120e0afe1aae766 Mon Sep 17 00:00:00 2001 From: Yulong Wang <7679871+fs-eire@users.noreply.github.com> Date: Wed, 18 Dec 2024 09:35:51 -0800 Subject: [PATCH] [WebGPU EP] allows GPUDevice to be released after use --- .../core/providers/webgpu/webgpu_context.cc | 103 +++++---- .../core/providers/webgpu/webgpu_context.h | 39 +++- .../webgpu/webgpu_execution_provider.cc | 9 +- .../webgpu/webgpu_execution_provider.h | 24 +- .../webgpu/webgpu_provider_factory.cc | 217 ++++++++++-------- 5 files changed, 225 insertions(+), 167 deletions(-) diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index d66c2a79d28a8..052049db0e1e0 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -23,37 +23,14 @@ namespace onnxruntime { namespace webgpu { -void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info, const void* dawn_proc_table) { - std::call_once(init_flag_, [this, &webgpu_ep_info, dawn_proc_table]() { - // Initialization.Step.1 - Create wgpu::Instance - if (instance_ == nullptr) { - const DawnProcTable* dawn_procs = reinterpret_cast(dawn_proc_table); -#if defined(BUILD_DAWN_MONOLITHIC_LIBRARY) - ORT_ENFORCE(dawn_procs == nullptr, "setting DawnProcTable is not allowed when dynamically linked to webgpu_dawn."); -#else -#if !defined(USE_EXTERNAL_DAWN) - if (dawn_procs == nullptr) { - dawn_procs = &dawn::native::GetProcs(); - } -#else - ORT_ENFORCE(dawn_procs != nullptr, "DawnProcTable must be provided."); -#endif - dawnProcSetProcs(dawn_procs); -#endif - - wgpu::InstanceDescriptor instance_desc{}; - instance_desc.features.timedWaitAnyEnable = true; - instance_ = wgpu::CreateInstance(&instance_desc); - - ORT_ENFORCE(instance_ != nullptr, "Failed to create wgpu::Instance."); - } - - // Initialization.Step.2 - Create wgpu::Adapter +void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_config, int backend_type) { + std::call_once(init_flag_, [this, &buffer_cache_config, backend_type]() { + // Create wgpu::Adapter if (adapter_ == nullptr) { wgpu::RequestAdapterOptions req_adapter_options = {}; wgpu::DawnTogglesDescriptor adapter_toggles_desc = {}; req_adapter_options.nextInChain = &adapter_toggles_desc; - req_adapter_options.backendType = static_cast(webgpu_ep_info.backend_type); + req_adapter_options.backendType = static_cast(backend_type); req_adapter_options.powerPreference = wgpu::PowerPreference::HighPerformance; auto enabled_adapter_toggles = GetEnabledAdapterToggles(); @@ -72,7 +49,7 @@ void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info ORT_ENFORCE(adapter_ != nullptr, "Failed to get a WebGPU adapter."); } - // Initialization.Step.3 - Create wgpu::Device + // Create wgpu::Device if (device_ == nullptr) { wgpu::DeviceDescriptor device_desc = {}; wgpu::DawnTogglesDescriptor device_toggles_desc = {}; @@ -124,7 +101,10 @@ void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info device_limits_ = device_supported_limits.limits; // create buffer manager - buffer_mgr_ = BufferManagerFactory::Create(*this, webgpu_ep_info.storage_buffer_cache_mode, webgpu_ep_info.uniform_buffer_cache_mode, webgpu_ep_info.query_resolve_buffer_cache_mode); + buffer_mgr_ = BufferManagerFactory::Create(*this, + buffer_cache_config.storage.mode, + buffer_cache_config.uniform.mode, + buffer_cache_config.query_resolve.mode); // create program manager program_mgr_ = std::make_unique(Device(), DeviceLimits()); @@ -635,18 +615,46 @@ void WebGpuContext::Flush() { num_pending_dispatches_ = 0; } -std::unordered_map> WebGpuContextFactory::contexts_; +std::unordered_map WebGpuContextFactory::contexts_; std::mutex WebGpuContextFactory::mutex_; +std::once_flag WebGpuContextFactory::init_default_flag_; +wgpu::Instance WebGpuContextFactory::default_instance_; + +WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& config) { + const int context_id = config.context_id; + WGPUInstance instance = config.instance; + WGPUAdapter adapter = config.adapter; + WGPUDevice device = config.device; -WebGpuContext& WebGpuContextFactory::CreateContext(int context_id, - WGPUInstance instance, - WGPUAdapter adapter, - WGPUDevice device, - ValidationMode validation_mode) { if (context_id == 0) { // context ID is preserved for the default context. User cannot use context ID 0 as a custom context. ORT_ENFORCE(instance == nullptr && adapter == nullptr && device == nullptr, "WebGPU EP default context (contextId=0) must not have custom WebGPU instance, adapter or device."); + + std::call_once(init_default_flag_, [dawn_proc_table = config.dawn_proc_table]() { + // Step.1 - setup dawn proc table + const DawnProcTable* dawn_procs = reinterpret_cast(dawn_proc_table); +#if defined(BUILD_DAWN_MONOLITHIC_LIBRARY) + ORT_ENFORCE(dawn_procs == nullptr, "setting DawnProcTable is not allowed when dynamically linked to webgpu_dawn."); +#else +#if !defined(USE_EXTERNAL_DAWN) + if (dawn_procs == nullptr) { + dawn_procs = &dawn::native::GetProcs(); + } +#else + ORT_ENFORCE(dawn_procs != nullptr, "DawnProcTable must be provided."); +#endif + dawnProcSetProcs(dawn_procs); +#endif + + // Step.2 - Create wgpu::Instance + wgpu::InstanceDescriptor instance_desc{}; + instance_desc.features.timedWaitAnyEnable = true; + default_instance_ = wgpu::CreateInstance(&instance_desc); + + ORT_ENFORCE(default_instance_ != nullptr, "Failed to create wgpu::Instance."); + }); + instance = default_instance_.Get(); } else { // for context ID > 0, user must provide custom WebGPU instance, adapter and device. ORT_ENFORCE(instance != nullptr && adapter != nullptr && device != nullptr, @@ -658,13 +666,16 @@ WebGpuContext& WebGpuContextFactory::CreateContext(int context_id, auto it = contexts_.find(context_id); if (it == contexts_.end()) { GSL_SUPPRESS(r.11) - auto context = std::unique_ptr(new WebGpuContext(instance, adapter, device, validation_mode)); - it = contexts_.emplace(context_id, std::move(context)).first; + auto context = std::unique_ptr(new WebGpuContext(instance, adapter, device, config.validation_mode)); + it = contexts_.emplace(context_id, WebGpuContextFactory::WebGpuContextInfo{std::move(context), 0}).first; } else if (context_id != 0) { - ORT_ENFORCE(it->second->instance_.Get() == instance && it->second->adapter_.Get() == adapter && it->second->device_.Get() == device, + ORT_ENFORCE(it->second.context->instance_.Get() == instance && + it->second.context->adapter_.Get() == adapter && + it->second.context->device_.Get() == device, "WebGPU EP context ID ", context_id, " is already created with different WebGPU instance, adapter or device."); } - return *it->second; + it->second.ref_count++; + return *it->second.context; } WebGpuContext& WebGpuContextFactory::GetContext(int context_id) { @@ -673,12 +684,24 @@ WebGpuContext& WebGpuContextFactory::GetContext(int context_id) { auto it = contexts_.find(context_id); ORT_ENFORCE(it != contexts_.end(), "WebGPU EP context ID ", context_id, " is not found."); - return *it->second; + return *it->second.context; +} + +void WebGpuContextFactory::ReleaseContext(int context_id) { + std::lock_guard lock(mutex_); + + auto it = contexts_.find(context_id); + ORT_ENFORCE(it != contexts_.end(), "WebGPU EP context ID ", context_id, " is not found."); + + if (--it->second.ref_count == 0) { + contexts_.erase(it); + } } void WebGpuContextFactory::Cleanup() { std::lock_guard lock(mutex_); contexts_.clear(); + default_instance_ = nullptr; } void CleanupWebGpuContexts() { diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h index be05b06523b9c..6c0b60a6d0245 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.h +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -25,28 +25,53 @@ class WebGpuContext; class ComputeContext; class ProgramBase; +struct WebGpuContextConfig { + int context_id; + WGPUInstance instance; + WGPUAdapter adapter; + WGPUDevice device; + const void* dawn_proc_table; + ValidationMode validation_mode; +}; + +struct WebGpuBufferCacheConfig { + struct ConfigEntry { + BufferCacheMode mode; + std::string config_string; + }; + ConfigEntry storage; + ConfigEntry uniform; + ConfigEntry query_resolve; + ConfigEntry default_entry; +}; + class WebGpuContextFactory { public: - static WebGpuContext& CreateContext(int context_id, - WGPUInstance instance, - WGPUAdapter adapter, - WGPUDevice device, - ValidationMode validation_mode); + struct WebGpuContextInfo { + std::unique_ptr context; + int ref_count; + }; + + static WebGpuContext& CreateContext(const WebGpuContextConfig& config); static WebGpuContext& GetContext(int context_id); + static void ReleaseContext(int context_id); + static void Cleanup(); private: WebGpuContextFactory() {} - static std::unordered_map> contexts_; + static std::unordered_map contexts_; static std::mutex mutex_; + static std::once_flag init_default_flag_; + static wgpu::Instance default_instance_; }; // Class WebGpuContext includes all necessary resources for the context. class WebGpuContext final { public: - void Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info, const void* dawn_proc_table); + void Initialize(const WebGpuBufferCacheConfig& buffer_cache_config, int backend_type); Status Wait(wgpu::Future f); diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 66209adf6f1a9..65fff8728a994 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -743,13 +743,13 @@ using namespace webgpu; WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id, WebGpuContext& context, - WebGpuExecutionProviderInfo&& info) + WebGpuExecutionProviderConfig&& config) : IExecutionProvider{kWebGpuExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)}, context_id_{context_id}, context_{context}, - preferred_data_layout_{info.data_layout}, - force_cpu_node_names_{std::move(info.force_cpu_node_names)}, - enable_graph_capture_{info.enable_graph_capture} { + preferred_data_layout_{config.data_layout}, + force_cpu_node_names_{std::move(config.force_cpu_node_names)}, + enable_graph_capture_{config.enable_graph_capture} { } std::vector WebGpuExecutionProvider::CreatePreferredAllocators() { @@ -824,6 +824,7 @@ std::unique_ptr WebGpuExecutionProvider::GetDataTran } WebGpuExecutionProvider::~WebGpuExecutionProvider() { + WebGpuContextFactory::ReleaseContext(context_id_); } std::unique_ptr WebGpuExecutionProvider::GetProfiler() { diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h index f9c43c6bfd7d0..ad81924e06901 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h @@ -22,32 +22,22 @@ enum class BufferCacheMode; class WebGpuProfiler; } // namespace webgpu -struct WebGpuExecutionProviderInfo { - WebGpuExecutionProviderInfo(DataLayout data_layout, bool enable_graph_capture) +struct WebGpuExecutionProviderConfig { + WebGpuExecutionProviderConfig(DataLayout data_layout, bool enable_graph_capture) : data_layout{data_layout}, - enable_graph_capture{enable_graph_capture}, - backend_type{}, - storage_buffer_cache_mode{}, - uniform_buffer_cache_mode{}, - query_resolve_buffer_cache_mode{}, - default_buffer_cache_mode{} {} - WebGpuExecutionProviderInfo(WebGpuExecutionProviderInfo&&) = default; - WebGpuExecutionProviderInfo& operator=(WebGpuExecutionProviderInfo&&) = default; - ORT_DISALLOW_COPY_AND_ASSIGNMENT(WebGpuExecutionProviderInfo); + enable_graph_capture{enable_graph_capture} {} + WebGpuExecutionProviderConfig(WebGpuExecutionProviderConfig&&) = default; + WebGpuExecutionProviderConfig& operator=(WebGpuExecutionProviderConfig&&) = default; + ORT_DISALLOW_COPY_AND_ASSIGNMENT(WebGpuExecutionProviderConfig); DataLayout data_layout; bool enable_graph_capture; - int backend_type; - webgpu::BufferCacheMode storage_buffer_cache_mode; - webgpu::BufferCacheMode uniform_buffer_cache_mode; - webgpu::BufferCacheMode query_resolve_buffer_cache_mode; - webgpu::BufferCacheMode default_buffer_cache_mode; std::vector force_cpu_node_names; }; class WebGpuExecutionProvider : public IExecutionProvider { public: - WebGpuExecutionProvider(int context_id, webgpu::WebGpuContext& context, WebGpuExecutionProviderInfo&& info); + WebGpuExecutionProvider(int context_id, webgpu::WebGpuContext& context, WebGpuExecutionProviderConfig&& config); ~WebGpuExecutionProvider() override; std::vector> GetCapability( diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc index 6cfe9aac0b0e9..64eb80b26fbf9 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc @@ -17,25 +17,25 @@ using namespace onnxruntime::webgpu::options; namespace onnxruntime { struct WebGpuProviderFactory : IExecutionProviderFactory { - WebGpuProviderFactory(int context_id, webgpu::WebGpuContext& context, WebGpuExecutionProviderInfo&& webgpu_ep_info) - : context_id_{context_id}, context_{context}, info_{std::move(webgpu_ep_info)} { + WebGpuProviderFactory(int context_id, webgpu::WebGpuContext& context, WebGpuExecutionProviderConfig&& webgpu_ep_config) + : context_id_{context_id}, context_{context}, config_{std::move(webgpu_ep_config)} { } std::unique_ptr CreateProvider() override { - return std::make_unique(context_id_, context_, std::move(info_)); + return std::make_unique(context_id_, context_, std::move(config_)); } private: int context_id_; webgpu::WebGpuContext& context_; - WebGpuExecutionProviderInfo info_; + WebGpuExecutionProviderConfig config_; }; std::shared_ptr WebGpuProviderFactoryCreator::Create(const ConfigOptions& config_options) { // - // STEP.1 - prepare WebGpuExecutionProviderInfo + // STEP.1 - prepare WebGpuExecutionProviderConfig // - WebGpuExecutionProviderInfo webgpu_ep_info{ + WebGpuExecutionProviderConfig webgpu_ep_config{ // preferred layout is NHWC by default DataLayout::NHWC, // graph capture feature is disabled by default @@ -45,109 +45,33 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( std::string preferred_layout_str; if (config_options.TryGetConfigEntry(kPreferredLayout, preferred_layout_str)) { if (preferred_layout_str == kPreferredLayout_NHWC) { - webgpu_ep_info.data_layout = DataLayout::NHWC; + webgpu_ep_config.data_layout = DataLayout::NHWC; } else if (preferred_layout_str == kPreferredLayout_NCHW) { - webgpu_ep_info.data_layout = DataLayout::NCHW; + webgpu_ep_config.data_layout = DataLayout::NCHW; } else { ORT_THROW("Invalid preferred layout: ", preferred_layout_str); } } - LOGS_DEFAULT(VERBOSE) << "WebGPU EP preferred layout: " << int(webgpu_ep_info.data_layout) << " (parsed from \"" + LOGS_DEFAULT(VERBOSE) << "WebGPU EP preferred layout: " << int(webgpu_ep_config.data_layout) << " (parsed from \"" << preferred_layout_str << "\")"; std::string enable_graph_capture_str; if (config_options.TryGetConfigEntry(kEnableGraphCapture, enable_graph_capture_str)) { if (enable_graph_capture_str == kEnableGraphCapture_ON) { - webgpu_ep_info.enable_graph_capture = true; + webgpu_ep_config.enable_graph_capture = true; } else if (enable_graph_capture_str == kEnableGraphCapture_OFF) { - webgpu_ep_info.enable_graph_capture = false; + webgpu_ep_config.enable_graph_capture = false; } else { ORT_THROW("Invalid enable graph capture: ", enable_graph_capture_str); } } - LOGS_DEFAULT(VERBOSE) << "WebGPU EP graph capture enable: " << webgpu_ep_info.enable_graph_capture; - - std::string backend_type_str; - if (config_options.TryGetConfigEntry(kDawnBackendType, backend_type_str)) { -#ifdef _WIN32 - // Setup Windows default backend type based on the build configuration -#if defined(onnxruntime_ENABLE_DAWN_BACKEND_D3D12) - webgpu_ep_info.backend_type = static_cast(WGPUBackendType_D3D12); -#elif defined(onnxruntime_ENABLE_DAWN_BACKEND_VULKAN) - webgpu_ep_info.backend_type = static_cast(WGPUBackendType_Vulkan); -#endif -#endif - if (backend_type_str == kDawnBackendType_D3D12) { - webgpu_ep_info.backend_type = static_cast(WGPUBackendType_D3D12); - } else if (backend_type_str == kDawnBackendType_Vulkan) { - webgpu_ep_info.backend_type = static_cast(WGPUBackendType_Vulkan); - } else { - ORT_THROW("Invalid Dawn backend type: ", backend_type_str); - } - } - LOGS_DEFAULT(VERBOSE) << "WebGPU EP Dawn backend type: " << webgpu_ep_info.backend_type; - - auto parse_buffer_cache_mode = [&config_options](const std::string& config_entry_str, - webgpu::BufferCacheMode default_value) -> webgpu::BufferCacheMode { - std::string buffer_cache_mode_str; - if (config_options.TryGetConfigEntry(config_entry_str, buffer_cache_mode_str)) { - if (buffer_cache_mode_str == kBufferCacheMode_Disabled) { - return webgpu::BufferCacheMode::Disabled; - } else if (buffer_cache_mode_str == kBufferCacheMode_LazyRelease) { - return webgpu::BufferCacheMode::LazyRelease; - } else if (buffer_cache_mode_str == kBufferCacheMode_Simple) { - return webgpu::BufferCacheMode::Simple; - } else if (buffer_cache_mode_str == kBufferCacheMode_Bucket) { - return webgpu::BufferCacheMode::Bucket; - } else { - ORT_THROW("Invalid buffer cache mode: ", config_entry_str); - } - } else { - return default_value; - } - }; - - webgpu_ep_info.storage_buffer_cache_mode = parse_buffer_cache_mode(kStorageBufferCacheMode, webgpu::BufferCacheMode::Bucket); - LOGS_DEFAULT(VERBOSE) << "WebGPU EP storage buffer cache mode: " << webgpu_ep_info.storage_buffer_cache_mode; - - webgpu_ep_info.uniform_buffer_cache_mode = parse_buffer_cache_mode(kUniformBufferCacheMode, webgpu::BufferCacheMode::Simple); - LOGS_DEFAULT(VERBOSE) << "WebGPU EP uniform buffer cache mode: " << webgpu_ep_info.uniform_buffer_cache_mode; - - webgpu_ep_info.query_resolve_buffer_cache_mode = parse_buffer_cache_mode(kQueryResolveBufferCacheMode, webgpu::BufferCacheMode::Disabled); - LOGS_DEFAULT(VERBOSE) << "WebGPU EP query resolve buffer cache mode: " << webgpu_ep_info.query_resolve_buffer_cache_mode; - - webgpu_ep_info.default_buffer_cache_mode = parse_buffer_cache_mode(kDefaultBufferCacheMode, webgpu::BufferCacheMode::Disabled); - LOGS_DEFAULT(VERBOSE) << "WebGPU EP default buffer cache mode: " << webgpu_ep_info.default_buffer_cache_mode; - - webgpu::ValidationMode validation_mode = -#ifndef NDEBUG - webgpu::ValidationMode::Full // for debug build, enable full validation by default -#else - webgpu::ValidationMode::Basic // for release build, enable basic validation by default -#endif // !NDEBUG - ; - std::string validation_mode_str; - if (config_options.TryGetConfigEntry(kValidationMode, validation_mode_str)) { - if (validation_mode_str == kValidationMode_Disabled) { - validation_mode = webgpu::ValidationMode::Disabled; - } else if (validation_mode_str == kValidationMode_wgpuOnly) { - validation_mode = webgpu::ValidationMode::WGPUOnly; - } else if (validation_mode_str == kValidationMode_basic) { - validation_mode = webgpu::ValidationMode::Basic; - } else if (validation_mode_str == kValidationMode_full) { - validation_mode = webgpu::ValidationMode::Full; - } else { - ORT_THROW("Invalid validation mode: ", validation_mode_str); - } - } + LOGS_DEFAULT(VERBOSE) << "WebGPU EP graph capture enable: " << webgpu_ep_config.enable_graph_capture; // parse force CPU node names // The force CPU node names are separated by EOL (\n or \r\n) in the config entry. // each line is a node name that will be forced to run on CPU. std::string force_cpu_node_names_str; if (config_options.TryGetConfigEntry(kForceCpuNodeNames, force_cpu_node_names_str)) { - std::vector force_cpu_node_names; - // split the string by EOL (\n or \r\n) std::istringstream ss(force_cpu_node_names_str); std::string line; @@ -157,14 +81,13 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( continue; } - force_cpu_node_names.push_back(line); + webgpu_ep_config.force_cpu_node_names.push_back(line); } - - webgpu_ep_info.force_cpu_node_names = std::move(force_cpu_node_names); } + LOGS_DEFAULT(VERBOSE) << "WebGPU EP force CPU node count: " << webgpu_ep_config.force_cpu_node_names.size(); // - // STEP.2 - prepare WebGpuContext + // STEP.2 - prepare WebGpuContextConfig // int context_id = 0; std::string context_id_str; @@ -204,14 +127,110 @@ std::shared_ptr WebGpuProviderFactoryCreator::Create( std::from_chars(dawn_proc_table_str.data(), dawn_proc_table_str.data() + dawn_proc_table_str.size(), dawn_proc_table).ec); } - auto& context = webgpu::WebGpuContextFactory::CreateContext(context_id, - reinterpret_cast(webgpu_instance), - reinterpret_cast(webgpu_adapter), - reinterpret_cast(webgpu_device), - validation_mode); - context.Initialize(webgpu_ep_info, reinterpret_cast(dawn_proc_table)); + webgpu::ValidationMode validation_mode = +#ifndef NDEBUG + webgpu::ValidationMode::Full // for debug build, enable full validation by default +#else + webgpu::ValidationMode::Basic // for release build, enable basic validation by default +#endif // !NDEBUG + ; + std::string validation_mode_str; + if (config_options.TryGetConfigEntry(kValidationMode, validation_mode_str)) { + if (validation_mode_str == kValidationMode_Disabled) { + validation_mode = webgpu::ValidationMode::Disabled; + } else if (validation_mode_str == kValidationMode_wgpuOnly) { + validation_mode = webgpu::ValidationMode::WGPUOnly; + } else if (validation_mode_str == kValidationMode_basic) { + validation_mode = webgpu::ValidationMode::Basic; + } else if (validation_mode_str == kValidationMode_full) { + validation_mode = webgpu::ValidationMode::Full; + } else { + ORT_THROW("Invalid validation mode: ", validation_mode_str); + } + } + + webgpu::WebGpuContextConfig context_config{ + context_id, + reinterpret_cast(webgpu_instance), + reinterpret_cast(webgpu_adapter), + reinterpret_cast(webgpu_device), + reinterpret_cast(dawn_proc_table), + validation_mode, + }; + + // + // STEP.3 - prepare parameters for WebGPU context initialization. + // + + int backend_type = 0; +#ifdef _WIN32 + // Setup Windows default backend type based on the build configuration +#if defined(DAWN_ENABLE_D3D12) + backend_type = static_cast(WGPUBackendType_D3D12); +#elif defined(DAWN_ENABLE_VULKAN) + backend_type = static_cast(WGPUBackendType_Vulkan); +#endif +#endif + + std::string backend_type_str; + if (config_options.TryGetConfigEntry(kDawnBackendType, backend_type_str)) { + if (backend_type_str == kDawnBackendType_D3D12) { + backend_type = static_cast(WGPUBackendType_D3D12); + } else if (backend_type_str == kDawnBackendType_Vulkan) { + backend_type = static_cast(WGPUBackendType_Vulkan); + } else { + ORT_THROW("Invalid Dawn backend type: ", backend_type_str); + } + } + LOGS_DEFAULT(VERBOSE) << "WebGPU EP Dawn backend type: " << backend_type; + + // buffer cache modes + auto parse_buffer_cache_mode = [&config_options](const std::string& config_entry_str, + webgpu::BufferCacheMode default_value) -> webgpu::BufferCacheMode { + std::string buffer_cache_mode_str; + if (config_options.TryGetConfigEntry(config_entry_str, buffer_cache_mode_str)) { + if (buffer_cache_mode_str == kBufferCacheMode_Disabled) { + return webgpu::BufferCacheMode::Disabled; + } else if (buffer_cache_mode_str == kBufferCacheMode_LazyRelease) { + return webgpu::BufferCacheMode::LazyRelease; + } else if (buffer_cache_mode_str == kBufferCacheMode_Simple) { + return webgpu::BufferCacheMode::Simple; + } else if (buffer_cache_mode_str == kBufferCacheMode_Bucket) { + return webgpu::BufferCacheMode::Bucket; + } else { + ORT_THROW("Invalid buffer cache mode: ", config_entry_str); + } + } else { + return default_value; + } + }; + + webgpu::WebGpuBufferCacheConfig buffer_cache_config; + + buffer_cache_config.storage.mode = parse_buffer_cache_mode(kStorageBufferCacheMode, webgpu::BufferCacheMode::Bucket); + LOGS_DEFAULT(VERBOSE) << "WebGPU EP storage buffer cache mode: " << buffer_cache_config.storage.mode; + + buffer_cache_config.uniform.mode = parse_buffer_cache_mode(kUniformBufferCacheMode, webgpu::BufferCacheMode::Simple); + LOGS_DEFAULT(VERBOSE) << "WebGPU EP uniform buffer cache mode: " << buffer_cache_config.uniform.mode; + + buffer_cache_config.query_resolve.mode = parse_buffer_cache_mode(kQueryResolveBufferCacheMode, webgpu::BufferCacheMode::Disabled); + LOGS_DEFAULT(VERBOSE) << "WebGPU EP query resolve buffer cache mode: " << buffer_cache_config.query_resolve.mode; + + buffer_cache_config.default_entry.mode = parse_buffer_cache_mode(kDefaultBufferCacheMode, webgpu::BufferCacheMode::Disabled); + LOGS_DEFAULT(VERBOSE) << "WebGPU EP default buffer cache mode: " << buffer_cache_config.default_entry.mode; + + // + // STEP.4 - start initialization. + // + + // Load the Dawn library and create the WebGPU instance and adapter. + auto& context = webgpu::WebGpuContextFactory::CreateContext(context_config); + + // Create WebGPU device and initialize the context. + context.Initialize(buffer_cache_config, backend_type); - return std::make_shared(context_id, context, std::move(webgpu_ep_info)); + // Create WebGPU EP factory. + return std::make_shared(context_id, context, std::move(webgpu_ep_config)); } } // namespace onnxruntime