Skip to content

Commit

Permalink
[TensorRT EP] Enable a minimal CUDA EP compilation without kernels (m…
Browse files Browse the repository at this point in the history
…icrosoft#19052)

Addresses microsoft#18542.
I followed the advice given by @RyanUnderhill
[here](microsoft#18731 (comment))
and went with a minimal CUDA EP for now.
  • Loading branch information
gedoensmax authored Jan 17, 2024
1 parent bd9d8fb commit bc219ed
Show file tree
Hide file tree
Showing 11 changed files with 97 additions and 39 deletions.
1 change: 1 addition & 0 deletions cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ option(onnxruntime_USE_CUDA "Build with CUDA support" OFF)
cmake_dependent_option(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS "Build with CUDA unit tests" OFF "onnxruntime_USE_CUDA;onnxruntime_BUILD_UNIT_TESTS;LINUX" OFF)

option(onnxruntime_USE_CUDA_NHWC_OPS "Build CUDA with NHWC op support" OFF)
# Reduces the CUDA EP to its memcpy ops only; intended for slim TensorRT builds.
option(onnxruntime_CUDA_MINIMAL "Build CUDA without any operations apart from memcpy ops. Useful for a very minimal TRT build" OFF)
option(onnxruntime_ENABLE_CUDA_LINE_NUMBER_INFO "When building with CUDA support, generate device code line number information." OFF)
option(onnxruntime_USE_OPENVINO "Build with OpenVINO support" OFF)
option(onnxruntime_USE_COREML "Build with CoreML support" OFF)
Expand Down
49 changes: 37 additions & 12 deletions cmake/onnxruntime_providers_cuda.cmake
Original file line number Diff line number Diff line change
@@ -1,10 +1,25 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

file(GLOB_RECURSE onnxruntime_providers_cuda_cc_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/core/providers/cuda/*.h"
"${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cc"
)

# Select the host-side (.h/.cc) sources of the CUDA execution provider.
# The minimal build uses a non-recursive glob, so only the top-level cuda/
# directory and cuda/tunable/ are picked up; the regular build globs the whole
# cuda/ tree recursively.
if (onnxruntime_CUDA_MINIMAL)
file(GLOB onnxruntime_providers_cuda_cc_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/core/providers/cuda/*.h"
"${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cc"
"${ONNXRUNTIME_ROOT}/core/providers/cuda/tunable/*.h"
"${ONNXRUNTIME_ROOT}/core/providers/cuda/tunable/*.cc"
)
# Exclude files that are left out of the minimal build (these are not PCH files).
# NOTE(review): presumably excluded because they depend on cuBLAS / Triton,
# which the minimal build does not link - confirm.
list(REMOVE_ITEM onnxruntime_providers_cuda_cc_srcs
"${ONNXRUNTIME_ROOT}/core/providers/cuda/integer_gemm.cc"
"${ONNXRUNTIME_ROOT}/core/providers/cuda/triton_kernel.h"
)
else()
file(GLOB_RECURSE onnxruntime_providers_cuda_cc_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/core/providers/cuda/*.h"
"${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cc"
)
endif()
# Remove pch files
list(REMOVE_ITEM onnxruntime_providers_cuda_cc_srcs
"${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_pch.h"
Expand All @@ -16,11 +31,16 @@
"${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h"
"${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc"
)
file(GLOB_RECURSE onnxruntime_providers_cuda_cu_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cu"
"${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cuh"
)


# Collect the device-side (.cu/.cuh) sources for the regular build. The minimal
# branch does not populate onnxruntime_providers_cuda_cu_srcs at all.
# NOTE(review): the minimal branch empties onnxruntime_providers_cuda_shared_srcs
# (presumably the list globbed from shared_library/ just above) rather than the
# .cu list. Confirm that dropping the shared sources is intended and not a typo
# for onnxruntime_providers_cuda_cu_srcs.
if (onnxruntime_CUDA_MINIMAL)
set(onnxruntime_providers_cuda_shared_srcs "")
else()
file(GLOB_RECURSE onnxruntime_providers_cuda_cu_srcs CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cu"
"${ONNXRUNTIME_ROOT}/core/providers/cuda/*.cuh"
)
endif()
source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_cuda_cc_srcs} ${onnxruntime_providers_cuda_shared_srcs} ${onnxruntime_providers_cuda_cu_srcs})
set(onnxruntime_providers_cuda_src ${onnxruntime_providers_cuda_cc_srcs} ${onnxruntime_providers_cuda_shared_srcs} ${onnxruntime_providers_cuda_cu_srcs})

Expand Down Expand Up @@ -156,10 +176,15 @@
endif()

add_dependencies(${target} onnxruntime_providers_shared ${onnxruntime_EXTERNAL_DEPENDENCIES})
target_link_libraries(${target} PRIVATE cublasLt cublas cudnn curand cufft ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface)
if(onnxruntime_CUDNN_HOME)
target_include_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/include)
target_link_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/lib)
# Minimal builds define USE_CUDA_MINIMAL for this target and link only the
# common dependencies. Regular builds additionally link cuBLAS/cuBLASLt, cuDNN,
# cuRAND and cuFFT, and honor a user-supplied cuDNN location.
if(onnxruntime_CUDA_MINIMAL)
# The C++ sources key off this macro to compile out cuBLAS/cuDNN usage.
target_compile_definitions(${target} PRIVATE USE_CUDA_MINIMAL)
target_link_libraries(${target} PRIVATE ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface)
else()
target_link_libraries(${target} PRIVATE cublasLt cublas cudnn curand cufft ${ABSEIL_LIBS} ${ONNXRUNTIME_PROVIDERS_SHARED} Boost::mp11 safeint_interface)
# When a cuDNN home is given, add its include/ and lib/ directories to the target.
if(onnxruntime_CUDNN_HOME)
target_include_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/include)
target_link_directories(${target} PRIVATE ${onnxruntime_CUDNN_HOME}/lib)
endif()
endif()

if (onnxruntime_USE_TRITON_KERNEL)
Expand Down
3 changes: 2 additions & 1 deletion include/onnxruntime/core/providers/cuda/cuda_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
#include "core/providers/custom_op_context.h"
#include <cuda.h>
#include <cuda_runtime.h>
#ifndef USE_CUDA_MINIMAL
#include <cublas_v2.h>
#include <cudnn.h>

#endif
namespace Ort {

namespace Custom {
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/cuda/cuda_call.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ const char* CudaErrString<cudaError_t>(cudaError_t x) {
return cudaGetErrorString(x);
}

#ifndef USE_CUDA_MINIMAL
template <>
const char* CudaErrString<cublasStatus_t>(cublasStatus_t e) {
cudaDeviceSynchronize();
Expand Down Expand Up @@ -76,6 +77,7 @@ const char* CudaErrString<cufftResult>(cufftResult e) {
return "Unknown cufft error status";
}
}
#endif

#ifdef ORT_USE_NCCL
template <>
Expand Down Expand Up @@ -132,6 +134,7 @@ std::conditional_t<THRW, void, Status> CudaCall(

template Status CudaCall<cudaError, false>(cudaError retCode, const char* exprString, const char* libName, cudaError successCode, const char* msg, const char* file, const int line);
template void CudaCall<cudaError, true>(cudaError retCode, const char* exprString, const char* libName, cudaError successCode, const char* msg, const char* file, const int line);
#ifndef USE_CUDA_MINIMAL
template Status CudaCall<cublasStatus_t, false>(cublasStatus_t retCode, const char* exprString, const char* libName, cublasStatus_t successCode, const char* msg, const char* file, const int line);
template void CudaCall<cublasStatus_t, true>(cublasStatus_t retCode, const char* exprString, const char* libName, cublasStatus_t successCode, const char* msg, const char* file, const int line);
template Status CudaCall<cudnnStatus_t, false>(cudnnStatus_t retCode, const char* exprString, const char* libName, cudnnStatus_t successCode, const char* msg, const char* file, const int line);
Expand All @@ -140,6 +143,7 @@ template Status CudaCall<curandStatus_t, false>(curandStatus_t retCode, const ch
template void CudaCall<curandStatus_t, true>(curandStatus_t retCode, const char* exprString, const char* libName, curandStatus_t successCode, const char* msg, const char* file, const int line);
template Status CudaCall<cufftResult, false>(cufftResult retCode, const char* exprString, const char* libName, cufftResult successCode, const char* msg, const char* file, const int line);
template void CudaCall<cufftResult, true>(cufftResult retCode, const char* exprString, const char* libName, cufftResult successCode, const char* msg, const char* file, const int line);
#endif

#ifdef ORT_USE_NCCL
template Status CudaCall<ncclResult_t, false>(ncclResult_t retCode, const char* exprString, const char* libName, ncclResult_t successCode, const char* msg, const char* file, const int line);
Expand Down
42 changes: 22 additions & 20 deletions onnxruntime/core/providers/cuda/cuda_common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,27 @@ namespace cuda {
// 0x04 - pedantic
constexpr const char* kCudaGemmOptions = "ORT_CUDA_GEMM_OPTIONS";

// Maps a cudaDataType_t value to the spelling of its enumerator.
// Values that are not listed below yield "<unknown>".
const char* CudaDataTypeToString(cudaDataType_t dt) {
  if (dt == CUDA_R_16F) {
    return "CUDA_R_16F";
  }
  if (dt == CUDA_R_16BF) {
    return "CUDA_R_16BF";
  }
  if (dt == CUDA_R_32F) {
    return "CUDA_R_32F";
  }
#if !defined(DISABLE_FLOAT8_TYPES)
  // Note: CUDA_R_8F_E4M3 is defined with CUDA>=11.8
  if (dt == CUDA_R_8F_E4M3) {
    return "CUDA_R_8F_E4M3";
  }
  if (dt == CUDA_R_8F_E5M2) {
    return "CUDA_R_8F_E5M2";
  }
#endif
  return "<unknown>";
}

#ifndef USE_CUDA_MINIMAL
// Initialize the singleton instance
HalfGemmOptions HalfGemmOptions::instance;

Expand Down Expand Up @@ -54,26 +75,6 @@ const char* cublasGetErrorEnum(cublasStatus_t error) {
}
}

const char* CudaDataTypeToString(cudaDataType_t dt) {
switch (dt) {
case CUDA_R_16F:
return "CUDA_R_16F";
case CUDA_R_16BF:
return "CUDA_R_16BF";
case CUDA_R_32F:
return "CUDA_R_32F";
#if !defined(DISABLE_FLOAT8_TYPES)
// Note: CUDA_R_8F_E4M3 is defined with CUDA>=11.8
case CUDA_R_8F_E4M3:
return "CUDA_R_8F_E4M3";
case CUDA_R_8F_E5M2:
return "CUDA_R_8F_E5M2";
#endif
default:
return "<unknown>";
}
}

const char* CublasComputeTypeToString(cublasComputeType_t ct) {
switch (ct) {
case CUBLAS_COMPUTE_16F:
Expand All @@ -92,6 +93,7 @@ const char* CublasComputeTypeToString(cublasComputeType_t ct) {
return "<unknown>";
}
}
#endif

// It must exist somewhere already.
cudaDataType_t ToCudaDataType(int32_t element_type) {
Expand Down
6 changes: 4 additions & 2 deletions onnxruntime/core/providers/cuda/cuda_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,14 @@ namespace onnxruntime {
namespace cuda {

#define CUDA_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(CUDA_CALL(expr))
#ifndef USE_CUDA_MINIMAL
#define CUBLAS_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(CUBLAS_CALL(expr))
#define CUSPARSE_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(CUSPARSE_CALL(expr))
#define CURAND_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(CURAND_CALL(expr))
#define CUDNN_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(CUDNN_CALL(expr))
#define CUDNN2_RETURN_IF_ERROR(expr, m) ORT_RETURN_IF_ERROR(CUDNN_CALL2(expr, m))
#define CUFFT_RETURN_IF_ERROR(expr) ORT_RETURN_IF_ERROR(CUFFT_CALL(expr))

#endif
// Type mapping for MLFloat16 to half
template <typename T>
class ToCudaType {
Expand Down Expand Up @@ -93,7 +94,7 @@ inline bool CalculateFdmStrides(gsl::span<fast_divmod> p, const std::vector<int6
}
return true;
}

#ifndef USE_CUDA_MINIMAL
class CublasMathModeSetter {
public:
CublasMathModeSetter(const cudaDeviceProp& prop, cublasHandle_t handle, cublasMath_t mode) : handle_(handle) {
Expand Down Expand Up @@ -189,6 +190,7 @@ const char* cublasGetErrorEnum(cublasStatus_t error);
const char* CudaDataTypeToString(cudaDataType_t dt);

const char* CublasComputeTypeToString(cublasComputeType_t ct);
#endif

cudaDataType_t ToCudaDataType(int32_t element_type);

Expand Down
14 changes: 12 additions & 2 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "core/providers/cuda/gpu_data_transfer.h"
#include "core/providers/cuda/cuda_profiler.h"

#ifndef USE_CUDA_MINIMAL
#ifndef DISABLE_CONTRIB_OPS
#include "contrib_ops/cuda/cuda_contrib_kernels.h"
#endif
Expand All @@ -27,6 +28,7 @@
#ifdef USE_TRITON_KERNEL
#include "core/providers/cuda/triton_kernel.h"
#endif
#endif

#include "core/providers/cuda/cuda_stream_handle.h"

Expand Down Expand Up @@ -169,21 +171,23 @@ CUDAExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId de
ArenaExtendStrategy /*arena_extend_strategy*/, CUDAExecutionProviderExternalAllocatorInfo /*external_allocator_info*/,
OrtArenaCfg* /*default_memory_arena_cfg*/) {
CUDA_CALL_THROW(cudaSetDevice(device_id));

#ifndef USE_CUDA_MINIMAL
CUBLAS_CALL_THROW(cublasCreate(&cublas_handle_));
CUBLAS_CALL_THROW(cublasLtCreate(&cublas_lt_handle_));
CUBLAS_CALL_THROW(cublasSetStream(cublas_handle_, stream));

CUDNN_CALL_THROW(cudnnCreate(&cudnn_handle_));
CUDNN_CALL_THROW(cudnnSetStream(cudnn_handle_, stream));

#endif
cuda_graph_.SetStream(stream);
}

// Destroys the per-thread cuBLAS, cuBLASLt and cuDNN handles. Failures are
// deliberately ignored (the status of each call is discarded). In
// USE_CUDA_MINIMAL builds the body is compiled out entirely.
CUDAExecutionProvider::PerThreadContext::~PerThreadContext() {
#ifndef USE_CUDA_MINIMAL
ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasDestroy(cublas_handle_)));
ORT_IGNORE_RETURN_VALUE(CUBLAS_CALL(cublasLtDestroy(cublas_lt_handle_)));
ORT_IGNORE_RETURN_VALUE(CUDNN_CALL(cudnnDestroy(cudnn_handle_)));
#endif
}

