Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TensorRT EP] Fix mem leak for TRT plugins custom ops #19248

Merged
merged 11 commits into from
Jan 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1834,13 +1834,21 @@
}

void TensorrtExecutionProvider::GetCustomOpDomainList(std::vector<OrtCustomOpDomain*>& custom_op_domain_list) const {
if (info_.custom_op_domain_list.empty()) {
common::Status status = CreateTensorRTCustomOpDomainList(info_);
if (!status.IsOK()) {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration.";
std::string extra_plugin_lib_paths{""};
if (info_.has_trt_options) {
if (!info_.extra_plugin_lib_paths.empty()) {
extra_plugin_lib_paths = info_.extra_plugin_lib_paths;
}
} else {
const std::string extra_plugin_lib_paths_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kExtraPluginLibPaths);

Check warning on line 1843 in onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc#L1843

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:1843:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
if (!extra_plugin_lib_paths_env.empty()) {
extra_plugin_lib_paths = extra_plugin_lib_paths_env;
}
}
auto status = CreateTensorRTCustomOpDomainList(custom_op_domain_list, extra_plugin_lib_paths);
if (status != Status::OK()) {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration.";
}
custom_op_domain_list = info_.custom_op_domain_list;
}

// Check the graph is the subgraph of control flow op
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,12 @@
* So, TensorRTCustomOp uses variadic inputs/outputs to pass ONNX graph validation.
*/
common::Status CreateTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>& domain_list, const std::string extra_plugin_lib_paths) {
std::unique_ptr<OrtCustomOpDomain> custom_op_domain = std::make_unique<OrtCustomOpDomain>();
custom_op_domain->domain_ = "trt.plugins";
static std::unique_ptr<OrtCustomOpDomain> custom_op_domain = std::make_unique<OrtCustomOpDomain>();
static std::vector<std::unique_ptr<TensorRTCustomOp>> created_custom_op_list;
if (custom_op_domain->domain_ != "" && custom_op_domain->custom_ops_.size() > 0) {
domain_list.push_back(custom_op_domain.get());
return Status::OK();
}

// Load any extra TRT plugin library if any.
// When the TRT plugin library is loaded, the global static object is created and the plugin is registered to TRT registry.
Expand Down Expand Up @@ -69,38 +73,19 @@
continue;
}

std::unique_ptr<TensorRTCustomOp> trt_custom_op = std::make_unique<TensorRTCustomOp>(onnxruntime::kTensorrtExecutionProvider, nullptr);
trt_custom_op->SetName(plugin_creator->getPluginName());
custom_op_domain->custom_ops_.push_back(trt_custom_op.release());
created_custom_op_list.push_back(std::make_unique<TensorRTCustomOp>(onnxruntime::kTensorrtExecutionProvider, nullptr)); // Make sure TensorRTCustomOp object won't be cleaned up

Check warning on line 76 in onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc#L76

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc:76:  Lines should be <= 120 characters long  [whitespace/line_length] [2]

Check warning on line 76 in onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc#L76

Add #include <memory> for make_unique<> [build/include_what_you_use] [4]
Raw output
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc:76:  Add #include <memory> for make_unique<>  [build/include_what_you_use] [4]
created_custom_op_list.back().get()->SetName(plugin_creator->getPluginName());
custom_op_domain->custom_ops_.push_back(created_custom_op_list.back().get());
registered_plugin_names.insert(plugin_name);
}
domain_list.push_back(custom_op_domain.release());
custom_op_domain->domain_ = "trt.plugins";
domain_list.push_back(custom_op_domain.get());
} catch (const std::exception&) {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration. Therefore, TRT EP can't create custom ops for TRT plugins";
}
return Status::OK();
}

common::Status CreateTensorRTCustomOpDomainList(TensorrtExecutionProviderInfo& info) {
std::vector<OrtCustomOpDomain*> domain_list;
std::string extra_plugin_lib_paths{""};
if (info.has_trt_options) {
if (!info.extra_plugin_lib_paths.empty()) {
extra_plugin_lib_paths = info.extra_plugin_lib_paths;
}
} else {
const std::string extra_plugin_lib_paths_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kExtraPluginLibPaths);
if (!extra_plugin_lib_paths_env.empty()) {
extra_plugin_lib_paths = extra_plugin_lib_paths_env;
}
}
auto status = CreateTensorRTCustomOpDomainList(domain_list, extra_plugin_lib_paths);
if (!domain_list.empty()) {
info.custom_op_domain_list = domain_list;
}
return Status::OK();
}

