Skip to content

Commit

Permalink
w2
Browse files Browse the repository at this point in the history
  • Loading branch information
fs-eire committed Aug 16, 2024
1 parent d6c6c18 commit 237b6a6
Show file tree
Hide file tree
Showing 11 changed files with 290 additions and 74 deletions.
14 changes: 4 additions & 10 deletions onnxruntime/core/providers/webgpu/allocator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,16 @@ void* GpuBufferAllocator::Alloc(size_t size) {
return nullptr;
}

// GetContext().Device().CreateBuffer()

//void* p = EM_ASM_PTR({ return Module.jsepAlloc($0); }, size);
ORT_ENFORCE(false, "not implemented");
auto buffer = context_.BufferManager().Create(size);

stats_.num_allocs++;
stats_.bytes_in_use += size;
return nullptr;
return buffer.Get();
}

void GpuBufferAllocator::Free(void* p) {
if (p != nullptr) {
//size_t size = (size_t)(void*)EM_ASM_PTR({ return Module.jsepFree($0); }, p);
ORT_ENFORCE(false, "not implemented");

//stats_.bytes_in_use -= size;
context_.BufferManager().Release(static_cast<WGPUBuffer>(p));
stats_.num_allocs--;
}
}

Expand Down
6 changes: 4 additions & 2 deletions onnxruntime/core/providers/webgpu/allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@ namespace webgpu {

class GpuBufferAllocator : public IAllocator {
public:
GpuBufferAllocator()
GpuBufferAllocator(const WebGpuContext& context)
: IAllocator(
OrtMemoryInfo(WEBGPU_BUFFER, OrtAllocatorType::OrtDeviceAllocator,
OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0),
0, OrtMemTypeDefault)) {
0, OrtMemTypeDefault)),
context_{context} {
}

virtual void* Alloc(size_t size) override;
Expand All @@ -24,6 +25,7 @@ class GpuBufferAllocator : public IAllocator {

private:
AllocatorStats stats_;
const WebGpuContext& context_;
};

} // namespace webgpu
Expand Down
173 changes: 173 additions & 0 deletions onnxruntime/core/providers/webgpu/buffer_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,176 @@
// Licensed under the MIT License.

#include "core/providers/webgpu/buffer_manager.h"

#include <cstdint>
#include <map>
#include <memory>
#include <unordered_map>
#include <utility>
#include <vector>

namespace onnxruntime {
namespace webgpu {

// Cache manager implementation that performs no caching at all: every
// acquisition is a cache miss and every released buffer is destroyed
// immediately.
class DisabledCacheManager : public IBufferCacheManager {
  // Round the requested size up to the next multiple of 16 bytes.
  size_t CalculateBufferSize(size_t request_size) override {
    constexpr size_t kAlignment = 16;
    return ((request_size + kAlignment - 1) / kAlignment) * kAlignment;
  }

  WGPUBuffer TryAcquireCachedBuffer(size_t /*buffer_size*/, wgpu::BufferUsage /*usage*/) override {
    // Caching is disabled, so always report a miss.
    return nullptr;
  }

  void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/, size_t /*buffer_size*/, wgpu::BufferUsage /*usage*/) override {
    // Nothing to track.
  }

  // Destroy the buffer right away instead of pooling it for reuse.
  void ReleaseBuffer(WGPUBuffer buffer, size_t /*buffer_size*/, wgpu::BufferUsage /*usage*/) override {
    wgpuBufferDestroy(buffer);
  }

  void OnRefresh() override {
    // Nothing pending to flush.
  }
};

// Cache manager that pools storage buffers keyed by their (rounded) size.
// Released buffers are parked in pending_buffers_ until the next stream
// refresh, at which point they become available for reuse.
class SimpleCacheManager : public IBufferCacheManager {
  // Round the requested size up to the next multiple of 16 bytes.
  size_t CalculateBufferSize(size_t request_size) override {
    return (request_size + 15) / 16 * 16;
  }

  WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size, wgpu::BufferUsage usage) override {
    // BUG FIX: the original tested `usage | wgpu::BufferUsage::Storage`, which
    // is non-zero for every usage; use `&` to test whether the Storage bit is
    // actually set.
    if ((usage & wgpu::BufferUsage::Storage) == wgpu::BufferUsage::Storage) {
      auto it = buffers_.find(buffer_size);
      if (it != buffers_.end() && !it->second.empty()) {
        auto buffer = it->second.back();
        it->second.pop_back();
        return buffer;
      }
    }

    return nullptr;
  }

  void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/, size_t /*buffer_size*/, wgpu::BufferUsage /*usage*/) override {
    // Newly created buffers are not tracked; they enter the pool on release.
  }

  void ReleaseBuffer(WGPUBuffer buffer, size_t buffer_size, wgpu::BufferUsage usage) override {
    // BUG FIX: same `|` vs `&` bitmask test as in TryAcquireCachedBuffer.
    if ((usage & wgpu::BufferUsage::Storage) == wgpu::BufferUsage::Storage) {
      // Defer reuse until OnRefresh() so in-flight GPU work can complete.
      pending_buffers_.emplace_back(buffer, buffer_size);
    } else {
      wgpuBufferDestroy(buffer);
    }
  }

  void OnRefresh() override {
    // Move everything released since the last refresh into the reuse pool.
    for (auto& pair : pending_buffers_) {
      buffers_[pair.second].push_back(pair.first);
    }
    pending_buffers_.clear();
  }

  // Maps a (rounded) buffer size to the list of reusable buffers of that size.
  std::map<size_t, std::vector<WGPUBuffer>> buffers_;
  // Buffers released since the last refresh, paired with their sizes.
  std::vector<std::pair<WGPUBuffer, size_t>> pending_buffers_;
};

// Cache manager that pools buffers into a fixed set of predefined bucket
// sizes. The implementation is not complete yet; all entry points throw.
class BucketCacheManager : public IBufferCacheManager {
  // Maps a bucket size (bytes) to the number of buffers to retain for it.
  static const std::unordered_map<size_t, size_t> kBucketSizes;
  // BUG FIX: `std::array<size_t>` is ill-formed — std::array requires a size
  // template argument. Use std::vector for the (sorted) list of bucket sizes.
  static const std::vector<size_t> kBucketSizesArray;

  size_t CalculateBufferSize(size_t /*request_size*/) override {
    ORT_NOT_IMPLEMENTED("TODO");
  }

  WGPUBuffer TryAcquireCachedBuffer(size_t /*buffer_size*/, wgpu::BufferUsage /*usage*/) override {
    ORT_NOT_IMPLEMENTED("TODO");
  }
  void RegisterBuffer(WGPUBuffer /*buffer*/, size_t /*request_size*/, size_t /*buffer_size*/, wgpu::BufferUsage /*usage*/) override {
    ORT_NOT_IMPLEMENTED("TODO");
  }
  void ReleaseBuffer(WGPUBuffer /*buffer*/, size_t /*buffer_size*/, wgpu::BufferUsage /*usage*/) override {
    ORT_NOT_IMPLEMENTED("TODO");
  }
  void OnRefresh() override {
    ORT_NOT_IMPLEMENTED("TODO");
  }
};