bool CUDAExecutionProvider::PerThreadContext::IsGraphCaptureAllowed() const {
Expand Down Expand Up @@ -441,6 +445,7 @@ namespace cuda {
// opset 1 to 9
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MemcpyFromHost);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MemcpyToHost);
#ifndef USE_CUDA_MINIMAL
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, float, Cos);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, double, Cos);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 7, MLFloat16, Cos);
Expand Down Expand Up @@ -1315,6 +1320,7 @@ class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDom
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Reshape);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Scan);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Shape);
#endif

template <>
KernelCreateInfo BuildKernelCreateInfo<void>() {
Expand All @@ -1326,6 +1332,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<void>, // default entry to avoid the list become empty after ops-reducing
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MemcpyFromHost)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MemcpyToHost)>,
#ifndef USE_CUDA_MINIMAL
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 4, 10, Concat)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 10, Unsqueeze)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, 8, Flatten)>,
Expand Down Expand Up @@ -2201,6 +2208,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Reshape)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Scan)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 19, Shape)>,
#endif
};

for (auto& function_table_entry : function_table) {
Expand All @@ -2210,6 +2218,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
}
}

#ifndef USE_CUDA_MINIMAL
#ifndef DISABLE_CONTRIB_OPS
ORT_RETURN_IF_ERROR(::onnxruntime::contrib::cuda::RegisterCudaContribKernels(kernel_registry));
#endif
Expand All @@ -2220,6 +2229,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {

#ifdef ENABLE_TRAINING_OPS
ORT_RETURN_IF_ERROR(::onnxruntime::cuda::RegisterCudaTrainingKernels(kernel_registry));
#endif
#endif

return Status::OK();
Expand Down
7 changes: 7 additions & 0 deletions onnxruntime/core/providers/cuda/cuda_pch.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,19 @@

#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#ifndef USE_CUDA_MINIMAL
#include <cublas_v2.h>
#include <cusparse.h>
#include <curand.h>
#include <cudnn.h>
#include <cufft.h>
#include <cublasLt.h>
#else
typedef void* cudnnHandle_t;
typedef void* cublasHandle_t;
typedef void* cublasLtHandle_t;
#endif

#ifdef ORT_USE_NCCL
#include <nccl.h>
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/cuda/cuda_stream_handle.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ CudaStream::CudaStream(cudaStream_t stream,
release_cpu_buffer_on_cuda_stream_(release_cpu_buffer_on_cuda_stream),
deferred_cpu_allocator_(*this),
ep_info_(ep_info) {
#ifndef USE_CUDA_MINIMAL
if (own_flag) {
CUBLAS_CALL_THROW(cublasCreate(&cublas_handle_));
CUBLAS_CALL_THROW(cublasSetStream(cublas_handle_, stream));
Expand All @@ -80,17 +81,20 @@ CudaStream::CudaStream(cudaStream_t stream,
cudnn_handle_ = external_cudnn_handle;
CUDNN_CALL_THROW(cudnnSetStream(cudnn_handle_, stream));
}
#endif
}

// Runs the end-of-run cleanup and, when this object owns the stream, releases
// the library handles and the CUDA stream itself.
//
// Only the cuBLAS/cuDNN teardown is compiled out under USE_CUDA_MINIMAL. The
// stream must still be destroyed in minimal builds, otherwise every owned
// stream is leaked.
CudaStream::~CudaStream() {
  ORT_IGNORE_RETURN_VALUE(CleanUpOnRunEnd());
  if (own_stream_) {
#ifndef USE_CUDA_MINIMAL
    // These handles are only created in non-minimal builds.
    cublasDestroy(cublas_handle_);
    cudnnDestroy(cudnn_handle_);
#endif
    auto* handle = GetHandle();
    if (handle)
      cudaStreamDestroy(static_cast<cudaStream_t>(handle));
  }
}

std::unique_ptr<synchronize::Notification> CudaStream::CreateNotification(size_t /*num_consumers*/) {
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/cuda/cudnn_common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#include "core/common/gsl.h"
#include "shared_inc/cuda_call.h"
#include "core/providers/cpu/tensor/utils.h"

#ifndef USE_CUDA_MINIMAL
namespace onnxruntime {
namespace cuda {

Expand Down Expand Up @@ -222,3 +222,4 @@ const Float8E5M2 Consts<Float8E5M2>::One = Float8E5M2(1.0f, true);

} // namespace cuda
} // namespace onnxruntime
#endif
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/cuda/cudnn_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <cfloat>

#include "core/providers/cuda/cuda_common.h"

#ifndef USE_CUDA_MINIMAL
namespace onnxruntime {
namespace cuda {

Expand Down Expand Up @@ -260,3 +260,4 @@ SetPoolingNdDescriptorHelper(cudnnPoolingDescriptor_t poolingDesc,

} // namespace cuda
} // namespace onnxruntime
#endif

0 comments on commit bc219ed

Please sign in to comment.