Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
chilo-ms committed Jan 23, 2024
1 parent 8f061a5 commit 23ffbc2
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ struct Tensorrt_Provider : Provider {
info.ep_context_file_path = options.trt_ep_context_file_path == nullptr ? "" : options.trt_ep_context_file_path;
info.ep_context_embed_mode = options.trt_ep_context_embed_mode;
info.engine_cache_prefix = options.trt_engine_cache_prefix == nullptr ? "" : options.trt_engine_cache_prefix;

info.custom_op_domain_list = options.custom_op_domain_list;

return std::make_shared<TensorrtProviderFactory>(info);
Expand Down
44 changes: 7 additions & 37 deletions onnxruntime/core/session/provider_bridge_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1693,8 +1693,10 @@ void AddTensorRTCustomOpDomainToSessionOption(OrtSessionOptions* options, OrtTen
for (auto ptr : custom_op_domains) {
if (!is_already_in_domains(ptr->domain_, options->custom_op_domains_)) {
options->custom_op_domains_.push_back(ptr);
// The OrtCustomOpDomains objects created by TRT EP's GetTensorRTCustomOpDomainList() should be released once session is finished.

Check warning on line 1696 in onnxruntime/core/session/provider_bridge_ort.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/session/provider_bridge_ort.cc#L1696

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/core/session/provider_bridge_ort.cc:1696:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
// TRT EP needs to keep all the pointers OrtCustomOpDomain obejcts and releases them upon TRT EP destruction.
if (tensorrt_options) {
tensorrt_options->custom_op_domain_list.push_back(ptr); // TensorRT EP should keep all the pointers to OrtCustomOpDomain obejcts for the purpose of releasing them at EP destruction
tensorrt_options->custom_op_domain_list.push_back(ptr);
}
} else {
LOGS_DEFAULT(WARNING) << "The custom op domain name " << ptr->domain_ << " is already in session option.";
Expand All @@ -1716,17 +1718,9 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Dnnl, _In_ OrtSessi

ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtSessionOptions* options, int device_id) {
API_IMPL_BEGIN
auto factory = onnxruntime::TensorrtProviderFactoryCreator::Create(device_id);
if (!factory) {
return OrtApis::CreateStatus(ORT_FAIL, "OrtSessionOptionsAppendExecutionProvider_Tensorrt: Failed to load shared library");
}

options->provider_factories.push_back(factory);

std::string extra_plugin_lib_paths = onnxruntime::Env::Default().GetEnvironmentVar("trt_extra_plugin_lib_paths");
AddTensorRTCustomOpDomainToSessionOption(options, nullptr, extra_plugin_lib_paths);

return nullptr;
OrtTensorRTProviderOptionsV2 tensorrt_options;
tensorrt_options.device_id = device_id;
return OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2(options, &tensorrt_options);
API_IMPL_END
}

Expand All @@ -1744,32 +1738,8 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_MIGraphX, _In_ OrtS

ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT, _In_ OrtSessionOptions* options, _In_ const OrtTensorRTProviderOptions* tensorrt_options) {
API_IMPL_BEGIN

std::shared_ptr<onnxruntime::IExecutionProviderFactory> factory;

OrtTensorRTProviderOptionsV2 trt_options_converted = onnxruntime::OrtTensorRTProviderOptionsToOrtTensorRTProviderOptionsV2(tensorrt_options);

Check warning on line 1741 in onnxruntime/core/session/provider_bridge_ort.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/session/provider_bridge_ort.cc#L1741

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/core/session/provider_bridge_ort.cc:1741:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
trt_options_converted.custom_op_domain_list.clear();
AddTensorRTCustomOpDomainToSessionOption(options, &trt_options_converted, "");

#if !defined(ORT_MINIMAL_BUILD) && defined(USE_TENSORRT)
auto ep_context_cache_enabled_from_sess_options = (options->value).config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") != "0";
// If EP context configs are provided in session options, we need to propagate them to provider options
if (ep_context_cache_enabled_from_sess_options) {
onnxruntime::UpdateOrtTensorRTProviderOptionsV2FromSessionOptionsConfigs(options, &trt_options_converted);
factory = onnxruntime::TensorrtProviderFactoryCreator::Create(&trt_options_converted);
}
factory = onnxruntime::TensorrtProviderFactoryCreator::Create(&trt_options_converted);
#else
factory = onnxruntime::TensorrtProviderFactoryCreator::Create(&trt_options_converted);
#endif

if (!factory) {
return OrtApis::CreateStatus(ORT_FAIL, "SessionOptionsAppendExecutionProvider_Tensorrt: Failed to load shared library");
}

options->provider_factories.push_back(factory);

return nullptr;
return OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2(options, &trt_options_converted);
API_IMPL_END
}

Expand Down
7 changes: 7 additions & 0 deletions onnxruntime/python/onnxruntime_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -745,6 +745,13 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
ORT_THROW("Invalid TensorRT EP option: ", option.first);
}
}
// The OrtCustomOpDomains objects created by TRT EP's GetTensorRTCustomOpDomainList() should be released once session is finished.

Check warning on line 748 in onnxruntime/python/onnxruntime_pybind_state.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/python/onnxruntime_pybind_state.cc#L748

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/python/onnxruntime_pybind_state.cc:748:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
// TRT EP needs to keep all the pointers OrtCustomOpDomain obejcts and releases them upon TRT EP destruction.
for (auto ptr : session_options.custom_op_domains_) {
if (ptr->domain_ == "trt.plugins") {
params->custom_op_domain_list.push_back(ptr);
}
}
if (std::shared_ptr<IExecutionProviderFactory> tensorrt_provider_factory = onnxruntime::TensorrtProviderFactoryCreator::Create(&params)) {
return tensorrt_provider_factory->CreateProvider();
}
Expand Down

0 comments on commit 23ffbc2

Please sign in to comment.