[TensorRT EP] Only instantiate TRT builder once (#18100)
TRT builder instantiation is slow (see
[here](#18071)).
The current TRT EP instantiates a builder object every time one is needed.
Since multiple places need the TRT builder, this causes significant
performance overhead. This commit creates the builder once, lazily, and
reuses it everywhere.
chilo-ms authored Nov 16, 2023
1 parent 6f9f653 commit 18a3675
Showing 2 changed files with 28 additions and 7 deletions.
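
For orientation before the diff: the change amounts to caching one lazily created builder instead of constructing a new one per use. A minimal sketch of that pattern, with simplified stand-in names (`BuilderCache` is not part of the EP; the real implementation is `GetBuilder()` in the diff below):

```cpp
#include <NvInfer.h>  // nvinfer1::IBuilder, nvinfer1::createInferBuilder
#include <memory>

// Sketch: create the expensive IBuilder once, then hand out a raw pointer.
class BuilderCache {
 public:
  explicit BuilderCache(nvinfer1::ILogger& logger) : logger_(logger) {}

  nvinfer1::IBuilder* Get() {
    if (!builder_) {
      // Construction is the slow part; it now happens at most once.
      builder_.reset(nvinfer1::createInferBuilder(logger_));
    }
    return builder_.get();
  }

 private:
  nvinfer1::ILogger& logger_;
  std::unique_ptr<nvinfer1::IBuilder> builder_;  // owned here, borrowed by callers
};
```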
25 changes: 19 additions & 6 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -1272,6 +1272,20 @@ Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream) {
   return Status::OK();
 }
 
+// Get the pointer to the IBuilder instance.
+// Note: This function is not thread safe. Calls to this function from different threads must be serialized
+// even though it doesn't make sense to have multiple threads initializing the same inference session.
+nvinfer1::IBuilder* TensorrtExecutionProvider::GetBuilder() const {
+  if (!builder_) {
+    TensorrtLogger& trt_logger = GetTensorrtLogger();
+    {
+      auto lock = GetApiLock();
+      builder_ = std::unique_ptr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(trt_logger));
+    }
+  }
+  return builder_.get();
+}
+
 void TensorrtExecutionProvider::GetCustomOpDomainList(std::vector<OrtCustomOpDomain*>& custom_op_domain_list) const {
   if (info_.custom_op_domain_list.empty()) {
     common::Status status = CreateTensorRTCustomOpDomainList(info_);
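
The new comment is explicit that `GetBuilder()` is not thread safe: the unsynchronized `if (!builder_)` check relies on callers serializing session initialization. If that assumption ever needed to be dropped, `std::call_once` would be one way to make the lazy creation race-free. A sketch, not part of this commit (`builder_once_` is a hypothetical extra member):

```cpp
#include <mutex>  // std::once_flag, std::call_once

// Hypothetical thread-safe variant: the first caller constructs the builder;
// any concurrent callers block until it is ready.
nvinfer1::IBuilder* TensorrtExecutionProvider::GetBuilderThreadSafe() const {
  std::call_once(builder_once_, [this]() {
    TensorrtLogger& trt_logger = GetTensorrtLogger();
    auto lock = GetApiLock();  // same serialization of TRT API calls as above
    builder_ = std::unique_ptr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(trt_logger));
  });
  return builder_.get();
}
// Would require declaring: mutable std::once_flag builder_once_;
```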
@@ -1633,7 +1647,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
       // Get supported node list recursively
       SubGraphCollection_t parser_nodes_list;
       TensorrtLogger& trt_logger = GetTensorrtLogger();
-      auto trt_builder = std::unique_ptr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(trt_logger));
+      auto trt_builder = GetBuilder();
       const auto explicitBatch = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
       auto trt_network = std::unique_ptr<nvinfer1::INetworkDefinition>(trt_builder->createNetworkV2(explicitBatch));
 
@@ -1985,7 +1999,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
     }
 
     TensorrtLogger& trt_logger = GetTensorrtLogger();
-    auto trt_builder = std::unique_ptr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(trt_logger));
+    auto trt_builder = GetBuilder();
     const auto explicitBatch = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
     auto trt_network = std::unique_ptr<nvinfer1::INetworkDefinition>(trt_builder->createNetworkV2(explicitBatch));
     auto trt_config = std::unique_ptr<nvinfer1::IBuilderConfig>(trt_builder->createBuilderConfig());
@@ -2438,7 +2452,6 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
     parsers_.emplace(fused_node.Name(), std::move(trt_parser));
     engines_.emplace(fused_node.Name(), std::move(trt_engine));
     contexts_.emplace(fused_node.Name(), std::move(trt_context));
-    builders_.emplace(fused_node.Name(), std::move(trt_builder));
     networks_.emplace(fused_node.Name(), std::move(trt_network));
     input_info_[fused_node.Name()].push_back(input_indexes);
     output_info_[fused_node.Name()].push_back(output_indexes);
@@ -2456,8 +2469,8 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
     if (!tactic_sources_.empty()) {
       tactics = GetTacticSourceFromString(tactic_sources_);
     }
-    *p = {context->allocate_func, context->release_func, context->allocator_handle, context->node_name,
-          &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], &builders_[context->node_name],
+    *p = {context->allocate_func, context->release_func, context->allocator_handle, context->node_name, builder_.get(),
+          &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name],
           &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name],
           input_shape_ranges_[context->node_name], sync_stream_after_enqueue_, &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_,
           dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_,
@@ -2490,7 +2503,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
     bool sync_stream_after_enqueue = trt_state->sync_stream_after_enqueue;
     auto fused_node_name = trt_state->fused_node_name;
     auto& shape_ranges = trt_state->input_shape_ranges;
-    auto trt_builder = trt_state->builder->get();
+    auto trt_builder = trt_state->builder;
     auto trt_engine = trt_state->engine->get();
     auto trt_context = trt_state->context->get();
     auto trt_profiles = trt_state->profiles;
10 changes: 9 additions & 1 deletion onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
@@ -105,10 +105,10 @@ struct TensorrtFuncState {
   DestroyFunc test_release_func = nullptr;
   AllocatorHandle allocator = nullptr;
   std::string fused_node_name;
+  nvinfer1::IBuilder* builder;
   tensorrt_ptr::unique_pointer<nvonnxparser::IParser>* parser = nullptr;
   std::unique_ptr<nvinfer1::ICudaEngine>* engine = nullptr;
   std::unique_ptr<nvinfer1::IExecutionContext>* context = nullptr;
-  std::unique_ptr<nvinfer1::IBuilder>* builder = nullptr;
   std::unique_ptr<nvinfer1::INetworkDefinition>* network = nullptr;
   std::vector<std::unordered_map<std::string, size_t>> input_info;
   std::vector<std::unordered_map<std::string, size_t>> output_info;
@@ -245,6 +245,8 @@ class TensorrtExecutionProvider : public IExecutionProvider {
   std::unordered_set<std::string> control_flow_op_set_ = {"If", "Loop", "Scan"};
   mutable std::unordered_map<std::string, std::unique_ptr<SubGraphContext>> subgraph_context_map_;
 
+  mutable std::unique_ptr<nvinfer1::IBuilder> builder_;
+
   // Following maps that hold TRT objects will be accessible by different threads if ORT is using multithreading.
   // In general, TensorRT objects are not thread safe; accesses to an object from different threads must be serialized by the client.
   // But there are still some thread safe operations, please see here https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading
@@ -456,5 +458,11 @@ class TensorrtExecutionProvider : public IExecutionProvider {
   void CaptureBegin();
   void CaptureEnd();
   void IncrementRegularRunCountBeforeGraphCapture();
+
+  /**
+   * Get the pointer to the IBuilder instance.
+   * This function only creates the instance on the first call.
+   */
+  nvinfer1::IBuilder* GetBuilder() const;
 };
 } // namespace onnxruntime
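
A note on the ownership implied by the header changes: `TensorrtFuncState::builder` is now a non-owning raw pointer aliasing the provider-owned `builder_` (the per-node `builders_` map is gone). The pointer stays valid as long as the provider outlives the compute states it hands out, which is the session's normal lifetime ordering. In miniature (a sketch, not the EP's code):

```cpp
#include <NvInfer.h>
#include <memory>

// Sketch of the ownership split after this commit.
struct FuncState {
  nvinfer1::IBuilder* builder = nullptr;  // borrowed; provider keeps it alive
};

struct Provider {
  std::unique_ptr<nvinfer1::IBuilder> builder_;  // sole owner
  FuncState MakeState() const { return {builder_.get()}; }
};
```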
