Skip to content

Commit

Permalink
[WebGPU EP] allows GPUDevice to be released after use (#23144)
Browse files Browse the repository at this point in the history
### Description

This change allows the `WebGpuContext` class to be released after all
active inference sessions are released. This will cause:
- for default context (ID=0), the underlying `wgpu::Device` and
`wgpu::Adapter` to be released, together with all resources created by
the Device.
- for custom context (ID>0), the reference counts of passed in Instance,
Adapter and Device will decrement correctly.
  • Loading branch information
fs-eire authored and guschmue committed Dec 20, 2024
1 parent 3e03adc commit 3f17fe2
Show file tree
Hide file tree
Showing 5 changed files with 225 additions and 167 deletions.
103 changes: 63 additions & 40 deletions onnxruntime/core/providers/webgpu/webgpu_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,32 +25,9 @@
namespace onnxruntime {
namespace webgpu {

void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info, const void* dawn_proc_table) {
std::call_once(init_flag_, [this, &webgpu_ep_info, dawn_proc_table]() {
// Initialization.Step.1 - Create wgpu::Instance
if (instance_ == nullptr) {
const DawnProcTable* dawn_procs = reinterpret_cast<const DawnProcTable*>(dawn_proc_table);
#if defined(BUILD_DAWN_MONOLITHIC_LIBRARY)
ORT_ENFORCE(dawn_procs == nullptr, "setting DawnProcTable is not allowed when dynamically linked to webgpu_dawn.");
#else
#if !defined(USE_EXTERNAL_DAWN)
if (dawn_procs == nullptr) {
dawn_procs = &dawn::native::GetProcs();
}
#else
ORT_ENFORCE(dawn_procs != nullptr, "DawnProcTable must be provided.");
#endif
dawnProcSetProcs(dawn_procs);
#endif

wgpu::InstanceDescriptor instance_desc{};
instance_desc.features.timedWaitAnyEnable = true;
instance_ = wgpu::CreateInstance(&instance_desc);

ORT_ENFORCE(instance_ != nullptr, "Failed to create wgpu::Instance.");
}

// Initialization.Step.2 - Create wgpu::Adapter
void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_config, int backend_type) {
std::call_once(init_flag_, [this, &buffer_cache_config, backend_type]() {
// Create wgpu::Adapter
if (adapter_ == nullptr) {
#if !defined(__EMSCRIPTEN__) && defined(_MSC_VER) && defined(DAWN_ENABLE_D3D12) && !defined(USE_EXTERNAL_DAWN)
// If we are using the D3D12 backend on Windows and the build does not use external Dawn, dxil.dll and dxcompiler.dll are required.
Expand Down Expand Up @@ -79,7 +56,7 @@ void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info
wgpu::RequestAdapterOptions req_adapter_options = {};
wgpu::DawnTogglesDescriptor adapter_toggles_desc = {};
req_adapter_options.nextInChain = &adapter_toggles_desc;
req_adapter_options.backendType = static_cast<wgpu::BackendType>(webgpu_ep_info.backend_type);
req_adapter_options.backendType = static_cast<wgpu::BackendType>(backend_type);
req_adapter_options.powerPreference = wgpu::PowerPreference::HighPerformance;

auto enabled_adapter_toggles = GetEnabledAdapterToggles();
Expand All @@ -98,7 +75,7 @@ void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info
ORT_ENFORCE(adapter_ != nullptr, "Failed to get a WebGPU adapter.");
}

// Initialization.Step.3 - Create wgpu::Device
// Create wgpu::Device
if (device_ == nullptr) {
wgpu::DeviceDescriptor device_desc = {};
wgpu::DawnTogglesDescriptor device_toggles_desc = {};
Expand Down Expand Up @@ -150,7 +127,10 @@ void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info
device_limits_ = device_supported_limits.limits;

// create buffer manager
buffer_mgr_ = BufferManagerFactory::Create(*this, webgpu_ep_info.storage_buffer_cache_mode, webgpu_ep_info.uniform_buffer_cache_mode, webgpu_ep_info.query_resolve_buffer_cache_mode);
buffer_mgr_ = BufferManagerFactory::Create(*this,
buffer_cache_config.storage.mode,
buffer_cache_config.uniform.mode,
buffer_cache_config.query_resolve.mode);

// create program manager
program_mgr_ = std::make_unique<ProgramManager>(Device(), DeviceLimits());
Expand Down Expand Up @@ -661,18 +641,46 @@ void WebGpuContext::Flush() {
num_pending_dispatches_ = 0;
}

std::unordered_map<int32_t, std::unique_ptr<WebGpuContext>> WebGpuContextFactory::contexts_;
std::unordered_map<int32_t, WebGpuContextFactory::WebGpuContextInfo> WebGpuContextFactory::contexts_;
std::mutex WebGpuContextFactory::mutex_;
std::once_flag WebGpuContextFactory::init_default_flag_;
wgpu::Instance WebGpuContextFactory::default_instance_;

WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& config) {
const int context_id = config.context_id;
WGPUInstance instance = config.instance;
WGPUAdapter adapter = config.adapter;
WGPUDevice device = config.device;

WebGpuContext& WebGpuContextFactory::CreateContext(int context_id,
WGPUInstance instance,
WGPUAdapter adapter,
WGPUDevice device,
ValidationMode validation_mode) {
if (context_id == 0) {
// context ID is preserved for the default context. User cannot use context ID 0 as a custom context.
ORT_ENFORCE(instance == nullptr && adapter == nullptr && device == nullptr,
"WebGPU EP default context (contextId=0) must not have custom WebGPU instance, adapter or device.");

std::call_once(init_default_flag_, [dawn_proc_table = config.dawn_proc_table]() {
// Step.1 - setup dawn proc table
const DawnProcTable* dawn_procs = reinterpret_cast<const DawnProcTable*>(dawn_proc_table);
#if defined(BUILD_DAWN_MONOLITHIC_LIBRARY)
ORT_ENFORCE(dawn_procs == nullptr, "setting DawnProcTable is not allowed when dynamically linked to webgpu_dawn.");
#else
#if !defined(USE_EXTERNAL_DAWN)
if (dawn_procs == nullptr) {
dawn_procs = &dawn::native::GetProcs();
}
#else
ORT_ENFORCE(dawn_procs != nullptr, "DawnProcTable must be provided.");
#endif
dawnProcSetProcs(dawn_procs);
#endif

// Step.2 - Create wgpu::Instance
wgpu::InstanceDescriptor instance_desc{};
instance_desc.features.timedWaitAnyEnable = true;
default_instance_ = wgpu::CreateInstance(&instance_desc);

ORT_ENFORCE(default_instance_ != nullptr, "Failed to create wgpu::Instance.");
});
instance = default_instance_.Get();
} else {
// for context ID > 0, user must provide custom WebGPU instance, adapter and device.
ORT_ENFORCE(instance != nullptr && adapter != nullptr && device != nullptr,
Expand All @@ -684,13 +692,16 @@ WebGpuContext& WebGpuContextFactory::CreateContext(int context_id,
auto it = contexts_.find(context_id);
if (it == contexts_.end()) {
GSL_SUPPRESS(r.11)
auto context = std::unique_ptr<WebGpuContext>(new WebGpuContext(instance, adapter, device, validation_mode));
it = contexts_.emplace(context_id, std::move(context)).first;
auto context = std::unique_ptr<WebGpuContext>(new WebGpuContext(instance, adapter, device, config.validation_mode));
it = contexts_.emplace(context_id, WebGpuContextFactory::WebGpuContextInfo{std::move(context), 0}).first;
} else if (context_id != 0) {
ORT_ENFORCE(it->second->instance_.Get() == instance && it->second->adapter_.Get() == adapter && it->second->device_.Get() == device,
ORT_ENFORCE(it->second.context->instance_.Get() == instance &&
it->second.context->adapter_.Get() == adapter &&
it->second.context->device_.Get() == device,
"WebGPU EP context ID ", context_id, " is already created with different WebGPU instance, adapter or device.");
}
return *it->second;
it->second.ref_count++;
return *it->second.context;
}

WebGpuContext& WebGpuContextFactory::GetContext(int context_id) {
Expand All @@ -699,12 +710,24 @@ WebGpuContext& WebGpuContextFactory::GetContext(int context_id) {
auto it = contexts_.find(context_id);
ORT_ENFORCE(it != contexts_.end(), "WebGPU EP context ID ", context_id, " is not found.");

return *it->second;
return *it->second.context;
}

void WebGpuContextFactory::ReleaseContext(int context_id) {
std::lock_guard<std::mutex> lock(mutex_);

auto it = contexts_.find(context_id);
ORT_ENFORCE(it != contexts_.end(), "WebGPU EP context ID ", context_id, " is not found.");

if (--it->second.ref_count == 0) {
contexts_.erase(it);
}
}

void WebGpuContextFactory::Cleanup() {
std::lock_guard<std::mutex> lock(mutex_);
contexts_.clear();
default_instance_ = nullptr;
}

void CleanupWebGpuContexts() {
Expand Down
39 changes: 32 additions & 7 deletions onnxruntime/core/providers/webgpu/webgpu_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,28 +26,53 @@ class WebGpuContext;
class ComputeContext;
class ProgramBase;

struct WebGpuContextConfig {
int context_id;
WGPUInstance instance;
WGPUAdapter adapter;
WGPUDevice device;
const void* dawn_proc_table;
ValidationMode validation_mode;
};

struct WebGpuBufferCacheConfig {
struct ConfigEntry {
BufferCacheMode mode;
std::string config_string;
};
ConfigEntry storage;
ConfigEntry uniform;
ConfigEntry query_resolve;
ConfigEntry default_entry;
};

class WebGpuContextFactory {
public:
static WebGpuContext& CreateContext(int context_id,
WGPUInstance instance,
WGPUAdapter adapter,
WGPUDevice device,
ValidationMode validation_mode);
struct WebGpuContextInfo {
std::unique_ptr<WebGpuContext> context;
int ref_count;
};

static WebGpuContext& CreateContext(const WebGpuContextConfig& config);
static WebGpuContext& GetContext(int context_id);

static void ReleaseContext(int context_id);

static void Cleanup();

private:
WebGpuContextFactory() {}

static std::unordered_map<int32_t, std::unique_ptr<WebGpuContext>> contexts_;
static std::unordered_map<int32_t, WebGpuContextInfo> contexts_;
static std::mutex mutex_;
static std::once_flag init_default_flag_;
static wgpu::Instance default_instance_;
};

// Class WebGpuContext includes all necessary resources for the context.
class WebGpuContext final {
public:
void Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info, const void* dawn_proc_table);
void Initialize(const WebGpuBufferCacheConfig& buffer_cache_config, int backend_type);

Status Wait(wgpu::Future f);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -743,13 +743,13 @@ using namespace webgpu;

WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id,
WebGpuContext& context,
WebGpuExecutionProviderInfo&& info)
WebGpuExecutionProviderConfig&& config)
: IExecutionProvider{kWebGpuExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)},
context_id_{context_id},
context_{context},
preferred_data_layout_{info.data_layout},
force_cpu_node_names_{std::move(info.force_cpu_node_names)},
enable_graph_capture_{info.enable_graph_capture} {
preferred_data_layout_{config.data_layout},
force_cpu_node_names_{std::move(config.force_cpu_node_names)},
enable_graph_capture_{config.enable_graph_capture} {
}

std::vector<AllocatorPtr> WebGpuExecutionProvider::CreatePreferredAllocators() {
Expand Down Expand Up @@ -824,6 +824,7 @@ std::unique_ptr<onnxruntime::IDataTransfer> WebGpuExecutionProvider::GetDataTran
}

WebGpuExecutionProvider::~WebGpuExecutionProvider() {
WebGpuContextFactory::ReleaseContext(context_id_);
}

std::unique_ptr<profiling::EpProfiler> WebGpuExecutionProvider::GetProfiler() {
Expand Down
24 changes: 7 additions & 17 deletions onnxruntime/core/providers/webgpu/webgpu_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,32 +22,22 @@ enum class BufferCacheMode;
class WebGpuProfiler;
} // namespace webgpu

struct WebGpuExecutionProviderInfo {
WebGpuExecutionProviderInfo(DataLayout data_layout, bool enable_graph_capture)
struct WebGpuExecutionProviderConfig {
WebGpuExecutionProviderConfig(DataLayout data_layout, bool enable_graph_capture)
: data_layout{data_layout},
enable_graph_capture{enable_graph_capture},
backend_type{},
storage_buffer_cache_mode{},
uniform_buffer_cache_mode{},
query_resolve_buffer_cache_mode{},
default_buffer_cache_mode{} {}
WebGpuExecutionProviderInfo(WebGpuExecutionProviderInfo&&) = default;
WebGpuExecutionProviderInfo& operator=(WebGpuExecutionProviderInfo&&) = default;
ORT_DISALLOW_COPY_AND_ASSIGNMENT(WebGpuExecutionProviderInfo);
enable_graph_capture{enable_graph_capture} {}
WebGpuExecutionProviderConfig(WebGpuExecutionProviderConfig&&) = default;
WebGpuExecutionProviderConfig& operator=(WebGpuExecutionProviderConfig&&) = default;
ORT_DISALLOW_COPY_AND_ASSIGNMENT(WebGpuExecutionProviderConfig);

DataLayout data_layout;
bool enable_graph_capture;
int backend_type;
webgpu::BufferCacheMode storage_buffer_cache_mode;
webgpu::BufferCacheMode uniform_buffer_cache_mode;
webgpu::BufferCacheMode query_resolve_buffer_cache_mode;
webgpu::BufferCacheMode default_buffer_cache_mode;
std::vector<std::string> force_cpu_node_names;
};

class WebGpuExecutionProvider : public IExecutionProvider {
public:
WebGpuExecutionProvider(int context_id, webgpu::WebGpuContext& context, WebGpuExecutionProviderInfo&& info);
WebGpuExecutionProvider(int context_id, webgpu::WebGpuContext& context, WebGpuExecutionProviderConfig&& config);
~WebGpuExecutionProvider() override;

std::vector<std::unique_ptr<ComputeCapability>> GetCapability(
Expand Down
Loading

0 comments on commit 3f17fe2

Please sign in to comment.