From ed7f830100f7cbb723f741bf6359421495588bed Mon Sep 17 00:00:00 2001
From: Yulong Wang <7679871+fs-eire@users.noreply.github.com>
Date: Sat, 17 Aug 2024 01:46:36 -0700
Subject: [PATCH] w3

---
 cmake/onnxruntime_providers_webgpu.cmake      |   2 +
 .../webgpu/webgpu_contrib_kernels.cc          |  38 +--
 .../core/providers/webgpu/allocator.cc        |   2 +-
 onnxruntime/core/providers/webgpu/allocator.h |   2 +
 .../core/providers/webgpu/buffer_manager.cc   | 243 +++++++++++++++---
 .../core/providers/webgpu/buffer_manager.h    |  31 +--
 .../core/providers/webgpu/compute_context.cc  |   5 +
 .../core/providers/webgpu/compute_context.h   |  32 +++
 .../core/providers/webgpu/data_transfer.cc    |   3 +-
 .../core/providers/webgpu/data_transfer.h     |   2 +
 .../webgpu/math/unary_elementwise_ops.cc      |  15 +-
 .../webgpu/math/unary_elementwise_ops.h       |   1 -
 .../core/providers/webgpu/webgpu_context.cc   |   7 +-
 .../core/providers/webgpu/webgpu_context.h    |  62 ++++-
 .../webgpu/webgpu_execution_provider.cc       |   2 +-
 .../webgpu/webgpu_execution_provider.h        |   7 +-
 .../core/providers/webgpu/webgpu_kernel.h     |   2 +-
 .../webgpu/webgpu_provider_factory.cc         |   6 +-
 18 files changed, 346 insertions(+), 116 deletions(-)
 create mode 100644 onnxruntime/core/providers/webgpu/compute_context.cc
 create mode 100644 onnxruntime/core/providers/webgpu/compute_context.h

diff --git a/cmake/onnxruntime_providers_webgpu.cmake b/cmake/onnxruntime_providers_webgpu.cmake
index f3f63aa18a7e3..303ab9483c38a 100644
--- a/cmake/onnxruntime_providers_webgpu.cmake
+++ b/cmake/onnxruntime_providers_webgpu.cmake
@@ -26,3 +26,5 @@
   onnxruntime_add_static_library(onnxruntime_providers_webgpu ${onnxruntime_providers_webgpu_cc_srcs})
   onnxruntime_add_include_to_target(onnxruntime_providers_webgpu onnxruntime_common onnx onnx_proto flatbuffers::flatbuffers Boost::mp11 safeint_interface)
   target_link_libraries(onnxruntime_providers_webgpu dawn::webgpu_dawn)
