Allow CUDA custom ops to allocate deferred CPU memory (#17893)
Expose a new allocator from the CUDA stream.
The allocator manages deferred CPU memory, which is recycled only before
the stream is destroyed.

---------

Co-authored-by: Randy Shuai <[email protected]>
RandySheriffH and RandyShuai authored Oct 20, 2023
1 parent 2f57625 commit 009cd4e
Showing 9 changed files with 100 additions and 17 deletions.
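
As a usage illustration, here is a minimal sketch of a custom op kernel that exercises the new CudaContext helpers added in this commit. It mirrors the test kernel in cuda_ops.cc below; the kernel name, the tensor signature, and the async-work placeholder are illustrative, not a prescribed API.

// Sketch only: include paths follow the test custom op in this PR and may differ per build.
#define ORT_API_MANUAL_INIT
#include "onnxruntime_cxx_api.h"
#undef ORT_API_MANUAL_INIT
#include "core/providers/cuda/cuda_context.h"
#include "onnxruntime_lite_custom_op.h"

void MyKernel(const Ort::Custom::CudaContext& cuda_ctx,
              const Ort::Custom::Tensor<float>& X,
              Ort::Custom::Tensor<float>& Y) {
  // Host scratch memory obtained from the stream's deferred CPU allocator.
  auto* scratch = static_cast<int32_t*>(cuda_ctx.AllocDeferredCpuMem(sizeof(int32_t)));

  // ... enqueue async work on cuda_ctx.cuda_stream that may still read or write scratch ...

  // "Free" only hands the buffer back to the stream; the actual release is deferred,
  // at the latest until just before the stream is destroyed, so it is safe to call
  // this right after the asynchronous launch.
  cuda_ctx.FreeDeferredCpuMem(scratch);
}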
31 changes: 31 additions & 0 deletions include/onnxruntime/core/providers/cuda/cuda_context.h
@@ -19,6 +19,7 @@ struct CudaContext : public CustomOpContext {
cudaStream_t cuda_stream = {};
cudnnHandle_t cudnn_handle = {};
cublasHandle_t cublas_handle = {};
OrtAllocator* deferred_cpu_allocator = {};

void Init(const OrtKernelContext& kernel_ctx) override {
const auto& ort_api = Ort::GetApi();
@@ -44,6 +45,36 @@ struct CudaContext : public CustomOpContext {
ORT_CXX_API_THROW("failed to fetch cublas handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
cublas_handle = reinterpret_cast<cublasHandle_t>(resource);

resource = {};
status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::deferred_cpu_allocator_t, &resource);
if (status) {
ORT_CXX_API_THROW("failed to fetch deferred cpu allocator", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
deferred_cpu_allocator = reinterpret_cast<OrtAllocator*>(resource);
}

void* AllocDeferredCpuMem(size_t size) const {
if (0 == size) {
return {};
}
const auto& ort_api = Ort::GetApi();
void* mem = {};
auto status = ort_api.AllocatorAlloc(deferred_cpu_allocator, size, &mem);
if (status) {
ORT_CXX_API_THROW("failed to allocate deferred cpu memory", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
return mem;
}

void FreeDeferredCpuMem(void* mem) const {
if (mem) {
const auto& ort_api = Ort::GetApi();
auto status = ort_api.AllocatorFree(deferred_cpu_allocator, mem);
if (status) {
ORT_CXX_API_THROW("failed to free deferred cpu memory", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
}
}
};

5 changes: 3 additions & 2 deletions include/onnxruntime/core/providers/cuda/cuda_resource.h
@@ -3,10 +3,11 @@

#include "core/providers/resource.h"

#define ORT_CUDA_RESOUCE_VERSION 1
#define ORT_CUDA_RESOUCE_VERSION 2

enum CudaResource : int {
cuda_stream_t = cuda_resource_offset,
cudnn_handle_t,
cublas_handle_t
cublas_handle_t,
deferred_cpu_allocator_t,
};
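
The version bump above matters for custom ops that query the resource through the C API themselves rather than via CudaContext. Below is a hedged sketch of such a lookup; the helper name and the null-on-failure policy are illustrative, and an older runtime that only knows resource version 1 will not expose deferred_cpu_allocator_t.

// Assumes `api` is the OrtApi obtained from OrtGetApiBase(), and `kernel_ctx` is the
// OrtKernelContext* passed into the custom op kernel.
#include "onnxruntime_c_api.h"
#include "core/providers/cuda/cuda_resource.h"

OrtAllocator* TryGetDeferredCpuAllocator(const OrtApi& api, const OrtKernelContext* kernel_ctx) {
  void* resource = nullptr;
  OrtStatus* status = api.KernelContext_GetResource(
      kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::deferred_cpu_allocator_t, &resource);
  if (status != nullptr) {
    api.ReleaseStatus(status);
    return nullptr;  // resource not available, e.g. an older provider build
  }
  return reinterpret_cast<OrtAllocator*>(resource);
}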
25 changes: 24 additions & 1 deletion onnxruntime/core/providers/cuda/cuda_stream_handle.cc
@@ -7,6 +7,25 @@

namespace onnxruntime {

DeferredCpuAllocator::DeferredCpuAllocator(CudaStream& cuda_stream) : cuda_stream_(cuda_stream) {
OrtAllocator::version = ORT_API_VERSION;
OrtAllocator::Alloc =
[](OrtAllocator* this_, size_t size) {
auto self = reinterpret_cast<DeferredCpuAllocator*>(this_);
return self->cuda_stream_.GetCpuAllocator()->Alloc(size);
};
OrtAllocator::Free =
[](OrtAllocator* this_, void* p) {
auto self = reinterpret_cast<DeferredCpuAllocator*>(this_);
self->cuda_stream_.EnqueDeferredCPUBuffer(p);
};
OrtAllocator::Info =
[](const OrtAllocator* this_) {
auto self = reinterpret_cast<const DeferredCpuAllocator*>(this_);
return &self->cuda_stream_.GetCpuAllocator()->Info();
};
}

struct CudaNotification : public synchronize::Notification {
CudaNotification(Stream& s) : Notification(s) {
CUDA_CALL_THROW(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
@@ -46,7 +65,8 @@ CudaStream::CudaStream(cudaStream_t stream,
cublasHandle_t external_cublas_handle) : Stream(stream, device),
own_stream_(own_flag),
cpu_allocator_(cpu_allocator),
release_cpu_buffer_on_cuda_stream_(release_cpu_buffer_on_cuda_stream) {
release_cpu_buffer_on_cuda_stream_(release_cpu_buffer_on_cuda_stream),
deferred_cpu_allocator_(*this) {
if (own_flag) {
CUBLAS_CALL_THROW(cublasCreate(&cublas_handle_));
CUBLAS_CALL_THROW(cublasSetStream(cublas_handle_, stream));
@@ -162,6 +182,9 @@ void* CudaStream::GetResource(int version, int id) const {
case CudaResource::cublas_handle_t:
return reinterpret_cast<void*>(cublas_handle_);
break;
case CudaResource::deferred_cpu_allocator_t:
return const_cast<DeferredCpuAllocator*>(&deferred_cpu_allocator_);
break;
default:
break;
}
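The cleanup path is not shown in this hunk; conceptually, EnqueDeferredCPUBuffer only records the pointer, and the stream hands all recorded buffers back to its CPU allocator once the work queued on it is known to be finished. Here is a deliberately simplified sketch of that lifecycle, with malloc/free standing in for the real CPU allocator; it is not the actual ORT implementation.

#include <cstdlib>
#include <vector>
#include <cuda_runtime.h>

struct ToyStream {
  cudaStream_t stream{};
  std::vector<void*> deferred_cpu_buffers;

  void EnqueDeferredCPUBuffer(void* p) { deferred_cpu_buffers.push_back(p); }

  // Invoked once the stream's pending work is complete (e.g. end of a run or teardown).
  void ReleaseDeferredBuffers() {
    cudaStreamSynchronize(stream);  // no kernel or copy may still touch the buffers
    for (void* p : deferred_cpu_buffers) std::free(p);
    deferred_cpu_buffers.clear();
  }
};
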
10 changes: 10 additions & 0 deletions onnxruntime/core/providers/cuda/cuda_stream_handle.h
@@ -9,6 +9,13 @@

namespace onnxruntime {

struct CudaStream;

struct DeferredCpuAllocator : public OrtAllocator {
DeferredCpuAllocator(CudaStream&);
CudaStream& cuda_stream_;
};

struct CudaStream : Stream {
CudaStream(cudaStream_t stream,
const OrtDevice& device,
@@ -36,10 +43,13 @@ struct CudaStream : Stream {

void* GetResource(int version, int id) const override;

onnxruntime::IAllocator* GetCpuAllocator() const { return cpu_allocator_.get(); }

private:
std::vector<void*> deferred_cpu_buffers_;
AllocatorPtr cpu_allocator_;
bool release_cpu_buffer_on_cuda_stream_{true};
DeferredCpuAllocator deferred_cpu_allocator_;
};

void RegisterCudaStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry,
9 changes: 4 additions & 5 deletions onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.cc
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#ifdef USE_CUDA
#if defined(USE_CUDA) && !defined(ENABLE_TRAINING)

#define ORT_API_MANUAL_INIT
#include "onnxruntime_cxx_api.h"
@@ -32,6 +32,9 @@ void KernelOne(const Ort::Custom::CudaContext& cuda_ctx,
CUSTOM_ENFORCE(cuda_ctx.cuda_stream, "failed to fetch cuda stream");
CUSTOM_ENFORCE(cuda_ctx.cudnn_handle, "failed to fetch cudnn handle");
CUSTOM_ENFORCE(cuda_ctx.cublas_handle, "failed to fetch cublas handle");
void* deferred_cpu_mem = cuda_ctx.AllocDeferredCpuMem(sizeof(int32_t));
CUSTOM_ENFORCE(deferred_cpu_mem, "failed to allocate deferred cpu allocator");
cuda_ctx.FreeDeferredCpuMem(deferred_cpu_mem);
auto z_raw = Z.Allocate(input_shape);
cuda_add(Z.NumberOfElement(), z_raw, X.Data(), Y.Data(), cuda_ctx.cuda_stream);
}
@@ -43,8 +46,4 @@ void RegisterOps(Ort::CustomOpDomain& domain) {

} // namespace Cuda

#else

void Cuda::RegisterOps(Ort::CustomOpDomain& domain) {}

#endif
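
For completeness, a hedged sketch of loading this test library into a session with the CUDA execution provider, so the kernel above actually receives a populated CudaContext. The library and model paths are placeholders, and the C++ helpers used (AppendExecutionProvider_CUDA, RegisterCustomOpsLibrary) are assumed to be available in the ORT version in use.

#include "onnxruntime_cxx_api.h"

int main() {
  Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "custom_op_demo"};
  Ort::SessionOptions session_options;

  OrtCUDAProviderOptions cuda_options{};  // defaults: device 0
  session_options.AppendExecutionProvider_CUDA(cuda_options);

  // Registers the "test.customop" and "v2" domains defined by custom_op_library.
  session_options.RegisterCustomOpsLibrary(ORT_TSTR("./libcustom_op_library.so"));

  Ort::Session session{env, ORT_TSTR("model_with_custom_op.onnx"), session_options};
  // ... create Ort::Value inputs and call session.Run(...) as usual ...
  return 0;
}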
10 changes: 9 additions & 1 deletion onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.h
@@ -5,6 +5,14 @@

namespace Cuda {

#if defined(USE_CUDA) && !defined(ENABLE_TRAINING)

void RegisterOps(Ort::CustomOpDomain& domain);

}
#else

void RegisterOps(Ort::CustomOpDomain&) {}

#endif

} // namespace Cuda
9 changes: 8 additions & 1 deletion onnxruntime/test/testdata/custom_op_library/custom_op_library.cc
@@ -13,6 +13,8 @@
#include "core/framework/ortdevice.h"
#include "core/framework/ortmemoryinfo.h"
#include "cpu/cpu_ops.h"
#include "cuda/cuda_ops.h"
#include "rocm/rocm_ops.h"
#include "onnxruntime_lite_custom_op.h"

static const char* c_OpDomain = "test.customop";
@@ -31,10 +33,15 @@ OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtA
ORT_TRY {
Ort::CustomOpDomain domain{c_OpDomain};
Cpu::RegisterOps(domain);

Ort::CustomOpDomain domain_v2{"v2"};
Cpu::RegisterOps(domain_v2);

Cuda::RegisterOps(domain);
Cuda::RegisterOps(domain_v2);

Rocm::RegisterOps(domain);
Rocm::RegisterOps(domain_v2);

Ort::UnownedSessionOptions session_options(options);
session_options.Add(domain);
session_options.Add(domain_v2);
8 changes: 2 additions & 6 deletions onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.cc
@@ -19,7 +19,7 @@ using namespace Ort::Custom;
throw std::runtime_error(msg); \
}

namespace Cuda {
namespace Rocm {

void KernelOne(const Ort::Custom::RocmContext& rocm_ctx,
const Ort::Custom::Tensor<float>& X,
@@ -38,10 +38,6 @@ void RegisterOps(Ort::CustomOpDomain& domain) {
domain.Add(c_CustomOpOne.get());
}

} // namespace Cuda

#else

void Cuda::RegisterOps(Ort::CustomOpDomain& domain) {}
} // namespace Rocm

#endif
10 changes: 9 additions & 1 deletion onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.h
@@ -5,6 +5,14 @@

namespace Rocm {

#ifdef USE_ROCM

void RegisterOps(Ort::CustomOpDomain& domain);

}
#else

inline void RegisterOps(Ort::CustomOpDomain&) {}

#endif

} // namespace Rocm
