diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index fe6b959b962de..39e5f5be000e5 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1834,13 +1834,21 @@ nvinfer1::IBuilder* TensorrtExecutionProvider::GetBuilder() const { } void TensorrtExecutionProvider::GetCustomOpDomainList(std::vector& custom_op_domain_list) const { - if (info_.custom_op_domain_list.empty()) { - common::Status status = CreateTensorRTCustomOpDomainList(info_); - if (!status.IsOK()) { - LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration."; + std::string extra_plugin_lib_paths{""}; + if (info_.has_trt_options) { + if (!info_.extra_plugin_lib_paths.empty()) { + extra_plugin_lib_paths = info_.extra_plugin_lib_paths; } + } else { + const std::string extra_plugin_lib_paths_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kExtraPluginLibPaths); + if (!extra_plugin_lib_paths_env.empty()) { + extra_plugin_lib_paths = extra_plugin_lib_paths_env; + } + } + auto status = CreateTensorRTCustomOpDomainList(custom_op_domain_list, extra_plugin_lib_paths); + if (status != Status::OK()) { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration."; } - custom_op_domain_list = info_.custom_op_domain_list; } // Check the graph is the subgraph of control flow op diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc index 4e466a5d568a6..eb340ba1e64b6 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc @@ -27,8 +27,12 @@ extern TensorrtLogger& GetTensorrtLogger(); * So, TensorRTCustomOp uses variadic inputs/outputs to pass ONNX graph validation. */ common::Status CreateTensorRTCustomOpDomainList(std::vector& domain_list, const std::string extra_plugin_lib_paths) { - std::unique_ptr custom_op_domain = std::make_unique(); - custom_op_domain->domain_ = "trt.plugins"; + static std::unique_ptr custom_op_domain = std::make_unique(); + static std::vector> created_custom_op_list; + if (custom_op_domain->domain_ != "" && custom_op_domain->custom_ops_.size() > 0) { + domain_list.push_back(custom_op_domain.get()); + return Status::OK(); + } // Load any extra TRT plugin library if any. // When the TRT plugin library is loaded, the global static object is created and the plugin is registered to TRT registry. @@ -69,38 +73,19 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector& continue; } - std::unique_ptr trt_custom_op = std::make_unique(onnxruntime::kTensorrtExecutionProvider, nullptr); - trt_custom_op->SetName(plugin_creator->getPluginName()); - custom_op_domain->custom_ops_.push_back(trt_custom_op.release()); + created_custom_op_list.push_back(std::make_unique(onnxruntime::kTensorrtExecutionProvider, nullptr)); // Make sure TensorRTCustomOp object won't be cleaned up + created_custom_op_list.back().get()->SetName(plugin_creator->getPluginName()); + custom_op_domain->custom_ops_.push_back(created_custom_op_list.back().get()); registered_plugin_names.insert(plugin_name); } - domain_list.push_back(custom_op_domain.release()); + custom_op_domain->domain_ = "trt.plugins"; + domain_list.push_back(custom_op_domain.get()); } catch (const std::exception&) { LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration. Therefore, TRT EP can't create custom ops for TRT plugins"; } return Status::OK(); } -common::Status CreateTensorRTCustomOpDomainList(TensorrtExecutionProviderInfo& info) { - std::vector domain_list; - std::string extra_plugin_lib_paths{""}; - if (info.has_trt_options) { - if (!info.extra_plugin_lib_paths.empty()) { - extra_plugin_lib_paths = info.extra_plugin_lib_paths; - } - } else { - const std::string extra_plugin_lib_paths_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kExtraPluginLibPaths); - if (!extra_plugin_lib_paths_env.empty()) { - extra_plugin_lib_paths = extra_plugin_lib_paths_env; - } - } - auto status = CreateTensorRTCustomOpDomainList(domain_list, extra_plugin_lib_paths); - if (!domain_list.empty()) { - info.custom_op_domain_list = domain_list; - } - return Status::OK(); -} - void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain) { if (domain != nullptr) { for (auto ptr : domain->custom_ops_) { diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 3269c9f0f4e4b..29c2c6b0cce16 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1713,17 +1713,9 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Dnnl, _In_ OrtSessi ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtSessionOptions* options, int device_id) { API_IMPL_BEGIN - auto factory = onnxruntime::TensorrtProviderFactoryCreator::Create(device_id); - if (!factory) { - return OrtApis::CreateStatus(ORT_FAIL, "OrtSessionOptionsAppendExecutionProvider_Tensorrt: Failed to load shared library"); - } - - options->provider_factories.push_back(factory); - - std::string extra_plugin_lib_paths = onnxruntime::Env::Default().GetEnvironmentVar("trt_extra_plugin_lib_paths"); - AddTensorRTCustomOpDomainToSessionOption(options, extra_plugin_lib_paths); - - return nullptr; + OrtTensorRTProviderOptionsV2 tensorrt_options; + tensorrt_options.device_id = device_id; + return OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2(options, &tensorrt_options); API_IMPL_END } @@ -1741,33 +1733,8 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_MIGraphX, _In_ OrtS ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT, _In_ OrtSessionOptions* options, _In_ const OrtTensorRTProviderOptions* tensorrt_options) { API_IMPL_BEGIN - - std::shared_ptr factory; - -#if !defined(ORT_MINIMAL_BUILD) && defined(USE_TENSORRT) - auto ep_context_cache_enabled_from_sess_options = (options->value).config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") != "0"; - // If EP context configs are provided in session options, we need to propagate them to provider options - if (ep_context_cache_enabled_from_sess_options) { - OrtTensorRTProviderOptionsV2 trt_options_converted = onnxruntime::OrtTensorRTProviderOptionsToOrtTensorRTProviderOptionsV2(tensorrt_options); - - onnxruntime::UpdateOrtTensorRTProviderOptionsV2FromSessionOptionsConfigs(options, &trt_options_converted); - factory = onnxruntime::TensorrtProviderFactoryCreator::Create(&trt_options_converted); - } else { - factory = onnxruntime::TensorrtProviderFactoryCreator::Create(tensorrt_options); - } -#else - factory = onnxruntime::TensorrtProviderFactoryCreator::Create(tensorrt_options); -#endif - - if (!factory) { - return OrtApis::CreateStatus(ORT_FAIL, "SessionOptionsAppendExecutionProvider_Tensorrt: Failed to load shared library"); - } - - options->provider_factories.push_back(factory); - - AddTensorRTCustomOpDomainToSessionOption(options, ""); - - return nullptr; + OrtTensorRTProviderOptionsV2 trt_options_converted = onnxruntime::OrtTensorRTProviderOptionsToOrtTensorRTProviderOptionsV2(tensorrt_options); + return OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2(options, &trt_options_converted); API_IMPL_END } @@ -1906,11 +1873,11 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2, // if provider options already have the EP context configs provided, the configs in session options will be ignored // since provider options has higher priority than session options. if (!ep_context_cache_enabled_from_provider_options && ep_context_cache_enabled_from_sess_options) { - // We need to create a new provider options V2 object and copy from provider_options, due to the "const" object pointed by provider_options can't be modified. - // Note: No need to worry about tensorrt_options being a local variable, CreateExecutionProviderFactory() in TRT EP will + // This function might need to update the "const" OrtTensorRTProviderOptionsV2 object which can't be modified. + // Therefore, we need to create a new OrtTensorRTProviderOptionsV2 object and copy from tensorrt_options and use this new object to create the factory instead. + // Note: No need to worry about new_tensorrt_options being a local variable, CreateExecutionProviderFactory() in TRT EP will // create a factory object that copies any provider options from tensorrt_options including "const char*" provider options. OrtTensorRTProviderOptionsV2 new_tensorrt_options = *tensorrt_options; // copy and assign from tensorrt_options - onnxruntime::UpdateOrtTensorRTProviderOptionsV2FromSessionOptionsConfigs(options, &new_tensorrt_options); factory = onnxruntime::TensorrtProviderFactoryCreator::Create(&new_tensorrt_options); } else { diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index f7ed5520727db..8e13982ca6861 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -443,9 +443,9 @@ void RegisterTensorRTPluginsAsCustomOps(PySessionOptions& so, const ProviderOpti if (it != options.end()) { trt_extra_plugin_lib_paths = it->second; } - std::vector domain_list; - tensorrt_provider_info->GetTensorRTCustomOpDomainList(domain_list, trt_extra_plugin_lib_paths); - for (auto ptr : domain_list) { + std::vector custom_op_domains; + tensorrt_provider_info->GetTensorRTCustomOpDomainList(custom_op_domains, trt_extra_plugin_lib_paths); + for (auto ptr : custom_op_domains) { if (!is_already_in_domains(ptr->domain_, so.custom_op_domains_)) { so.custom_op_domains_.push_back(ptr); } else {