diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc index 59edc415da7b9..fe85541a065a4 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc @@ -34,7 +34,8 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector& // When the TRT plugin library is loaded, the global static object is created and the plugin is registered to TRT registry. // This is done through macro, for example, REGISTER_TENSORRT_PLUGIN(VisionTransformerPluginCreator). // extra_plugin_lib_paths has the format of "path_1;path_2....;path_n" - if (!extra_plugin_lib_paths.empty()) { + static bool is_loaded = false; + if (!extra_plugin_lib_paths.empty() && !is_loaded) { std::stringstream extra_plugin_libs(extra_plugin_lib_paths); std::string lib; while (std::getline(extra_plugin_libs, lib, ';')) { @@ -45,6 +46,7 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector& LOGS_DEFAULT(WARNING) << "[TensorRT EP]" << status.ToString(); } } + is_loaded = true; } try { @@ -79,6 +81,26 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector& return Status::OK(); } +common::Status CreateTensorRTCustomOpDomainList(TensorrtExecutionProviderInfo& info) { + std::vector domain_list; + std::string extra_plugin_lib_paths{""}; + if (info.has_trt_options) { + if (!info.extra_plugin_lib_paths.empty()) { + extra_plugin_lib_paths = info.extra_plugin_lib_paths; + } + } else { + const std::string extra_plugin_lib_paths_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kExtraPluginLibPaths); + if (!extra_plugin_lib_paths_env.empty()) { + extra_plugin_lib_paths = extra_plugin_lib_paths_env; + } + } + auto status = CreateTensorRTCustomOpDomainList(domain_list, extra_plugin_lib_paths); + if (!domain_list.empty()) { + info.custom_op_domain_list = domain_list; + } + return Status::OK(); +} + void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain) { if (domain != nullptr) { for (auto ptr : domain->custom_ops_) { diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h index 543567f284518..232ec6d3b5799 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h @@ -14,6 +14,7 @@ namespace onnxruntime { common::Status LoadDynamicLibrary(onnxruntime::PathString library_name); common::Status CreateTensorRTCustomOpDomainList(std::vector& domain_list, const std::string extra_plugin_lib_paths); +common::Status CreateTensorRTCustomOpDomainList(TensorrtExecutionProviderInfo& info); void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain); void ReleaseTensorRTCustomOpDomainList(std::vector& custom_op_domain_list); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc index 2dfddc35ed6e9..b5dbe1ac459b1 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc @@ -53,16 +53,10 @@ struct TensorrtProviderFactory : IExecutionProviderFactory { std::unique_ptr CreateProvider() override; - void GetCustomOpDomainList(std::vector& custom_op_domain_list); - private: TensorrtExecutionProviderInfo info_; }; -void TensorrtProviderFactory::GetCustomOpDomainList(std::vector& custom_op_domain_list) { - custom_op_domain_list = info_.custom_op_domain_list; -} - std::unique_ptr TensorrtProviderFactory::CreateProvider() { return std::make_unique(info_); } @@ -81,6 +75,11 @@ struct Tensorrt_Provider : Provider { info.device_id = device_id; info.has_trt_options = false; + common::Status status = CreateTensorRTCustomOpDomainList(info); + if (!status.IsOK()) { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration."; + } + return std::make_shared(info); } @@ -122,6 +121,11 @@ struct Tensorrt_Provider : Provider { info.profile_opt_shapes = options.trt_profile_opt_shapes == nullptr ? "" : options.trt_profile_opt_shapes; info.cuda_graph_enable = options.trt_cuda_graph_enable != 0; + common::Status status = CreateTensorRTCustomOpDomainList(info); + if (!status.IsOK()) { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration."; + } + return std::make_shared(info); } @@ -134,11 +138,6 @@ struct Tensorrt_Provider : Provider { return onnxruntime::TensorrtExecutionProviderInfo::ToProviderOptions(options); } - void GetCustomOpDomainList(IExecutionProviderFactory* factory, std::vector& custom_op_domains_ptr) override { - TensorrtProviderFactory* trt_factory = reinterpret_cast(factory); - trt_factory->GetCustomOpDomainList(custom_op_domains_ptr); - } - void Initialize() override { InitializeRegistry(); } diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index ed7f8f484ba32..5fa15a03cd5b2 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1420,10 +1420,6 @@ std::shared_ptr TensorrtProviderFactoryCreator::Creat return s_library_tensorrt.Get().CreateExecutionProviderFactory(provider_options); } -void TensorrtProviderGetCustomOpDomainList(IExecutionProviderFactory* factory, std::vector& custom_op_domains_ptr) { - s_library_tensorrt.Get().GetCustomOpDomainList(factory, custom_op_domains_ptr); -} - std::shared_ptr MIGraphXProviderFactoryCreator::Create(const OrtMIGraphXProviderOptions* provider_options) { return s_library_migraphx.Get().CreateExecutionProviderFactory(provider_options); }