// Bucket table: maps a bucket size in bytes to the maximum number of buffers
// retained for that bucket. Used to initialize BucketCacheManager::kBucketSizes.
constexpr std::initializer_list<std::pair<size_t, size_t>> BUCKET_TABLE = {
    {64, 250},
    {128, 200},
    {256, 200},
    {512, 200},
    {2048, 230},
    {4096, 200},
    {8192, 50},
    {16384, 50},
    {32768, 50},
    {65536, 50},
    {131072, 50},
    {262144, 50},
    {524288, 50},
    {1048576, 50},
    {2097152, 30},
    {4194304, 20},
    {8388608, 10},
    {12582912, 10},
    {16777216, 10},
    {26214400, 15},
    {33554432, 22},
    {44236800, 2},
    {58982400, 6},
    // Ideally the very large sizes below would not be cached at all, but
    // skipping them causes major performance regressions for models like
    // sd-turbo, so they are kept with a small retention count.
    {67108864, 6},
    {134217728, 6},
    {167772160, 6},
};
const std::unordered_map<size_t, size_t> BucketCacheManager::kBucketSizes{BUCKET_TABLE};


// Factory: instantiate the IBufferCacheManager implementation matching the
// requested cache mode.
std::unique_ptr<IBufferCacheManager> CreateBufferCacheManager(BufferCacheMode cache_mode) {
  if (cache_mode == BufferCacheMode::None) {
    return std::make_unique<DisabledCacheManager>();
  }
  if (cache_mode == BufferCacheMode::Simple) {
    return std::make_unique<SimpleCacheManager>();
  }
  if (cache_mode == BufferCacheMode::Bucket) {
    return std::make_unique<BucketCacheManager>();
  }
  ORT_NOT_IMPLEMENTED("Unsupported buffer cache mode");
}

// Construct a BufferManager for `device`, owning a cache manager chosen by
// `cache_mode` (see CreateBufferCacheManager).
BufferManager::BufferManager(wgpu::Device device, BufferCacheMode cache_mode) : IBufferManager{device, CreateBufferCacheManager(cache_mode)} {
}

// TODO: not implemented yet — currently a silent no-op, so callers (e.g. the
// CPU->GPU path in DataTransfer::CopyTensor) will see no data copied.
void BufferManager::Upload(void* src, WGPUBuffer dst, size_t size) const {
}

// TODO: not implemented yet — currently a silent no-op, so the GPU->GPU copy
// path in DataTransfer::CopyTensor performs no copy.
void BufferManager::MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) const {
}

// Create (or fetch from the cache) a GPU buffer large enough for `size` bytes.
// The actual allocation size is rounded up by the cache manager's sizing
// policy. The returned raw handle is owned by the caller and must be handed
// back via Release().
WGPUBuffer BufferManager::Create(size_t size, wgpu::BufferUsage usage) const {
  auto buffer_size = cache_->CalculateBufferSize(size);

  auto buffer = cache_->TryAcquireCachedBuffer(buffer_size, usage);
  if (buffer) {
    return buffer;
  }

  // cache miss, create a new buffer
  wgpu::BufferDescriptor desc{};
  desc.size = buffer_size;
  desc.usage = usage;
  wgpu::Buffer new_buffer = device_.CreateBuffer(&desc);

  // NOTE: `usage` is an enum class with no stream operator, so it is cast for
  // the error message.
  ORT_ENFORCE(new_buffer, "Failed to create GPU buffer: size=", buffer_size,
              ", usage=", static_cast<uint64_t>(usage), ".");

  // BUG FIX: the original assigned the RAII wgpu::Buffer wrapper directly to
  // the raw WGPUBuffer handle (ill-formed) and would have released the
  // reference when the wrapper went out of scope. MoveToCHandle() detaches
  // ownership so the raw handle stays valid after this function returns.
  buffer = new_buffer.MoveToCHandle();

  cache_->RegisterBuffer(buffer, size, buffer_size, usage);
  return buffer;
}

// Hand a buffer back to the cache manager for pooling or destruction.
// BUG FIX: the original passed a placeholder size of 0 and
// wgpu::BufferUsage::None, so size-keyed cache managers could never match a
// released buffer against a future request of the same size. Query the
// buffer's actual size and usage from the WebGPU API instead.
void BufferManager::Release(WGPUBuffer buffer) const {
  cache_->ReleaseBuffer(buffer,
                        static_cast<size_t>(wgpuBufferGetSize(buffer)),
                        static_cast<wgpu::BufferUsage>(wgpuBufferGetUsage(buffer)));
}

