Skip to content

Commit

Permalink
w3
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire committed Aug 17, 2024
1 parent a92422f commit fc4f0e6
Show file tree
Hide file tree
Showing 17 changed files with 344 additions and 116 deletions.
38 changes: 19 additions & 19 deletions onnxruntime/contrib_ops/webgpu/webgpu_contrib_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,25 +35,25 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Sk
Status RegisterWebGpuContribKernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
BuildKernelCreateInfo<void>, // default entry to avoid the list become empty after ops-reducing
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Attention)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasAdd)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasSplitGelu)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FastGelu)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FusedConv)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Gelu)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GroupQueryAttention)>,
// // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 16, LayerNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBits)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MultiHeadAttention)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, QuickGelu)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, RotaryEmbedding)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1,
// SkipLayerNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1,
// SimplifiedLayerNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1,
// SkipSimplifiedLayerNormalization)>
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Attention)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasAdd)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, BiasSplitGelu)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FastGelu)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, FusedConv)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, Gelu)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, GroupQueryAttention)>,
// // LayerNormalization used to be a contrib op that (incorrectly) used kOnnxDomain so we need to version it
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 16, LayerNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MatMulNBits)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, MultiHeadAttention)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, QuickGelu)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1, RotaryEmbedding)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1,
// SkipLayerNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1,
// SimplifiedLayerNormalization)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kMSDomain, 1,
// SkipSimplifiedLayerNormalization)>
};

