Skip to content

Commit

Permalink
type punning
Browse files Browse the repository at this point in the history
  • Loading branch information
RandyShuai committed Jan 8, 2024
1 parent d68d1cb commit 8774631
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 3 deletions.
7 changes: 6 additions & 1 deletion include/onnxruntime/core/providers/cuda/cuda_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,13 +80,18 @@ struct CudaContext : public CustomOpContext {

template<typename T>
T FetchResource(const OrtKernelContext& kernel_ctx, CudaResource resource_type) {
  // Fetches an EP-owned resource (e.g. cuda stream, cudnn/cublas handle) from the
  // kernel context and returns it as T.
  //
  // The resource travels through the C API as a void*, so T must fit inside a
  // pointer; anything larger cannot be transported and is rejected up front.
  // Throws ORT_INVALID_ARGUMENT if T is too large, ORT_RUNTIME_EXCEPTION if the
  // provider fails to supply the requested resource.
  if (sizeof(T) > sizeof(void*)) {
    ORT_CXX_API_THROW("void* is not large enough to hold resource type: " + std::to_string(resource_type),
                      OrtErrorCode::ORT_INVALID_ARGUMENT);
  }
  const auto& ort_api = Ort::GetApi();
  void* resource = {};
  OrtStatus* status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION,
                                                        resource_type, &resource);
  if (status) {
    ORT_CXX_API_THROW("Failed to fetch cuda ep resource, resouce type: " + std::to_string(resource_type),
                      OrtErrorCode::ORT_RUNTIME_EXCEPTION);
  }
  // memcpy instead of reinterpret_cast/dereference: punning through a cast
  // violates strict aliasing (UB); memcpy of the object representation is the
  // well-defined equivalent and compiles to the same code.
  T t = {};
  memcpy(&t, &resource, sizeof(T));
  return t;
}

void* AllocDeferredCpuMem(size_t size) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3473,7 +3473,8 @@ void TensorrtExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegis
stream_,
external_stream_ /* use_existing_stream */,
external_cudnn_handle_,
external_cublas_handle_);
external_cublas_handle_,
{});
}

OrtDevice TensorrtExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ void KernelOne(const Ort::Custom::CudaContext& cuda_ctx,
CUSTOM_ENFORCE(cuda_ctx.cuda_stream, "failed to fetch cuda stream");
CUSTOM_ENFORCE(cuda_ctx.cudnn_handle, "failed to fetch cudnn handle");
CUSTOM_ENFORCE(cuda_ctx.cublas_handle, "failed to fetch cublas handle");
CUSTOM_ENFORCE(cuda_ctx.gpu_mem_limit == std::numeric_limits<size_t>::max(), "");
CUSTOM_ENFORCE(cuda_ctx.gpu_mem_limit == std::numeric_limits<size_t>::max(), "gpu_mem_limit mismatch");
void* deferred_cpu_mem = cuda_ctx.AllocDeferredCpuMem(sizeof(int32_t));
CUSTOM_ENFORCE(deferred_cpu_mem, "failed to allocate deferred cpu allocator");
cuda_ctx.FreeDeferredCpuMem(deferred_cpu_mem);
Expand Down

0 comments on commit 8774631

Please sign in to comment.