From 8774631aaf51df42d2bfa53811de814d289cd109 Mon Sep 17 00:00:00 2001
From: Randy Shuai
Date: Mon, 8 Jan 2024 15:40:43 -0800
Subject: [PATCH] type punning

---
Reviewer notes (text between the "---" separator and the first "diff --git"
line is not part of the commit message).

What the patch does
  * cuda_context.h: FetchResource no longer reads the void* returned by
    KernelContext_GetResource through a reinterpret_cast of its address.
    It now rejects any T larger than void* and copies sizeof(T) bytes out
    of the pointer object with memcpy.
  * tensorrt_execution_provider.cc: passes one extra trailing argument, {},
    in the call shown in the RegisterStreamHandlers hunk.
  * cuda_ops.cc: gives the gpu_mem_limit CUSTOM_ENFORCE a non-empty message.

Changed relative to the patch as posted
  * The size guard is now "if constexpr" instead of a plain "if".
    sizeof(T) > sizeof(void*) is a compile-time constant for every
    instantiation, so behaviour is identical, but the throw branch is only
    compiled for a T that is actually too large and constant-condition
    warnings are avoided. Requires C++17 - TODO confirm the header already
    targets it.
  * Because that added line changed, the post-image blob id on the
    cuda_context.h "index" line is stale. It is left as posted.

Open questions (not changed here)
  * NOTE(review): memcpy copies the first sizeof(T) bytes of the pointer
    object. If the provider packs small values by integer-to-pointer
    conversion (not visible in this patch), a T narrower than void* only
    gets the right bytes on little-endian targets - confirm.
  * NOTE(review): memcpy into T assumes T is trivially copyable - consider
    a static_assert.
  * NOTE(review): the parameter that receives {} is not visible here.
    Presumably the callee gained a trailing parameter - confirm, and
    consider naming it inline like the neighbouring
    /* use_existing_stream */ comment.
  * "resouce" / ORT_CUDA_RESOUCE_VERSION are pre-existing spellings in
    unchanged context lines and were left untouched.

Transcription caveat
  * This copy arrived whitespace-collapsed and appears to have lost
    angle-bracketed spans (e.g. after "template", "static_cast",
    "reinterpret_cast", "std::numeric_limits" and the author name).
    Line breaks and indentation below were reconstructed. The lost spans
    were not guessed. Regenerate from commit 8774631aaf51 before applying.

 include/onnxruntime/core/providers/cuda/cuda_context.h     | 7 ++++++-
 .../core/providers/tensorrt/tensorrt_execution_provider.cc | 3 ++-
 .../test/testdata/custom_op_library/cuda/cuda_ops.cc       | 2 +-
 3 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/include/onnxruntime/core/providers/cuda/cuda_context.h b/include/onnxruntime/core/providers/cuda/cuda_context.h
index 635849aa6a3b3..0b731c222f6a6 100644
--- a/include/onnxruntime/core/providers/cuda/cuda_context.h
+++ b/include/onnxruntime/core/providers/cuda/cuda_context.h
@@ -80,13 +80,18 @@ struct CudaContext : public CustomOpContext {
 
   template
   T FetchResource(const OrtKernelContext& kernel_ctx, CudaResource resource_type) {
+    if constexpr (sizeof(T) > sizeof(void*)) {
+      ORT_CXX_API_THROW("void* is not large enough to hold resource type: " + std::to_string(resource_type), OrtErrorCode::ORT_INVALID_ARGUMENT);
+    }
     const auto& ort_api = Ort::GetApi();
     void* resource = {};
     OrtStatus* status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, resource_type, &resource);
     if (status) {
       ORT_CXX_API_THROW("Failed to fetch cuda ep resource, resouce type: " + std::to_string(resource_type), OrtErrorCode::ORT_RUNTIME_EXCEPTION);
     }
-    return static_cast(*reinterpret_cast(&resource));
+    T t = {};
+    memcpy(&t, &resource, sizeof(T));
+    return t;
   }
 
   void* AllocDeferredCpuMem(size_t size) const {
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
index 684303a8b6448..7397b84373db7 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -3473,7 +3473,8 @@ void TensorrtExecutionProvider::RegisterStreamHandlers(IStreamCommandHandleRegis
                             stream_,
                             external_stream_ /* use_existing_stream */,
                             external_cudnn_handle_,
-                            external_cublas_handle_);
+                            external_cublas_handle_,
+                            {});
 }
 
 OrtDevice TensorrtExecutionProvider::GetOrtDeviceByMemType(OrtMemType mem_type) const {
diff --git a/onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.cc b/onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.cc
index 2708b6d38aedb..05fb5147e4815 100644
--- a/onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.cc
+++ b/onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.cc
@@ -31,7 +31,7 @@ void KernelOne(const Ort::Custom::CudaContext& cuda_ctx,
   CUSTOM_ENFORCE(cuda_ctx.cuda_stream, "failed to fetch cuda stream");
   CUSTOM_ENFORCE(cuda_ctx.cudnn_handle, "failed to fetch cudnn handle");
   CUSTOM_ENFORCE(cuda_ctx.cublas_handle, "failed to fetch cublas handle");
-  CUSTOM_ENFORCE(cuda_ctx.gpu_mem_limit == std::numeric_limits::max(), "");
+  CUSTOM_ENFORCE(cuda_ctx.gpu_mem_limit == std::numeric_limits::max(), "gpu_mem_limit mismatch");
   void* deferred_cpu_mem = cuda_ctx.AllocDeferredCpuMem(sizeof(int32_t));
   CUSTOM_ENFORCE(deferred_cpu_mem, "failed to allocate deferred cpu allocator");
   cuda_ctx.FreeDeferredCpuMem(deferred_cpu_mem);