Skip to content

Commit

Permalink
use unique_ptr
Browse files Browse the repository at this point in the history
  • Loading branch information
chilo-ms committed Jan 25, 2024
1 parent 4b3c473 commit 18b8ba8
Show file tree
Hide file tree
Showing 10 changed files with 21 additions and 22 deletions.
2 changes: 1 addition & 1 deletion include/onnxruntime/core/framework/execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ class IExecutionProvider {
kernel implementation is needed for custom op since the real implementation is inside TRT. This custom op acts as
a role to help pass ONNX model validation.
*/
virtual void GetCustomOpDomainList(std::vector<std::shared_ptr<OrtCustomOpDomain>>& /*provider custom op domain list*/) const {};
virtual void GetCustomOpDomainList(std::vector<OrtCustomOpDomain*>& /*provider custom op domain list*/) const {};

/**
Returns an opaque handle whose exact type varies based on the provider
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1833,7 +1833,7 @@ nvinfer1::IBuilder* TensorrtExecutionProvider::GetBuilder() const {
return builder_.get();
}

void TensorrtExecutionProvider::GetCustomOpDomainList(std::vector<std::shared_ptr<OrtCustomOpDomain>>& custom_op_domain_list) const {
void TensorrtExecutionProvider::GetCustomOpDomainList(std::vector<OrtCustomOpDomain*>& custom_op_domain_list) const {
std::string extra_plugin_lib_paths{""};
if (info_.has_trt_options) {
if (!info_.extra_plugin_lib_paths.empty()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ class TensorrtExecutionProvider : public IExecutionProvider {

void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override;

void GetCustomOpDomainList(std::vector<std::shared_ptr<OrtCustomOpDomain>>& custom_op_domain_list) const override;
void GetCustomOpDomainList(std::vector<OrtCustomOpDomain*>& custom_op_domain_list) const override;

OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ extern TensorrtLogger& GetTensorrtLogger();
* Note: Current TRT plugin doesn't have APIs to get number of inputs/outputs of the plugin.
* So, TensorRTCustomOp uses variadic inputs/outputs to pass ONNX graph validation.
*/
common::Status CreateTensorRTCustomOpDomainList(std::vector<std::shared_ptr<OrtCustomOpDomain>>& domain_list, const std::string extra_plugin_lib_paths) {
static std::shared_ptr<OrtCustomOpDomain> custom_op_domain = std::make_shared<OrtCustomOpDomain>();
static std::unordered_set<std::shared_ptr<TensorRTCustomOp>> custom_op_set;
common::Status CreateTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>& domain_list, const std::string extra_plugin_lib_paths) {
static std::unique_ptr<OrtCustomOpDomain> custom_op_domain = std::make_unique<OrtCustomOpDomain>();
static std::vector<std::unique_ptr<TensorRTCustomOp>> created_custom_op_list;
if (custom_op_domain->domain_ != "" && custom_op_domain->custom_ops_.size() > 0) {
domain_list.push_back(custom_op_domain);
domain_list.push_back(custom_op_domain.get());
return Status::OK();
}

Expand Down Expand Up @@ -73,14 +73,13 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector<std::shared_ptr<OrtC
continue;
}

std::shared_ptr<TensorRTCustomOp> trt_custom_op = std::make_shared<TensorRTCustomOp>(onnxruntime::kTensorrtExecutionProvider, nullptr);
trt_custom_op->SetName(plugin_creator->getPluginName());
custom_op_domain->custom_ops_.push_back(trt_custom_op.get());
custom_op_set.insert(trt_custom_op); // Make sure trt_custom_op object won't be cleaned up
created_custom_op_list.push_back(std::make_unique<TensorRTCustomOp>(onnxruntime::kTensorrtExecutionProvider, nullptr)); // Make sure TensorRTCustomOp object won't be cleaned up

Check warning on line 76 in onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc#L76

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc:76:  Lines should be <= 120 characters long  [whitespace/line_length] [2]

Check warning on line 76 in onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc#L76

Add #include <memory> for make_unique<> [build/include_what_you_use] [4]
Raw output
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc:76:  Add #include <memory> for make_unique<>  [build/include_what_you_use] [4]
created_custom_op_list.back().get()->SetName(plugin_creator->getPluginName());
custom_op_domain->custom_ops_.push_back(created_custom_op_list.back().get());
registered_plugin_names.insert(plugin_name);
}
custom_op_domain->domain_ = "trt.plugins";
domain_list.push_back(custom_op_domain);
domain_list.push_back(custom_op_domain.get());
} catch (const std::exception&) {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration. Therefore, TRT EP can't create custom ops for TRT plugins";
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ using namespace onnxruntime;
namespace onnxruntime {

common::Status LoadDynamicLibrary(onnxruntime::PathString library_name);
common::Status CreateTensorRTCustomOpDomainList(std::vector<std::shared_ptr<OrtCustomOpDomain>>& domain_list, const std::string extra_plugin_lib_paths);
common::Status CreateTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>& domain_list, const std::string extra_plugin_lib_paths);
common::Status CreateTensorRTCustomOpDomainList(TensorrtExecutionProviderInfo& info);
void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain);
void ReleaseTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>& custom_op_domain_list);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ struct ProviderInfo_TensorRT_Impl final : ProviderInfo_TensorRT {
return nullptr;
}