+
+  set_target_properties(onnxruntime_providers_webgpu PROPERTIES FOLDER "ONNXRuntime")
diff --git a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc
index b2b5f88eb84e1..91f51df588fca 100644
--- a/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc
@@ -35,25 +35,25 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Sk
 Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) {
   static const BuildKernelCreateInfoFn function_table[] = {
       BuildKernelCreateInfo<void>,  // default entry to avoid the list become empty after ops-reducing
-      // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Attention)>,
-      // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasAdd)>,
-      // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasSplitGelu)>,
-      // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FastGelu)>,
-      // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FusedConv)>,
-      // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Gelu)>,
-      // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GroupQueryAttention)>,
-      // // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it
-      // BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 16, LayerNormalization)>,
-      // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBits)>,
-      // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MultiHeadAttention)>,
-      // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, QuickGelu)>,
-      // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, RotaryEmbedding)>,
-      // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1,
-      //                                                       SkipLayerNormalization)>,
-      // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1,
-      //                                                       SimplifiedLayerNormalization)>,
-      // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1,
-      //                                                       SkipSimplifiedLayerNormalization)>
+                                    // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Attention)>,
+                                    // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasAdd)>,
+                                    // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasSplitGelu)>,
+                                    // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FastGelu)>,
+                                    // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FusedConv)>,
+                                    // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Gelu)>,
+                                    // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GroupQueryAttention)>,
+                                    // // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it
+                                    // BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 16, LayerNormalization)>,
+                                    // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBits)>,
+                                    // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MultiHeadAttention)>,
+                                    // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, QuickGelu)>,
+                                    // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, RotaryEmbedding)>,
+                                    // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1,
+                                    //                                                       SkipLayerNormalization)>,
+                                    // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1,
+                                    //                                                       SimplifiedLayerNormalization)>,
+                                    // BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1,
+                                    //                                                       SkipSimplifiedLayerNormalization)>
   };
 
   for (auto& function_table_entry : function_table) {
diff --git a/onnxruntime/core/providers/webgpu/allocator.cc b/onnxruntime/core/providers/webgpu/allocator.cc
index d3e25637dfd97..8e27acdc285d4 100644
--- a/onnxruntime/core/providers/webgpu/allocator.cc
+++ b/onnxruntime/core/providers/webgpu/allocator.cc
@@ -20,7 +20,7 @@ void* GpuBufferAllocator::Alloc(size_t size) {
   auto buffer = context_.BufferManager().Create(size);
 
   stats_.num_allocs++;
-  return buffer.Get();
+  return buffer;
 }
 
 void GpuBufferAllocator::Free(void* p) {
diff --git a/onnxruntime/core/providers/webgpu/allocator.h b/onnxruntime/core/providers/webgpu/allocator.h
index d139c9eb9574c..51ca65a8b4822 100644
--- a/onnxruntime/core/providers/webgpu/allocator.h
+++ b/onnxruntime/core/providers/webgpu/allocator.h
@@ -9,6 +9,8 @@
 namespace onnxruntime {
 namespace webgpu {
 
+class WebGpuContext;
+
 class GpuBufferAllocator : public IAllocator {
  public:
   GpuBufferAllocator(const WebGpuContext& context)
diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.cc b/onnxruntime/core/providers/webgpu/buffer_manager.cc
index d0f4b1142caf2..8722eba77eaa1 100644
--- a/onnxruntime/core/providers/webgpu/buffer_manager.cc
+++ b/onnxruntime/core/providers/webgpu/buffer_manager.cc
@@ -2,23 +2,28 @@
 // Licensed under the MIT License.
 
 #include "core/providers/webgpu/buffer_manager.h"
+#include "core/providers/webgpu/webgpu_context.h"
 
 namespace onnxruntime {
 namespace webgpu {
 
+size_t NormalizeBufferSize(size_t size) {
+  return (size + 15) / 16 * 16;
+}
+
 class DisabledCacheManager : public IBufferCacheManager {
   size_t CalculateBufferSize(size_t request_size) override {
-    return (request_size + 15) / 16 * 16;
+    return NormalizeBufferSize(request_size);
   }
 
   WGPUBuffer TryAcquireCachedBuffer(size_t /*buffer_size*/, wgpu::BufferUsage /*usage*/) override {
     // always return empty buffer
     return nullptr;
   }
-  void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/, size_t /*buffer_size*/, wgpu::BufferUsage /*usage*/) override {
+  void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/) override {
     // no-op
   }
-  void ReleaseBuffer(WGPUBuffer buffer, size_t /*buffer_size*/, wgpu::BufferUsage /*usage*/) override {
+  void ReleaseBuffer(WGPUBuffer buffer) override {
     wgpuBufferDestroy(buffer);
   }
 
@@ -29,7 +34,7 @@ class DisabledCacheManager : public IBufferCacheManager {
 
 class SimpleCacheManager : public IBufferCacheManager {
   size_t CalculateBufferSize(size_t request_size) override {
-    return (request_size + 15) / 16 * 16;
+    return NormalizeBufferSize(request_size);
   }
 
   WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size, wgpu::BufferUsage usage) override {
@@ -44,49 +49,33 @@ class SimpleCacheManager : public IBufferCacheManager {
 
     return nullptr;
   }
-  void RegisterBuffer(WGPUBuffer buffer, size_t /*request_size*/, size_t buffer_size, wgpu::BufferUsage usage) override {
+
+  void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/) override {
+    // no-op
   }
-  void ReleaseBuffer(WGPUBuffer buffer, size_t buffer_size, wgpu::BufferUsage usage) override {
-    if (usage | wgpu::BufferUsage::Storage) {
-      pending_buffers_.emplace_back(buffer, buffer_size);
+
+  void ReleaseBuffer(WGPUBuffer buffer) override {
+    auto usage = wgpuBufferGetUsage(buffer);
+    if (usage & WGPUBufferUsage_Storage) {
+      pending_buffers_.emplace_back(buffer);
     } else {
       wgpuBufferDestroy(buffer);
     }
   }
+
   void OnRefresh() override {
-    for (auto& pair : pending_buffers_) {
-      buffers_[pair.second].push_back(pair.first);
+    for (auto& buffer : pending_buffers_) {
+      buffers_[wgpuBufferGetSize(buffer)].push_back(buffer);
     }
     pending_buffers_.clear();
   }
 
   std::map<size_t, std::vector<WGPUBuffer>> buffers_;
-  std::vector<std::pair<WGPUBuffer, size_t>> pending_buffers_;
+  std::vector<WGPUBuffer> pending_buffers_;
 };
 
-class BucketCacheManager : public IBufferCacheManager {
-  static const std::unordered_map<size_t, size_t> kBucketSizes;
-  static const std::array<size_t> kBucketSizesArray;
-
-  size_t CalculateBufferSize(size_t request_size) override {
-    ORT_NOT_IMPLEMENTED("TODO");
-  }
-
-  WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size, wgpu::BufferUsage usage) override {
-    ORT_NOT_IMPLEMENTED("TODO");
-  }
-  void RegisterBuffer(WGPUBuffer buffer, size_t request_size, size_t buffer_size, wgpu::BufferUsage usage) override {
-    ORT_NOT_IMPLEMENTED("TODO");
-  }
-  void ReleaseBuffer(WGPUBuffer buffer, size_t buffer_size, wgpu::BufferUsage usage) override {
-    ORT_NOT_IMPLEMENTED("TODO");
-  }
-  void OnRefresh() override {
-    ORT_NOT_IMPLEMENTED("TODO");
-  }
-};
-
-constexpr std::initializer_list<std::pair<size_t, size_t>> BUCKET_TABLE = {
+// TODO: maybe use different bucket size for storage and uniform buffers?
+constexpr std::initializer_list<std::pair<const size_t, size_t>> BUCKET_DEFAULT_LIMIT_TABLE = {
     {64, 250},
     {128, 200},
     {256, 200},
@@ -116,8 +105,116 @@ constexpr std::initializer_list<std::pair<size_t, size_t>> BUCKET_TABLE = {
     {134217728, 6},
     {167772160, 6},
 };
-const std::unordered_map<size_t, size_t> BucketCacheManager::kBucketSizes{BUCKET_TABLE};
 
+class BucketCacheManager : public IBufferCacheManager {
+ public:
+  BucketCacheManager() : buckets_limit_{BUCKET_DEFAULT_LIMIT_TABLE} {
+    Initialize();
+  }
+  BucketCacheManager(std::unordered_map<size_t, size_t>&& buckets_limit) : buckets_limit_{buckets_limit} {
+    Initialize();
+  }
+
+  size_t CalculateBufferSize(size_t request_size) override {
+    // binary search size
+    auto it = std::lower_bound(buckets_keys_.begin(), buckets_keys_.end(), request_size);
+    if (it == buckets_keys_.end()) {
+      return NormalizeBufferSize(request_size);
+    } else {
+      return *it;
+    }
+  }
+
+  WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size, wgpu::BufferUsage usage) override {
+    std::unordered_map<size_t, std::vector<WGPUBuffer>>* buckets = nullptr;
+    if (usage & wgpu::BufferUsage::Storage) {
+      buckets = &buckets_storage_;
+    } else if (usage & wgpu::BufferUsage::Uniform) {
+      buckets = &buckets_uniform_;
+    }
+    if (buckets) {
+      auto it = buckets->find(buffer_size);
+      if (it != buckets->end() && !it->second.empty()) {
+        auto buffer = it->second.back();
+        it->second.pop_back();
+        return buffer;
+      }
+    }
+    return nullptr;
+  }
+
+  void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/) override {
+    // no-op
+  }
+
+  void ReleaseBuffer(WGPUBuffer buffer) override {
+    std::vector<WGPUBuffer>* pending_buffers = nullptr;
+    auto usage = wgpuBufferGetUsage(buffer);
+    if (usage & WGPUBufferUsage_Storage) {
+      pending_buffers = &pending_storage_buffers_;
+    } else if (usage & WGPUBufferUsage_Uniform) {
+      pending_buffers = &pending_uniform_buffers_;
+    }
+    if (pending_buffers) {
+      pending_buffers->emplace_back(buffer);
+    } else {
+      wgpuBufferDestroy(buffer);
+    }
+  }
+
+  void OnRefresh() override {
+    // TODO: consider graph capture. currently not supported
+
+    for (auto& buffer : pending_storage_buffers_) {
+      auto buffer_size = wgpuBufferGetSize(buffer);
+
+      auto it = buckets_storage_.find(buffer_size);
+      if (it != buckets_storage_.end() && it->second.size() < buckets_limit_[buffer_size]) {
+        it->second.push_back(buffer);
+      } else {
+        wgpuBufferDestroy(buffer);
+      }
+    }
+
+    for (auto& buffer : pending_uniform_buffers_) {
+      auto buffer_size = wgpuBufferGetSize(buffer);
+
+      auto it = buckets_uniform_.find(buffer_size);
+      if (it != buckets_uniform_.end() && it->second.size() < buckets_limit_[buffer_size]) {
+        it->second.push_back(buffer);
+      } else {
+        wgpuBufferDestroy(buffer);
+      }
+    }
+  }
+
+ protected:
+  void Initialize() {
+    buckets_keys_.reserve(buckets_limit_.size());
+    buckets_storage_.reserve(buckets_limit_.size());
+    buckets_uniform_.reserve(buckets_limit_.size());
+    for (const auto& pair : buckets_limit_) {
+      buckets_keys_.push_back(pair.first);
+      buckets_storage_.emplace(pair.first, std::vector<WGPUBuffer>());
+      buckets_uniform_.emplace(pair.first, std::vector<WGPUBuffer>());
+    }
+#ifndef NDEBUG
+    for (size_t i = 0; i < buckets_keys_.size(); ++i) {
+      ORT_ENFORCE(buckets_keys_[i] % 16 == 0, "Bucket sizes must be multiples of 16.");
+    }
+
+    for (size_t i = 1; i < buckets_keys_.size(); ++i) {
+      ORT_ENFORCE(buckets_keys_[i] > buckets_keys_[i - 1], "Bucket sizes must be in increasing order.");
+    }
+#endif
+  }
+  std::unordered_map<size_t, size_t> buckets_limit_;
+  std::unordered_map<size_t, std::vector<WGPUBuffer>> buckets_storage_;
+  std::vector<WGPUBuffer> pending_storage_buffers_;
+  std::unordered_map<size_t, std::vector<WGPUBuffer>> buckets_uniform_;
+  std::vector<WGPUBuffer> pending_uniform_buffers_;
+  std::vector<size_t> buckets_keys_;
+};
 
 std::unique_ptr<IBufferCacheManager> CreateBufferCacheManager(BufferCacheMode cache_mode) {
   switch (cache_mode) {
@@ -132,13 +229,51 @@ std::unique_ptr<IBufferCacheManager> CreateBufferCacheManager(BufferCacheMode ca
   }
 }
 
-BufferManager::BufferManager(wgpu::Device device, BufferCacheMode cache_mode) : IBufferManager{device, CreateBufferCacheManager(cache_mode)} {
+class BufferManager : public IBufferManager {
+ public:
+  BufferManager(const WebGpuContext& context, BufferCacheMode cache_mode);
+
+  void Upload(void* src, WGPUBuffer dst, size_t size) const override;
+  void MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) const override;
+  WGPUBuffer Create(size_t size, wgpu::BufferUsage usage) const override;
+  void Release(WGPUBuffer buffer) const override;
+  void Download(WGPUBuffer src, void* dst, size_t size) const override;
+  void RefreshPendingBuffers() const override;
+};
+
+BufferManager::BufferManager(const WebGpuContext& context, BufferCacheMode cache_mode) : IBufferManager{context, CreateBufferCacheManager(cache_mode)} {
 }
 
 void BufferManager::Upload(void* src, WGPUBuffer dst, size_t size) const {
+  auto buffer_size = NormalizeBufferSize(size);
+
+  wgpu::BufferDescriptor desc;
+  desc.size = buffer_size;
+  desc.usage = wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapWrite;
+  desc.mappedAtCreation = true;
+
+  auto staging_buffer = context_.Device().CreateBuffer(&desc);
+  auto mapped_data = staging_buffer.GetMappedRange();
+  memcpy(mapped_data, src, size);
+  staging_buffer.Unmap();
+
+  auto& command_encoder = context_.GetCommandEncoder();
+  context_.EndComputePass();
+  command_encoder.CopyBufferToBuffer(staging_buffer, 0, dst, 0, buffer_size);
+  staging_buffer.Destroy();
 }
 
 void BufferManager::MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) const {
+  ORT_ENFORCE(src != dst, "Source and destination buffers must be different.");
+
+  auto buffer_size = NormalizeBufferSize(size);
+  ORT_ENFORCE(buffer_size <= wgpuBufferGetSize(src) && buffer_size <= wgpuBufferGetSize(dst),
+              "Source and destination buffers must have enough space for the copy operation. src_size=",
+              wgpuBufferGetSize(src), ", dst_size=", wgpuBufferGetSize(dst), ", copy_size=", buffer_size, ".");
+
+  auto& command_encoder = context_.GetCommandEncoder();
+  context_.EndComputePass();
+  command_encoder.CopyBufferToBuffer(src, 0, dst, 0, buffer_size);
 }
 
 WGPUBuffer BufferManager::Create(size_t size, wgpu::BufferUsage usage) const {
@@ -153,25 +288,49 @@ WGPUBuffer BufferManager::Create(size_t size, wgpu::BufferUsage usage) const {
   wgpu::BufferDescriptor desc;
   desc.size = buffer_size;
   desc.usage = usage;
-  buffer = device_.CreateBuffer(&desc);
+  buffer = context_.Device().CreateBuffer(&desc).MoveToCHandle();
 
-  ORT_ENFORCE(buffer, "Failed to create GPU buffer: size=", buffer_size, ", usage=", usage, ".");
+  ORT_ENFORCE(buffer, "Failed to create GPU buffer: size=", buffer_size, ", usage=", uint64_t(usage), ".");
 
-  cache_->RegisterBuffer(buffer, size, buffer_size, usage);
+  cache_->RegisterBuffer(buffer, size);
   return buffer;
 }
 
 void BufferManager::Release(WGPUBuffer buffer) const {
-  cache_->ReleaseBuffer(buffer, 0, wgpu::BufferUsage::None);
+  cache_->ReleaseBuffer(buffer);
 }
 
-wgpu::Future BufferManager::Download(WGPUBuffer src, void* dst, size_t size) const {
-  return wgpu::Future();
+void BufferManager::Download(WGPUBuffer src, void* dst, size_t size) const {
+  auto buffer_size = NormalizeBufferSize(size);
+
+  wgpu::BufferDescriptor desc;
+  desc.size = buffer_size;
+  desc.usage = wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead;
+
+  auto staging_buffer = context_.Device().CreateBuffer(&desc);
+  auto& command_encoder = context_.GetCommandEncoder();
+  context_.EndComputePass();
+  command_encoder.CopyBufferToBuffer(src, 0, staging_buffer, 0, buffer_size);
+  context_.Flush();
+
+  wgpu::BufferMapCallbackInfo callback_info;
+  callback_info.mode = wgpu::CallbackMode::WaitAnyOnly;
+  callback_info.callback = [](WGPUBufferMapAsyncStatus status, void*) {
+    ORT_ENFORCE(status == WGPUBufferMapAsyncStatus_Success, "Failed to download data from buffer");
+  };
+  ORT_ENFORCE(context_.Wait(staging_buffer.MapAsync(wgpu::MapMode::Read, 0, buffer_size, callback_info)) == Status::OK());
+
+  auto mapped_data = staging_buffer.GetMappedRange();
+  memcpy(dst, mapped_data, size);
 }
 
 void BufferManager::RefreshPendingBuffers() const {
   cache_->OnRefresh();
 }
 
+std::unique_ptr<IBufferManager> BufferManagerFactory::Create(const WebGpuContext& context, BufferCacheMode mode) {
+  return std::make_unique<BufferManager>(context, mode);
+}
+
 }  // namespace webgpu
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webgpu/buffer_manager.h b/onnxruntime/core/providers/webgpu/buffer_manager.h
index 8dac698f9ccef..a411c44812339 100644
--- a/onnxruntime/core/providers/webgpu/buffer_manager.h
+++ b/onnxruntime/core/providers/webgpu/buffer_manager.h
@@ -15,6 +15,8 @@
 namespace onnxruntime {
 namespace webgpu {
 
+class WebGpuContext;
+
 enum class BufferCacheMode {
   None,
   Simple,
@@ -32,10 +34,10 @@ class IBufferCacheManager {
   virtual WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size, wgpu::BufferUsage usage) = 0;
 
   // register a newly created buffer
-  virtual void RegisterBuffer(WGPUBuffer buffer, size_t request_size, size_t buffer_size, wgpu::BufferUsage usage) = 0;
+  virtual void RegisterBuffer(WGPUBuffer buffer, size_t request_size) = 0;
 
   // release a buffer
-  virtual void ReleaseBuffer(WGPUBuffer buffer, size_t buffer_size, wgpu::BufferUsage usage) = 0;
+  virtual void ReleaseBuffer(WGPUBuffer buffer) = 0;
 
   // when a stream refresh is requested
   virtual void OnRefresh() = 0;
@@ -43,7 +45,7 @@ class IBufferCacheManager {
 
 class IBufferManager {
  protected:
-  IBufferManager(wgpu::Device device, std::unique_ptr<IBufferCacheManager> cache) : device_{device}, cache_{std::move(cache)} {}
+  IBufferManager(const WebGpuContext& context, std::unique_ptr<IBufferCacheManager> cache) : context_{context}, cache_{std::move(cache)} {}
 
  public:
   virtual ~IBufferManager() = default;
@@ -51,35 +53,22 @@ class IBufferManager {
   virtual void MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) const = 0;
   virtual WGPUBuffer Create(size_t size, wgpu::BufferUsage usage = wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst) const = 0;
   virtual void Release(WGPUBuffer buffer) const = 0;
-  virtual wgpu::Future Download(WGPUBuffer src, void* dst, size_t size) const = 0;
+  virtual void Download(WGPUBuffer src, void* dst, size_t size) const = 0;
   virtual void RefreshPendingBuffers() const = 0;
 
   // TODO: add statistics
 
  protected:
-  wgpu::Device device_;
+  const WebGpuContext& context_;
   std::unique_ptr<IBufferCacheManager> cache_;
 };
 
-class BufferManager : public IBufferManager {
+class BufferManagerFactory {
  public:
-  BufferManager(wgpu::Device device, BufferCacheMode cache_mode);
-
-  void Upload(void* src, WGPUBuffer dst, size_t size) const override;
-  void MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) const override;
-  WGPUBuffer Create(size_t size, wgpu::BufferUsage usage) const override;
-  void Release(WGPUBuffer buffer) const override;
-  wgpu::Future Download(WGPUBuffer src, void* dst, size_t size) const override;
-  void RefreshPendingBuffers() const override;
+  static std::unique_ptr<IBufferManager> Create(const WebGpuContext& context, BufferCacheMode mode);
 
  private:
-  struct PendingBuffer {
-    wgpu::Buffer buffer;
-    void* data;
-    size_t size;
-  };
-
-  std::vector<PendingBuffer> pending_buffers_;
+  BufferManagerFactory() {}
 };
 
 }  // namespace webgpu
diff --git a/onnxruntime/core/providers/webgpu/compute_context.cc b/onnxruntime/core/providers/webgpu/compute_context.cc
new file mode 100644
index 0000000000000..a558b8da68533
--- /dev/null
+++ b/onnxruntime/core/providers/webgpu/compute_context.cc
@@ -0,0 +1,5 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/webgpu/compute_context.h"
+#include "core/providers/webgpu/webgpu_context.h"
diff --git a/onnxruntime/core/providers/webgpu/compute_context.h b/onnxruntime/core/providers/webgpu/compute_context.h
new file mode 100644
index 0000000000000..627454d5ef7cf
--- /dev/null
+++ b/onnxruntime/core/providers/webgpu/compute_context.h
@@ -0,0 +1,32 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#ifdef __EMSCRIPTEN__
+#include <emscripten/emscripten.h>
+#endif
+
+#include <webgpu/webgpu_cpp.h>
+
+#include "core/framework/execution_provider.h"
+
+namespace onnxruntime {
+namespace webgpu {
+
+class WebGpuContext;
+
+class ComputeContext {
+ public:
+  ComputeContext(const WebGpuContext& context) : context_{context} {}
+
+  virtual ~ComputeContext() = default;
+
+  virtual void Dispatch(WGPUComputePassEncoder pass) const = 0;
+
+ protected:
+  const WebGpuContext& context_;
+};
+
+}  // namespace webgpu
+}  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webgpu/data_transfer.cc b/onnxruntime/core/providers/webgpu/data_transfer.cc
index d211980128501..615ae11175782 100644
--- a/onnxruntime/core/providers/webgpu/data_transfer.cc
+++ b/onnxruntime/core/providers/webgpu/data_transfer.cc
@@ -37,8 +37,7 @@ common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const {
       }
     } else /* if (src_device.Type() == OrtDevice::GPU) */ {
       // copy from GPU to CPU
-      ORT_RETURN_IF_ERROR(context_.Wait(
-          context_.BufferManager().Download(static_cast<WGPUBuffer>(const_cast<void*>(src_data)), dst_data, bytes)));
+      context_.BufferManager().Download(static_cast<WGPUBuffer>(const_cast<void*>(src_data)), dst_data, bytes);
     }
   }
 
diff --git a/onnxruntime/core/providers/webgpu/data_transfer.h b/onnxruntime/core/providers/webgpu/data_transfer.h
index 26f3d159d2b4d..79853483e0c23 100644
--- a/onnxruntime/core/providers/webgpu/data_transfer.h
+++ b/onnxruntime/core/providers/webgpu/data_transfer.h
@@ -9,6 +9,8 @@
 namespace onnxruntime {
 namespace webgpu {
 
+class WebGpuContext;
+
 class DataTransfer : public IDataTransfer {
  public:
   DataTransfer(const WebGpuContext& context) : context_{context} {};
diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc
index 114256d2c2218..16672a434b321 100644
--- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc
+++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.cc
@@ -19,21 +19,20 @@ Status UnaryElementwise::ComputeInternal(OpKernelContext* context) const {
   return Status(common::ONNXRUNTIME, common::FAIL);
 }
 
-#define WEBGPU_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS, TYPE)  \
-  ONNX_OPERATOR_KERNEL_EX(                                             \
-      OP_TYPE, kOnnxDomain, VERSION, kWebGpuExecutionProvider,         \
-      KernelDefBuilder().TypeConstraint("T", TYPE),                    \
+#define WEBGPU_ELEMENTWISE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS, TYPE) \
+  ONNX_OPERATOR_KERNEL_EX(                                              \
+      OP_TYPE, kOnnxDomain, VERSION, kWebGpuExecutionProvider,          \
+      KernelDefBuilder().TypeConstraint("T", TYPE),                     \
       KERNEL_CLASS);
 
 #define WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE) \
-  ONNX_OPERATOR_VERSIONED_KERNEL_EX(                                                             \
-      OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kWebGpuExecutionProvider,                  \
-      KernelDefBuilder().TypeConstraint("T", TYPE),                                              \
+  ONNX_OPERATOR_VERSIONED_KERNEL_EX(                                                               \
+      OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kWebGpuExecutionProvider,                    \
+      KernelDefBuilder().TypeConstraint("T", TYPE),                                                \
       KERNEL_CLASS);
 
 WEBGPU_ELEMENTWISE_VERSIONED_KERNEL(Abs, 6, 12, Abs, WebGpuSupportedFloatTypes())
 WEBGPU_ELEMENTWISE_KERNEL(Abs, 13, Abs, WebGpuSupportedFloatTypes())
 
-
 }  // namespace webgpu
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h
index 0da0cfa74c150..7e913b6656bc1 100644
--- a/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h
+++ b/onnxruntime/core/providers/webgpu/math/unary_elementwise_ops.h
@@ -26,6 +26,5 @@ class Abs final : public UnaryElementwise {
   // Status ComputeInternal(OpKernelContext* context) const override;
 };
 
-
 }  // namespace webgpu
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.cc b/onnxruntime/core/providers/webgpu/webgpu_context.cc
index 1c8831992096b..e4686a355c09e 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_context.cc
+++ b/onnxruntime/core/providers/webgpu/webgpu_context.cc
@@ -91,7 +91,7 @@ void WebGpuContext::Initialize() {
     ORT_ENFORCE(device_.GetLimits(&limits));
 
     // create buffer manager
-    buffer_mgr_ = std::make_unique<webgpu::BufferManager>(device_);
+    buffer_mgr_ = BufferManagerFactory::Create(*this, BufferCacheMode::None);
   });
 }
 
@@ -103,12 +103,15 @@ Status WebGpuContext::Wait(wgpu::Future f) const {
   return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to wait for the operation:", uint32_t(status));
 }
 
+std::unordered_map<int32_t, std::unique_ptr<WebGpuContext>> WebGpuContextFactory::contexts_;
+std::mutex WebGpuContextFactory::mutex_;
+
 WebGpuContext& WebGpuContextFactory::GetOrCreateContext(int32_t context_id) {
   std::lock_guard<std::mutex> lock(mutex_);
 
   auto it = contexts_.find(context_id);
   if (it == contexts_.end()) {
-    auto context = std::make_unique<WebGpuContext>();
+    auto context = std::unique_ptr<WebGpuContext>(new WebGpuContext());
     it = contexts_.emplace(context_id, std::move(context)).first;
   }
   return *it->second;
diff --git a/onnxruntime/core/providers/webgpu/webgpu_context.h b/onnxruntime/core/providers/webgpu/webgpu_context.h
index 904507da9bb8c..bdba190278c15 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_context.h
+++ b/onnxruntime/core/providers/webgpu/webgpu_context.h
@@ -16,17 +16,66 @@
 
 namespace onnxruntime {
 namespace webgpu {
+class WebGpuContext;  // forward declaration; full definition follows in this header
+
+class WebGpuContextFactory {  // process-wide registry of WebGpuContext instances, keyed by context id
+ public:
+  static WebGpuContext& GetOrCreateContext(int32_t context_id = 0);  // mutex-protected lookup-or-create
+
+ private:
+  WebGpuContextFactory() {}  // static-only class; not instantiable
+
+  static std::unordered_map<int32_t, std::unique_ptr<WebGpuContext>> contexts_;  // guarded by mutex_
+  static std::mutex mutex_;
+};
 
 // Class WebGpuContext includes all necessary resources for the context.
 class WebGpuContext final {
  public:
   void Initialize();
 
+  // non copyable
+  WebGpuContext(const WebGpuContext&) = delete;
+  WebGpuContext& operator=(const WebGpuContext&) = delete;
+
+  // non movable
+  WebGpuContext(WebGpuContext&&) = delete;
+  WebGpuContext& operator=(WebGpuContext&&) = delete;
+
   Status Wait(wgpu::Future f) const;
 
   const wgpu::Adapter& Adapter() const { return adapter_; }
   const wgpu::Device& Device() const { return device_; }
 
+  const wgpu::CommandEncoder& GetCommandEncoder() const {  // lazily creates the command encoder on first use
+    if (!current_command_encoder_) {
+      current_command_encoder_ = device_.CreateCommandEncoder();
+    }
+    return current_command_encoder_;
+  }
+
+  void EndComputePass() const {  // ends the active compute pass, if any; safe to call when none is open
+    if (current_compute_pass_encoder_) {
+      current_compute_pass_encoder_.End();
+      current_compute_pass_encoder_ = nullptr;
+    }
+  }
+
+  void Flush() const {  // submits all recorded GPU commands; no-op if nothing was recorded
+    if (!current_command_encoder_) {
+      return;
+    }
+
+    EndComputePass();  // an open pass must be ended before Finish() per WebGPU validation rules
+
+    // TODO: add support for GPU Query
+
+    auto command_buffer = current_command_encoder_.Finish();
+    Device().GetQueue().Submit(1, &command_buffer);
+    BufferManager().RefreshPendingBuffers();  // presumably recycles buffers tied to the submitted work — confirm in BufferManager
+    current_command_encoder_ = nullptr;  // next GetCommandEncoder() starts a fresh encoder
+  }
+
   const IBufferManager& BufferManager() const { return *buffer_mgr_; }
 
  private:
@@ -39,20 +88,11 @@ class WebGpuContext final {
   wgpu::Device device_;
 
   std::unique_ptr<IBufferManager> buffer_mgr_;
+  mutable wgpu::CommandEncoder current_command_encoder_;  // mutable: lazily created and reset from const members (GetCommandEncoder/Flush)
+  mutable wgpu::ComputePassEncoder current_compute_pass_encoder_;  // mutable: cleared by const EndComputePass()
 
   friend class WebGpuContextFactory;
 };
 
-class WebGpuContextFactory {
- public:
-  static WebGpuContext& GetOrCreateContext(int32_t context_id = 0);
-
- private:
-  WebGpuContextFactory() {}
-
-  static std::unordered_map<int32_t, std::unique_ptr<WebGpuContext>> contexts_;
-  static std::mutex mutex_;
-};
-
 }  // namespace webgpu
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
index 181bcde2f47e0..46e9d6d54454f 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
+++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc
@@ -722,7 +722,7 @@ WebGpuExecutionProvider::WebGpuExecutionProvider(const WebGpuContext& context,
                                                  const WebGpuExecutionProviderInfo& info,
                                                  const SessionOptions* session_options)
     : IExecutionProvider{kWebGpuExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)},
-      context_(context),
+      context_{context},
       preferred_data_layout_{info.data_layout} {
   if (session_options) {
     enable_graph_capture_ = session_options->config_options.GetConfigOrDefault("enableGraphCapture", "false") == "true";
diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h
index 873b4f458e33b..23dbb7432abdd 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h
+++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.h
@@ -8,17 +8,16 @@
 #include "core/framework/session_options.h"
 #include "core/graph/constants.h"
 #include "core/providers/providers.h"
-#include "core/providers/webgpu/webgpu_context.h"
 
 struct pthreadpool;
 namespace onnxruntime {
-
 namespace webgpu {
 
 // forward declaration for this EP's namespace.
 template <typename T>
 KernelCreateInfo BuildKernelCreateInfo();
 
+class WebGpuContext;
 }  // namespace webgpu
 
 struct WebGpuExecutionProviderInfo {
@@ -40,7 +39,7 @@ struct WebGpuExecutionProviderInfo {
 
 class WebGpuExecutionProvider : public IExecutionProvider {
  public:
-  WebGpuExecutionProvider(const WebGpuContext& context, const WebGpuExecutionProviderInfo& info, const SessionOptions* session_options);
+  WebGpuExecutionProvider(const webgpu::WebGpuContext& context, const WebGpuExecutionProviderInfo& info, const SessionOptions* session_options);
   ~WebGpuExecutionProvider() override;
 
   std::vector<std::unique_ptr<ComputeCapability>> GetCapability(
@@ -70,7 +69,7 @@ class WebGpuExecutionProvider : public IExecutionProvider {
  private:
   bool IsGraphCaptureAllowed() const;
   void IncrementRegularRunCountBeforeGraphCapture();
-  const WebGpuContext& context_;
+  const webgpu::WebGpuContext& context_;
   DataLayout preferred_data_layout_;
   bool enable_graph_capture_ = false;
   bool is_graph_captured_ = false;
diff --git a/onnxruntime/core/providers/webgpu/webgpu_kernel.h b/onnxruntime/core/providers/webgpu/webgpu_kernel.h
index bcd0e0c59b2c9..bdabcf5548f81 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_kernel.h
+++ b/onnxruntime/core/providers/webgpu/webgpu_kernel.h
@@ -36,5 +36,5 @@ class WebGpuKernel : public OpKernel {
   virtual Status ComputeInternal(OpKernelContext* p_op_kernel_context) const = 0;
 };
 
-}  // namespace cuda
+}  // namespace webgpu
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc
index 3e2571be5b721..754cc4b1e83ca 100644
--- a/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc
+++ b/onnxruntime/core/providers/webgpu/webgpu_provider_factory.cc
@@ -12,8 +12,8 @@ namespace onnxruntime {
 
 struct WebGpuProviderFactory : IExecutionProviderFactory {
   WebGpuProviderFactory(const webgpu::WebGpuContext& context, const ProviderOptions& provider_options, const SessionOptions* session_options)
-      : context_{context},
-        info_{provider_options},
+      : info_{provider_options},
+        context_{context},
         session_options_(session_options) {
   }
 
@@ -23,8 +23,8 @@ struct WebGpuProviderFactory : IExecutionProviderFactory {
 
  private:
   WebGpuExecutionProviderInfo info_;
-  const SessionOptions* session_options_;
   const webgpu::WebGpuContext& context_;
+  const SessionOptions* session_options_;
 };
 
 std::shared_ptr<IExecutionProviderFactory> WebGpuProviderFactoryCreator::Create(