void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain) {
if (domain != nullptr) {
for (auto ptr : domain->custom_ops_) {
Expand Down
49 changes: 8 additions & 41 deletions onnxruntime/core/session/provider_bridge_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1713,17 +1713,9 @@

ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtSessionOptions* options, int device_id) {
API_IMPL_BEGIN
auto factory = onnxruntime::TensorrtProviderFactoryCreator::Create(device_id);
if (!factory) {
return OrtApis::CreateStatus(ORT_FAIL, "OrtSessionOptionsAppendExecutionProvider_Tensorrt: Failed to load shared library");
}

options->provider_factories.push_back(factory);

std::string extra_plugin_lib_paths = onnxruntime::Env::Default().GetEnvironmentVar("trt_extra_plugin_lib_paths");
AddTensorRTCustomOpDomainToSessionOption(options, extra_plugin_lib_paths);

return nullptr;
OrtTensorRTProviderOptionsV2 tensorrt_options;
tensorrt_options.device_id = device_id;
return OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2(options, &tensorrt_options);
API_IMPL_END
}

Expand All @@ -1741,33 +1733,8 @@

ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT, _In_ OrtSessionOptions* options, _In_ const OrtTensorRTProviderOptions* tensorrt_options) {
API_IMPL_BEGIN

std::shared_ptr<onnxruntime::IExecutionProviderFactory> factory;

#if !defined(ORT_MINIMAL_BUILD) && defined(USE_TENSORRT)
auto ep_context_cache_enabled_from_sess_options = (options->value).config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") != "0";
// If EP context configs are provided in session options, we need to propagate them to provider options
if (ep_context_cache_enabled_from_sess_options) {
OrtTensorRTProviderOptionsV2 trt_options_converted = onnxruntime::OrtTensorRTProviderOptionsToOrtTensorRTProviderOptionsV2(tensorrt_options);

onnxruntime::UpdateOrtTensorRTProviderOptionsV2FromSessionOptionsConfigs(options, &trt_options_converted);
factory = onnxruntime::TensorrtProviderFactoryCreator::Create(&trt_options_converted);
} else {
factory = onnxruntime::TensorrtProviderFactoryCreator::Create(tensorrt_options);
}
#else
factory = onnxruntime::TensorrtProviderFactoryCreator::Create(tensorrt_options);
#endif

if (!factory) {
return OrtApis::CreateStatus(ORT_FAIL, "SessionOptionsAppendExecutionProvider_Tensorrt: Failed to load shared library");
}

options->provider_factories.push_back(factory);

AddTensorRTCustomOpDomainToSessionOption(options, "");

return nullptr;
OrtTensorRTProviderOptionsV2 trt_options_converted = onnxruntime::OrtTensorRTProviderOptionsToOrtTensorRTProviderOptionsV2(tensorrt_options);

Check warning on line 1736 in onnxruntime/core/session/provider_bridge_ort.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/session/provider_bridge_ort.cc#L1736

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/core/session/provider_bridge_ort.cc:1736:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
return OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2(options, &trt_options_converted);
API_IMPL_END
}

Expand Down Expand Up @@ -1906,11 +1873,11 @@
// if provider options already have the EP context configs provided, the configs in session options will be ignored
// since provider options has higher priority than session options.
if (!ep_context_cache_enabled_from_provider_options && ep_context_cache_enabled_from_sess_options) {
// We need to create a new provider options V2 object and copy from provider_options, due to the "const" object pointed by provider_options can't be modified.
// Note: No need to worry about tensorrt_options being a local variable, CreateExecutionProviderFactory() in TRT EP will
// This function might need to update the "const" OrtTensorRTProviderOptionsV2 object which can't be modified.
// Therefore, we need to create a new OrtTensorRTProviderOptionsV2 object and copy from tensorrt_options and use this new object to create the factory instead.

Check warning on line 1877 in onnxruntime/core/session/provider_bridge_ort.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/session/provider_bridge_ort.cc#L1877

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/core/session/provider_bridge_ort.cc:1877:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
// Note: No need to worry about new_tensorrt_options being a local variable, CreateExecutionProviderFactory() in TRT EP will

Check warning on line 1878 in onnxruntime/core/session/provider_bridge_ort.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/session/provider_bridge_ort.cc#L1878

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/core/session/provider_bridge_ort.cc:1878:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
// create a factory object that copies any provider options from tensorrt_options including "const char*" provider options.
OrtTensorRTProviderOptionsV2 new_tensorrt_options = *tensorrt_options; // copy and assign from tensorrt_options

onnxruntime::UpdateOrtTensorRTProviderOptionsV2FromSessionOptionsConfigs(options, &new_tensorrt_options);
factory = onnxruntime::TensorrtProviderFactoryCreator::Create(&new_tensorrt_options);
} else {
Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/python/onnxruntime_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -443,9 +443,9 @@ void RegisterTensorRTPluginsAsCustomOps(PySessionOptions& so, const ProviderOpti
if (it != options.end()) {
trt_extra_plugin_lib_paths = it->second;
}
std::vector<OrtCustomOpDomain*> domain_list;
tensorrt_provider_info->GetTensorRTCustomOpDomainList(domain_list, trt_extra_plugin_lib_paths);
for (auto ptr : domain_list) {
std::vector<OrtCustomOpDomain*> custom_op_domains;
tensorrt_provider_info->GetTensorRTCustomOpDomainList(custom_op_domains, trt_extra_plugin_lib_paths);
for (auto ptr : custom_op_domains) {
if (!is_already_in_domains(ptr->domain_, so.custom_op_domains_)) {
so.custom_op_domains_.push_back(ptr);
} else {
Expand Down
Loading