Enrich cuda resources with ep options #19014

Merged · 9 commits · Jan 11, 2024
59 changes: 33 additions & 26 deletions include/onnxruntime/core/providers/cuda/cuda_context.h
@@ -28,38 +28,45 @@
cudnnHandle_t cudnn_handle = {};
cublasHandle_t cublas_handle = {};
OrtAllocator* deferred_cpu_allocator = {};
// below are cuda ep options
int16_t device_id = 0;
int32_t arena_extend_strategy = 0;
int32_t cudnn_conv_algo_search = 0;
bool cudnn_conv_use_max_workspace = true;
bool cudnn_conv1d_pad_to_nc1d = false;
bool enable_skip_layer_norm_strict_mode = false;
bool prefer_nhwc = false;

void Init(const OrtKernelContext& kernel_ctx) {
const auto& ort_api = Ort::GetApi();
void* resource = {};
OrtStatus* status = nullptr;

status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::cuda_stream_t, &resource);
if (status) {
ORT_CXX_API_THROW("failed to fetch cuda stream", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
cuda_stream = reinterpret_cast<cudaStream_t>(resource);

resource = {};
status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::cudnn_handle_t, &resource);
if (status) {
ORT_CXX_API_THROW("failed to fetch cudnn handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
cudnn_handle = reinterpret_cast<cudnnHandle_t>(resource);
cuda_stream = FetchResource<cudaStream_t>(kernel_ctx, CudaResource::cuda_stream_t);
cudnn_handle = FetchResource<cudnnHandle_t>(kernel_ctx, CudaResource::cudnn_handle_t);
cublas_handle = FetchResource<cublasHandle_t>(kernel_ctx, CudaResource::cublas_handle_t);
deferred_cpu_allocator = FetchResource<OrtAllocator*>(kernel_ctx, CudaResource::deferred_cpu_allocator_t);

device_id = FetchResource<int16_t>(kernel_ctx, CudaResource::device_id_t);
arena_extend_strategy = FetchResource<int32_t>(kernel_ctx, CudaResource::arena_extend_strategy_t);
cudnn_conv_algo_search = FetchResource<int32_t>(kernel_ctx, CudaResource::cudnn_conv_algo_search_t);
cudnn_conv_use_max_workspace = FetchResource<bool>(kernel_ctx, CudaResource::cudnn_conv_use_max_workspace_t);

cudnn_conv1d_pad_to_nc1d = FetchResource<bool>(kernel_ctx, CudaResource::cudnn_conv1d_pad_to_nc1d_t);
enable_skip_layer_norm_strict_mode = FetchResource<bool>(kernel_ctx, CudaResource::enable_skip_layer_norm_strict_mode_t);
prefer_nhwc = FetchResource<bool>(kernel_ctx, CudaResource::prefer_nhwc_t);
}

resource = {};
status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::cublas_handle_t, &resource);
if (status) {
ORT_CXX_API_THROW("failed to fetch cublas handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
template <typename T>
T FetchResource(const OrtKernelContext& kernel_ctx, CudaResource resource_type) {
if (sizeof(T) > sizeof(void*)) {
ORT_CXX_API_THROW("void* is not large enough to hold resource type: " + std::to_string(resource_type), OrtErrorCode::ORT_INVALID_ARGUMENT);

Check warning on line 59 in include/onnxruntime/core/providers/cuda/cuda_context.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] include/onnxruntime/core/providers/cuda/cuda_context.h#L59

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
include/onnxruntime/core/providers/cuda/cuda_context.h:59:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
}
cublas_handle = reinterpret_cast<cublasHandle_t>(resource);

resource = {};
status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::deferred_cpu_allocator_t, &resource);
const auto& ort_api = Ort::GetApi();
void* resource = {};
OrtStatus* status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, resource_type, &resource);
if (status) {
ORT_CXX_API_THROW("failed to fetch deferred cpu allocator", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
ORT_CXX_API_THROW("Failed to fetch cuda ep resource, resouce type: " + std::to_string(resource_type), OrtErrorCode::ORT_RUNTIME_EXCEPTION);

Check warning on line 65 in include/onnxruntime/core/providers/cuda/cuda_context.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] include/onnxruntime/core/providers/cuda/cuda_context.h#L65

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
include/onnxruntime/core/providers/cuda/cuda_context.h:65:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
}
deferred_cpu_allocator = reinterpret_cast<OrtAllocator*>(resource);
T t = {};
memcpy(&t, &resource, sizeof(T));
return t;
}

void* AllocDeferredCpuMem(size_t size) const {
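For context: once Init() has run, the EP options above are ordinary members of Ort::Custom::CudaContext, so a custom kernel can branch on them directly. A minimal sketch of a consumer (the kernel name and the NHWC dispatch are illustrative, not part of this PR; includes mirror the custom-op test further down):

#include "core/providers/cuda/cuda_context.h"
#include "onnxruntime_lite_custom_op.h"

void MyNhwcAwareKernel(const Ort::Custom::CudaContext& cuda_ctx,
                       const Ort::Custom::Tensor<float>& X,
                       Ort::Custom::Tensor<float>& Y) {
  // Init() has already pulled every option through FetchResource<T>.
  if (cuda_ctx.prefer_nhwc) {
    // dispatch an NHWC variant of the kernel here
  }
  int16_t device = cuda_ctx.device_id;  // e.g. select per-device tuning tables
  (void)device;
  // ... then launch work on cuda_ctx.cuda_stream as usual ...
}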
12 changes: 10 additions & 2 deletions include/onnxruntime/core/providers/cuda/cuda_resource.h
@@ -3,11 +3,19 @@

#include "core/providers/resource.h"

#define ORT_CUDA_RESOUCE_VERSION 2
#define ORT_CUDA_RESOUCE_VERSION 3

enum CudaResource : int {
cuda_stream_t = cuda_resource_offset,
cuda_stream_t = cuda_resource_offset, // 10000
cudnn_handle_t,
cublas_handle_t,
deferred_cpu_allocator_t,
// below are cuda ep options
device_id_t, // 10004
arena_extend_strategy_t,
cudnn_conv_algo_search_t,
cudnn_conv_use_max_workspace_t,
cudnn_conv1d_pad_to_nc1d_t,
enable_skip_layer_norm_strict_mode_t,
prefer_nhwc_t,
};
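Note on the version bump: KernelContext_GetResource is versioned, so asking an older runtime for one of the new IDs yields an error status rather than a value. A hedged sketch of fetching device_id_t through the raw C API, essentially what FetchResource<T> in cuda_context.h wraps; the fall-back-to-zero policy here is illustrative:

#include <cstdint>
#include <cstring>
#include "onnxruntime_cxx_api.h"
#include "core/providers/cuda/cuda_resource.h"

int16_t GetDeviceIdOrZero(const OrtKernelContext* ctx) {
  void* resource = nullptr;
  OrtStatus* status = Ort::GetApi().KernelContext_GetResource(
      ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::device_id_t, &resource);
  if (status != nullptr) {
    Ort::GetApi().ReleaseStatus(status);
    return 0;  // runtime predates resource version 3, or this isn't the CUDA EP
  }
  int16_t device_id = 0;
  std::memcpy(&device_id, &resource, sizeof(device_id));  // unpack the value carried in the pointer
  return device_id;
}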
2 changes: 1 addition & 1 deletion include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -4418,7 +4418,7 @@ struct OrtApi {
ORT_API2_STATUS(GetCUDAProviderOptionsByName, _In_ const OrtCUDAProviderOptionsV2* cuda_options, _In_ const char* key, _Outptr_ void** ptr);

/**
* Get a EP resoure.
* Get a EP resource.
* E.g. a cuda stream or a cublas handle
*
* \param context - Kernel context
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -2465,7 +2465,8 @@ void CUDAExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegistry&
stream_,
use_ep_level_unified_stream_,
GetPerThreadContext().CudnnHandle(),
GetPerThreadContext().CublasHandle());
GetPerThreadContext().CublasHandle(),
info_);
}

OrtDevice CUDAExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const {
45 changes: 35 additions & 10 deletions onnxruntime/core/providers/cuda/cuda_stream_handle.cc
@@ -62,11 +62,13 @@
bool release_cpu_buffer_on_cuda_stream,
bool own_flag,
cudnnHandle_t external_cudnn_handle,
cublasHandle_t external_cublas_handle) : Stream(stream, device),
own_stream_(own_flag),
cpu_allocator_(cpu_allocator),
release_cpu_buffer_on_cuda_stream_(release_cpu_buffer_on_cuda_stream),
deferred_cpu_allocator_(*this) {
cublasHandle_t external_cublas_handle,
const CUDAExecutionProviderInfo& ep_info) : Stream(stream, device),
own_stream_(own_flag),
cpu_allocator_(cpu_allocator),
release_cpu_buffer_on_cuda_stream_(release_cpu_buffer_on_cuda_stream),
deferred_cpu_allocator_(*this),
ep_info_(ep_info) {
if (own_flag) {
CUBLAS_CALL_THROW(cublasCreate(&cublas_handle_));
CUBLAS_CALL_THROW(cublasSetStream(cublas_handle_, stream));
@@ -185,6 +187,27 @@
case CudaResource::deferred_cpu_allocator_t:
return const_cast<DeferredCpuAllocator*>(&deferred_cpu_allocator_);
break;
case CudaResource::device_id_t:
return reinterpret_cast<void*>(ep_info_.device_id);
break;
case CudaResource::arena_extend_strategy_t:
return reinterpret_cast<void*>(ep_info_.arena_extend_strategy);
break;
case CudaResource::cudnn_conv_algo_search_t:
return reinterpret_cast<void*>(ep_info_.cudnn_conv_algo_search);
break;
case CudaResource::cudnn_conv_use_max_workspace_t:
return reinterpret_cast<void*>(ep_info_.cudnn_conv_use_max_workspace);
break;
case CudaResource::cudnn_conv1d_pad_to_nc1d_t:
return reinterpret_cast<void*>(ep_info_.cudnn_conv1d_pad_to_nc1d);
break;
case CudaResource::enable_skip_layer_norm_strict_mode_t:
return reinterpret_cast<void*>(ep_info_.enable_skip_layer_norm_strict_mode);
break;
case CudaResource::prefer_nhwc_t:
return reinterpret_cast<void*>(ep_info_.prefer_nhwc);
break;
default:
break;
}
@@ -207,26 +230,28 @@
cudaStream_t external_stream,
bool use_existing_stream,
cudnnHandle_t external_cudnn_handle,
cublasHandle_t external_cublas_handle) {
cublasHandle_t external_cublas_handle,
const CUDAExecutionProviderInfo& ep_info) {
// wait cuda notification on cuda ep
stream_handle_registry.RegisterWaitFn(device_type, device_type, WaitCudaNotificationOnDevice);
// wait cuda notification on cpu ep
stream_handle_registry.RegisterWaitFn(device_type, OrtDevice::CPU, WaitCudaNotificationOnHost);
if (!use_existing_stream)
stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator, release_cpu_buffer_on_cuda_stream](const OrtDevice& device) {
stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator, release_cpu_buffer_on_cuda_stream, ep_info](const OrtDevice& device) {
CUDA_CALL_THROW(cudaSetDevice(device.Id()));
cudaStream_t stream = nullptr;
CUDA_CALL_THROW(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
// CUDA_CALL_THROW(cudaStreamCreate(&stream));
return std::make_unique<CudaStream>(stream, device, cpu_allocator, release_cpu_buffer_on_cuda_stream, true, nullptr, nullptr);
return std::make_unique<CudaStream>(stream, device, cpu_allocator, release_cpu_buffer_on_cuda_stream, true, nullptr, nullptr, ep_info);
});
else
stream_handle_registry.RegisterCreateStreamFn(device_type, [cpu_allocator,
release_cpu_buffer_on_cuda_stream,
external_stream,
external_cudnn_handle,
external_cublas_handle](const OrtDevice& device) {
return std::make_unique<CudaStream>(external_stream, device, cpu_allocator, release_cpu_buffer_on_cuda_stream, false, external_cudnn_handle, external_cublas_handle);
external_cublas_handle,
ep_info](const OrtDevice& device) {
return std::make_unique<CudaStream>(external_stream, device, cpu_allocator, release_cpu_buffer_on_cuda_stream, false, external_cudnn_handle, external_cublas_handle, ep_info);
});
}

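End to end, the plumbing is: CUDAExecutionProvider forwards its info_ into RegisterCudaStreamHandles, each CudaStream keeps a copy as ep_info_, GetResource widens the requested option into the void* return slot, and FetchResource<T> memcpy's sizeof(T) bytes back out. A standalone illustration of that pack/unpack contract (plain C++, not ORT code). It assumes the option fits in a pointer (which the sizeof(T) > sizeof(void*) guard enforces) and that the first bytes of the pointer object are the low-order ones, i.e. little-endian, true of the platforms the CUDA EP ships on:

#include <cassert>
#include <cstdint>
#include <cstring>

int main() {
  int16_t device_id = 3;  // producer side: an EP option value
  void* packed = reinterpret_cast<void*>(static_cast<uintptr_t>(device_id));
  int16_t unpacked = 0;
  std::memcpy(&unpacked, &packed, sizeof(unpacked));  // consumer side: FetchResource's unpack
  assert(unpacked == 3);
  return 0;
}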
8 changes: 6 additions & 2 deletions onnxruntime/core/providers/cuda/cuda_stream_handle.h
@@ -6,6 +6,7 @@
#include "core/providers/cuda/shared_inc/cuda_utils.h"
#include "core/providers/cuda/shared_inc/cuda_call.h"
#include "core/framework/stream_handles.h"
#include "core/providers/cuda/cuda_execution_provider_info.h"

namespace onnxruntime {

@@ -23,7 +24,8 @@ struct CudaStream : Stream {
bool release_cpu_buffer_on_cuda_stream,
bool own_flag,
cudnnHandle_t external_cudnn_handle,
cublasHandle_t external_cublass_handle);
cublasHandle_t external_cublass_handle,
const CUDAExecutionProviderInfo& ep_info);

~CudaStream();

@@ -50,6 +52,7 @@
AllocatorPtr cpu_allocator_;
bool release_cpu_buffer_on_cuda_stream_{true};
DeferredCpuAllocator deferred_cpu_allocator_;
const CUDAExecutionProviderInfo ep_info_;
};

void RegisterCudaStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry,
@@ -59,6 +62,7 @@ void RegisterCudaStreamHandles(IStreamCommandHandleRegis
cudaStream_t external_stream,
bool use_existing_stream,
cudnnHandle_t external_cudnn_handle,
cublasHandle_t external_cublass_handle);
cublasHandle_t external_cublass_handle,
const CUDAExecutionProviderInfo& ep_info);
void WaitCudaNotificationOnDevice(Stream& stream, synchronize::Notification& notification);
} // namespace onnxruntime
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -3473,7 +3473,8 @@ void TensorrtExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegis
stream_,
external_stream_ /* use_existing_stream */,
external_cudnn_handle_,
external_cublas_handle_);
external_cublas_handle_,
{});
}

OrtDevice TensorrtExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const {
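Reviewer note: the TensorRT EP shares these stream handles but passes a default-constructed CUDAExecutionProviderInfo, so a custom op running under it will observe the CUDA EP option defaults (device_id 0, and so on) rather than session-specific values; that looks intentional, since the new IDs expose CUDA-EP-only options.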
3 changes: 0 additions & 3 deletions onnxruntime/core/session/custom_ops.cc
@@ -373,9 +373,6 @@ ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetResource, _In_ const OrtKernelCont
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Failed to fetch a stream hosting the requested resource");
}
*resource = stream->GetResource(resource_version, resource_id);
if (!(*resource)) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "Requested resource does not exist");
}
return nullptr;
API_IMPL_END
};
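The deleted null check is not incidental: with this change a resource can legitimately come back as zero, because boolean and small integral options travel packed inside the pointer itself. A minimal illustration of why the old "Requested resource does not exist" guard would now misfire (plain C++, not ORT code):

#include <cassert>
#include <cstdint>

int main() {
  bool prefer_nhwc = false;  // a perfectly valid option value
  void* packed = reinterpret_cast<void*>(static_cast<uintptr_t>(prefer_nhwc));
  assert(packed == nullptr);  // the removed guard would have reported this as missing
  return 0;
}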
4 changes: 2 additions & 2 deletions onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.cc
@@ -28,14 +28,14 @@ void KernelOne(const Ort::Custom::CudaContext& cuda_ctx,
const Ort::Custom::Tensor<float>& X,
const Ort::Custom::Tensor<float>& Y,
Ort::Custom::Tensor<float>& Z) {
auto input_shape = X.Shape();
CUSTOM_ENFORCE(cuda_ctx.cuda_stream, "failed to fetch cuda stream");
CUSTOM_ENFORCE(cuda_ctx.cudnn_handle, "failed to fetch cudnn handle");
CUSTOM_ENFORCE(cuda_ctx.cublas_handle, "failed to fetch cublas handle");
CUSTOM_ENFORCE(cuda_ctx.arena_extend_strategy == 0, "arena_extend_strategy mismatch");
void* deferred_cpu_mem = cuda_ctx.AllocDeferredCpuMem(sizeof(int32_t));
CUSTOM_ENFORCE(deferred_cpu_mem, "failed to allocate deferred cpu allocator");
cuda_ctx.FreeDeferredCpuMem(deferred_cpu_mem);
auto z_raw = Z.Allocate(input_shape);
auto z_raw = Z.Allocate(X.Shape());
cuda_add(Z.NumberOfElement(), z_raw, X.Data(), Y.Data(), cuda_ctx.cuda_stream);
}
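The added CUSTOM_ENFORCE exercises the new option path end to end; 0 is presumably the EP's default arena_extend_strategy (kNextPowerOfTwo) arriving through the resource round trip. Inlining X.Shape() into Z.Allocate() is a drive-by cleanup.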