// TODO: not implemented yet — returns a default-constructed (invalid) future
// and performs no copy, so the GPU->CPU path in DataTransfer::CopyTensor will
// not receive data.
wgpu::Future BufferManager::Download(WGPUBuffer src, void* dst, size_t size) const {
  return wgpu::Future();
}

// Called on stream refresh: lets the cache manager recycle buffers released
// since the last refresh (see IBufferCacheManager::OnRefresh).
void BufferManager::RefreshPendingBuffers() const {
  cache_->OnRefresh();
}

} // namespace webgpu
} // namespace onnxruntime
55 changes: 38 additions & 17 deletions onnxruntime/core/providers/webgpu/buffer_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,41 +15,62 @@
namespace onnxruntime {
namespace webgpu {

enum class BufferCacheMode {
None,
Simple,
Bucket
};

class IBufferCacheManager {
public:
virtual ~IBufferCacheManager() = default;

virtual wgpu::Buffer GetBuffer(wgpu::Device device, size_t size) = 0;
virtual void ReleaseBuffer(wgpu::Buffer buffer) = 0;
// calculate actual buffer size to allocate based on the requested size.
virtual size_t CalculateBufferSize(size_t request_size) = 0;

// return a buffer if available in cache. otherwise empty.
virtual WGPUBuffer TryAcquireCachedBuffer(size_t buffer_size, wgpu::BufferUsage usage) = 0;

// register a newly created buffer
virtual void RegisterBuffer(WGPUBuffer buffer, size_t request_size, size_t buffer_size, wgpu::BufferUsage usage) = 0;

// release a buffer
virtual void ReleaseBuffer(WGPUBuffer buffer, size_t buffer_size, wgpu::BufferUsage usage) = 0;

// when a stream refresh is requested
virtual void OnRefresh() = 0;
};

class IBufferManager {
protected:
IBufferManager(wgpu::Device device) : device_(device) {}
IBufferManager(wgpu::Device device, std::unique_ptr<IBufferCacheManager> cache) : device_{device}, cache_{std::move(cache)} {}

public:
virtual ~IBufferManager() = default;
virtual void Upload(void* src, WGPUBuffer dst, size_t size) = 0;
virtual void MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) = 0;
virtual wgpu::Buffer Create(size_t size, wgpu::BufferUsage usage) = 0;
virtual void Release(WGPUBuffer buffer) = 0;
virtual wgpu::Future Download(WGPUBuffer src, void* dst, size_t size) = 0;
virtual void RefreshPendingBuffers() = 0;
virtual void Upload(void* src, WGPUBuffer dst, size_t size) const = 0;
virtual void MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) const = 0;
virtual WGPUBuffer Create(size_t size, wgpu::BufferUsage usage = wgpu::BufferUsage::Storage | wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::CopyDst) const = 0;
virtual void Release(WGPUBuffer buffer) const = 0;
virtual wgpu::Future Download(WGPUBuffer src, void* dst, size_t size) const = 0;
virtual void RefreshPendingBuffers() const = 0;

// TODO: add statistics

protected:
wgpu::Device device_;
std::unique_ptr<IBufferCacheManager> cache_;
};

class BufferManager : public IBufferManager {
public:
BufferManager(wgpu::Device device) : IBufferManager(device) {}

void Upload(void* src, WGPUBuffer dst, size_t size) override;
void MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) override;
wgpu::Buffer Create(size_t size, wgpu::BufferUsage usage) override;
void Release(WGPUBuffer buffer) override;
wgpu::Future Download(WGPUBuffer src, void* dst, size_t size) override;
void RefreshPendingBuffers() override;
BufferManager(wgpu::Device device, BufferCacheMode cache_mode);

