From 8080e593873a7e652afc5b9c84fba6eade07cab7 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Mon, 22 Jan 2024 22:26:50 +0000 Subject: [PATCH 01/10] update --- .../tensorrt/tensorrt_execution_provider_custom_ops.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc index 4e466a5d568a6..24d135f224390 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc @@ -71,7 +71,7 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector& std::unique_ptr trt_custom_op = std::make_unique(onnxruntime::kTensorrtExecutionProvider, nullptr); trt_custom_op->SetName(plugin_creator->getPluginName()); - custom_op_domain->custom_ops_.push_back(trt_custom_op.release()); + custom_op_domain->custom_ops_.push_back(trt_custom_op.get()); registered_plugin_names.insert(plugin_name); } domain_list.push_back(custom_op_domain.release()); From 8f061a562fb5f98881325c7c5850387fce2e4815 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 23 Jan 2024 01:14:36 +0000 Subject: [PATCH 02/10] update --- .../tensorrt/tensorrt_provider_options.h | 2 + .../tensorrt_execution_provider_custom_ops.cc | 2 +- .../tensorrt/tensorrt_provider_factory.cc | 2 + .../core/session/provider_bridge_ort.cc | 44 ++++++++++--------- 4 files changed, 28 insertions(+), 22 deletions(-) diff --git a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h index 32a9f06464ace..910e9d5fb4417 100644 --- a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h +++ b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h @@ -70,4 +70,6 @@ struct OrtTensorRTProviderOptionsV2 { int trt_ep_context_embed_mode{0}; // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix + + std::vector custom_op_domain_list; }; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc index 24d135f224390..4e466a5d568a6 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc @@ -71,7 +71,7 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector& std::unique_ptr trt_custom_op = std::make_unique(onnxruntime::kTensorrtExecutionProvider, nullptr); trt_custom_op->SetName(plugin_creator->getPluginName()); - custom_op_domain->custom_ops_.push_back(trt_custom_op.get()); + custom_op_domain->custom_ops_.push_back(trt_custom_op.release()); registered_plugin_names.insert(plugin_name); } domain_list.push_back(custom_op_domain.release()); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc index 568da57a50956..7f6a353706fb6 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc @@ -113,6 +113,8 @@ struct Tensorrt_Provider : Provider { info.ep_context_file_path = options.trt_ep_context_file_path == nullptr ? "" : options.trt_ep_context_file_path; info.ep_context_embed_mode = options.trt_ep_context_embed_mode; info.engine_cache_prefix = options.trt_engine_cache_prefix == nullptr ? "" : options.trt_engine_cache_prefix; + + info.custom_op_domain_list = options.custom_op_domain_list; return std::make_shared(info); } diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 3269c9f0f4e4b..93ee472e2d8b1 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1677,7 +1677,7 @@ ProviderOptions GetProviderInfo_Cuda(const OrtCUDAProviderOptionsV2* provider_op } // namespace onnxruntime -void AddTensorRTCustomOpDomainToSessionOption(OrtSessionOptions* options, std::string extra_plugin_lib_paths) { +void AddTensorRTCustomOpDomainToSessionOption(OrtSessionOptions* options, OrtTensorRTProviderOptionsV2* tensorrt_options, std::string extra_plugin_lib_paths) { auto is_already_in_domains = [&](std::string& domain_name, std::vector& domains) { for (auto ptr : domains) { if (domain_name == ptr->domain_) { @@ -1693,6 +1693,9 @@ void AddTensorRTCustomOpDomainToSessionOption(OrtSessionOptions* options, std::s for (auto ptr : custom_op_domains) { if (!is_already_in_domains(ptr->domain_, options->custom_op_domains_)) { options->custom_op_domains_.push_back(ptr); + if (tensorrt_options) { + tensorrt_options->custom_op_domain_list.push_back(ptr); // TensorRT EP should keep all the pointers to OrtCustomOpDomain obejcts for the purpose of releasing them at EP destruction + } } else { LOGS_DEFAULT(WARNING) << "The custom op domain name " << ptr->domain_ << " is already in session option."; } @@ -1721,7 +1724,7 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtS options->provider_factories.push_back(factory); std::string extra_plugin_lib_paths = onnxruntime::Env::Default().GetEnvironmentVar("trt_extra_plugin_lib_paths"); - AddTensorRTCustomOpDomainToSessionOption(options, extra_plugin_lib_paths); + AddTensorRTCustomOpDomainToSessionOption(options, nullptr, extra_plugin_lib_paths); return nullptr; API_IMPL_END @@ -1744,19 +1747,20 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT, _In std::shared_ptr factory; + OrtTensorRTProviderOptionsV2 trt_options_converted = onnxruntime::OrtTensorRTProviderOptionsToOrtTensorRTProviderOptionsV2(tensorrt_options); + trt_options_converted.custom_op_domain_list.clear(); + AddTensorRTCustomOpDomainToSessionOption(options, &trt_options_converted, ""); + #if !defined(ORT_MINIMAL_BUILD) && defined(USE_TENSORRT) auto ep_context_cache_enabled_from_sess_options = (options->value).config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") != "0"; // If EP context configs are provided in session options, we need to propagate them to provider options if (ep_context_cache_enabled_from_sess_options) { - OrtTensorRTProviderOptionsV2 trt_options_converted = onnxruntime::OrtTensorRTProviderOptionsToOrtTensorRTProviderOptionsV2(tensorrt_options); - onnxruntime::UpdateOrtTensorRTProviderOptionsV2FromSessionOptionsConfigs(options, &trt_options_converted); factory = onnxruntime::TensorrtProviderFactoryCreator::Create(&trt_options_converted); - } else { - factory = onnxruntime::TensorrtProviderFactoryCreator::Create(tensorrt_options); } + factory = onnxruntime::TensorrtProviderFactoryCreator::Create(&trt_options_converted); #else - factory = onnxruntime::TensorrtProviderFactoryCreator::Create(tensorrt_options); + factory = onnxruntime::TensorrtProviderFactoryCreator::Create(&trt_options_converted); #endif if (!factory) { @@ -1765,8 +1769,6 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT, _In options->provider_factories.push_back(factory); - AddTensorRTCustomOpDomainToSessionOption(options, ""); - return nullptr; API_IMPL_END } @@ -1898,6 +1900,16 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2, std::shared_ptr factory; + // This function might need to update the "const" OrtTensorRTProviderOptionsV2 object which can't be modified. + // Therefore, we need to create a new OrtTensorRTProviderOptionsV2 object and copy from tensorrt_options and use this new object to create the factory instead. + // Note: No need to worry about new_tensorrt_options being a local variable, CreateExecutionProviderFactory() in TRT EP will + // create a factory object that copies any provider options from tensorrt_options including "const char*" provider options. + OrtTensorRTProviderOptionsV2 new_tensorrt_options = *tensorrt_options; // copy and assign from tensorrt_options + + std::string extra_plugin_lib_paths = (tensorrt_options == nullptr || tensorrt_options->trt_extra_plugin_lib_paths == nullptr) ? "" : tensorrt_options->trt_extra_plugin_lib_paths; + new_tensorrt_options.custom_op_domain_list.clear(); + AddTensorRTCustomOpDomainToSessionOption(options, &new_tensorrt_options, extra_plugin_lib_paths); + #if !defined(ORT_MINIMAL_BUILD) && defined(USE_TENSORRT) auto ep_context_cache_enabled_from_provider_options = tensorrt_options->trt_dump_ep_context_model != 0; auto ep_context_cache_enabled_from_sess_options = (options->value).config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") != "0"; @@ -1906,18 +1918,11 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2, // if provider options already have the EP context configs provided, the configs in session options will be ignored // since provider options has higher priority than session options. if (!ep_context_cache_enabled_from_provider_options && ep_context_cache_enabled_from_sess_options) { - // We need to create a new provider options V2 object and copy from provider_options, due to the "const" object pointed by provider_options can't be modified. - // Note: No need to worry about tensorrt_options being a local variable, CreateExecutionProviderFactory() in TRT EP will - // create a factory object that copies any provider options from tensorrt_options including "const char*" provider options. - OrtTensorRTProviderOptionsV2 new_tensorrt_options = *tensorrt_options; // copy and assign from tensorrt_options - onnxruntime::UpdateOrtTensorRTProviderOptionsV2FromSessionOptionsConfigs(options, &new_tensorrt_options); - factory = onnxruntime::TensorrtProviderFactoryCreator::Create(&new_tensorrt_options); - } else { - factory = onnxruntime::TensorrtProviderFactoryCreator::Create(tensorrt_options); } + factory = onnxruntime::TensorrtProviderFactoryCreator::Create(&new_tensorrt_options); #else - factory = onnxruntime::TensorrtProviderFactoryCreator::Create(tensorrt_options); + factory = onnxruntime::TensorrtProviderFactoryCreator::Create(&new_tensorrt_options); #endif if (!factory) { @@ -1926,9 +1931,6 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2, options->provider_factories.push_back(factory); - std::string extra_plugin_lib_paths = (tensorrt_options == nullptr || tensorrt_options->trt_extra_plugin_lib_paths == nullptr) ? "" : tensorrt_options->trt_extra_plugin_lib_paths; - AddTensorRTCustomOpDomainToSessionOption(options, extra_plugin_lib_paths); - return nullptr; API_IMPL_END } From 23ffbc2c795bdaf83fefe83607a3ece96e234ec4 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Tue, 23 Jan 2024 22:31:27 +0000 Subject: [PATCH 03/10] update --- .../tensorrt/tensorrt_provider_factory.cc | 2 +- .../core/session/provider_bridge_ort.cc | 44 +++---------------- .../python/onnxruntime_pybind_state.cc | 7 +++ 3 files changed, 15 insertions(+), 38 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc index 7f6a353706fb6..9f80fae676029 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc @@ -113,7 +113,7 @@ struct Tensorrt_Provider : Provider { info.ep_context_file_path = options.trt_ep_context_file_path == nullptr ? "" : options.trt_ep_context_file_path; info.ep_context_embed_mode = options.trt_ep_context_embed_mode; info.engine_cache_prefix = options.trt_engine_cache_prefix == nullptr ? "" : options.trt_engine_cache_prefix; - + info.custom_op_domain_list = options.custom_op_domain_list; return std::make_shared(info); diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 93ee472e2d8b1..1d7807d04fe52 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1693,8 +1693,10 @@ void AddTensorRTCustomOpDomainToSessionOption(OrtSessionOptions* options, OrtTen for (auto ptr : custom_op_domains) { if (!is_already_in_domains(ptr->domain_, options->custom_op_domains_)) { options->custom_op_domains_.push_back(ptr); + // The OrtCustomOpDomains objects created by TRT EP's GetTensorRTCustomOpDomainList() should be released once session is finished. + // TRT EP needs to keep all the pointers OrtCustomOpDomain obejcts and releases them upon TRT EP destruction. if (tensorrt_options) { - tensorrt_options->custom_op_domain_list.push_back(ptr); // TensorRT EP should keep all the pointers to OrtCustomOpDomain obejcts for the purpose of releasing them at EP destruction + tensorrt_options->custom_op_domain_list.push_back(ptr); } } else { LOGS_DEFAULT(WARNING) << "The custom op domain name " << ptr->domain_ << " is already in session option."; @@ -1716,17 +1718,9 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Dnnl, _In_ OrtSessi ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtSessionOptions* options, int device_id) { API_IMPL_BEGIN - auto factory = onnxruntime::TensorrtProviderFactoryCreator::Create(device_id); - if (!factory) { - return OrtApis::CreateStatus(ORT_FAIL, "OrtSessionOptionsAppendExecutionProvider_Tensorrt: Failed to load shared library"); - } - - options->provider_factories.push_back(factory); - - std::string extra_plugin_lib_paths = onnxruntime::Env::Default().GetEnvironmentVar("trt_extra_plugin_lib_paths"); - AddTensorRTCustomOpDomainToSessionOption(options, nullptr, extra_plugin_lib_paths); - - return nullptr; + OrtTensorRTProviderOptionsV2 tensorrt_options; + tensorrt_options.device_id = device_id; + return OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2(options, &tensorrt_options); API_IMPL_END } @@ -1744,32 +1738,8 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_MIGraphX, _In_ OrtS ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT, _In_ OrtSessionOptions* options, _In_ const OrtTensorRTProviderOptions* tensorrt_options) { API_IMPL_BEGIN - - std::shared_ptr factory; - OrtTensorRTProviderOptionsV2 trt_options_converted = onnxruntime::OrtTensorRTProviderOptionsToOrtTensorRTProviderOptionsV2(tensorrt_options); - trt_options_converted.custom_op_domain_list.clear(); - AddTensorRTCustomOpDomainToSessionOption(options, &trt_options_converted, ""); - -#if !defined(ORT_MINIMAL_BUILD) && defined(USE_TENSORRT) - auto ep_context_cache_enabled_from_sess_options = (options->value).config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") != "0"; - // If EP context configs are provided in session options, we need to propagate them to provider options - if (ep_context_cache_enabled_from_sess_options) { - onnxruntime::UpdateOrtTensorRTProviderOptionsV2FromSessionOptionsConfigs(options, &trt_options_converted); - factory = onnxruntime::TensorrtProviderFactoryCreator::Create(&trt_options_converted); - } - factory = onnxruntime::TensorrtProviderFactoryCreator::Create(&trt_options_converted); -#else - factory = onnxruntime::TensorrtProviderFactoryCreator::Create(&trt_options_converted); -#endif - - if (!factory) { - return OrtApis::CreateStatus(ORT_FAIL, "SessionOptionsAppendExecutionProvider_Tensorrt: Failed to load shared library"); - } - - options->provider_factories.push_back(factory); - - return nullptr; + return OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2(options, &trt_options_converted); API_IMPL_END } diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index f7ed5520727db..72df0b9006bf9 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -745,6 +745,13 @@ std::unique_ptr CreateExecutionProviderInstance( ORT_THROW("Invalid TensorRT EP option: ", option.first); } } + // The OrtCustomOpDomains objects created by TRT EP's GetTensorRTCustomOpDomainList() should be released once session is finished. + // TRT EP needs to keep all the pointers OrtCustomOpDomain obejcts and releases them upon TRT EP destruction. + for (auto ptr : session_options.custom_op_domains_) { + if (ptr->domain_ == "trt.plugins") { + params->custom_op_domain_list.push_back(ptr); + } + } if (std::shared_ptr tensorrt_provider_factory = onnxruntime::TensorrtProviderFactoryCreator::Create(¶ms)) { return tensorrt_provider_factory->CreateProvider(); } From b1689a8e4ac7792814793e618c18011b2661131c Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Wed, 24 Jan 2024 23:31:17 +0000 Subject: [PATCH 04/10] update --- cmake/external/emsdk | 2 +- .../core/framework/execution_provider.h | 2 +- .../tensorrt/tensorrt_provider_options.h | 2 -- .../tensorrt/tensorrt_execution_provider.cc | 17 +++++++++++------ .../tensorrt/tensorrt_execution_provider.h | 2 +- .../tensorrt_execution_provider_custom_ops.cc | 16 ++++++++++------ .../tensorrt_execution_provider_custom_ops.h | 2 +- .../tensorrt/tensorrt_provider_factory.cc | 4 +--- .../tensorrt/tensorrt_provider_factory.h | 2 +- onnxruntime/core/session/inference_session.cc | 8 +++++--- onnxruntime/core/session/provider_bridge_ort.cc | 10 ++-------- onnxruntime/python/onnxruntime_pybind_state.cc | 7 ------- 12 files changed, 34 insertions(+), 40 deletions(-) diff --git a/cmake/external/emsdk b/cmake/external/emsdk index 4e2496141eda1..a896e3d066448 160000 --- a/cmake/external/emsdk +++ b/cmake/external/emsdk @@ -1 +1 @@ -Subproject commit 4e2496141eda15040c44e9bbf237a1326368e34c +Subproject commit a896e3d066448b3530dbcaa48869fafefd738f57 diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index 1de0217c7e1fa..901f2e3a9f8d8 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -155,7 +155,7 @@ class IExecutionProvider { kernel implementation is needed for custom op since the real implementation is inside TRT. This custom op acts as a role to help pass ONNX model validation. */ - virtual void GetCustomOpDomainList(std::vector& /*provider custom op domain list*/) const {}; + virtual void GetCustomOpDomainList(std::vector>& /*provider custom op domain list*/) const {}; /** Returns an opaque handle whose exact type varies based on the provider diff --git a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h index 910e9d5fb4417..32a9f06464ace 100644 --- a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h +++ b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h @@ -70,6 +70,4 @@ struct OrtTensorRTProviderOptionsV2 { int trt_ep_context_embed_mode{0}; // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix - - std::vector custom_op_domain_list; }; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index fe6b959b962de..f72cdd5c4f899 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1833,14 +1833,19 @@ nvinfer1::IBuilder* TensorrtExecutionProvider::GetBuilder() const { return builder_.get(); } -void TensorrtExecutionProvider::GetCustomOpDomainList(std::vector& custom_op_domain_list) const { - if (info_.custom_op_domain_list.empty()) { - common::Status status = CreateTensorRTCustomOpDomainList(info_); - if (!status.IsOK()) { - LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration."; +void TensorrtExecutionProvider::GetCustomOpDomainList(std::vector>& custom_op_domain_list) const { + std::string extra_plugin_lib_paths{""}; + if (info_.has_trt_options) { + if (!info_.extra_plugin_lib_paths.empty()) { + extra_plugin_lib_paths = info_.extra_plugin_lib_paths; + } + } else { + const std::string extra_plugin_lib_paths_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kExtraPluginLibPaths); + if (!extra_plugin_lib_paths_env.empty()) { + extra_plugin_lib_paths = extra_plugin_lib_paths_env; } } - custom_op_domain_list = info_.custom_op_domain_list; + CreateTensorRTCustomOpDomainList(custom_op_domain_list, extra_plugin_lib_paths); } // Check the graph is the subgraph of control flow op diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index ad2d2c55c67e1..bea10b7ad8cfb 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -244,7 +244,7 @@ class TensorrtExecutionProvider : public IExecutionProvider { void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override; - void GetCustomOpDomainList(std::vector& custom_op_domain_list) const override; + void GetCustomOpDomainList(std::vector>& custom_op_domain_list) const override; OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc index 4e466a5d568a6..04005ab8a050e 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc @@ -26,9 +26,12 @@ extern TensorrtLogger& GetTensorrtLogger(); * Note: Current TRT plugin doesn't have APIs to get number of inputs/outputs of the plugin. * So, TensorRTCustomOp uses variadic inputs/outputs to pass ONNX graph validation. */ -common::Status CreateTensorRTCustomOpDomainList(std::vector& domain_list, const std::string extra_plugin_lib_paths) { - std::unique_ptr custom_op_domain = std::make_unique(); - custom_op_domain->domain_ = "trt.plugins"; +common::Status CreateTensorRTCustomOpDomainList(std::vector>& domain_list, const std::string extra_plugin_lib_paths) { + static std::shared_ptr custom_op_domain = std::make_shared(); + if (custom_op_domain->domain_ != "" && custom_op_domain->custom_ops_.size() > 0) { + domain_list.push_back(custom_op_domain); + return Status::OK(); + } // Load any extra TRT plugin library if any. // When the TRT plugin library is loaded, the global static object is created and the plugin is registered to TRT registry. @@ -69,12 +72,13 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector& continue; } - std::unique_ptr trt_custom_op = std::make_unique(onnxruntime::kTensorrtExecutionProvider, nullptr); + std::shared_ptr trt_custom_op = std::make_shared(onnxruntime::kTensorrtExecutionProvider, nullptr); trt_custom_op->SetName(plugin_creator->getPluginName()); - custom_op_domain->custom_ops_.push_back(trt_custom_op.release()); + custom_op_domain->custom_ops_.push_back(trt_custom_op.get()); registered_plugin_names.insert(plugin_name); } - domain_list.push_back(custom_op_domain.release()); + custom_op_domain->domain_ = "trt.plugins"; + domain_list.push_back(custom_op_domain); } catch (const std::exception&) { LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration. Therefore, TRT EP can't create custom ops for TRT plugins"; } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h index b19d9ab0f66d0..0c9251a12fa75 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h @@ -13,7 +13,7 @@ using namespace onnxruntime; namespace onnxruntime { common::Status LoadDynamicLibrary(onnxruntime::PathString library_name); -common::Status CreateTensorRTCustomOpDomainList(std::vector& domain_list, const std::string extra_plugin_lib_paths); +common::Status CreateTensorRTCustomOpDomainList(std::vector>& domain_list, const std::string extra_plugin_lib_paths); common::Status CreateTensorRTCustomOpDomainList(TensorrtExecutionProviderInfo& info); void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain); void ReleaseTensorRTCustomOpDomainList(std::vector& custom_op_domain_list); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc index 9f80fae676029..6b4eaf65ec138 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc @@ -32,7 +32,7 @@ struct ProviderInfo_TensorRT_Impl final : ProviderInfo_TensorRT { return nullptr; } - OrtStatus* GetTensorRTCustomOpDomainList(std::vector& domain_list, const std::string extra_plugin_lib_paths) override { + OrtStatus* GetTensorRTCustomOpDomainList(std::vector>& domain_list, const std::string extra_plugin_lib_paths) override { common::Status status = CreateTensorRTCustomOpDomainList(domain_list, extra_plugin_lib_paths); if (!status.IsOK()) { return CreateStatus(ORT_FAIL, "[TensorRT EP] Can't create custom ops for TRT plugins."); @@ -114,8 +114,6 @@ struct Tensorrt_Provider : Provider { info.ep_context_embed_mode = options.trt_ep_context_embed_mode; info.engine_cache_prefix = options.trt_engine_cache_prefix == nullptr ? "" : options.trt_engine_cache_prefix; - info.custom_op_domain_list = options.custom_op_domain_list; - return std::make_shared(info); } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h index 231e14e5c95f2..bd708e0bec7f5 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h @@ -8,7 +8,7 @@ namespace onnxruntime { struct ProviderInfo_TensorRT { virtual OrtStatus* GetCurrentGpuDeviceId(_In_ int* device_id) = 0; virtual OrtStatus* UpdateProviderOptions(void* provider_options, const ProviderOptions& options, bool string_copy) = 0; - virtual OrtStatus* GetTensorRTCustomOpDomainList(std::vector& domain_list, const std::string extra_plugin_lib_paths) = 0; + virtual OrtStatus* GetTensorRTCustomOpDomainList(std::vector>& domain_list, const std::string extra_plugin_lib_paths) = 0; virtual OrtStatus* ReleaseCustomOpDomainList(std::vector& domain_list) = 0; protected: diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index e8853c8824738..6de53243f880f 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -664,14 +664,16 @@ common::Status InferenceSession::RegisterExecutionProvider(const std::shared_ptr #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) // Register Custom Op if EP requests it std::vector custom_op_domains; - std::vector candidate_custom_op_domains; + std::vector> candidate_custom_op_domains; p_exec_provider->GetCustomOpDomainList(candidate_custom_op_domains); auto registry_kernels = kernel_registry_manager_.GetKernelRegistriesByProviderType(p_exec_provider->Type()); // Register the custom op domain only if it has not been registered before if (registry_kernels.empty()) { - custom_op_domains = candidate_custom_op_domains; + for (auto candidate_custom_op_domain : candidate_custom_op_domains) { + custom_op_domains.push_back(candidate_custom_op_domain.get()); + } } else { for (auto candidate_custom_op_domain : candidate_custom_op_domains) { for (auto registry_kernel : registry_kernels) { @@ -686,7 +688,7 @@ common::Status InferenceSession::RegisterExecutionProvider(const std::shared_ptr } } if (need_register) { - custom_op_domains.push_back(candidate_custom_op_domain); + custom_op_domains.push_back(candidate_custom_op_domain.get()); } } } diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 1d7807d04fe52..9e8c215af39f0 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1687,17 +1687,12 @@ void AddTensorRTCustomOpDomainToSessionOption(OrtSessionOptions* options, OrtTen return false; }; - std::vector custom_op_domains; + std::vector> custom_op_domains; onnxruntime::ProviderInfo_TensorRT& provider_info = onnxruntime::GetProviderInfo_TensorRT(); provider_info.GetTensorRTCustomOpDomainList(custom_op_domains, extra_plugin_lib_paths); for (auto ptr : custom_op_domains) { if (!is_already_in_domains(ptr->domain_, options->custom_op_domains_)) { - options->custom_op_domains_.push_back(ptr); - // The OrtCustomOpDomains objects created by TRT EP's GetTensorRTCustomOpDomainList() should be released once session is finished. - // TRT EP needs to keep all the pointers OrtCustomOpDomain obejcts and releases them upon TRT EP destruction. - if (tensorrt_options) { - tensorrt_options->custom_op_domain_list.push_back(ptr); - } + options->custom_op_domains_.push_back(ptr.get()); } else { LOGS_DEFAULT(WARNING) << "The custom op domain name " << ptr->domain_ << " is already in session option."; } @@ -1877,7 +1872,6 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2, OrtTensorRTProviderOptionsV2 new_tensorrt_options = *tensorrt_options; // copy and assign from tensorrt_options std::string extra_plugin_lib_paths = (tensorrt_options == nullptr || tensorrt_options->trt_extra_plugin_lib_paths == nullptr) ? "" : tensorrt_options->trt_extra_plugin_lib_paths; - new_tensorrt_options.custom_op_domain_list.clear(); AddTensorRTCustomOpDomainToSessionOption(options, &new_tensorrt_options, extra_plugin_lib_paths); #if !defined(ORT_MINIMAL_BUILD) && defined(USE_TENSORRT) diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 72df0b9006bf9..f7ed5520727db 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -745,13 +745,6 @@ std::unique_ptr CreateExecutionProviderInstance( ORT_THROW("Invalid TensorRT EP option: ", option.first); } } - // The OrtCustomOpDomains objects created by TRT EP's GetTensorRTCustomOpDomainList() should be released once session is finished. - // TRT EP needs to keep all the pointers OrtCustomOpDomain obejcts and releases them upon TRT EP destruction. - for (auto ptr : session_options.custom_op_domains_) { - if (ptr->domain_ == "trt.plugins") { - params->custom_op_domain_list.push_back(ptr); - } - } if (std::shared_ptr tensorrt_provider_factory = onnxruntime::TensorrtProviderFactoryCreator::Create(¶ms)) { return tensorrt_provider_factory->CreateProvider(); } From 23bba529f88cb6164f79ad78aa0398c433134206 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Wed, 24 Jan 2024 23:36:05 +0000 Subject: [PATCH 05/10] revert emsdk --- cmake/external/emsdk | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/external/emsdk b/cmake/external/emsdk index a896e3d066448..4e2496141eda1 160000 --- a/cmake/external/emsdk +++ b/cmake/external/emsdk @@ -1 +1 @@ -Subproject commit a896e3d066448b3530dbcaa48869fafefd738f57 +Subproject commit 4e2496141eda15040c44e9bbf237a1326368e34c From ccb2f1a3e5dc9c566791c371a0460abe0f5dc0f6 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 25 Jan 2024 00:00:12 +0000 Subject: [PATCH 06/10] update --- .../core/session/provider_bridge_ort.cc | 25 ++++++++++--------- 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 2330f80eb84df..552a38047a50b 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1677,7 +1677,7 @@ ProviderOptions GetProviderInfo_Cuda(const OrtCUDAProviderOptionsV2* provider_op } // namespace onnxruntime -void AddTensorRTCustomOpDomainToSessionOption(OrtSessionOptions* options, OrtTensorRTProviderOptionsV2* tensorrt_options, std::string extra_plugin_lib_paths) { +void AddTensorRTCustomOpDomainToSessionOption(OrtSessionOptions* options, std::string extra_plugin_lib_paths) { auto is_already_in_domains = [&](std::string& domain_name, std::vector& domains) { for (auto ptr : domains) { if (domain_name == ptr->domain_) { @@ -1865,15 +1865,6 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2, std::shared_ptr factory; - // This function might need to update the "const" OrtTensorRTProviderOptionsV2 object which can't be modified. - // Therefore, we need to create a new OrtTensorRTProviderOptionsV2 object and copy from tensorrt_options and use this new object to create the factory instead. - // Note: No need to worry about new_tensorrt_options being a local variable, CreateExecutionProviderFactory() in TRT EP will - // create a factory object that copies any provider options from tensorrt_options including "const char*" provider options. - OrtTensorRTProviderOptionsV2 new_tensorrt_options = *tensorrt_options; // copy and assign from tensorrt_options - - std::string extra_plugin_lib_paths = (tensorrt_options == nullptr || tensorrt_options->trt_extra_plugin_lib_paths == nullptr) ? "" : tensorrt_options->trt_extra_plugin_lib_paths; - AddTensorRTCustomOpDomainToSessionOption(options, &new_tensorrt_options, extra_plugin_lib_paths); - #if !defined(ORT_MINIMAL_BUILD) && defined(USE_TENSORRT) auto ep_context_cache_enabled_from_provider_options = tensorrt_options->trt_dump_ep_context_model != 0; auto ep_context_cache_enabled_from_sess_options = (options->value).config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") != "0"; @@ -1882,11 +1873,18 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2, // if provider options already have the EP context configs provided, the configs in session options will be ignored // since provider options has higher priority than session options. if (!ep_context_cache_enabled_from_provider_options && ep_context_cache_enabled_from_sess_options) { + // This function might need to update the "const" OrtTensorRTProviderOptionsV2 object which can't be modified. + // Therefore, we need to create a new OrtTensorRTProviderOptionsV2 object and copy from tensorrt_options and use this new object to create the factory instead. + // Note: No need to worry about new_tensorrt_options being a local variable, CreateExecutionProviderFactory() in TRT EP will + // create a factory object that copies any provider options from tensorrt_options including "const char*" provider options. + OrtTensorRTProviderOptionsV2 new_tensorrt_options = *tensorrt_options; // copy and assign from tensorrt_options onnxruntime::UpdateOrtTensorRTProviderOptionsV2FromSessionOptionsConfigs(options, &new_tensorrt_options); + factory = onnxruntime::TensorrtProviderFactoryCreator::Create(&new_tensorrt_options); + } else { + factory = onnxruntime::TensorrtProviderFactoryCreator::Create(tensorrt_options); } - factory = onnxruntime::TensorrtProviderFactoryCreator::Create(&new_tensorrt_options); #else - factory = onnxruntime::TensorrtProviderFactoryCreator::Create(&new_tensorrt_options); + factory = onnxruntime::TensorrtProviderFactoryCreator::Create(tensorrt_options); #endif if (!factory) { @@ -1895,6 +1893,9 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2, options->provider_factories.push_back(factory); + std::string extra_plugin_lib_paths = (tensorrt_options == nullptr || tensorrt_options->trt_extra_plugin_lib_paths == nullptr) ? "" : tensorrt_options->trt_extra_plugin_lib_paths; + AddTensorRTCustomOpDomainToSessionOption(options, extra_plugin_lib_paths); + return nullptr; API_IMPL_END } From 178e3065782eeb7fce1d6fd459244270d039bb82 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 25 Jan 2024 00:21:05 +0000 Subject: [PATCH 07/10] update --- .../tensorrt/tensorrt_execution_provider.cc | 5 ++++- .../tensorrt_execution_provider_custom_ops.cc | 20 ------------------- 2 files changed, 4 insertions(+), 21 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index f72cdd5c4f899..c58297457076c 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1845,7 +1845,10 @@ void TensorrtExecutionProvider::GetCustomOpDomainList(std::vector domain_list; - std::string extra_plugin_lib_paths{""}; - if (info.has_trt_options) { - if (!info.extra_plugin_lib_paths.empty()) { - extra_plugin_lib_paths = info.extra_plugin_lib_paths; - } - } else { - const std::string extra_plugin_lib_paths_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kExtraPluginLibPaths); - if (!extra_plugin_lib_paths_env.empty()) { - extra_plugin_lib_paths = extra_plugin_lib_paths_env; - } - } - auto status = CreateTensorRTCustomOpDomainList(domain_list, extra_plugin_lib_paths); - if (!domain_list.empty()) { - info.custom_op_domain_list = domain_list; - } - return Status::OK(); -} - void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain) { if (domain != nullptr) { for (auto ptr : domain->custom_ops_) { From 4b3c47322475dcfbde549085782ab9f5d458cb8e Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 25 Jan 2024 02:08:52 +0000 Subject: [PATCH 08/10] fix bug --- .../tensorrt/tensorrt_execution_provider_custom_ops.cc | 2 ++ onnxruntime/python/onnxruntime_pybind_state.cc | 8 ++++---- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc index 9452b3f1fd46b..77d1ffbc67915 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc @@ -28,6 +28,7 @@ extern TensorrtLogger& GetTensorrtLogger(); */ common::Status CreateTensorRTCustomOpDomainList(std::vector>& domain_list, const std::string extra_plugin_lib_paths) { static std::shared_ptr custom_op_domain = std::make_shared(); + static std::unordered_set> custom_op_set; if (custom_op_domain->domain_ != "" && custom_op_domain->custom_ops_.size() > 0) { domain_list.push_back(custom_op_domain); return Status::OK(); @@ -75,6 +76,7 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector trt_custom_op = std::make_shared(onnxruntime::kTensorrtExecutionProvider, nullptr); trt_custom_op->SetName(plugin_creator->getPluginName()); custom_op_domain->custom_ops_.push_back(trt_custom_op.get()); + custom_op_set.insert(trt_custom_op); // Make sure trt_custom_op object won't be cleaned up registered_plugin_names.insert(plugin_name); } custom_op_domain->domain_ = "trt.plugins"; diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index f7ed5520727db..a76c2723c1f4b 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -443,11 +443,11 @@ void RegisterTensorRTPluginsAsCustomOps(PySessionOptions& so, const ProviderOpti if (it != options.end()) { trt_extra_plugin_lib_paths = it->second; } - std::vector domain_list; - tensorrt_provider_info->GetTensorRTCustomOpDomainList(domain_list, trt_extra_plugin_lib_paths); - for (auto ptr : domain_list) { + std::vector> custom_op_domains; + tensorrt_provider_info->GetTensorRTCustomOpDomainList(custom_op_domains, trt_extra_plugin_lib_paths); + for (auto ptr : custom_op_domains) { if (!is_already_in_domains(ptr->domain_, so.custom_op_domains_)) { - so.custom_op_domains_.push_back(ptr); + so.custom_op_domains_.push_back(ptr.get()); } else { LOGS_DEFAULT(WARNING) << "The custom op domain name " << ptr->domain_ << " is already in session option."; } From 18b8ba8db20a01675826484102557ea62777c4aa Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 25 Jan 2024 16:31:34 +0000 Subject: [PATCH 09/10] use unique_ptr --- .../core/framework/execution_provider.h | 2 +- .../tensorrt/tensorrt_execution_provider.cc | 2 +- .../tensorrt/tensorrt_execution_provider.h | 2 +- .../tensorrt_execution_provider_custom_ops.cc | 17 ++++++++--------- .../tensorrt_execution_provider_custom_ops.h | 2 +- .../tensorrt/tensorrt_provider_factory.cc | 2 +- .../tensorrt/tensorrt_provider_factory.h | 2 +- onnxruntime/core/session/inference_session.cc | 6 +++--- onnxruntime/core/session/provider_bridge_ort.cc | 4 ++-- onnxruntime/python/onnxruntime_pybind_state.cc | 4 ++-- 10 files changed, 21 insertions(+), 22 deletions(-) diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index 901f2e3a9f8d8..1de0217c7e1fa 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -155,7 +155,7 @@ class IExecutionProvider { kernel implementation is needed for custom op since the real implementation is inside TRT. This custom op acts as a role to help pass ONNX model validation. */ - virtual void GetCustomOpDomainList(std::vector>& /*provider custom op domain list*/) const {}; + virtual void GetCustomOpDomainList(std::vector& /*provider custom op domain list*/) const {}; /** Returns an opaque handle whose exact type varies based on the provider diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index c58297457076c..39e5f5be000e5 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1833,7 +1833,7 @@ nvinfer1::IBuilder* TensorrtExecutionProvider::GetBuilder() const { return builder_.get(); } -void TensorrtExecutionProvider::GetCustomOpDomainList(std::vector>& custom_op_domain_list) const { +void TensorrtExecutionProvider::GetCustomOpDomainList(std::vector& custom_op_domain_list) const { std::string extra_plugin_lib_paths{""}; if (info_.has_trt_options) { if (!info_.extra_plugin_lib_paths.empty()) { diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index bea10b7ad8cfb..ad2d2c55c67e1 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -244,7 +244,7 @@ class TensorrtExecutionProvider : public IExecutionProvider { void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override; - void GetCustomOpDomainList(std::vector>& custom_op_domain_list) const override; + void GetCustomOpDomainList(std::vector& custom_op_domain_list) const override; OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc index 77d1ffbc67915..eb340ba1e64b6 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc @@ -26,11 +26,11 @@ extern TensorrtLogger& GetTensorrtLogger(); * Note: Current TRT plugin doesn't have APIs to get number of inputs/outputs of the plugin. * So, TensorRTCustomOp uses variadic inputs/outputs to pass ONNX graph validation. */ -common::Status CreateTensorRTCustomOpDomainList(std::vector>& domain_list, const std::string extra_plugin_lib_paths) { - static std::shared_ptr custom_op_domain = std::make_shared(); - static std::unordered_set> custom_op_set; +common::Status CreateTensorRTCustomOpDomainList(std::vector& domain_list, const std::string extra_plugin_lib_paths) { + static std::unique_ptr custom_op_domain = std::make_unique(); + static std::vector> created_custom_op_list; if (custom_op_domain->domain_ != "" && custom_op_domain->custom_ops_.size() > 0) { - domain_list.push_back(custom_op_domain); + domain_list.push_back(custom_op_domain.get()); return Status::OK(); } @@ -73,14 +73,13 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector trt_custom_op = std::make_shared(onnxruntime::kTensorrtExecutionProvider, nullptr); - trt_custom_op->SetName(plugin_creator->getPluginName()); - custom_op_domain->custom_ops_.push_back(trt_custom_op.get()); - custom_op_set.insert(trt_custom_op); // Make sure trt_custom_op object won't be cleaned up + created_custom_op_list.push_back(std::make_unique(onnxruntime::kTensorrtExecutionProvider, nullptr)); // Make sure TensorRTCustomOp object won't be cleaned up + created_custom_op_list.back().get()->SetName(plugin_creator->getPluginName()); + custom_op_domain->custom_ops_.push_back(created_custom_op_list.back().get()); registered_plugin_names.insert(plugin_name); } custom_op_domain->domain_ = "trt.plugins"; - domain_list.push_back(custom_op_domain); + domain_list.push_back(custom_op_domain.get()); } catch (const std::exception&) { LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration. Therefore, TRT EP can't create custom ops for TRT plugins"; } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h index 0c9251a12fa75..b19d9ab0f66d0 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h @@ -13,7 +13,7 @@ using namespace onnxruntime; namespace onnxruntime { common::Status LoadDynamicLibrary(onnxruntime::PathString library_name); -common::Status CreateTensorRTCustomOpDomainList(std::vector>& domain_list, const std::string extra_plugin_lib_paths); +common::Status CreateTensorRTCustomOpDomainList(std::vector& domain_list, const std::string extra_plugin_lib_paths); common::Status CreateTensorRTCustomOpDomainList(TensorrtExecutionProviderInfo& info); void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain); void ReleaseTensorRTCustomOpDomainList(std::vector& custom_op_domain_list); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc index 6b4eaf65ec138..568da57a50956 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc @@ -32,7 +32,7 @@ struct ProviderInfo_TensorRT_Impl final : ProviderInfo_TensorRT { return nullptr; } - OrtStatus* GetTensorRTCustomOpDomainList(std::vector>& domain_list, const std::string extra_plugin_lib_paths) override { + OrtStatus* GetTensorRTCustomOpDomainList(std::vector& domain_list, const std::string extra_plugin_lib_paths) override { common::Status status = CreateTensorRTCustomOpDomainList(domain_list, extra_plugin_lib_paths); if (!status.IsOK()) { return CreateStatus(ORT_FAIL, "[TensorRT EP] Can't create custom ops for TRT plugins."); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h index bd708e0bec7f5..231e14e5c95f2 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h @@ -8,7 +8,7 @@ namespace onnxruntime { struct ProviderInfo_TensorRT { virtual OrtStatus* GetCurrentGpuDeviceId(_In_ int* device_id) = 0; virtual OrtStatus* UpdateProviderOptions(void* provider_options, const ProviderOptions& options, bool string_copy) = 0; - virtual OrtStatus* GetTensorRTCustomOpDomainList(std::vector>& domain_list, const std::string extra_plugin_lib_paths) = 0; + virtual OrtStatus* GetTensorRTCustomOpDomainList(std::vector& domain_list, const std::string extra_plugin_lib_paths) = 0; virtual OrtStatus* ReleaseCustomOpDomainList(std::vector& domain_list) = 0; protected: diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index c69d37a6ab35e..ad46a03e8fd75 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -664,7 +664,7 @@ common::Status InferenceSession::RegisterExecutionProvider(const std::shared_ptr #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) // Register Custom Op if EP requests it std::vector custom_op_domains; - std::vector> candidate_custom_op_domains; + std::vector candidate_custom_op_domains; p_exec_provider->GetCustomOpDomainList(candidate_custom_op_domains); auto registry_kernels = kernel_registry_manager_.GetKernelRegistriesByProviderType(p_exec_provider->Type()); @@ -672,7 +672,7 @@ common::Status InferenceSession::RegisterExecutionProvider(const std::shared_ptr // Register the custom op domain only if it has not been registered before if (registry_kernels.empty()) { for (auto candidate_custom_op_domain : candidate_custom_op_domains) { - custom_op_domains.push_back(candidate_custom_op_domain.get()); + custom_op_domains.push_back(candidate_custom_op_domain); } } else { for (auto candidate_custom_op_domain : candidate_custom_op_domains) { @@ -688,7 +688,7 @@ common::Status InferenceSession::RegisterExecutionProvider(const std::shared_ptr } } if (need_register) { - custom_op_domains.push_back(candidate_custom_op_domain.get()); + custom_op_domains.push_back(candidate_custom_op_domain); } } } diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 552a38047a50b..f48110aa7ee5b 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1687,12 +1687,12 @@ void AddTensorRTCustomOpDomainToSessionOption(OrtSessionOptions* options, std::s return false; }; - std::vector> custom_op_domains; + std::vector custom_op_domains; onnxruntime::ProviderInfo_TensorRT& provider_info = onnxruntime::GetProviderInfo_TensorRT(); provider_info.GetTensorRTCustomOpDomainList(custom_op_domains, extra_plugin_lib_paths); for (auto ptr : custom_op_domains) { if (!is_already_in_domains(ptr->domain_, options->custom_op_domains_)) { - options->custom_op_domains_.push_back(ptr.get()); + options->custom_op_domains_.push_back(ptr); } else { LOGS_DEFAULT(WARNING) << "The custom op domain name " << ptr->domain_ << " is already in session option."; } diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index a76c2723c1f4b..8e13982ca6861 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -443,11 +443,11 @@ void RegisterTensorRTPluginsAsCustomOps(PySessionOptions& so, const ProviderOpti if (it != options.end()) { trt_extra_plugin_lib_paths = it->second; } - std::vector> custom_op_domains; + std::vector custom_op_domains; tensorrt_provider_info->GetTensorRTCustomOpDomainList(custom_op_domains, trt_extra_plugin_lib_paths); for (auto ptr : custom_op_domains) { if (!is_already_in_domains(ptr->domain_, so.custom_op_domains_)) { - so.custom_op_domains_.push_back(ptr.get()); + so.custom_op_domains_.push_back(ptr); } else { LOGS_DEFAULT(WARNING) << "The custom op domain name " << ptr->domain_ << " is already in session option."; } From 16a04ef17da7517c84416c00fbd951774ce50558 Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Thu, 25 Jan 2024 16:39:58 +0000 Subject: [PATCH 10/10] update --- onnxruntime/core/session/inference_session.cc | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index ad46a03e8fd75..39f47c09f2402 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -671,9 +671,7 @@ common::Status InferenceSession::RegisterExecutionProvider(const std::shared_ptr // Register the custom op domain only if it has not been registered before if (registry_kernels.empty()) { - for (auto candidate_custom_op_domain : candidate_custom_op_domains) { - custom_op_domains.push_back(candidate_custom_op_domain); - } + custom_op_domains = candidate_custom_op_domains; } else { for (auto candidate_custom_op_domain : candidate_custom_op_domains) { for (auto registry_kernel : registry_kernels) {