Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TensorRT EP] Update TRT10.0 deprecated api #20989

Merged
merged 14 commits into from
Jul 2, 2024
58 changes: 39 additions & 19 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,11 +160,20 @@
nvinfer1::TacticSource source{};
t = toUpper(t);
if (t == "CUBLAS") {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic kCUBLAS is deprecated in TensorRT 10.0";
#if NV_TENSORRT_MAJOR < 10
source = nvinfer1::TacticSource::kCUBLAS;
#endif
} else if (t == "CUBLASLT" || t == "CUBLAS_LT") {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic kCUBLAS_LT is deprecated in TensorRT 9.0";
#if NV_TENSORRT_MAJOR < 9
source = nvinfer1::TacticSource::kCUBLAS_LT;
#endif
} else if (t == "CUDNN") {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Tactic kCUDNN is deprecated in TensorRT 10.0";
#if NV_TENSORRT_MAJOR < 10
source = nvinfer1::TacticSource::kCUDNN;
#endif
} else if (t == "EDGE_MASK_CONVOLUTIONS") {
source = nvinfer1::TacticSource::kEDGE_MASK_CONVOLUTIONS;
} else if (t == "JIT_CONVOLUTIONS") {
Expand Down Expand Up @@ -289,6 +298,25 @@
return g_host->CudaCall_true(retCode, exprString, libName, successCode, msg, file, line);
}

#if NV_TENSORRT_MAJOR >= 10

Check warning on line 301 in onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:301: Lines should be <= 120 characters long [whitespace/line_length] [2]
void* OutputAllocator::reallocateOutputAsync(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size,
uint64_t /*alignment*/, cudaStream_t /*stream*/) noexcept {
// Some memory allocators return nullptr when allocating zero bytes, but TensorRT requires a non-null ptr
// even for empty tensors, so allocate a dummy byte.
size = std::max(size, static_cast<uint64_t>(1));

Check warning on line 306 in onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:306: Lines should be <= 120 characters long [whitespace/line_length] [2]
if (size > allocated_size) {
cudaFree(outputPtr);
outputPtr = nullptr;
allocated_size = 0;
if (cudaMalloc(&outputPtr, size) == cudaSuccess) {
allocated_size = size;
}
}
// if cudaMalloc fails, returns nullptr.
return outputPtr;
}
#else
// Only override this method when TensorRT <= 8.6
void* OutputAllocator::reallocateOutput(char const* /*tensorName*/, void* /*currentMemory*/, uint64_t size,
uint64_t /*alignment*/) noexcept {
// Some memory allocators return nullptr when allocating zero bytes, but TensorRT requires a non-null ptr
Expand All @@ -305,6 +333,7 @@
// if cudaMalloc fails, returns nullptr.
return outputPtr;
}
#endif

void OutputAllocator::notifyShape(char const* /*tensorName*/, nvinfer1::Dims const& dims) noexcept {
output_shapes.clear();
Expand Down Expand Up @@ -3142,14 +3171,10 @@
if (mem_size > max_ctx_mem_size_) {
max_ctx_mem_size_ = mem_size;
}

#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4996) // nvinfer1::ICudaEngine::createExecutionContextWithoutDeviceMemory was deprecated
#endif
#if NV_TENSORRT_MAJOR < 10
trt_context = std::unique_ptr<nvinfer1::IExecutionContext>(trt_engine->createExecutionContextWithoutDeviceMemory());
#if defined(_MSC_VER)
#pragma warning(pop)
#else
trt_context = std::unique_ptr<nvinfer1::IExecutionContext>(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED));
#endif
} else {
trt_context = std::unique_ptr<nvinfer1::IExecutionContext>(trt_engine->createExecutionContext());
Expand Down Expand Up @@ -3596,14 +3621,12 @@

if (context_update) {
if (trt_state->context_memory_sharing_enable) {
#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4996) // nvinfer1::ICudaEngine::createExecutionContextWithoutDeviceMemory was deprecated
#endif
#if NV_TENSORRT_MAJOR < 10
*(trt_state->context) = std::unique_ptr<nvinfer1::IExecutionContext>(
trt_state->engine->get()->createExecutionContextWithoutDeviceMemory());
#if defined(_MSC_VER)
#pragma warning(pop)
#else
*(trt_state->context) = std::unique_ptr<nvinfer1::IExecutionContext>(
trt_state->engine->get()->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED));
#endif
} else {
*(trt_state->context) = std::unique_ptr<nvinfer1::IExecutionContext>(
Expand Down Expand Up @@ -3820,13 +3843,10 @@
if (mem_size > max_ctx_mem_size_) {
max_ctx_mem_size_ = mem_size;
}
#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4996) // nvinfer1::ICudaEngine::createExecutionContextWithoutDeviceMemory was deprecated
#endif
#if NV_TENSORRT_MAJOR < 10
trt_context = std::unique_ptr<nvinfer1::IExecutionContext>(trt_engine->createExecutionContextWithoutDeviceMemory());
#if defined(_MSC_VER)
#pragma warning(pop)
#else
trt_context = std::unique_ptr<nvinfer1::IExecutionContext>(trt_engine->createExecutionContext(nvinfer1::ExecutionContextAllocationStrategy::kUSER_MANAGED));
#endif
} else {
trt_context = std::unique_ptr<nvinfer1::IExecutionContext>(trt_engine->createExecutionContext());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,11 @@
//
class OutputAllocator : public nvinfer1::IOutputAllocator {
public:
#if NV_TENSORRT_MAJOR >= 10
void* reallocateOutputAsync(char const* tensorName, void* currentMemory, uint64_t size, uint64_t alignment, cudaStream_t stream) noexcept override;

Check warning on line 120 in onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h:120: Lines should be <= 120 characters long [whitespace/line_length] [2]
#else
void* reallocateOutput(char const* tensorName, void* currentMemory, uint64_t size, uint64_t alignment) noexcept override;

#endif
void notifyShape(char const* tensorName, nvinfer1::Dims const& dims) noexcept override;

void* getBuffer() {
Expand Down
Loading