From b1689a8e4ac7792814793e618c18011b2661131c Mon Sep 17 00:00:00 2001 From: Chi Lo Date: Wed, 24 Jan 2024 23:31:17 +0000 Subject: [PATCH] update --- cmake/external/emsdk | 2 +- .../core/framework/execution_provider.h | 2 +- .../tensorrt/tensorrt_provider_options.h | 2 -- .../tensorrt/tensorrt_execution_provider.cc | 17 +++++++++++------ .../tensorrt/tensorrt_execution_provider.h | 2 +- .../tensorrt_execution_provider_custom_ops.cc | 16 ++++++++++------ .../tensorrt_execution_provider_custom_ops.h | 2 +- .../tensorrt/tensorrt_provider_factory.cc | 4 +--- .../tensorrt/tensorrt_provider_factory.h | 2 +- onnxruntime/core/session/inference_session.cc | 8 +++++--- onnxruntime/core/session/provider_bridge_ort.cc | 10 ++-------- onnxruntime/python/onnxruntime_pybind_state.cc | 7 ------- 12 files changed, 34 insertions(+), 40 deletions(-) diff --git a/cmake/external/emsdk b/cmake/external/emsdk index 4e2496141eda1..a896e3d066448 160000 --- a/cmake/external/emsdk +++ b/cmake/external/emsdk @@ -1 +1 @@ -Subproject commit 4e2496141eda15040c44e9bbf237a1326368e34c +Subproject commit a896e3d066448b3530dbcaa48869fafefd738f57 diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index 1de0217c7e1fa..901f2e3a9f8d8 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -155,7 +155,7 @@ class IExecutionProvider { kernel implementation is needed for custom op since the real implementation is inside TRT. This custom op acts as a role to help pass ONNX model validation. */ - virtual void GetCustomOpDomainList(std::vector& /*provider custom op domain list*/) const {}; + virtual void GetCustomOpDomainList(std::vector>& /*provider custom op domain list*/) const {}; /** Returns an opaque handle whose exact type varies based on the provider diff --git a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h index 910e9d5fb4417..32a9f06464ace 100644 --- a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h +++ b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h @@ -70,6 +70,4 @@ struct OrtTensorRTProviderOptionsV2 { int trt_ep_context_embed_mode{0}; // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix - - std::vector custom_op_domain_list; }; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index fe6b959b962de..f72cdd5c4f899 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1833,14 +1833,19 @@ nvinfer1::IBuilder* TensorrtExecutionProvider::GetBuilder() const { return builder_.get(); } -void TensorrtExecutionProvider::GetCustomOpDomainList(std::vector& custom_op_domain_list) const { - if (info_.custom_op_domain_list.empty()) { - common::Status status = CreateTensorRTCustomOpDomainList(info_); - if (!status.IsOK()) { - LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration."; +void TensorrtExecutionProvider::GetCustomOpDomainList(std::vector>& custom_op_domain_list) const { + std::string extra_plugin_lib_paths{""}; + if (info_.has_trt_options) { + if (!info_.extra_plugin_lib_paths.empty()) { + extra_plugin_lib_paths = info_.extra_plugin_lib_paths; + } + } else { + const std::string extra_plugin_lib_paths_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kExtraPluginLibPaths); + if (!extra_plugin_lib_paths_env.empty()) { + extra_plugin_lib_paths = extra_plugin_lib_paths_env; } } - custom_op_domain_list = info_.custom_op_domain_list; + CreateTensorRTCustomOpDomainList(custom_op_domain_list, extra_plugin_lib_paths); } // Check the graph is the subgraph of control flow op diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index ad2d2c55c67e1..bea10b7ad8cfb 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -244,7 +244,7 @@ class TensorrtExecutionProvider : public IExecutionProvider { void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override; - void GetCustomOpDomainList(std::vector& custom_op_domain_list) const override; + void GetCustomOpDomainList(std::vector>& custom_op_domain_list) const override; OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc index 4e466a5d568a6..04005ab8a050e 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc @@ -26,9 +26,12 @@ extern TensorrtLogger& GetTensorrtLogger(); * Note: Current TRT plugin doesn't have APIs to get number of inputs/outputs of the plugin. * So, TensorRTCustomOp uses variadic inputs/outputs to pass ONNX graph validation. */ -common::Status CreateTensorRTCustomOpDomainList(std::vector& domain_list, const std::string extra_plugin_lib_paths) { - std::unique_ptr custom_op_domain = std::make_unique(); - custom_op_domain->domain_ = "trt.plugins"; +common::Status CreateTensorRTCustomOpDomainList(std::vector>& domain_list, const std::string extra_plugin_lib_paths) { + static std::shared_ptr custom_op_domain = std::make_shared(); + if (custom_op_domain->domain_ != "" && custom_op_domain->custom_ops_.size() > 0) { + domain_list.push_back(custom_op_domain); + return Status::OK(); + } // Load any extra TRT plugin library if any. // When the TRT plugin library is loaded, the global static object is created and the plugin is registered to TRT registry. @@ -69,12 +72,13 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector& continue; } - std::unique_ptr trt_custom_op = std::make_unique(onnxruntime::kTensorrtExecutionProvider, nullptr); + std::shared_ptr trt_custom_op = std::make_shared(onnxruntime::kTensorrtExecutionProvider, nullptr); trt_custom_op->SetName(plugin_creator->getPluginName()); - custom_op_domain->custom_ops_.push_back(trt_custom_op.release()); + custom_op_domain->custom_ops_.push_back(trt_custom_op.get()); registered_plugin_names.insert(plugin_name); } - domain_list.push_back(custom_op_domain.release()); + custom_op_domain->domain_ = "trt.plugins"; + domain_list.push_back(custom_op_domain); } catch (const std::exception&) { LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration. Therefore, TRT EP can't create custom ops for TRT plugins"; } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h index b19d9ab0f66d0..0c9251a12fa75 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.h @@ -13,7 +13,7 @@ using namespace onnxruntime; namespace onnxruntime { common::Status LoadDynamicLibrary(onnxruntime::PathString library_name); -common::Status CreateTensorRTCustomOpDomainList(std::vector& domain_list, const std::string extra_plugin_lib_paths); +common::Status CreateTensorRTCustomOpDomainList(std::vector>& domain_list, const std::string extra_plugin_lib_paths); common::Status CreateTensorRTCustomOpDomainList(TensorrtExecutionProviderInfo& info); void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain); void ReleaseTensorRTCustomOpDomainList(std::vector& custom_op_domain_list); diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc index 9f80fae676029..6b4eaf65ec138 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc @@ -32,7 +32,7 @@ struct ProviderInfo_TensorRT_Impl final : ProviderInfo_TensorRT { return nullptr; } - OrtStatus* GetTensorRTCustomOpDomainList(std::vector& domain_list, const std::string extra_plugin_lib_paths) override { + OrtStatus* GetTensorRTCustomOpDomainList(std::vector>& domain_list, const std::string extra_plugin_lib_paths) override { common::Status status = CreateTensorRTCustomOpDomainList(domain_list, extra_plugin_lib_paths); if (!status.IsOK()) { return CreateStatus(ORT_FAIL, "[TensorRT EP] Can't create custom ops for TRT plugins."); @@ -114,8 +114,6 @@ struct Tensorrt_Provider : Provider { info.ep_context_embed_mode = options.trt_ep_context_embed_mode; info.engine_cache_prefix = options.trt_engine_cache_prefix == nullptr ? "" : options.trt_engine_cache_prefix; - info.custom_op_domain_list = options.custom_op_domain_list; - return std::make_shared(info); } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h index 231e14e5c95f2..bd708e0bec7f5 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h @@ -8,7 +8,7 @@ namespace onnxruntime { struct ProviderInfo_TensorRT { virtual OrtStatus* GetCurrentGpuDeviceId(_In_ int* device_id) = 0; virtual OrtStatus* UpdateProviderOptions(void* provider_options, const ProviderOptions& options, bool string_copy) = 0; - virtual OrtStatus* GetTensorRTCustomOpDomainList(std::vector& domain_list, const std::string extra_plugin_lib_paths) = 0; + virtual OrtStatus* GetTensorRTCustomOpDomainList(std::vector>& domain_list, const std::string extra_plugin_lib_paths) = 0; virtual OrtStatus* ReleaseCustomOpDomainList(std::vector& domain_list) = 0; protected: diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index e8853c8824738..6de53243f880f 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -664,14 +664,16 @@ common::Status InferenceSession::RegisterExecutionProvider(const std::shared_ptr #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) // Register Custom Op if EP requests it std::vector custom_op_domains; - std::vector candidate_custom_op_domains; + std::vector> candidate_custom_op_domains; p_exec_provider->GetCustomOpDomainList(candidate_custom_op_domains); auto registry_kernels = kernel_registry_manager_.GetKernelRegistriesByProviderType(p_exec_provider->Type()); // Register the custom op domain only if it has not been registered before if (registry_kernels.empty()) { - custom_op_domains = candidate_custom_op_domains; + for (auto candidate_custom_op_domain : candidate_custom_op_domains) { + custom_op_domains.push_back(candidate_custom_op_domain.get()); + } } else { for (auto candidate_custom_op_domain : candidate_custom_op_domains) { for (auto registry_kernel : registry_kernels) { @@ -686,7 +688,7 @@ common::Status InferenceSession::RegisterExecutionProvider(const std::shared_ptr } } if (need_register) { - custom_op_domains.push_back(candidate_custom_op_domain); + custom_op_domains.push_back(candidate_custom_op_domain.get()); } } } diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 1d7807d04fe52..9e8c215af39f0 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1687,17 +1687,12 @@ void AddTensorRTCustomOpDomainToSessionOption(OrtSessionOptions* options, OrtTen return false; }; - std::vector custom_op_domains; + std::vector> custom_op_domains; onnxruntime::ProviderInfo_TensorRT& provider_info = onnxruntime::GetProviderInfo_TensorRT(); provider_info.GetTensorRTCustomOpDomainList(custom_op_domains, extra_plugin_lib_paths); for (auto ptr : custom_op_domains) { if (!is_already_in_domains(ptr->domain_, options->custom_op_domains_)) { - options->custom_op_domains_.push_back(ptr); - // The OrtCustomOpDomains objects created by TRT EP's GetTensorRTCustomOpDomainList() should be released once session is finished. - // TRT EP needs to keep all the pointers OrtCustomOpDomain obejcts and releases them upon TRT EP destruction. - if (tensorrt_options) { - tensorrt_options->custom_op_domain_list.push_back(ptr); - } + options->custom_op_domains_.push_back(ptr.get()); } else { LOGS_DEFAULT(WARNING) << "The custom op domain name " << ptr->domain_ << " is already in session option."; } @@ -1877,7 +1872,6 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2, OrtTensorRTProviderOptionsV2 new_tensorrt_options = *tensorrt_options; // copy and assign from tensorrt_options std::string extra_plugin_lib_paths = (tensorrt_options == nullptr || tensorrt_options->trt_extra_plugin_lib_paths == nullptr) ? "" : tensorrt_options->trt_extra_plugin_lib_paths; - new_tensorrt_options.custom_op_domain_list.clear(); AddTensorRTCustomOpDomainToSessionOption(options, &new_tensorrt_options, extra_plugin_lib_paths); #if !defined(ORT_MINIMAL_BUILD) && defined(USE_TENSORRT) diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 72df0b9006bf9..f7ed5520727db 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -745,13 +745,6 @@ std::unique_ptr CreateExecutionProviderInstance( ORT_THROW("Invalid TensorRT EP option: ", option.first); } } - // The OrtCustomOpDomains objects created by TRT EP's GetTensorRTCustomOpDomainList() should be released once session is finished. - // TRT EP needs to keep all the pointers OrtCustomOpDomain obejcts and releases them upon TRT EP destruction. - for (auto ptr : session_options.custom_op_domains_) { - if (ptr->domain_ == "trt.plugins") { - params->custom_op_domain_list.push_back(ptr); - } - } if (std::shared_ptr tensorrt_provider_factory = onnxruntime::TensorrtProviderFactoryCreator::Create(¶ms)) { return tensorrt_provider_factory->CreateProvider(); }