Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
chilo-ms committed Sep 20, 2023
1 parent 7c5c374 commit 52304cb
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>&
// When the TRT plugin library is loaded, the global static object is created and the plugin is registered to TRT registry.
// This is done through macro, for example, REGISTER_TENSORRT_PLUGIN(VisionTransformerPluginCreator).
// extra_plugin_lib_paths has the format of "path_1;path_2....;path_n"
if (!extra_plugin_lib_paths.empty()) {
static bool is_loaded = false;
if (!extra_plugin_lib_paths.empty() && !is_loaded) {
std::stringstream extra_plugin_libs(extra_plugin_lib_paths);
std::string lib;
while (std::getline(extra_plugin_libs, lib, ';')) {
Expand All @@ -45,6 +46,7 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>&
LOGS_DEFAULT(WARNING) << "[TensorRT EP]" << status.ToString();
}
}
is_loaded = true;
}

try {
Expand Down Expand Up @@ -79,6 +81,26 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>&
return Status::OK();
}

common::Status CreateTensorRTCustomOpDomainList(TensorrtExecutionProviderInfo& info) {
std::vector<OrtCustomOpDomain*> domain_list;
std::string extra_plugin_lib_paths{""};
if (info.has_trt_options) {
if (!info.extra_plugin_lib_paths.empty()) {
extra_plugin_lib_paths = info.extra_plugin_lib_paths;
}
} else {
const std::string extra_plugin_lib_paths_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kExtraPluginLibPaths);
if (!extra_plugin_lib_paths_env.empty()) {
extra_plugin_lib_paths = extra_plugin_lib_paths_env;
}
}
auto status = CreateTensorRTCustomOpDomainList(domain_list, extra_plugin_lib_paths);
if (!domain_list.empty()) {
info.custom_op_domain_list = domain_list;
}
return Status::OK();
}

void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain) {
if (domain != nullptr) {
for (auto ptr : domain->custom_ops_) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ namespace onnxruntime {

common::Status LoadDynamicLibrary(onnxruntime::PathString library_name);
common::Status CreateTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>& domain_list, const std::string extra_plugin_lib_paths);
common::Status CreateTensorRTCustomOpDomainList(TensorrtExecutionProviderInfo& info);
void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain);
void ReleaseTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>& custom_op_domain_list);

Expand Down
21 changes: 10 additions & 11 deletions onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,10 @@ struct TensorrtProviderFactory : IExecutionProviderFactory {

std::unique_ptr<IExecutionProvider> CreateProvider() override;

void GetCustomOpDomainList(std::vector<OrtCustomOpDomain*>& custom_op_domain_list);

private:
TensorrtExecutionProviderInfo info_;
};

void TensorrtProviderFactory::GetCustomOpDomainList(std::vector<OrtCustomOpDomain*>& custom_op_domain_list) {
custom_op_domain_list = info_.custom_op_domain_list;
}

std::unique_ptr<IExecutionProvider> TensorrtProviderFactory::CreateProvider() {
return std::make_unique<TensorrtExecutionProvider>(info_);
}
Expand All @@ -81,6 +75,11 @@ struct Tensorrt_Provider : Provider {
info.device_id = device_id;
info.has_trt_options = false;

common::Status status = CreateTensorRTCustomOpDomainList(info);
if (!status.IsOK()) {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration.";
}

return std::make_shared<TensorrtProviderFactory>(info);
}

Expand Down Expand Up @@ -122,6 +121,11 @@ struct Tensorrt_Provider : Provider {
info.profile_opt_shapes = options.trt_profile_opt_shapes == nullptr ? "" : options.trt_profile_opt_shapes;
info.cuda_graph_enable = options.trt_cuda_graph_enable != 0;

common::Status status = CreateTensorRTCustomOpDomainList(info);
if (!status.IsOK()) {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration.";
}

return std::make_shared<TensorrtProviderFactory>(info);
}

Expand All @@ -134,11 +138,6 @@ struct Tensorrt_Provider : Provider {
return onnxruntime::TensorrtExecutionProviderInfo::ToProviderOptions(options);
}

void GetCustomOpDomainList(IExecutionProviderFactory* factory, std::vector<OrtCustomOpDomain*>& custom_op_domains_ptr) override {
TensorrtProviderFactory* trt_factory = reinterpret_cast<TensorrtProviderFactory*>(factory);
trt_factory->GetCustomOpDomainList(custom_op_domains_ptr);
}

void Initialize() override {
InitializeRegistry();
}
Expand Down
4 changes: 0 additions & 4 deletions onnxruntime/core/session/provider_bridge_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1420,10 +1420,6 @@ std::shared_ptr<IExecutionProviderFactory> TensorrtProviderFactoryCreator::Creat
return s_library_tensorrt.Get().CreateExecutionProviderFactory(provider_options);
}

void TensorrtProviderGetCustomOpDomainList(IExecutionProviderFactory* factory, std::vector<OrtCustomOpDomain*>& custom_op_domains_ptr) {
s_library_tensorrt.Get().GetCustomOpDomainList(factory, custom_op_domains_ptr);
}

std::shared_ptr<IExecutionProviderFactory> MIGraphXProviderFactoryCreator::Create(const OrtMIGraphXProviderOptions* provider_options) {
return s_library_migraphx.Get().CreateExecutionProviderFactory(provider_options);
}
Expand Down

0 comments on commit 52304cb

Please sign in to comment.