[WebGPU EP] allows GPUDevice to be released after use #23144

Merged 2 commits on Dec 19, 2024
103 changes: 63 additions & 40 deletions onnxruntime/core/providers/webgpu/webgpu_context.cc
@@ -25,32 +25,9 @@
namespace onnxruntime {
namespace webgpu {

void WebGpuContext::Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info, const void* dawn_proc_table) {
std::call_once(init_flag_, [this, &webgpu_ep_info, dawn_proc_table]() {
// Initialization.Step.1 - Create wgpu::Instance
if (instance_ == nullptr) {
const DawnProcTable* dawn_procs = reinterpret_cast<const DawnProcTable*>(dawn_proc_table);
#if defined(BUILD_DAWN_MONOLITHIC_LIBRARY)
ORT_ENFORCE(dawn_procs == nullptr, "setting DawnProcTable is not allowed when dynamically linked to webgpu_dawn.");
#else
#if !defined(USE_EXTERNAL_DAWN)
if (dawn_procs == nullptr) {
dawn_procs = &dawn::native::GetProcs();
}
#else
ORT_ENFORCE(dawn_procs != nullptr, "DawnProcTable must be provided.");
#endif
dawnProcSetProcs(dawn_procs);
#endif

wgpu::InstanceDescriptor instance_desc{};
instance_desc.features.timedWaitAnyEnable = true;
instance_ = wgpu::CreateInstance(&instance_desc);

ORT_ENFORCE(instance_ != nullptr, "Failed to create wgpu::Instance.");
}

// Initialization.Step.2 - Create wgpu::Adapter
void WebGpuContext::Initialize(const WebGpuBufferCacheConfig& buffer_cache_config, int backend_type) {
std::call_once(init_flag_, [this, &buffer_cache_config, backend_type]() {
// Create wgpu::Adapter
if (adapter_ == nullptr) {
#if !defined(__EMSCRIPTEN__) && defined(_MSC_VER) && defined(DAWN_ENABLE_D3D12) && !defined(USE_EXTERNAL_DAWN)
// If we are using the D3D12 backend on Windows and the build does not use external Dawn, dxil.dll and dxcompiler.dll are required.
@@ -79,7 +56,7 @@
wgpu::RequestAdapterOptions req_adapter_options = {};
wgpu::DawnTogglesDescriptor adapter_toggles_desc = {};
req_adapter_options.nextInChain = &adapter_toggles_desc;
req_adapter_options.backendType = static_cast<wgpu::BackendType>(webgpu_ep_info.backend_type);
req_adapter_options.backendType = static_cast<wgpu::BackendType>(backend_type);
req_adapter_options.powerPreference = wgpu::PowerPreference::HighPerformance;

auto enabled_adapter_toggles = GetEnabledAdapterToggles();
@@ -98,7 +75,7 @@
ORT_ENFORCE(adapter_ != nullptr, "Failed to get a WebGPU adapter.");
}

// Initialization.Step.3 - Create wgpu::Device
// Create wgpu::Device
if (device_ == nullptr) {
wgpu::DeviceDescriptor device_desc = {};
wgpu::DawnTogglesDescriptor device_toggles_desc = {};
@@ -150,7 +127,10 @@
device_limits_ = device_supported_limits.limits;

// create buffer manager
buffer_mgr_ = BufferManagerFactory::Create(*this, webgpu_ep_info.storage_buffer_cache_mode, webgpu_ep_info.uniform_buffer_cache_mode, webgpu_ep_info.query_resolve_buffer_cache_mode);
buffer_mgr_ = BufferManagerFactory::Create(*this,
buffer_cache_config.storage.mode,
buffer_cache_config.uniform.mode,
buffer_cache_config.query_resolve.mode);

// create program manager
program_mgr_ = std::make_unique<ProgramManager>(Device(), DeviceLimits());
@@ -661,18 +641,46 @@
num_pending_dispatches_ = 0;
}

std::unordered_map<int32_t, std::unique_ptr<WebGpuContext>> WebGpuContextFactory::contexts_;
std::unordered_map<int32_t, WebGpuContextFactory::WebGpuContextInfo> WebGpuContextFactory::contexts_;

[cpplint] warning on line 644 of onnxruntime/core/providers/webgpu/webgpu_context.cc (GitHub Actions / Optional Lint C++): Add #include <unordered_map> for unordered_map<> [build/include_what_you_use] [4]
std::mutex WebGpuContextFactory::mutex_;
std::once_flag WebGpuContextFactory::init_default_flag_;
wgpu::Instance WebGpuContextFactory::default_instance_;

WebGpuContext& WebGpuContextFactory::CreateContext(const WebGpuContextConfig& config) {
const int context_id = config.context_id;
WGPUInstance instance = config.instance;
WGPUAdapter adapter = config.adapter;
WGPUDevice device = config.device;

WebGpuContext& WebGpuContextFactory::CreateContext(int context_id,
WGPUInstance instance,
WGPUAdapter adapter,
WGPUDevice device,
ValidationMode validation_mode) {
if (context_id == 0) {
// context ID is preserved for the default context. User cannot use context ID 0 as a custom context.
ORT_ENFORCE(instance == nullptr && adapter == nullptr && device == nullptr,
"WebGPU EP default context (contextId=0) must not have custom WebGPU instance, adapter or device.");

std::call_once(init_default_flag_, [dawn_proc_table = config.dawn_proc_table]() {
// Step.1 - setup dawn proc table
const DawnProcTable* dawn_procs = reinterpret_cast<const DawnProcTable*>(dawn_proc_table);
#if defined(BUILD_DAWN_MONOLITHIC_LIBRARY)
ORT_ENFORCE(dawn_procs == nullptr, "setting DawnProcTable is not allowed when dynamically linked to webgpu_dawn.");
#else
#if !defined(USE_EXTERNAL_DAWN)
if (dawn_procs == nullptr) {
dawn_procs = &dawn::native::GetProcs();
}
#else
ORT_ENFORCE(dawn_procs != nullptr, "DawnProcTable must be provided.");
#endif
dawnProcSetProcs(dawn_procs);
#endif

// Step.2 - Create wgpu::Instance
wgpu::InstanceDescriptor instance_desc{};
instance_desc.features.timedWaitAnyEnable = true;
default_instance_ = wgpu::CreateInstance(&instance_desc);

ORT_ENFORCE(default_instance_ != nullptr, "Failed to create wgpu::Instance.");
});
instance = default_instance_.Get();
} else {
// for context ID > 0, user must provide custom WebGPU instance, adapter and device.
ORT_ENFORCE(instance != nullptr && adapter != nullptr && device != nullptr,
@@ -684,13 +692,16 @@
auto it = contexts_.find(context_id);
if (it == contexts_.end()) {
GSL_SUPPRESS(r.11)
auto context = std::unique_ptr<WebGpuContext>(new WebGpuContext(instance, adapter, device, validation_mode));
it = contexts_.emplace(context_id, std::move(context)).first;
auto context = std::unique_ptr<WebGpuContext>(new WebGpuContext(instance, adapter, device, config.validation_mode));
it = contexts_.emplace(context_id, WebGpuContextFactory::WebGpuContextInfo{std::move(context), 0}).first;

[cpplint] warning on line 696 of onnxruntime/core/providers/webgpu/webgpu_context.cc (GitHub Actions / Optional Lint C++): Add #include <utility> for move [build/include_what_you_use] [4]
} else if (context_id != 0) {
ORT_ENFORCE(it->second->instance_.Get() == instance && it->second->adapter_.Get() == adapter && it->second->device_.Get() == device,
ORT_ENFORCE(it->second.context->instance_.Get() == instance &&
it->second.context->adapter_.Get() == adapter &&
it->second.context->device_.Get() == device,
"WebGPU EP context ID ", context_id, " is already created with different WebGPU instance, adapter or device.");
}
return *it->second;
it->second.ref_count++;
return *it->second.context;
}

WebGpuContext& WebGpuContextFactory::GetContext(int context_id) {
@@ -699,12 +710,24 @@
auto it = contexts_.find(context_id);
ORT_ENFORCE(it != contexts_.end(), "WebGPU EP context ID ", context_id, " is not found.");

return *it->second;
return *it->second.context;
}

void WebGpuContextFactory::ReleaseContext(int context_id) {
std::lock_guard<std::mutex> lock(mutex_);

auto it = contexts_.find(context_id);
ORT_ENFORCE(it != contexts_.end(), "WebGPU EP context ID ", context_id, " is not found.");

if (--it->second.ref_count == 0) {
contexts_.erase(it);
}
}

void WebGpuContextFactory::Cleanup() {
std::lock_guard<std::mutex> lock(mutex_);
contexts_.clear();
default_instance_ = nullptr;
}

void CleanupWebGpuContexts() {
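The core of this change is that `WebGpuContextFactory` now keeps a per-ID `WebGpuContextInfo { context, ref_count }` instead of a bare `unique_ptr`: `CreateContext` bumps the count, the new `ReleaseContext` decrements it and erases the entry at zero, which is what finally lets the GPUDevice be released after the last user goes away. A minimal, self-contained sketch of that pattern (`Context`, `ContextFactory`, and `Info` are stand-ins, not the real ORT classes):

```cpp
#include <cstdio>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <unordered_map>

// "Context" stands in for WebGpuContext; in ORT it owns the wgpu::Device and
// related resources, so erasing it is what releases the GPU device.
struct Context {
  explicit Context(int id) : id(id) {}
  int id;
};

class ContextFactory {
 public:
  // First call for an id creates the context; later calls reuse it.
  // Every call bumps the reference count, mirroring CreateContext in the diff.
  static Context& Create(int context_id) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = contexts_.find(context_id);
    if (it == contexts_.end()) {
      it = contexts_.emplace(context_id, Info{std::make_unique<Context>(context_id), 0}).first;
    }
    it->second.ref_count++;
    return *it->second.context;
  }

  // Mirrors the new ReleaseContext: drop one reference and destroy the
  // context once the last user has released it.
  static void Release(int context_id) {
    std::lock_guard<std::mutex> lock(mutex_);
    auto it = contexts_.find(context_id);
    if (it == contexts_.end()) throw std::runtime_error("unknown context id");
    if (--it->second.ref_count == 0) {
      contexts_.erase(it);
    }
  }

 private:
  struct Info {
    std::unique_ptr<Context> context;
    int ref_count;
  };
  static std::unordered_map<int, Info> contexts_;
  static std::mutex mutex_;
};

std::unordered_map<int, ContextFactory::Info> ContextFactory::contexts_;
std::mutex ContextFactory::mutex_;

int main() {
  Context& a = ContextFactory::Create(0);  // ref_count == 1
  Context& b = ContextFactory::Create(0);  // same context, ref_count == 2
  std::printf("same context: %s\n", (&a == &b) ? "yes" : "no");
  ContextFactory::Release(0);  // still alive, one user left
  ContextFactory::Release(0);  // last reference gone -> context destroyed
  return 0;
}
```

Holding the mutex across both the lookup and the count update keeps concurrent create/release calls for the same ID from racing, matching the `std::lock_guard` usage in the diff.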
39 changes: 32 additions & 7 deletions onnxruntime/core/providers/webgpu/webgpu_context.h
@@ -26,28 +26,53 @@
class ComputeContext;
class ProgramBase;

struct WebGpuContextConfig {
int context_id;
WGPUInstance instance;
WGPUAdapter adapter;
WGPUDevice device;
const void* dawn_proc_table;
ValidationMode validation_mode;
};

struct WebGpuBufferCacheConfig {
struct ConfigEntry {
BufferCacheMode mode;
std::string config_string;
};
ConfigEntry storage;
ConfigEntry uniform;
ConfigEntry query_resolve;
ConfigEntry default_entry;
};

class WebGpuContextFactory {
public:
static WebGpuContext& CreateContext(int context_id,
WGPUInstance instance,
WGPUAdapter adapter,
WGPUDevice device,
ValidationMode validation_mode);
struct WebGpuContextInfo {
std::unique_ptr<WebGpuContext> context;
int ref_count;
};

static WebGpuContext& CreateContext(const WebGpuContextConfig& config);
static WebGpuContext& GetContext(int context_id);

static void ReleaseContext(int context_id);

static void Cleanup();

private:
WebGpuContextFactory() {}

static std::unordered_map<int32_t, std::unique_ptr<WebGpuContext>> contexts_;
static std::unordered_map<int32_t, WebGpuContextInfo> contexts_;

[cpplint] warning on line 66 of onnxruntime/core/providers/webgpu/webgpu_context.h (GitHub Actions / Optional Lint C++): Add #include <unordered_map> for unordered_map<> [build/include_what_you_use] [4]
static std::mutex mutex_;
static std::once_flag init_default_flag_;
static wgpu::Instance default_instance_;
};

// Class WebGpuContext includes all necessary resources for the context.
class WebGpuContext final {
public:
void Initialize(const WebGpuExecutionProviderInfo& webgpu_ep_info, const void* dawn_proc_table);
void Initialize(const WebGpuBufferCacheConfig& buffer_cache_config, int backend_type);

Status Wait(wgpu::Future f);

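The header change splits what used to be a single `WebGpuExecutionProviderInfo` into two narrower structs: `WebGpuContextConfig` (everything the factory needs to create or adopt a context) and `WebGpuBufferCacheConfig` (per-buffer-kind cache modes handed to `WebGpuContext::Initialize`). A sketch of how a caller might populate them; the handle aliases and enumerator names are placeholders rather than the real webgpu.h/ORT definitions, and the chosen cache modes are only examples:

```cpp
#include <string>

// Self-contained stand-ins so the sketch compiles on its own.
using WGPUInstance = void*;
using WGPUAdapter = void*;
using WGPUDevice = void*;
enum class ValidationMode { Disabled, Basic, Full };
enum class BufferCacheMode { Disabled, LazyRelease, Simple, Bucket };

// Same shape as the structs added to webgpu_context.h.
struct WebGpuContextConfig {
  int context_id;
  WGPUInstance instance;
  WGPUAdapter adapter;
  WGPUDevice device;
  const void* dawn_proc_table;
  ValidationMode validation_mode;
};

struct WebGpuBufferCacheConfig {
  struct ConfigEntry {
    BufferCacheMode mode;
    std::string config_string;
  };
  ConfigEntry storage;
  ConfigEntry uniform;
  ConfigEntry query_resolve;
  ConfigEntry default_entry;
};

int main() {
  // Default context (context_id == 0): instance/adapter/device must stay null;
  // the factory creates them itself, optionally from a caller-supplied Dawn proc table.
  WebGpuContextConfig context_cfg{0, nullptr, nullptr, nullptr, nullptr, ValidationMode::Basic};

  // Buffer cache behavior is configured separately and passed to
  // WebGpuContext::Initialize instead of travelling inside the EP info struct.
  WebGpuBufferCacheConfig cache_cfg{
      /*storage*/ {BufferCacheMode::Bucket, ""},
      /*uniform*/ {BufferCacheMode::LazyRelease, ""},
      /*query_resolve*/ {BufferCacheMode::Disabled, ""},
      /*default_entry*/ {BufferCacheMode::Disabled, ""}};

  (void)context_cfg;
  (void)cache_cfg;
  return 0;
}
```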
onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
@@ -743,13 +743,13 @@ using namespace webgpu;

WebGpuExecutionProvider::WebGpuExecutionProvider(int context_id,
WebGpuContext& context,
WebGpuExecutionProviderInfo&& info)
WebGpuExecutionProviderConfig&& config)
: IExecutionProvider{kWebGpuExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)},
context_id_{context_id},
context_{context},
preferred_data_layout_{info.data_layout},
force_cpu_node_names_{std::move(info.force_cpu_node_names)},
enable_graph_capture_{info.enable_graph_capture} {
preferred_data_layout_{config.data_layout},
force_cpu_node_names_{std::move(config.force_cpu_node_names)},
enable_graph_capture_{config.enable_graph_capture} {
}

std::vector<AllocatorPtr> WebGpuExecutionProvider::CreatePreferredAllocators() {
@@ -824,6 +824,7 @@ std::unique_ptr<onnxruntime::IDataTransfer> WebGpuExecutionProvider::GetDataTran
}

WebGpuExecutionProvider::~WebGpuExecutionProvider() {
WebGpuContextFactory::ReleaseContext(context_id_);
}

std::unique_ptr<profiling::EpProfiler> WebGpuExecutionProvider::GetProfiler() {
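On the execution-provider side, the only functional change is the added `WebGpuContextFactory::ReleaseContext(context_id_)` call in the destructor, which balances the `CreateContext` done when the EP was set up (that acquiring call lives in the provider factory, outside this diff). A toy sketch of the pairing, with the acquire folded into the constructor for brevity:

```cpp
#include <cstdio>

// Toy stand-ins: the real types are WebGpuContextFactory and WebGpuExecutionProvider.
namespace factory {
int ref_count = 0;

void CreateContext(int /*context_id*/) { ++ref_count; }

void ReleaseContext(int /*context_id*/) {
  if (--ref_count == 0) {
    std::printf("last reference released -> context (and GPUDevice) can be destroyed\n");
  }
}
}  // namespace factory

class ExecutionProvider {
 public:
  explicit ExecutionProvider(int context_id) : context_id_(context_id) {
    factory::CreateContext(context_id_);  // acquire one reference for this EP instance
  }
  ~ExecutionProvider() {
    factory::ReleaseContext(context_id_);  // the one line this diff adds to ~WebGpuExecutionProvider
  }

 private:
  int context_id_;
};

int main() {
  {
    ExecutionProvider ep_a(0);
    ExecutionProvider ep_b(0);  // two EPs share context 0
  }  // ep_b, then ep_a are destroyed; only the second release frees the context
  return 0;
}
```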
24 changes: 7 additions & 17 deletions onnxruntime/core/providers/webgpu/webgpu_execution_provider.h
@@ -22,32 +22,22 @@ enum class BufferCacheMode;
class WebGpuProfiler;
} // namespace webgpu

struct WebGpuExecutionProviderInfo {
WebGpuExecutionProviderInfo(DataLayout data_layout, bool enable_graph_capture)
struct WebGpuExecutionProviderConfig {
WebGpuExecutionProviderConfig(DataLayout data_layout, bool enable_graph_capture)
: data_layout{data_layout},
enable_graph_capture{enable_graph_capture},
backend_type{},
storage_buffer_cache_mode{},
uniform_buffer_cache_mode{},
query_resolve_buffer_cache_mode{},
default_buffer_cache_mode{} {}
WebGpuExecutionProviderInfo(WebGpuExecutionProviderInfo&&) = default;
WebGpuExecutionProviderInfo& operator=(WebGpuExecutionProviderInfo&&) = default;
ORT_DISALLOW_COPY_AND_ASSIGNMENT(WebGpuExecutionProviderInfo);
enable_graph_capture{enable_graph_capture} {}
WebGpuExecutionProviderConfig(WebGpuExecutionProviderConfig&&) = default;
WebGpuExecutionProviderConfig& operator=(WebGpuExecutionProviderConfig&&) = default;
ORT_DISALLOW_COPY_AND_ASSIGNMENT(WebGpuExecutionProviderConfig);

DataLayout data_layout;
bool enable_graph_capture;
int backend_type;
webgpu::BufferCacheMode storage_buffer_cache_mode;
webgpu::BufferCacheMode uniform_buffer_cache_mode;
webgpu::BufferCacheMode query_resolve_buffer_cache_mode;
webgpu::BufferCacheMode default_buffer_cache_mode;
std::vector<std::string> force_cpu_node_names;
};

class WebGpuExecutionProvider : public IExecutionProvider {
public:
WebGpuExecutionProvider(int context_id, webgpu::WebGpuContext& context, WebGpuExecutionProviderInfo&& info);
WebGpuExecutionProvider(int context_id, webgpu::WebGpuContext& context, WebGpuExecutionProviderConfig&& config);
~WebGpuExecutionProvider() override;

std::vector<std::unique_ptr<ComputeCapability>> GetCapability(
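With the backend type and buffer-cache modes moved into `WebGpuContextConfig`/`WebGpuBufferCacheConfig`, the renamed `WebGpuExecutionProviderConfig` keeps only the data layout, the graph-capture flag, and the forced-CPU node list, and remains move-only. A small sketch of that shape; the enum values and node names below are made up for illustration:

```cpp
#include <string>
#include <utility>
#include <vector>

enum class DataLayout { NCHW, NHWC };  // placeholder for ORT's DataLayout

// Same move-only shape as the renamed config; ORT_DISALLOW_COPY_AND_ASSIGNMENT
// is spelled out by hand to keep the sketch self-contained.
struct WebGpuExecutionProviderConfig {
  WebGpuExecutionProviderConfig(DataLayout data_layout, bool enable_graph_capture)
      : data_layout{data_layout}, enable_graph_capture{enable_graph_capture} {}
  WebGpuExecutionProviderConfig(WebGpuExecutionProviderConfig&&) = default;
  WebGpuExecutionProviderConfig& operator=(WebGpuExecutionProviderConfig&&) = default;
  WebGpuExecutionProviderConfig(const WebGpuExecutionProviderConfig&) = delete;
  WebGpuExecutionProviderConfig& operator=(const WebGpuExecutionProviderConfig&) = delete;

  DataLayout data_layout;
  bool enable_graph_capture;
  std::vector<std::string> force_cpu_node_names;
};

int main() {
  WebGpuExecutionProviderConfig config(DataLayout::NHWC, /*enable_graph_capture*/ false);
  config.force_cpu_node_names = {"Shape", "Range"};  // hypothetical node names

  // The provider constructor takes the config by rvalue reference, so callers
  // hand ownership over with std::move; copying is disallowed.
  WebGpuExecutionProviderConfig sunk = std::move(config);
  return static_cast<int>(sunk.enable_graph_capture);
}
```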