OrtStatus* GetTensorRTCustomOpDomainList(std::vector<std::shared_ptr<OrtCustomOpDomain>>& domain_list, const std::string extra_plugin_lib_paths) override {
OrtStatus* GetTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>& domain_list, const std::string extra_plugin_lib_paths) override {
common::Status status = CreateTensorRTCustomOpDomainList(domain_list, extra_plugin_lib_paths);
if (!status.IsOK()) {
return CreateStatus(ORT_FAIL, "[TensorRT EP] Can't create custom ops for TRT plugins.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace onnxruntime {
struct ProviderInfo_TensorRT {
virtual OrtStatus* GetCurrentGpuDeviceId(_In_ int* device_id) = 0;
virtual OrtStatus* UpdateProviderOptions(void* provider_options, const ProviderOptions& options, bool string_copy) = 0;
virtual OrtStatus* GetTensorRTCustomOpDomainList(std::vector<std::shared_ptr<OrtCustomOpDomain>>& domain_list, const std::string extra_plugin_lib_paths) = 0;
virtual OrtStatus* GetTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>& domain_list, const std::string extra_plugin_lib_paths) = 0;
virtual OrtStatus* ReleaseCustomOpDomainList(std::vector<OrtCustomOpDomain*>& domain_list) = 0;

protected:
Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -664,15 +664,15 @@ common::Status InferenceSession::RegisterExecutionProvider(const std::shared_ptr
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
// Register Custom Op if EP requests it
std::vector<OrtCustomOpDomain*> custom_op_domains;
std::vector<std::shared_ptr<OrtCustomOpDomain>> candidate_custom_op_domains;
std::vector<OrtCustomOpDomain*> candidate_custom_op_domains;
p_exec_provider->GetCustomOpDomainList(candidate_custom_op_domains);

auto registry_kernels = kernel_registry_manager_.GetKernelRegistriesByProviderType(p_exec_provider->Type());

// Register the custom op domain only if it has not been registered before
if (registry_kernels.empty()) {
for (auto candidate_custom_op_domain : candidate_custom_op_domains) {
custom_op_domains.push_back(candidate_custom_op_domain.get());
custom_op_domains.push_back(candidate_custom_op_domain);
}
} else {
for (auto candidate_custom_op_domain : candidate_custom_op_domains) {
Expand All @@ -688,7 +688,7 @@ common::Status InferenceSession::RegisterExecutionProvider(const std::shared_ptr
}
}
if (need_register) {
custom_op_domains.push_back(candidate_custom_op_domain.get());
custom_op_domains.push_back(candidate_custom_op_domain);
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/session/provider_bridge_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1687,12 +1687,12 @@ void AddTensorRTCustomOpDomainToSessionOption(OrtSessionOptions* options, std::s
return false;
};

std::vector<std::shared_ptr<OrtCustomOpDomain>> custom_op_domains;
std::vector<OrtCustomOpDomain*> custom_op_domains;
onnxruntime::ProviderInfo_TensorRT& provider_info = onnxruntime::GetProviderInfo_TensorRT();
provider_info.GetTensorRTCustomOpDomainList(custom_op_domains, extra_plugin_lib_paths);
for (auto ptr : custom_op_domains) {
if (!is_already_in_domains(ptr->domain_, options->custom_op_domains_)) {
options->custom_op_domains_.push_back(ptr.get());
options->custom_op_domains_.push_back(ptr);
} else {
LOGS_DEFAULT(WARNING) << "The custom op domain name " << ptr->domain_ << " is already in session option.";
}
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/python/onnxruntime_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -443,11 +443,11 @@ void RegisterTensorRTPluginsAsCustomOps(PySessionOptions& so, const ProviderOpti
if (it != options.end()) {
trt_extra_plugin_lib_paths = it->second;
}
std::vector<std::shared_ptr<OrtCustomOpDomain>> custom_op_domains;
std::vector<OrtCustomOpDomain*> custom_op_domains;
tensorrt_provider_info->GetTensorRTCustomOpDomainList(custom_op_domains, trt_extra_plugin_lib_paths);
for (auto ptr : custom_op_domains) {
if (!is_already_in_domains(ptr->domain_, so.custom_op_domains_)) {
so.custom_op_domains_.push_back(ptr.get());
so.custom_op_domains_.push_back(ptr);
} else {
LOGS_DEFAULT(WARNING) << "The custom op domain name " << ptr->domain_ << " is already in session option.";
}
Expand Down

0 comments on commit 18b8ba8

Please sign in to comment.