diff --git a/cmake/onnxruntime_providers_webgpu.cmake b/cmake/onnxruntime_providers_webgpu.cmake index f3f63aa18a7e3..303ab9483c38a 100644 --- a/cmake/onnxruntime_providers_webgpu.cmake +++ b/cmake/onnxruntime_providers_webgpu.cmake @@ -26,3 +26,5 @@ onnxruntime_add_static_library(onnxruntime_providers_webgpu ${onnxruntime_providers_webgpu_cc_srcs}) onnxruntime_add_include_to_target(onnxruntime_providers_webgpu onnxruntime_common onnx onnx_proto flatbuffers::flatbuffers Boost::mp11 safeint_interface) target_link_libraries(onnxruntime_providers_webgpu dawn::webgpu_dawn) + + set_target_properties(onnxruntime_providers_webgpu PROPERTIES FOLDER "ONNXRuntime") diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc index b2b5f88eb84e1..91f51df588fca 100644 --- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc @@ -35,25 +35,25 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Sk Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, // default entry to avoid the list become empty after ops-reducing - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo, + // BuildKernelCreateInfo }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/webgpu/allocator.cc b/onnxruntime/core/providers/webgpu/allocator.cc index d3e25637dfd97..8e27acdc285d4 100644 --- a/onnxruntime/core/providers/webgpu/allocator.cc +++ b/onnxruntime/core/providers/webgpu/allocator.cc @@ -20,7 +20,7 @@ void* GpuBufferAllocator::Alloc(size_t size) { auto buffer = context_.BufferManager().Create(size); stats_.num_allocs++; - return buffer.Get(); + return buffer; } void GpuBufferAllocator::Free(void* p) { diff --git a/onnxruntime/core/providers/webgpu/allocator.h b/onnxruntime/core/providers/webgpu/allocator.h index d139c9eb9574c..51ca65a8b4822 100644 --- a/onnxruntime/core/providers/webgpu/allocator.h +++ b/onnxruntime/core/providers/webgpu/allocator.h @@ -9,6 +9,8 @@ namespace onnxruntime { namespace webgpu { +class WebGpuContext; + class GpuBufferAllocator : public IAllocator { public: GpuBufferAllocator(const WebGpuContext& context) diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.cc b/onnxruntime/core/providers/webgpu/buffer_manager.cc index d0f4b1142caf2..8722eba77eaa1 100644 --- a/onnxruntime/core/providers/webgpu/buffer_manager.cc +++ b/onnxruntime/core/providers/webgpu/buffer_manager.cc @@ -2,23 +2,28 @@ // Licensed under the MIT License. #include "core/providers/webgpu/buffer_manager.h" +#include "core/providers/webgpu/webgpu_context.h" namespace onnxruntime { namespace webgpu { +size_t NormalizeBufferSize(size_t size) { + return (size + 15) / 16 * 16; +} + class DisabledCacheManager : public IBufferCacheManager { size_t CalculateBufferSize(size_t request_size) override { - return (request_size + 15) / 16 * 16; + return NormalizeBufferSize(request_size); } WGPUBuffer TryAcquireCachedBuffer(size_t /*buffer_size*/, wgpu::BufferUsage /*usage*/) override { // always return empty buffer return nullptr; } - void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/, size_t /*buffer_size*/, wgpu::BufferUsage /*usage*/) override { + void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/) override { // no-op } - void ReleaseBuffer(WGPUBuffer buffer, size_t /*buffer_size*/, wgpu::BufferUsage /*usage*/) override { + void ReleaseBuffer(WGPUBuffer buffer) override { wgpuBufferDestroy(buffer); } @@ -29,7 +34,7 @@ class DisabledCacheManager : public IBufferCacheManager { class SimpleCacheManager : public IBufferCacheManager { size_t CalculateBufferSize(size_t request_size) override { - return (request_size + 15) / 16 * 16; + return NormalizeBufferSize(request_size); } WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size, wgpu::BufferUsage usage) override { @@ -44,49 +49,33 @@ class SimpleCacheManager : public IBufferCacheManager { return nullptr; } - void RegisterBuffer(WGPUBuffer buffer, size_t /*request_size*/, size_t buffer_size, wgpu::BufferUsage usage) override { + + void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/) override { + // no-op } - void ReleaseBuffer(WGPUBuffer buffer, size_t buffer_size, wgpu::BufferUsage usage) override { - if (usage | wgpu::BufferUsage::Storage) { - pending_buffers_.emplace_back(buffer, buffer_size); + + void ReleaseBuffer(WGPUBuffer buffer) override { + auto usage = wgpuBufferGetUsage(buffer); + if (usage | WGPUBufferUsage_Storage) { + pending_buffers_.emplace_back(buffer); } else { wgpuBufferDestroy(buffer); } } + void OnRefresh() override { - for (auto& pair : pending_buffers_) { - buffers_[pair.second].push_back(pair.first); + for (auto& buffer : pending_buffers_) { + buffers_[wgpuBufferGetSize(buffer)].push_back(buffer); } pending_buffers_.clear(); } std::map> buffers_; - std::vector> pending_buffers_; + std::vector pending_buffers_; }; -class BucketCacheManager : public IBufferCacheManager { - static const std::unordered_map kBucketSizes; - static const std::array kBucketSizesArray; - - size_t CalculateBufferSize(size_t request_size) override { - ORT_NOT_IMPLEMENTED("TODO"); - } - - WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size, wgpu::BufferUsage usage) override { - ORT_NOT_IMPLEMENTED("TODO"); - } - void RegisterBuffer(WGPUBuffer buffer, size_t request_size, size_t buffer_size, wgpu::BufferUsage usage) override { - ORT_NOT_IMPLEMENTED("TODO"); - } - void ReleaseBuffer(WGPUBuffer buffer, size_t buffer_size, wgpu::BufferUsage usage) override { - ORT_NOT_IMPLEMENTED("TODO"); - } - void OnRefresh() override { - ORT_NOT_IMPLEMENTED("TODO"); - } -}; - -constexpr std::initializer_list> BUCKET_TABLE = { +// TODO: maybe use different bucket size for storage and uniform buffers? +constexpr std::initializer_list> BUCKET_DEFAULT_LIMIT_TABLE = { {64, 250}, {128, 200}, {256, 200}, @@ -116,8 +105,116 @@ constexpr std::initializer_list> BUCKET_TABLE = { {134217728, 6}, {167772160, 6}, }; -const std::unordered_map BucketCacheManager::kBucketSizes{BUCKET_TABLE}; +class BucketCacheManager : public IBufferCacheManager { + public: + BucketCacheManager() : buckets_limit_{BUCKET_DEFAULT_LIMIT_TABLE} { + Initialize(); + } + BucketCacheManager(std::unordered_map&& buckets_limit) : buckets_limit_{buckets_limit} { + Initialize(); + } + + size_t CalculateBufferSize(size_t request_size) override { + // binary serch size + auto it = std::lower_bound(buckets_keys_.begin(), buckets_keys_.end(), request_size); + if (it == buckets_keys_.end()) { + return NormalizeBufferSize(request_size); + } else { + return *it; + } + } + + WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size, wgpu::BufferUsage usage) override { + std::unordered_map>* buckets = nullptr; + if (usage | wgpu::BufferUsage::Storage) { + buckets = &buckets_storage_; + } else if (usage | wgpu::BufferUsage::Uniform) { + buckets = &buckets_uniform_; + } + if (buckets) { + auto it = buckets->find(buffer_size); + if (it != buckets->end() && !it->second.empty()) { + auto buffer = it->second.back(); + it->second.pop_back(); + return buffer; + } + } + return nullptr; + } + + void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/) override { + // no-op + } + + void ReleaseBuffer(WGPUBuffer buffer) override { + std::vector* pending_buffers = nullptr; + auto usage = wgpuBufferGetUsage(buffer); + if (usage | WGPUBufferUsage_Storage) { + pending_buffers = &pending_storage_buffers_; + } else if (usage | WGPUBufferUsage_Uniform) { + pending_buffers = &pending_uniform_buffers_; + } + if (pending_buffers) { + pending_buffers->emplace_back(buffer); + } else { + wgpuBufferDestroy(buffer); + } + } + + void OnRefresh() override { + // TODO: consider graph capture. currently not supported + + for (auto& buffer : pending_storage_buffers_) { + auto buffer_size = wgpuBufferGetSize(buffer); + + auto it = buckets_storage_.find(buffer_size); + if (it != buckets_storage_.end() && it->second.size() < buckets_limit_[buffer_size]) { + it->second.push_back(buffer); + } else { + wgpuBufferDestroy(buffer); + } + } + + for (auto& buffer : pending_uniform_buffers_) { + auto buffer_size = wgpuBufferGetSize(buffer); + + auto it = buckets_uniform_.find(buffer_size); + if (it != buckets_uniform_.end() && it->second.size() < buckets_limit_[buffer_size]) { + it->second.push_back(buffer); + } else { + wgpuBufferDestroy(buffer); + } + } + } + + protected: + void Initialize() { + buckets_keys_.reserve(buckets_limit_.size()); + buckets_storage_.reserve(buckets_limit_.size()); + buckets_uniform_.reserve(buckets_limit_.size()); + for (const auto& pair : buckets_limit_) { + buckets_keys_.push_back(pair.first); + buckets_storage_.emplace(pair.first, std::vector()); + buckets_uniform_.emplace(pair.first, std::vector()); + } +#ifndef NDEBUG + for (size_t i = 0; i < buckets_keys_.size(); ++i) { + ORT_ENFORCE(buckets_keys_[i] % 16 == 0, "Bucket sizes must be multiples of 16."); + } + + for (size_t i = 1; i < buckets_keys_.size(); ++i) { + ORT_ENFORCE(buckets_keys_[i] > buckets_keys_[i - 1], "Bucket sizes must be in increasing order."); + } +#endif + } + std::unordered_map buckets_limit_; + std::unordered_map> buckets_storage_; + std::vector pending_storage_buffers_; + std::unordered_map> buckets_uniform_; + std::vector pending_uniform_buffers_; + std::vector buckets_keys_; +}; std::unique_ptr CreateBufferCacheManager(BufferCacheMode cache_mode) { switch (cache_mode) { @@ -132,13 +229,51 @@ std::unique_ptr CreateBufferCacheManager(BufferCacheMode ca } } -BufferManager::BufferManager(wgpu::Device device, BufferCacheMode cache_mode) : IBufferManager{device, CreateBufferCacheManager(cache_mode)} { +class BufferManager : public IBufferManager { + public: + BufferManager(const WebGpuContext& context, BufferCacheMode cache_mode); + + void Upload(void* src, WGPUBuffer dst, size_t size) const override; + void MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) const override; + WGPUBuffer Create(size_t size, wgpu::BufferUsage usage) const override; + void Release(WGPUBuffer buffer) const override; + void Download(WGPUBuffer src, void* dst, size_t size) const override; + void RefreshPendingBuffers() const override; +}; + +BufferManager::BufferManager(const WebGpuContext& context, BufferCacheMode cache_mode) : IBufferManager{context, CreateBufferCacheManager(cache_mode)} { } void BufferManager::Upload(void* src, WGPUBuffer dst, size_t size) const { + auto buffer_size = NormalizeBufferSize(size); + + wgpu::BufferDescriptor desc; + desc.size = buffer_size; + desc.usage = wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapWrite; + desc.mappedAtCreation = true; + + auto staging_buffer = context_.Device().CreateBuffer(&desc); + auto mapped_data = staging_buffer.GetMappedRange(); + memcpy(mapped_data, src, size); + staging_buffer.Unmap(); + + auto& command_encoder = context_.GetCommandEncoder(); + context_.EndComputePass(); + command_encoder.CopyBufferToBuffer(staging_buffer, 0, dst, 0, buffer_size); + staging_buffer.Destroy(); } void BufferManager::MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) const { + ORT_ENFORCE(src != dst, "Source and destination buffers must be different."); + + auto buffer_size = NormalizeBufferSize(size); + ORT_ENFORCE(buffer_size <= wgpuBufferGetSize(src) && buffer_size <= wgpuBufferGetSize(dst), + "Source and destination buffers must have enough space for the copy operation. src_size=", + wgpuBufferGetSize(src), ", dst_size=", wgpuBufferGetSize(dst), ", copy_size=", buffer_size, "."); + + auto& command_encoder = context_.GetCommandEncoder(); + context_.EndComputePass(); + command_encoder.CopyBufferToBuffer(src, 0, dst, 0, buffer_size); } WGPUBuffer BufferManager::Create(size_t size, wgpu::BufferUsage usage) const { @@ -153,25 +288,49 @@ WGPUBuffer BufferManager::Create(size_t size, wgpu::BufferUsage usage) const { wgpu::BufferDescriptor desc; desc.size = buffer_size; desc.usage = usage; - buffer = device_.CreateBuffer(&desc); + buffer = context_.Device().CreateBuffer(&desc).MoveToCHandle(); - ORT_ENFORCE(buffer, "Failed to create GPU buffer: size=", buffer_size, ", usage=", usage, "."); + ORT_ENFORCE(buffer, "Failed to create GPU buffer: size=", buffer_size, ", usage=", uint64_t(usage), "."); - cache_->RegisterBuffer(buffer, size, buffer_size, usage); + cache_->RegisterBuffer(buffer, size); return buffer; } void BufferManager::Release(WGPUBuffer buffer) const { - cache_->ReleaseBuffer(buffer, 0, wgpu::BufferUsage::None); + cache_->ReleaseBuffer(buffer); } -wgpu::Future BufferManager::Download(WGPUBuffer src, void* dst, size_t size) const { - return wgpu::Future(); +void BufferManager::Download(WGPUBuffer src, void* dst, size_t size) const { + auto buffer_size = NormalizeBufferSize(size); + + wgpu::BufferDescriptor desc; + desc.size = buffer_size; + desc.usage = wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead; + + auto staging_buffer = context_.Device().CreateBuffer(&desc); + auto& command_encoder = context_.GetCommandEncoder(); + context_.EndComputePass(); + command_encoder.CopyBufferToBuffer(src, 0, staging_buffer, 0, buffer_size); + context_.Flush(); + + wgpu::BufferMapCallbackInfo callback_info; + callback_info.mode = wgpu::CallbackMode::WaitAnyOnly; + callback_info.callback = [](WGPUBufferMapAsyncStatus status, void*) { + ORT_ENFORCE(status == WGPUBufferMapAsyncStatus_Success, "Failed to download data from buffer"); + }; + ORT_ENFORCE(context_.Wait(staging_buffer.MapAsync(wgpu::MapMode::Read, 0, buffer_size, callback_info)) == Status::OK()); + + auto mapped_data = staging_buffer.GetMappedRange(); + memcpy(dst, mapped_data, size); } void BufferManager::RefreshPendingBuffers() const { cache_->OnRefresh(); } +std::unique_ptr BufferManagerFactory::Create(const WebGpuContext& context, BufferCacheMode mode) { + return std::make_unique(context, mode); +} + } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.h b/onnxruntime/core/providers/webgpu/buffer_manager.h index 8dac698f9ccef..a411c44812339 100644 --- a/onnxruntime/core/providers/webgpu/buffer_manager.h +++ b/onnxruntime/core/providers/webgpu/buffer_manager.h @@ -15,6 +15,8 @@ namespace onnxruntime { namespace webgpu { +class WebGpuContext; + enum class BufferCacheMode { None, Simple, @@ -32,10 +34,10 @@ class IBufferCacheManager { virtual WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size, wgpu::BufferUsage usage) = 0; // register a newly created buffer - virtual void RegisterBuffer(WGPUBuffer buffer, size_t request_size, size_t buffer_size, wgpu::BufferUsage usage) = 0; + virtual void RegisterBuffer(WGPUBuffer buffer, size_t request_size) = 0; // release a buffer - virtual void ReleaseBuffer(WGPUBuffer buffer, size_t buffer_size, wgpu::BufferUsage usage) = 0; + virtual void ReleaseBuffer(WGPUBuffer buffer) = 0; // when a stream refresh is requested virtual void OnRefresh() = 0; @@ -43,7 +45,7 @@ class IBufferCacheManager { class IBufferManager { protected: - IBufferManager(wgpu::Device device, std::unique_ptr cache) : device_{device}, cache_{std::move(cache)} {} + IBufferManager(const WebGpuContext& context, std::unique_ptr cache) : context_{context}, cache_{std::move(cache)} {} public: virtual ~IBufferManager() = default; @@ -51,35 +53,22 @@ class IBufferManager { virtual void MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) const = 0; virtual WGPUBuffer Create(size_t size, wgpu::BufferUsage usage = wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst) const = 0; virtual void Release(WGPUBuffer buffer) const = 0; - virtual wgpu::Future Download(WGPUBuffer src, void* dst, size_t size) const = 0; + virtual void Download(WGPUBuffer src, void* dst, size_t size) const = 0; virtual void RefreshPendingBuffers() const = 0; // TODO: add statistics protected: - wgpu::Device device_; + const WebGpuContext& context_; std::unique_ptr cache_; }; -class BufferManager : public IBufferManager { +class BufferManagerFactory { public: - BufferManager(wgpu::Device device, BufferCacheMode cache_mode); - - void Upload(void* src, WGPUBuffer dst, size_t size) const override; - void MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) const override; - WGPUBuffer Create(size_t size, wgpu::BufferUsage usage) const override; - void Release(WGPUBuffer buffer) const override; - wgpu::Future Download(WGPUBuffer src, void* dst, size_t size) const override; - void RefreshPendingBuffers() const override; + static std::unique_ptr Create(const WebGpuContext& context, BufferCacheMode mode); private: - struct PendingBuffer { - wgpu::Buffer buffer; - void* data; - size_t size; - }; - - std::vector pending_buffers_; + BufferManagerFactory() {} }; } // namespace webgpu diff --git a/onnxruntime/core/providers/webgpu/compute_context.cc b/onnxruntime/core/providers/webgpu/compute_context.cc new file mode 100644 index 0000000000000..a558b8da68533 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/compute_context.cc @@ -0,0 +1,5 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webgpu/compute_context.h" +#include "core/providers/webgpu/webgpu_context.h" diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h new file mode 100644 index 0000000000000..627454d5ef7cf --- /dev/null +++ b/onnxruntime/core/providers/webgpu/compute_context.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#ifdef __EMSCRIPTEN__ +#include +#endif + +#include + +#include "core/framework/execution_provider.h" + +namespace onnxruntime { +namespace webgpu { + +class WebGpuContext; + +class ComputeContext { + public: + ComputeContext(const WebGpuContext& context) : context_{context} {} + + virtual ~ComputeContext() = default; + + virtual void Dispatch(WGPUComputePassEncoder pass) const = 0; + + protected: + const WebGpuContext& context_; +}; + +} // namespace webgpu +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/data_transfer.cc b/onnxruntime/core/providers/webgpu/data_transfer.cc index d211980128501..615ae11175782 100644 --- a/onnxruntime/core/providers/webgpu/data_transfer.cc +++ b/onnxruntime/core/providers/webgpu/data_transfer.cc @@ -37,8 +37,7 @@ common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { } } else /* if (src_device.Type() == OrtDevice::GPU) */ { // copy from GPU to CPU - ORT_RETURN_IF_ERROR(context_.Wait( - context_.BufferManager().Download(static_cast(const_cast(src_data)), dst_data, bytes))); + context_.BufferManager().Download(static_cast(const_cast(src_data)), dst_data, bytes); } } diff --git a/onnxruntime/core/providers/webgpu/data_transfer.h b/onnxruntime/core/providers/webgpu/data_transfer.h index 26f3d159d2b4d..79853483e0c23 100644 --- a/onnxruntime/core/providers/webgpu/data_transfer.h +++ b/onnxruntime/core/providers/webgpu/data_transfer.h @@ -9,6 +9,8 @@ namespace onnxruntime { namespace webgpu { +class WebGpuContext; + class DataTransfer : public IDataTransfer { public: DataTransfer(const WebGpuContext& context) : context_{context} {}; diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc index 114256d2c2218..16672a434b321 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc @@ -19,21 +19,20 @@ Status UnaryElementwise::ComputeInternal(OpKernelContext* context) const { return Status(common::ONNXRUNTIME, common::FAIL); } -#define WEBGPU_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS, TYPE) \ - ONNX_OPERATOR_KERNEL_EX( \ - OP_TYPE, kOnnxDomain, VERSION, kWebGpuExecutionProvider, \ - KernelDefBuilder().TypeConstraint("T", TYPE), \ +#define WEBGPU_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS, TYPE) \ + ONNX_OPERATOR_KERNEL_EX( \ + OP_TYPE, kOnnxDomain, VERSION, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE), \ KERNEL_CLASS); #define WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE) \ - ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ - OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kWebGpuExecutionProvider, \ - KernelDefBuilder().TypeConstraint("T", TYPE), \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kWebGpuExecutionProvider, \ + KernelDefBuilder().TypeConstraint("T", TYPE), \ KERNEL_CLASS); WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Abs, 6, 12, Abs, WebGpuSupportedFloatTypes()) WEBGPU_ELEMENTWISE_KERNEL(Abs, 13, Abs, WebGpuSupportedFloatTypes()) - } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h index 0da0cfa74c150..7e913b6656bc1 100644 --- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h +++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h @@ -26,6 +26,5 @@ class Abs final : public UnaryElementwise { // Status ComputeInternal(OpKernelContext* context) const override; }; - } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc index 1c8831992096b..e4686a355c09e 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc @@ -91,7 +91,7 @@ void WebGpuContext::Initialize() { ORT_ENFORCE(device_.GetLimits(&limits)); // create buffer manager - buffer_mgr_ = std::make_unique(device_); + buffer_mgr_ = BufferManagerFactory::Create(*this, BufferCacheMode::None); }); } @@ -103,12 +103,15 @@ Status WebGpuContext::Wait(wgpu::Future f) const { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to wait for the operation:", uint32_t(status)); } +std::unordered_map> WebGpuContextFactory::contexts_; +std::mutex WebGpuContextFactory::mutex_; + WebGpuContext& WebGpuContextFactory::GetOrCreateContext(int32_t context_id) { std::lock_guard lock(mutex_); auto it = contexts_.find(context_id); if (it == contexts_.end()) { - auto context = std::make_unique(); + auto context = std::unique_ptr(new WebGpuContext()); it = contexts_.emplace(context_id, std::move(context)).first; } return *it->second; diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h index 904507da9bb8c..bdba190278c15 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_context.h +++ b/onnxruntime/core/providers/webgpu/webgpu_context.h @@ -16,17 +16,66 @@ namespace onnxruntime { namespace webgpu { +class WebGpuContext; + +class WebGpuContextFactory { + public: + static WebGpuContext& GetOrCreateContext(int32_t context_id = 0); + + private: + WebGpuContextFactory() {} + + static std::unordered_map> contexts_; + static std::mutex mutex_; +}; // Class WebGpuContext includes all necessary resources for the context. class WebGpuContext final { public: void Initialize(); + // non copyable + WebGpuContext(const WebGpuContext&) = delete; + WebGpuContext& operator=(const WebGpuContext&) = delete; + + // non movable + WebGpuContext(WebGpuContext&&) = delete; + WebGpuContext& operator=(WebGpuContext&&) = delete; + Status Wait(wgpu::Future f) const; const wgpu::Adapter& Adapter() const { return adapter_; } const wgpu::Device& Device() const { return device_; } + const wgpu::CommandEncoder& GetCommandEncoder() const { + if (!current_command_encoder_) { + current_command_encoder_ = device_.CreateCommandEncoder(); + } + return current_command_encoder_; + } + + void EndComputePass() const { + if (current_compute_pass_encoder_) { + current_compute_pass_encoder_.End(); + current_compute_pass_encoder_ = nullptr; + } + } + + void Flush() const { + if (!current_command_encoder_) { + return; + } + + EndComputePass(); + + // TODO: add support for GPU Query + + auto command_buffer = current_command_encoder_.Finish(); + Device().GetQueue().Submit(1, &command_buffer); + BufferManager().RefreshPendingBuffers(); + current_command_encoder_ = nullptr; + } + const IBufferManager& BufferManager() const { return *buffer_mgr_; } private: @@ -39,20 +88,11 @@ class WebGpuContext final { wgpu::Device device_; std::unique_ptr buffer_mgr_; + mutable wgpu::CommandEncoder current_command_encoder_; + mutable wgpu::ComputePassEncoder current_compute_pass_encoder_; friend class WebGpuContextFactory; }; -class WebGpuContextFactory { - public: - static WebGpuContext& GetOrCreateContext(int32_t context_id = 0); - - private: - WebGpuContextFactory() {} - - static std::unordered_map> contexts_; - static std::mutex mutex_; -}; - } // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 181bcde2f47e0..46e9d6d54454f 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -722,7 +722,7 @@ WebGpuExecutionProvider::WebGpuExecutionProvider(const WebGpuContext& context, const WebGpuExecutionProviderInfo& info, const SessionOptions* session_options) : IExecutionProvider{kWebGpuExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)}, - context_(context), + context_{context}, preferred_data_layout_{info.data_layout} { if (session_options) { enable_graph_capture_ = session_options->config_options.GetConfigOrDefault("enableGraphCapture", "false") == "true"; diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h index 873b4f458e33b..23dbb7432abdd 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h @@ -8,17 +8,16 @@ #include "core/framework/session_options.h" #include "core/graph/constants.h" #include "core/providers/providers.h" -#include "core/providers/webgpu/webgpu_context.h" struct pthreadpool; namespace onnxruntime { - namespace webgpu { // forward declaration for this EP's namespace. template KernelCreateInfo BuildKernelCreateInfo(); +class WebGpuContext; } // namespace webgpu struct WebGpuExecutionProviderInfo { @@ -40,7 +39,7 @@ struct WebGpuExecutionProviderInfo { class WebGpuExecutionProvider : public IExecutionProvider { public: - WebGpuExecutionProvider(const WebGpuContext& context, const WebGpuExecutionProviderInfo& info, const SessionOptions* session_options); + WebGpuExecutionProvider(const webgpu::WebGpuContext& context, const WebGpuExecutionProviderInfo& info, const SessionOptions* session_options); ~WebGpuExecutionProvider() override; std::vector> GetCapability( @@ -70,7 +69,7 @@ class WebGpuExecutionProvider : public IExecutionProvider { private: bool IsGraphCaptureAllowed() const; void IncrementRegularRunCountBeforeGraphCapture(); - const WebGpuContext& context_; + const webgpu::WebGpuContext& context_; DataLayout preferred_data_layout_; bool enable_graph_capture_ = false; bool is_graph_captured_ = false; diff --git a/onnxruntime/core/providers/webgpu/webgpu_kernel.h b/onnxruntime/core/providers/webgpu/webgpu_kernel.h index bcd0e0c59b2c9..bdabcf5548f81 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_kernel.h +++ b/onnxruntime/core/providers/webgpu/webgpu_kernel.h @@ -36,5 +36,5 @@ class WebGpuKernel : public OpKernel { virtual Status ComputeInternal(OpKernelContext* p_op_kernel_context) const = 0; }; -} // namespace cuda +} // namespace webgpu } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc index 3e2571be5b721..754cc4b1e83ca 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc @@ -12,8 +12,8 @@ namespace onnxruntime { struct WebGpuProviderFactory : IExecutionProviderFactory { WebGpuProviderFactory(const webgpu::WebGpuContext& context, const ProviderOptions& provider_options, const SessionOptions* session_options) - : context_{context}, - info_{provider_options}, + : info_{provider_options}, + context_{context}, session_options_(session_options) { } @@ -23,8 +23,8 @@ struct WebGpuProviderFactory : IExecutionProviderFactory { private: WebGpuExecutionProviderInfo info_; - const SessionOptions* session_options_; const webgpu::WebGpuContext& context_; + const SessionOptions* session_options_; }; std::shared_ptr WebGpuProviderFactoryCreator::Create(