Skip to content

Commit

Permalink
[TensorRT EP] Fix mem leak for TRT plugins custom ops (#19248)
Browse files Browse the repository at this point in the history
TRT EP's GetTensorRTCustomOpDomainList() will create vector of
OrtCustomOpDomain objects and release the ownership of those objects.
But, thoses objects are not released forever.
In session level, we need to make TRT EP remember what OrtCustomOpDomain
objects it created and release them at EP destruction time.
  • Loading branch information
chilo-ms authored Jan 25, 2024
1 parent 2b285cd commit a2867b9
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 75 deletions.
18 changes: 13 additions & 5 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1834,13 +1834,21 @@ nvinfer1::IBuilder* TensorrtExecutionProvider::GetBuilder() const {
}

void TensorrtExecutionProvider::GetCustomOpDomainList(std::vector<OrtCustomOpDomain*>& custom_op_domain_list) const {
if (info_.custom_op_domain_list.empty()) {
common::Status status = CreateTensorRTCustomOpDomainList(info_);
if (!status.IsOK()) {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration.";
std::string extra_plugin_lib_paths{""};
if (info_.has_trt_options) {
if (!info_.extra_plugin_lib_paths.empty()) {
extra_plugin_lib_paths = info_.extra_plugin_lib_paths;
}
} else {
const std::string extra_plugin_lib_paths_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kExtraPluginLibPaths);
if (!extra_plugin_lib_paths_env.empty()) {
extra_plugin_lib_paths = extra_plugin_lib_paths_env;
}
}
auto status = CreateTensorRTCustomOpDomainList(custom_op_domain_list, extra_plugin_lib_paths);
if (status != Status::OK()) {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration.";
}
custom_op_domain_list = info_.custom_op_domain_list;
}

// Check the graph is the subgraph of control flow op
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,12 @@ extern TensorrtLogger& GetTensorrtLogger();
* So, TensorRTCustomOp uses variadic inputs/outputs to pass ONNX graph validation.
*/
common::Status CreateTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>& domain_list, const std::string extra_plugin_lib_paths) {
std::unique_ptr<OrtCustomOpDomain> custom_op_domain = std::make_unique<OrtCustomOpDomain>();
custom_op_domain->domain_ = "trt.plugins";
static std::unique_ptr<OrtCustomOpDomain> custom_op_domain = std::make_unique<OrtCustomOpDomain>();
static std::vector<std::unique_ptr<TensorRTCustomOp>> created_custom_op_list;
if (custom_op_domain->domain_ != "" && custom_op_domain->custom_ops_.size() > 0) {
domain_list.push_back(custom_op_domain.get());
return Status::OK();
}

// Load any extra TRT plugin library if any.
// When the TRT plugin library is loaded, the global static object is created and the plugin is registered to TRT registry.
Expand Down Expand Up @@ -69,38 +73,19 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>&
continue;
}

std::unique_ptr<TensorRTCustomOp> trt_custom_op = std::make_unique<TensorRTCustomOp>(onnxruntime::kTensorrtExecutionProvider, nullptr);
trt_custom_op->SetName(plugin_creator->getPluginName());
custom_op_domain->custom_ops_.push_back(trt_custom_op.release());
created_custom_op_list.push_back(std::make_unique<TensorRTCustomOp>(onnxruntime::kTensorrtExecutionProvider, nullptr)); // Make sure TensorRTCustomOp object won't be cleaned up
created_custom_op_list.back().get()->SetName(plugin_creator->getPluginName());
custom_op_domain->custom_ops_.push_back(created_custom_op_list.back().get());
registered_plugin_names.insert(plugin_name);
}
domain_list.push_back(custom_op_domain.release());
custom_op_domain->domain_ = "trt.plugins";
domain_list.push_back(custom_op_domain.get());
} catch (const std::exception&) {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration. Therefore, TRT EP can't create custom ops for TRT plugins";
}
return Status::OK();
}

common::Status CreateTensorRTCustomOpDomainList(TensorrtExecutionProviderInfo& info) {
std::vector<OrtCustomOpDomain*> domain_list;
std::string extra_plugin_lib_paths{""};
if (info.has_trt_options) {
if (!info.extra_plugin_lib_paths.empty()) {
extra_plugin_lib_paths = info.extra_plugin_lib_paths;
}
} else {
const std::string extra_plugin_lib_paths_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kExtraPluginLibPaths);
if (!extra_plugin_lib_paths_env.empty()) {
extra_plugin_lib_paths = extra_plugin_lib_paths_env;
}
}
auto status = CreateTensorRTCustomOpDomainList(domain_list, extra_plugin_lib_paths);
if (!domain_list.empty()) {
info.custom_op_domain_list = domain_list;
}
return Status::OK();
}

