diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index 901f2e3a9f8d8..1de0217c7e1fa 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -155,7 +155,7 @@ class IExecutionProvider { kernel implementation is needed for custom op since the real implementation is inside TRT. This custom op acts as a role to help pass ONNX model validation. */ - virtual void GetCustomOpDomainList(std::vector<std::shared_ptr<OrtCustomOpDomain>>& /*provider custom op domain list*/) const {}; + virtual void GetCustomOpDomainList(std::vector<OrtCustomOpDomain*>& /*provider custom op domain list*/) const {}; /** Returns an opaque handle whose exact type varies based on the provider diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index c58297457076c..39e5f5be000e5 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1833,7 +1833,7 @@ nvinfer1::IBuilder* TensorrtExecutionProvider::GetBuilder() const { return builder_.get(); } -void TensorrtExecutionProvider::GetCustomOpDomainList(std::vector<std::shared_ptr<OrtCustomOpDomain>>& custom_op_domain_list) const { +void TensorrtExecutionProvider::GetCustomOpDomainList(std::vector<OrtCustomOpDomain*>& custom_op_domain_list) const { std::string extra_plugin_lib_paths{""}; if (info_.has_trt_options) { if (!info_.extra_plugin_lib_paths.empty()) { diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index bea10b7ad8cfb..ad2d2c55c67e1 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -244,7 +244,7 @@ class TensorrtExecutionProvider : public IExecutionProvider { void 
RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override; - void GetCustomOpDomainList(std::vector<std::shared_ptr<OrtCustomOpDomain>>& custom_op_domain_list) const override; + void GetCustomOpDomainList(std::vector<OrtCustomOpDomain*>& custom_op_domain_list) const override; OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc index 77d1ffbc67915..eb340ba1e64b6 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc @@ -26,11 +26,11 @@ extern TensorrtLogger& GetTensorrtLogger(); * Note: Current TRT plugin doesn't have APIs to get number of inputs/outputs of the plugin. * So, TensorRTCustomOp uses variadic inputs/outputs to pass ONNX graph validation. */ -common::Status CreateTensorRTCustomOpDomainList(std::vector<std::shared_ptr<OrtCustomOpDomain>>& domain_list, const std::string extra_plugin_lib_paths) { - static std::shared_ptr<OrtCustomOpDomain> custom_op_domain = std::make_shared<OrtCustomOpDomain>(); - static std::unordered_set<std::shared_ptr<TensorRTCustomOp>> custom_op_set; +common::Status CreateTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>& domain_list, const std::string extra_plugin_lib_paths) { + static std::unique_ptr<OrtCustomOpDomain> custom_op_domain = std::make_unique<OrtCustomOpDomain>(); + static std::vector<std::unique_ptr<TensorRTCustomOp>> created_custom_op_list; if (custom_op_domain->domain_ != "" && custom_op_domain->custom_ops_.size() > 0) { - domain_list.push_back(custom_op_domain); + domain_list.push_back(custom_op_domain.get()); return Status::OK(); } @@ -73,14 +73,13 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector<std - std::shared_ptr<TensorRTCustomOp> trt_custom_op = std::make_shared<TensorRTCustomOp>(onnxruntime::kTensorrtExecutionProvider, nullptr); - trt_custom_op->SetName(plugin_creator->getPluginName()); - custom_op_domain->custom_ops_.push_back(trt_custom_op.get()); - custom_op_set.insert(trt_custom_op); // Make sure trt_custom_op object won't be 
cleaned up + created_custom_op_list.push_back(std::make_unique<TensorRTCustomOp>(onnxruntime::kTensorrtExecutionProvider, nullptr)); // Make sure TensorRTCustomOp object won't be cleaned up + created_custom_op_list.back().get()->SetName(plugin_creator->getPluginName()); + custom_op_domain->custom_ops_.push_back(created_custom_op_list.back().get()); registered_plugin_names.insert(plugin_name); } custom_op_domain->domain_ = "trt.plugins"; - domain_list.push_back(custom_op_domain); + domain_list.push_back(custom_op_domain.get()); } catch (const std::exception&) { LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration. Therefore, TRT EP can't create custom ops for TRT plugins"; } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h index 0c9251a12fa75..b19d9ab0f66d0 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h @@ -13,7 +13,7 @@ using namespace onnxruntime; namespace onnxruntime { common::Status LoadDynamicLibrary(onnxruntime::PathString library_name); -common::Status CreateTensorRTCustomOpDomainList(std::vector<std::shared_ptr<OrtCustomOpDomain>>& domain_list, const std::string extra_plugin_lib_paths); +common::Status CreateTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>& domain_list, const std::string extra_plugin_lib_paths); common::Status CreateTensorRTCustomOpDomainList(TensorrtExecutionProviderInfo& info); void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain); void ReleaseTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>& custom_op_domain_list); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc index 6b4eaf65ec138..568da57a50956 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc +++ 
b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc @@ -32,7 +32,7 @@ struct ProviderInfo_TensorRT_Impl final : ProviderInfo_TensorRT { return nullptr; } - OrtStatus* GetTensorRTCustomOpDomainList(std::vector<std::shared_ptr<OrtCustomOpDomain>>& domain_list, const std::string extra_plugin_lib_paths) override { + OrtStatus* GetTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>& domain_list, const std::string extra_plugin_lib_paths) override { common::Status status = CreateTensorRTCustomOpDomainList(domain_list, extra_plugin_lib_paths); if (!status.IsOK()) { return CreateStatus(ORT_FAIL, "[TensorRT EP] Can't create custom ops for TRT plugins."); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h index bd708e0bec7f5..231e14e5c95f2 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h @@ -8,7 +8,7 @@ namespace onnxruntime { struct ProviderInfo_TensorRT { virtual OrtStatus* GetCurrentGpuDeviceId(_In_ int* device_id) = 0; virtual OrtStatus* UpdateProviderOptions(void* provider_options, const ProviderOptions& options, bool string_copy) = 0; - virtual OrtStatus* GetTensorRTCustomOpDomainList(std::vector<std::shared_ptr<OrtCustomOpDomain>>& domain_list, const std::string extra_plugin_lib_paths) = 0; + virtual OrtStatus* GetTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>& domain_list, const std::string extra_plugin_lib_paths) = 0; virtual OrtStatus* ReleaseCustomOpDomainList(std::vector<OrtCustomOpDomain*>& domain_list) = 0; protected: diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index c69d37a6ab35e..ad46a03e8fd75 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -664,7 +664,7 @@ common::Status InferenceSession::RegisterExecutionProvider(const std::shared_ptr #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) // Register Custom Op if EP 
requests it std::vector<OrtCustomOpDomain*> custom_op_domains; - std::vector<std::shared_ptr<OrtCustomOpDomain>> candidate_custom_op_domains; + std::vector<OrtCustomOpDomain*> candidate_custom_op_domains; p_exec_provider->GetCustomOpDomainList(candidate_custom_op_domains); auto registry_kernels = kernel_registry_manager_.GetKernelRegistriesByProviderType(p_exec_provider->Type()); @@ -672,7 +672,7 @@ common::Status InferenceSession::RegisterExecutionProvider(const std::shared_ptr // Register the custom op domain only if it has not been registered before if (registry_kernels.empty()) { for (auto candidate_custom_op_domain : candidate_custom_op_domains) { - custom_op_domains.push_back(candidate_custom_op_domain.get()); + custom_op_domains.push_back(candidate_custom_op_domain); } } else { for (auto candidate_custom_op_domain : candidate_custom_op_domains) { @@ -688,7 +688,7 @@ common::Status InferenceSession::RegisterExecutionProvider(const std::shared_ptr } } if (need_register) { - custom_op_domains.push_back(candidate_custom_op_domain.get()); + custom_op_domains.push_back(candidate_custom_op_domain); } } } diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 552a38047a50b..f48110aa7ee5b 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1687,12 +1687,12 @@ void AddTensorRTCustomOpDomainToSessionOption(OrtSessionOptions* options, std::s return false; }; - std::vector<std::shared_ptr<OrtCustomOpDomain>> custom_op_domains; + std::vector<OrtCustomOpDomain*> custom_op_domains; onnxruntime::ProviderInfo_TensorRT& provider_info = onnxruntime::GetProviderInfo_TensorRT(); provider_info.GetTensorRTCustomOpDomainList(custom_op_domains, extra_plugin_lib_paths); for (auto ptr : custom_op_domains) { if (!is_already_in_domains(ptr->domain_, options->custom_op_domains_)) { - options->custom_op_domains_.push_back(ptr.get()); + options->custom_op_domains_.push_back(ptr); } else { LOGS_DEFAULT(WARNING) << "The custom op domain name " << ptr->domain_ << " is already in session 
option."; } diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index a76c2723c1f4b..8e13982ca6861 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -443,11 +443,11 @@ void RegisterTensorRTPluginsAsCustomOps(PySessionOptions& so, const ProviderOpti if (it != options.end()) { trt_extra_plugin_lib_paths = it->second; } - std::vector<std::shared_ptr<OrtCustomOpDomain>> custom_op_domains; + std::vector<OrtCustomOpDomain*> custom_op_domains; tensorrt_provider_info->GetTensorRTCustomOpDomainList(custom_op_domains, trt_extra_plugin_lib_paths); for (auto ptr : custom_op_domains) { if (!is_already_in_domains(ptr->domain_, so.custom_op_domains_)) { - so.custom_op_domains_.push_back(ptr.get()); + so.custom_op_domains_.push_back(ptr); } else { LOGS_DEFAULT(WARNING) << "The custom op domain name " << ptr->domain_ << " is already in session option."; }