Skip to content

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
chilo-ms committed Jan 25, 2024
1 parent 178e306 commit 4b3c473
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ extern TensorrtLogger& GetTensorrtLogger();
*/
common::Status CreateTensorRTCustomOpDomainList(std::vector<std::shared_ptr<OrtCustomOpDomain>>& domain_list, const std::string extra_plugin_lib_paths) {

Check warning on line 29 in onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc#L29

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc:29:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
static std::shared_ptr<OrtCustomOpDomain> custom_op_domain = std::make_shared<OrtCustomOpDomain>();
static std::unordered_set<std::shared_ptr<TensorRTCustomOp>> custom_op_set;
if (custom_op_domain->domain_ != "" && custom_op_domain->custom_ops_.size() > 0) {
domain_list.push_back(custom_op_domain);
return Status::OK();
Expand Down Expand Up @@ -75,6 +76,7 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector<std::shared_ptr<OrtC
std::shared_ptr<TensorRTCustomOp> trt_custom_op = std::make_shared<TensorRTCustomOp>(onnxruntime::kTensorrtExecutionProvider, nullptr);

Check warning on line 76 in onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc#L76

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc:76:  Lines should be <= 120 characters long  [whitespace/line_length] [2]

Check warning on line 76 in onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc#L76

Add #include <memory> for shared_ptr<> [build/include_what_you_use] [4]
Raw output
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc:76:  Add #include <memory> for shared_ptr<>  [build/include_what_you_use] [4]
trt_custom_op->SetName(plugin_creator->getPluginName());
custom_op_domain->custom_ops_.push_back(trt_custom_op.get());
custom_op_set.insert(trt_custom_op); // Make sure trt_custom_op object won't be cleaned up
registered_plugin_names.insert(plugin_name);
}
custom_op_domain->domain_ = "trt.plugins";
Expand Down
8 changes: 4 additions & 4 deletions onnxruntime/python/onnxruntime_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -443,11 +443,11 @@ void RegisterTensorRTPluginsAsCustomOps(PySessionOptions& so, const ProviderOpti
if (it != options.end()) {
trt_extra_plugin_lib_paths = it->second;
}
std::vector<OrtCustomOpDomain*> domain_list;
tensorrt_provider_info->GetTensorRTCustomOpDomainList(domain_list, trt_extra_plugin_lib_paths);
for (auto ptr : domain_list) {
std::vector<std::shared_ptr<OrtCustomOpDomain>> custom_op_domains;
tensorrt_provider_info->GetTensorRTCustomOpDomainList(custom_op_domains, trt_extra_plugin_lib_paths);
for (auto ptr : custom_op_domains) {
if (!is_already_in_domains(ptr->domain_, so.custom_op_domains_)) {
so.custom_op_domains_.push_back(ptr);
so.custom_op_domains_.push_back(ptr.get());
} else {
LOGS_DEFAULT(WARNING) << "The custom op domain name " << ptr->domain_ << " is already in session option.";
}
Expand Down

0 comments on commit 4b3c473

Please sign in to comment.