void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain) {
if (domain != nullptr) {
for (auto ptr : domain->custom_ops_) {
Expand Down
49 changes: 8 additions & 41 deletions onnxruntime/core/session/provider_bridge_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1713,17 +1713,9 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Dnnl, _In_ OrtSessi

ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtSessionOptions* options, int device_id) {
API_IMPL_BEGIN
auto factory = onnxruntime::TensorrtProviderFactoryCreator::Create(device_id);
if (!factory) {
return OrtApis::CreateStatus(ORT_FAIL, "OrtSessionOptionsAppendExecutionProvider_Tensorrt: Failed to load shared library");
}

options->provider_factories.push_back(factory);

std::string extra_plugin_lib_paths = onnxruntime::Env::Default().GetEnvironmentVar("trt_extra_plugin_lib_paths");
AddTensorRTCustomOpDomainToSessionOption(options, extra_plugin_lib_paths);

return nullptr;
OrtTensorRTProviderOptionsV2 tensorrt_options;
tensorrt_options.device_id = device_id;
return OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2(options, &tensorrt_options);
API_IMPL_END
}

Expand All @@ -1741,33 +1733,8 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_MIGraphX, _In_ OrtS

ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT, _In_ OrtSessionOptions* options, _In_ const OrtTensorRTProviderOptions* tensorrt_options) {
API_IMPL_BEGIN

std::shared_ptr<onnxruntime::IExecutionProviderFactory> factory;

#if !defined(ORT_MINIMAL_BUILD) && defined(USE_TENSORRT)
auto ep_context_cache_enabled_from_sess_options = (options->value).config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") != "0";
// If EP context configs are provided in session options, we need to propagate them to provider options
if (ep_context_cache_enabled_from_sess_options) {
OrtTensorRTProviderOptionsV2 trt_options_converted = onnxruntime::OrtTensorRTProviderOptionsToOrtTensorRTProviderOptionsV2(tensorrt_options);

onnxruntime::UpdateOrtTensorRTProviderOptionsV2FromSessionOptionsConfigs(options, &trt_options_converted);
factory = onnxruntime::TensorrtProviderFactoryCreator::Create(&trt_options_converted);
} else {
factory = onnxruntime::TensorrtProviderFactoryCreator::Create(tensorrt_options);
}
#else
factory = onnxruntime::TensorrtProviderFactoryCreator::Create(tensorrt_options);
#endif

if (!factory) {
return OrtApis::CreateStatus(ORT_FAIL, "SessionOptionsAppendExecutionProvider_Tensorrt: Failed to load shared library");
}

options->provider_factories.push_back(factory);

AddTensorRTCustomOpDomainToSessionOption(options, "");

return nullptr;
OrtTensorRTProviderOptionsV2 trt_options_converted = onnxruntime::OrtTensorRTProviderOptionsToOrtTensorRTProviderOptionsV2(tensorrt_options);
return OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2(options, &trt_options_converted);
API_IMPL_END
}

Expand Down Expand Up @@ -1906,11 +1873,11 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2,
// if provider options already have the EP context configs provided, the configs in session options will be ignored
// since provider options has higher priority than session options.
if (!ep_context_cache_enabled_from_provider_options && ep_context_cache_enabled_from_sess_options) {
// We need to create a new provider options V2 object and copy from provider_options, due to the "const" object pointed by provider_options can't be modified.
// Note: No need to worry about tensorrt_options being a local variable, CreateExecutionProviderFactory() in TRT EP will
// This function might need to update the "const" OrtTensorRTProviderOptionsV2 object which can't be modified.
// Therefore, we need to create a new OrtTensorRTProviderOptionsV2 object and copy from tensorrt_options and use this new object to create the factory instead.
// Note: No need to worry about new_tensorrt_options being a local variable, CreateExecutionProviderFactory() in TRT EP will
// create a factory object that copies any provider options from tensorrt_options including "const char*" provider options.
OrtTensorRTProviderOptionsV2 new_tensorrt_options = *tensorrt_options; // copy and assign from tensorrt_options

onnxruntime::UpdateOrtTensorRTProviderOptionsV2FromSessionOptionsConfigs(options, &new_tensorrt_options);
factory = onnxruntime::TensorrtProviderFactoryCreator::Create(&new_tensorrt_options);
} else {
Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/python/onnxruntime_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -443,9 +443,9 @@ void RegisterTensorRTPluginsAsCustomOps(PySessionOptions& so, const ProviderOpti
if (it != options.end()) {
trt_extra_plugin_lib_paths = it->second;
}
std::vector<OrtCustomOpDomain*> domain_list;
tensorrt_provider_info->GetTensorRTCustomOpDomainList(domain_list, trt_extra_plugin_lib_paths);
for (auto ptr : domain_list) {
std::vector<OrtCustomOpDomain*> custom_op_domains;
tensorrt_provider_info->GetTensorRTCustomOpDomainList(custom_op_domains, trt_extra_plugin_lib_paths);
for (auto ptr : custom_op_domains) {
if (!is_already_in_domains(ptr->domain_, so.custom_op_domains_)) {
so.custom_op_domains_.push_back(ptr);
} else {
Expand Down

0 comments on commit a2867b9

Please sign in to comment.