From 18a3675bf73f86e05a200428c06c053357bbc51b Mon Sep 17 00:00:00 2001
From: Chi Lo <54722500+chilo-ms@users.noreply.github.com>
Date: Thu, 16 Nov 2023 07:39:41 +0000
Subject: [PATCH] [TensorRT EP] Only instantiate TRT builder once (#18100)

The TRT builder instantiation is slow (see
[here](https://github.com/microsoft/onnxruntime/issues/18071)). In the
current TRT EP, we instantiate a builder object every time we need one.
The TRT builder is needed in multiple places, so this causes significant
performance overhead.
---
 .../tensorrt/tensorrt_execution_provider.cc  | 25 ++++++++++++++-----
 .../tensorrt/tensorrt_execution_provider.h   | 10 +++++++-
 2 files changed, 28 insertions(+), 7 deletions(-)

diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
index 020af451cdcd5..3b3732bb716f9 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -1272,6 +1272,20 @@ Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream) {
   return Status::OK();
 }
 
+// Get the pointer to the IBuilder instance.
+// Note: This function is not thread safe. Calls to this function from different threads must be serialized
+// even though it doesn't make sense to have multiple threads initializing the same inference session.
+nvinfer1::IBuilder* TensorrtExecutionProvider::GetBuilder() const {
+  if (!builder_) {
+    TensorrtLogger& trt_logger = GetTensorrtLogger();
+    {
+      auto lock = GetApiLock();
+      builder_ = std::unique_ptr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(trt_logger));
+    }
+  }
+  return builder_.get();
+}
+
 void TensorrtExecutionProvider::GetCustomOpDomainList(std::vector<OrtCustomOpDomain*>& custom_op_domain_list) const {
   if (info_.custom_op_domain_list.empty()) {
     common::Status status = CreateTensorRTCustomOpDomainList(info_);
@@ -1633,7 +1647,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
       // Get supported node list recursively
       SubGraphCollection_t parser_nodes_list;
       TensorrtLogger& trt_logger = GetTensorrtLogger();
-      auto trt_builder = std::unique_ptr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(trt_logger));
+      auto trt_builder = GetBuilder();
       const auto explicitBatch = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
       auto trt_network = std::unique_ptr<nvinfer1::INetworkDefinition>(trt_builder->createNetworkV2(explicitBatch));
 
@@ -1985,7 +1999,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector
-    auto trt_builder = std::unique_ptr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(trt_logger));
+    auto trt_builder = GetBuilder();
     const auto explicitBatch = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
     auto trt_network = std::unique_ptr<nvinfer1::INetworkDefinition>(trt_builder->createNetworkV2(explicitBatch));
     auto trt_config = std::unique_ptr<nvinfer1::IBuilderConfig>(trt_builder->createBuilderConfig());
@@ -2438,7 +2452,6 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector
-      *p = {context->allocate_func, context->release_func, context->allocator_handle, context->node_name,
-            &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], &builders_[context->node_name],
+      *p = {context->allocate_func, context->release_func, context->allocator_handle, context->node_name, builder_.get(),
+            &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name],
             &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name],
             input_shape_ranges_[context->node_name], sync_stream_after_enqueue_, &tensorrt_mu_, fp16_enable_,
             int8_enable_, int8_calibration_cache_available_, dla_enable_, dla_core_, &max_workspace_size_,
             trt_node_name_with_precision, engine_cache_enable_, cache_path_,
@@ -2490,7 +2503,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector
       bool sync_stream_after_enqueue = trt_state->sync_stream_after_enqueue;
       auto fused_node_name = trt_state->fused_node_name;
       auto& shape_ranges = trt_state->input_shape_ranges;
-      auto trt_builder = trt_state->builder->get();
+      auto trt_builder = trt_state->builder;
       auto trt_engine = trt_state->engine->get();
       auto trt_context = trt_state->context->get();
       auto trt_profiles = trt_state->profiles;
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
index cda08715ea009..a945d219088aa 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
@@ -105,10 +105,10 @@ struct TensorrtFuncState {
   DestroyFunc test_release_func = nullptr;
   AllocatorHandle allocator = nullptr;
   std::string fused_node_name;
+  nvinfer1::IBuilder* builder;
   tensorrt_ptr::unique_pointer<nvonnxparser::IParser>* parser = nullptr;
   std::unique_ptr<nvinfer1::ICudaEngine>* engine = nullptr;
   std::unique_ptr<nvinfer1::IExecutionContext>* context = nullptr;
-  std::unique_ptr<nvinfer1::IBuilder>* builder = nullptr;
   std::unique_ptr<nvinfer1::INetworkDefinition>* network = nullptr;
   std::vector<std::unordered_map<std::string, size_t>> input_info;
   std::vector<std::unordered_map<std::string, size_t>> output_info;
@@ -245,6 +245,8 @@ class TensorrtExecutionProvider : public IExecutionProvider {
   std::unordered_set<std::string> control_flow_op_set_ = {"If", "Loop", "Scan"};
   mutable std::unordered_map<std::string, std::unique_ptr<SubGraphContext>> subgraph_context_map_;
 
+  mutable std::unique_ptr<nvinfer1::IBuilder> builder_;
+
   // Following maps that hold TRT objects will be accessible by different threads if ORT is using multithreading.
   // In general, TensorRT objects are not thread safe; accesses to an object from different threads must be serialized by the client.
   // But there are still some thread safe operations, please see here https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading
@@ -456,5 +458,11 @@ class TensorrtExecutionProvider : public IExecutionProvider {
   void CaptureBegin();
   void CaptureEnd();
   void IncrementRegularRunCountBeforeGraphCapture();
+
+  /**
+   * Get the pointer to the IBuilder instance.
+   * This function only creates the instance the first time it is called.
+   */
+  nvinfer1::IBuilder* GetBuilder() const;
 };
 } // namespace onnxruntime
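
Note (outside the patch): the change caches a single builder on the execution provider (`builder_`) and creates it lazily in `GetBuilder()`, instead of calling `nvinfer1::createInferBuilder()` at every use site. Below is a minimal, self-contained sketch of that lazy-initialization pattern; `Provider`, `ExpensiveBuilder`, and the plain `std::mutex` are hypothetical stand-ins for `TensorrtExecutionProvider`, `nvinfer1::IBuilder`, and the lock returned by `GetApiLock()`, not code from onnxruntime.

```cpp
#include <iostream>
#include <memory>
#include <mutex>

// Stand-in for an object that is expensive to construct (like the TRT builder).
struct ExpensiveBuilder {
  ExpensiveBuilder() { std::cout << "constructing builder (slow)\n"; }
  void Build() const { std::cout << "building\n"; }
};

class Provider {
 public:
  // Create the builder on first use and hand out the cached instance afterwards.
  // As with the patched GetBuilder(), the emptiness check itself is not
  // synchronized, so callers are expected to serialize calls; the lock only
  // guards the expensive construction step.
  ExpensiveBuilder* GetBuilder() const {
    if (!builder_) {
      std::lock_guard<std::mutex> lock(api_mutex_);
      builder_ = std::make_unique<ExpensiveBuilder>();
    }
    return builder_.get();
  }

 private:
  mutable std::mutex api_mutex_;
  mutable std::unique_ptr<ExpensiveBuilder> builder_;  // cached, owned by the provider
};

int main() {
  Provider provider;
  provider.GetBuilder()->Build();  // constructs the builder once
  provider.GetBuilder()->Build();  // reuses the cached instance
  return 0;
}
```

With this shape, every use site in the EP (GetSupportedList, Compile, and the cached TensorrtFuncState) shares one builder instead of paying the createInferBuilder cost repeatedly.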