void Upload(void* src, WGPUBuffer dst, size_t size) const override;
void MemCpy(WGPUBuffer src, WGPUBuffer dst, size_t size) const override;
WGPUBuffer Create(size_t size, wgpu::BufferUsage usage) const override;
void Release(WGPUBuffer buffer) const override;
wgpu::Future Download(WGPUBuffer src, void* dst, size_t size) const override;
void RefreshPendingBuffers() const override;

private:
struct PendingBuffer {
Expand Down
10 changes: 5 additions & 5 deletions onnxruntime/core/providers/webgpu/data_transfer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,16 @@ common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const {
if (dst_device.Type() == OrtDevice::GPU) {
if (src_device.Type() == OrtDevice::GPU) {
// copy from GPU to GPU
GetContext().BufferManager().MemCpy(static_cast<WGPUBuffer>(const_cast<void*>(src_data)),
static_cast<WGPUBuffer>(dst_data), bytes);
context_.BufferManager().MemCpy(static_cast<WGPUBuffer>(const_cast<void*>(src_data)),
static_cast<WGPUBuffer>(dst_data), bytes);
} else {
// copy from CPU to GPU
GetContext().BufferManager().Upload(const_cast<void*>(src_data), static_cast<WGPUBuffer>(dst_data), bytes);
context_.BufferManager().Upload(const_cast<void*>(src_data), static_cast<WGPUBuffer>(dst_data), bytes);
}
} else /* if (src_device.Type() == OrtDevice::GPU) */ {
// copy from GPU to CPU
ORT_RETURN_IF_ERROR(GetContext().Wait(
GetContext().BufferManager().Download(static_cast<WGPUBuffer>(const_cast<void*>(src_data)), dst_data, bytes)));
ORT_RETURN_IF_ERROR(context_.Wait(
context_.BufferManager().Download(static_cast<WGPUBuffer>(const_cast<void*>(src_data)), dst_data, bytes)));
}
}

Expand Down
7 changes: 5 additions & 2 deletions onnxruntime/core/providers/webgpu/data_transfer.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@ namespace webgpu {

class DataTransfer : public IDataTransfer {
public:
DataTransfer() {};
~DataTransfer() {};
DataTransfer(const WebGpuContext& context) : context_{context} {};
~DataTransfer(){};

bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override;

common::Status CopyTensor(const Tensor& src, Tensor& dst) const override;

private:
const WebGpuContext& context_;
};

} // namespace webgpu
Expand Down
25 changes: 19 additions & 6 deletions onnxruntime/core/providers/webgpu/webgpu_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,8 @@ wgpu::RequiredLimits GetAvailableRequiredLimits(const wgpu::Adapter& adapter) {
return required_limits;
}

void WebGpuContext::Init() {
static std::once_flag init_flag;
std::call_once(init_flag, [this]() {
void WebGpuContext::Initialize() {
std::call_once(init_flag_, [this]() {
// Initialization.Step.1 - Create wgpu::Instance

wgpu::InstanceDescriptor instance_desc{};
Expand Down Expand Up @@ -96,9 +95,23 @@ void WebGpuContext::Init() {
});
}

WebGpuContext& GetContext() {
static WebGpuContext context;
return context;
Status WebGpuContext::Wait(wgpu::Future f) const {
auto status = instance_.WaitAny(f, UINT64_MAX);
if (status == wgpu::WaitStatus::Success) {
return Status::OK();
}
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to wait for the operation:", uint32_t(status));
}

WebGpuContext& WebGpuContextFactory::GetOrCreateContext(int32_t context_id) {
std::lock_guard<std::mutex> lock(mutex_);

auto it = contexts_.find(context_id);
if (it == contexts_.end()) {
auto context = std::make_unique<WebGpuContext>();
it = contexts_.emplace(context_id, std::move(context)).first;
}
return *it->second;
}

} // namespace webgpu
Expand Down
Loading

0 comments on commit 237b6a6

Please sign in to comment.