for (auto& function_table_entry : function_table) {
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/webgpu/allocator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ void* GpuBufferAllocator::Alloc(size_t size) {
auto buffer = context_.BufferManager().Create(size);

stats_.num_allocs++;
return buffer.Get();
return buffer;
}

void GpuBufferAllocator::Free(void* p) {
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/webgpu/allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
namespace onnxruntime {
namespace webgpu {

class WebGpuContext;

class GpuBufferAllocator : public IAllocator {
public:
GpuBufferAllocator(const WebGpuContext& context)
Expand Down
243 changes: 201 additions & 42 deletions onnxruntime/core/providers/webgpu/buffer_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,28 @@
// Licensed under the MIT License.

#include "core/providers/webgpu/buffer_manager.h"
#include "core/providers/webgpu/webgpu_context.h"

namespace onnxruntime {
namespace webgpu {

size_t NormalizeBufferSize(size_t size) {
return (size + 15) / 16 * 16;
}

class DisabledCacheManager : public IBufferCacheManager {
size_t CalculateBufferSize(size_t request_size) override {
return (request_size + 15) / 16 * 16;
return NormalizeBufferSize(request_size);
}

WGPUBuffer TryAcquireCachedBuffer(size_t /*buffer_size*/, wgpu::BufferUsage /*usage*/) override {
// always return empty buffer
return nullptr;
}
void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/, size_t /*buffer_size*/, wgpu::BufferUsage /*usage*/) override {
void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/) override {
// no-op
}
void ReleaseBuffer(WGPUBuffer buffer, size_t /*buffer_size*/, wgpu::BufferUsage /*usage*/) override {
void ReleaseBuffer(WGPUBuffer buffer) override {
wgpuBufferDestroy(buffer);
}

Expand All @@ -29,7 +34,7 @@ class DisabledCacheManager : public IBufferCacheManager {

class SimpleCacheManager : public IBufferCacheManager {
size_t CalculateBufferSize(size_t request_size) override {
return (request_size + 15) / 16 * 16;
return NormalizeBufferSize(request_size);
}

WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size, wgpu::BufferUsage usage) override {
Expand All @@ -44,49 +49,33 @@ class SimpleCacheManager : public IBufferCacheManager {

return nullptr;
}
void RegisterBuffer(WGPUBuffer buffer, size_t /*request_size*/, size_t buffer_size, wgpu::BufferUsage usage) override {

void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/) override {
// no-op
}
void ReleaseBuffer(WGPUBuffer buffer, size_t buffer_size, wgpu::BufferUsage usage) override {
if (usage | wgpu::BufferUsage::Storage) {
pending_buffers_.emplace_back(buffer, buffer_size);

void ReleaseBuffer(WGPUBuffer buffer) override {
auto usage = wgpuBufferGetUsage(buffer);
if (usage | WGPUBufferUsage_Storage) {
pending_buffers_.emplace_back(buffer);
} else {
wgpuBufferDestroy(buffer);
}
}

void OnRefresh() override {
for (auto& pair : pending_buffers_) {
buffers_[pair.second].push_back(pair.first);
for (auto& buffer : pending_buffers_) {
buffers_[wgpuBufferGetSize(buffer)].push_back(buffer);
}
pending_buffers_.clear();
}

std::map<size_t, std::vector<WGPUBuffer>> buffers_;
std::vector<std::pair<WGPUBuffer, size_t>> pending_buffers_;
std::vector<WGPUBuffer> pending_buffers_;
};

class BucketCacheManager : public IBufferCacheManager {
static const std::unordered_map<size_t, size_t> kBucketSizes;
static const std::array<size_t> kBucketSizesArray;

size_t CalculateBufferSize(size_t request_size) override {
ORT_NOT_IMPLEMENTED("TODO");
}

WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size, wgpu::BufferUsage usage) override {
ORT_NOT_IMPLEMENTED("TODO");
}
void RegisterBuffer(WGPUBuffer buffer, size_t request_size, size_t buffer_size, wgpu::BufferUsage usage) override {
ORT_NOT_IMPLEMENTED("TODO");
}
void ReleaseBuffer(WGPUBuffer buffer, size_t buffer_size, wgpu::BufferUsage usage) override {
ORT_NOT_IMPLEMENTED("TODO");
}
void OnRefresh() override {
ORT_NOT_IMPLEMENTED("TODO");
}
};

constexpr std::initializer_list<std::pair<size_t, size_t>> BUCKET_TABLE = {
// TODO: maybe use different bucket size for storage and uniform buffers?
constexpr std::initializer_list<std::pair<const size_t, size_t>> BUCKET_DEFAULT_LIMIT_TABLE = {
{64, 250},
{128, 200},
{256, 200},
Expand Down Expand Up @@ -116,8 +105,116 @@ constexpr std::initializer_list<std::pair<size_t, size_t>> BUCKET_TABLE = {
{134217728, 6},
{167772160, 6},
};
const std::unordered_map<size_t, size_t> BucketCacheManager::kBucketSizes{BUCKET_TABLE};

class BucketCacheManager : public IBufferCacheManager {
public:
BucketCacheManager() : buckets_limit_{BUCKET_DEFAULT_LIMIT_TABLE} {
Initialize();
}
BucketCacheManager(std::unordered_map<size_t, size_t>&& buckets_limit) : buckets_limit_{buckets_limit} {
Initialize();
}

size_t CalculateBufferSize(size_t request_size) override {
// binary serch size
auto it = std::lower_bound(buckets_keys_.begin(), buckets_keys_.end(), request_size);
if (it == buckets_keys_.end()) {
return NormalizeBufferSize(request_size);
} else {
return *it;
}
}

WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size, wgpu::BufferUsage usage) override {
std::unordered_map<size_t, std::vector<WGPUBuffer>>* buckets = nullptr;
if (usage | wgpu::BufferUsage::Storage) {
buckets = &buckets_storage_;
} else if (usage | wgpu::BufferUsage::Uniform) {
buckets = &buckets_uniform_;
}
if (buckets) {
auto it = buckets->find(buffer_size);
if (it != buckets->end() && !it->second.empty()) {
auto buffer = it->second.back();
it->second.pop_back();
return buffer;
}
}
return nullptr;
}

void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/) override {
// no-op
}

void ReleaseBuffer(WGPUBuffer buffer) override {
std::vector<WGPUBuffer>* pending_buffers = nullptr;
auto usage = wgpuBufferGetUsage(buffer);
if (usage | WGPUBufferUsage_Storage) {
pending_buffers = &pending_storage_buffers_;
} else if (usage | WGPUBufferUsage_Uniform) {
pending_buffers = &pending_uniform_buffers_;
}
if (pending_buffers) {
pending_buffers->emplace_back(buffer);
} else {
wgpuBufferDestroy(buffer);
}
}

void OnRefresh() override {
// TODO: consider graph capture. currently not supported

for (auto& buffer : pending_storage_buffers_) {
auto buffer_size = wgpuBufferGetSize(buffer);

auto it = buckets_storage_.find(buffer_size);
if (it != buckets_storage_.end() && it->second.size() < buckets_limit_[buffer_size]) {
it->second.push_back(buffer);
} else {
wgpuBufferDestroy(buffer);
}
}

for (auto& buffer : pending_uniform_buffers_) {
auto buffer_size = wgpuBufferGetSize(buffer);

auto it = buckets_uniform_.find(buffer_size);
if (it != buckets_uniform_.end() && it->second.size() < buckets_limit_[buffer_size]) {
it->second.push_back(buffer);
} else {
wgpuBufferDestroy(buffer);
}
}
}

protected:
void Initialize() {
buckets_keys_.reserve(buckets_limit_.size());
buckets_storage_.reserve(buckets_limit_.size());
buckets_uniform_.reserve(buckets_limit_.size());
for (const auto& pair : buckets_limit_) {
buckets_keys_.push_back(pair.first);
buckets_storage_.emplace(pair.first, std::vector<WGPUBuffer>());
buckets_uniform_.emplace(pair.first, std::vector<WGPUBuffer>());
}
#ifndef NDEBUG
for (size_t i = 0; i < buckets_keys_.size(); ++i) {
ORT_ENFORCE(buckets_keys_[i] % 16 == 0, "Bucket sizes must be multiples of 16.");
}

for (size_t i = 1; i < buckets_keys_.size(); ++i) {
ORT_ENFORCE(buckets_keys_[i] > buckets_keys_[i - 1], "Bucket sizes must be in increasing order.");
}
#endif
}
std::unordered_map<size_t, size_t> buckets_limit_;
std::unordered_map<size_t, std::vector<WGPUBuffer>> buckets_storage_;
std::vector<WGPUBuffer> pending_storage_buffers_;
std::unordered_map<size_t, std::vector<WGPUBuffer>> buckets_uniform_;
std::vector<WGPUBuffer> pending_uniform_buffers_;
std::vector<size_t> buckets_keys_;
};

std::unique_ptr<IBufferCacheManager> CreateBufferCacheManager(BufferCacheMode cache_mode) {
switch (cache_mode) {
Expand All @@ -132,13 +229,51 @@ std::unique_ptr<IBufferCacheManager> CreateBufferCacheManager(BufferCacheMode ca
}
}

BufferManager::BufferManager(wgpu::Device device, BufferCacheMode cache_mode) : IBufferManager{device, CreateBufferCacheManager(cache_mode)} {
class BufferManager : public IBufferManager {
public:
BufferManager(const WebGpuContext& context, BufferCacheMode cache_mode);

void Upload(void* src, WGPUBuffer dst, size_t size) const override;
void MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) const override;
WGPUBuffer Create(size_t size, wgpu::BufferUsage usage) const override;
void Release(WGPUBuffer buffer) const override;
void Download(WGPUBuffer src, void* dst, size_t size) const override;
void RefreshPendingBuffers() const override;
};

BufferManager::BufferManager(const WebGpuContext& context, BufferCacheMode cache_mode) : IBufferManager{context, CreateBufferCacheManager(cache_mode)} {
}

void BufferManager::Upload(void* src, WGPUBuffer dst, size_t size) const {
auto buffer_size = NormalizeBufferSize(size);

wgpu::BufferDescriptor desc;
desc.size = buffer_size;
desc.usage = wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapWrite;
desc.mappedAtCreation = true;

auto staging_buffer = context_.Device().CreateBuffer(&desc);
auto mapped_data = staging_buffer.GetMappedRange();
memcpy(mapped_data, src, size);
staging_buffer.Unmap();

auto& command_encoder = context_.GetCommandEncoder();
context_.EndComputePass();
command_encoder.CopyBufferToBuffer(staging_buffer, 0, dst, 0, buffer_size);
staging_buffer.Destroy();
}

void BufferManager::MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) const {
ORT_ENFORCE(src != dst, "Source and destination buffers must be different.");

auto buffer_size = NormalizeBufferSize(size);
ORT_ENFORCE(buffer_size <= wgpuBufferGetSize(src) && buffer_size <= wgpuBufferGetSize(dst),
"Source and destination buffers must have enough space for the copy operation. src_size=",
wgpuBufferGetSize(src), ", dst_size=", wgpuBufferGetSize(dst), ", copy_size=", buffer_size, ".");

auto& command_encoder = context_.GetCommandEncoder();
context_.EndComputePass();
command_encoder.CopyBufferToBuffer(src, 0, dst, 0, buffer_size);
}

WGPUBuffer BufferManager::Create(size_t size, wgpu::BufferUsage usage) const {
Expand All @@ -153,25 +288,49 @@ WGPUBuffer BufferManager::Create(size_t size, wgpu::BufferUsage usage) const {
wgpu::BufferDescriptor desc;
desc.size = buffer_size;
desc.usage = usage;
buffer = device_.CreateBuffer(&desc);
buffer = context_.Device().CreateBuffer(&desc).MoveToCHandle();

ORT_ENFORCE(buffer, "Failed to create GPU buffer: size=", buffer_size, ", usage=", usage, ".");
ORT_ENFORCE(buffer, "Failed to create GPU buffer: size=", buffer_size, ", usage=", uint64_t(usage), ".");

cache_->RegisterBuffer(buffer, size, buffer_size, usage);
cache_->RegisterBuffer(buffer, size);
return buffer;
}

void BufferManager::Release(WGPUBuffer buffer) const {
cache_->ReleaseBuffer(buffer, 0, wgpu::BufferUsage::None);
cache_->ReleaseBuffer(buffer);
}

wgpu::Future BufferManager::Download(WGPUBuffer src, void* dst, size_t size) const {
return wgpu::Future();
void BufferManager::Download(WGPUBuffer src, void* dst, size_t size) const {
auto buffer_size = NormalizeBufferSize(size);

wgpu::BufferDescriptor desc;
desc.size = buffer_size;
desc.usage = wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead;

auto staging_buffer = context_.Device().CreateBuffer(&desc);
auto& command_encoder = context_.GetCommandEncoder();
context_.EndComputePass();
command_encoder.CopyBufferToBuffer(src, 0, staging_buffer, 0, buffer_size);
context_.Flush();

wgpu::BufferMapCallbackInfo callback_info;
callback_info.mode = wgpu::CallbackMode::WaitAnyOnly;
callback_info.callback = [](WGPUBufferMapAsyncStatus status, void*) {
ORT_ENFORCE(status == WGPUBufferMapAsyncStatus_Success, "Failed to download data from buffer");
};
ORT_ENFORCE(context_.Wait(staging_buffer.MapAsync(wgpu::MapMode::Read, 0, buffer_size, callback_info)) == Status::OK());

auto mapped_data = staging_buffer.GetMappedRange();
memcpy(dst, mapped_data, size);
}

void BufferManager::RefreshPendingBuffers() const {
cache_->OnRefresh();
}

std::unique_ptr<IBufferManager> BufferManagerFactory::Create(const WebGpuContext& context, BufferCacheMode mode) {
return std::make_unique<BufferManager>(context, mode);
}

} // namespace webgpu
} // namespace onnxruntime
Loading

0 comments on commit fc4f0e6

Please sign in to comment.