[TensorRT EP] Refactor OrtTensorRTProviderOptions initialization and make it easy to add new field #17617

Merged 26 commits on Oct 6, 2023
Changes from 22 commits
Commits (26)
b25517f
Add member initialization for OrtTensorRTProviderOptionsV2
chilo-ms Sep 17, 2023
afb5524
Move OrtSessionOptionsAppendExecutionProvider_Tensorrt declaration t…
chilo-ms Sep 18, 2023
f91583d
add destructor for OrtTensorRTProviderOptionsV2
chilo-ms Sep 19, 2023
29875f6
add ProviderInfo_TensorRT
chilo-ms Sep 19, 2023
e0a5f3e
move UpdateProviderOptions to TensorrtExecutionProviderInfo
chilo-ms Sep 19, 2023
e3810b9
remove ~OrtTensorRTProviderOptionsV2()
chilo-ms Sep 19, 2023
adb68c1
refactor tensorrt ep GetCustomOpDomainList
chilo-ms Sep 19, 2023
9d90248
add python api to register ep custom op domain
chilo-ms Sep 19, 2023
a59a5a0
remove comment code
chilo-ms Sep 19, 2023
e9b9ba3
add a switch for copy_string
chilo-ms Sep 19, 2023
7afe24e
remove redundant code
chilo-ms Sep 19, 2023
fa49c53
Merge branch 'main' into chi/trt_plugin_python
chilo-ms Sep 19, 2023
83eea74
remove include tensorrt_provider_factory.h
chilo-ms Sep 19, 2023
d5936a2
fix bug
chilo-ms Sep 19, 2023
7c5c374
fix bug
chilo-ms Sep 19, 2023
52304cb
update
chilo-ms Sep 20, 2023
2b6429f
fix bug
chilo-ms Sep 20, 2023
8c38ebd
fix format
chilo-ms Sep 20, 2023
53fc69e
fix bug
chilo-ms Sep 21, 2023
6cb7b09
revert python binding for TRT EP factory creation
chilo-ms Sep 21, 2023
64de02f
fix bug
chilo-ms Sep 21, 2023
f172c91
fix bug
chilo-ms Sep 21, 2023
6f68fbc
refactor and fix format
chilo-ms Sep 27, 2023
e79e9bf
Merge branch 'main' into chi/trt_plugin_python
chilo-ms Sep 27, 2023
b762c00
fix format
chilo-ms Sep 27, 2023
1c4de7f
fix bug for package pipelines
chilo-ms Oct 4, 2023
This file was deleted.

@@ -11,38 +11,38 @@
/// User can only get the instance of OrtTensorRTProviderOptionsV2 via CreateTensorRTProviderOptions.
/// </summary>
struct OrtTensorRTProviderOptionsV2 {
int device_id; // cuda device id.
int has_user_compute_stream; // indicator of user specified CUDA compute stream.
void* user_compute_stream; // user specified CUDA compute stream.
int trt_max_partition_iterations; // maximum iterations for TensorRT parser to get capability
int trt_min_subgraph_size; // minimum size of TensorRT subgraphs
size_t trt_max_workspace_size; // maximum workspace size for TensorRT.
int trt_fp16_enable; // enable TensorRT FP16 precision. Default 0 = false, nonzero = true
int trt_int8_enable; // enable TensorRT INT8 precision. Default 0 = false, nonzero = true
const char* trt_int8_calibration_table_name; // TensorRT INT8 calibration table name.
int trt_int8_use_native_calibration_table; // use native TensorRT generated calibration table. Default 0 = false, nonzero = true
int trt_dla_enable; // enable DLA. Default 0 = false, nonzero = true
int trt_dla_core; // DLA core number. Default 0
int trt_dump_subgraphs; // dump TRT subgraph. Default 0 = false, nonzero = true
int trt_engine_cache_enable; // enable engine caching. Default 0 = false, nonzero = true
const char* trt_engine_cache_path; // specify engine cache path
int trt_engine_decryption_enable; // enable engine decryption. Default 0 = false, nonzero = true
const char* trt_engine_decryption_lib_path; // specify engine decryption library path
int trt_force_sequential_engine_build; // force building TensorRT engine sequentially. Default 0 = false, nonzero = true
int trt_context_memory_sharing_enable; // enable context memory sharing between subgraphs. Default 0 = false, nonzero = true
int trt_layer_norm_fp32_fallback; // force Pow + Reduce ops in layer norm to FP32. Default 0 = false, nonzero = true
int trt_timing_cache_enable; // enable TensorRT timing cache. Default 0 = false, nonzero = true
int trt_force_timing_cache; // force the TensorRT cache to be used even if device profile does not match. Default 0 = false, nonzero = true
int trt_detailed_build_log; // Enable detailed build step logging on TensorRT EP with timing for each engine build. Default 0 = false, nonzero = true
int trt_build_heuristics_enable; // Build engine using heuristics to reduce build time. Default 0 = false, nonzero = true
int trt_sparsity_enable; // Control if sparsity can be used by TRT. Default 0 = false, 1 = true
int trt_builder_optimization_level; // Set the builder optimization level. WARNING: levels below 3 do not guarantee good engine performance, but greatly improve build time. Default 3, valid range [0-5]
int trt_auxiliary_streams; // Set maximum number of auxiliary streams per inference stream. Setting this value to 0 will lead to optimal memory usage. Default -1 = heuristics
const char* trt_tactic_sources;                      // Specify the tactics to be used by adding (+) or removing (-) tactics from the default
// tactic sources (default = all available tactics) e.g. "-CUDNN,+CUBLAS" available keys: "CUBLAS"|"CUBLAS_LT"|"CUDNN"|"EDGE_MASK_CONVOLUTIONS"
const char* trt_extra_plugin_lib_paths; // specify extra TensorRT plugin library paths
const char* trt_profile_min_shapes; // Specify the range of the input shapes to build the engine with
const char* trt_profile_max_shapes; // Specify the range of the input shapes to build the engine with
const char* trt_profile_opt_shapes; // Specify the range of the input shapes to build the engine with
int trt_cuda_graph_enable; // Enable CUDA graph in ORT TRT
int device_id{0}; // cuda device id.
int has_user_compute_stream{0}; // indicator of user specified CUDA compute stream.
void* user_compute_stream{nullptr}; // user specified CUDA compute stream.
int trt_max_partition_iterations{1000}; // maximum iterations for TensorRT parser to get capability
int trt_min_subgraph_size{1}; // minimum size of TensorRT subgraphs
size_t trt_max_workspace_size{1 << 30}; // maximum workspace size for TensorRT.
int trt_fp16_enable{0}; // enable TensorRT FP16 precision. Default 0 = false, nonzero = true
int trt_int8_enable{0}; // enable TensorRT INT8 precision. Default 0 = false, nonzero = true
const char* trt_int8_calibration_table_name{nullptr}; // TensorRT INT8 calibration table name.
int trt_int8_use_native_calibration_table{0}; // use native TensorRT generated calibration table. Default 0 = false, nonzero = true
int trt_dla_enable{0}; // enable DLA. Default 0 = false, nonzero = true
int trt_dla_core{0}; // DLA core number. Default 0
int trt_dump_subgraphs{0}; // dump TRT subgraph. Default 0 = false, nonzero = true
int trt_engine_cache_enable{0}; // enable engine caching. Default 0 = false, nonzero = true
const char* trt_engine_cache_path{nullptr}; // specify engine cache path
int trt_engine_decryption_enable{0}; // enable engine decryption. Default 0 = false, nonzero = true
const char* trt_engine_decryption_lib_path{nullptr}; // specify engine decryption library path
int trt_force_sequential_engine_build{0}; // force building TensorRT engine sequentially. Default 0 = false, nonzero = true
int trt_context_memory_sharing_enable{0}; // enable context memory sharing between subgraphs. Default 0 = false, nonzero = true
int trt_layer_norm_fp32_fallback{0}; // force Pow + Reduce ops in layer norm to FP32. Default 0 = false, nonzero = true
int trt_timing_cache_enable{0}; // enable TensorRT timing cache. Default 0 = false, nonzero = true
int trt_force_timing_cache{0}; // force the TensorRT cache to be used even if device profile does not match. Default 0 = false, nonzero = true
int trt_detailed_build_log{0}; // Enable detailed build step logging on TensorRT EP with timing for each engine build. Default 0 = false, nonzero = true
int trt_build_heuristics_enable{0}; // Build engine using heuristics to reduce build time. Default 0 = false, nonzero = true
int trt_sparsity_enable{0}; // Control if sparsity can be used by TRT. Default 0 = false, 1 = true
int trt_builder_optimization_level{3}; // Set the builder optimization level. WARNING: levels below 3 do not guarantee good engine performance, but greatly improve build time. Default 3, valid range [0-5]
int trt_auxiliary_streams{-1}; // Set maximum number of auxiliary streams per inference stream. Setting this value to 0 will lead to optimal memory usage. Default -1 = heuristics
const char* trt_tactic_sources{nullptr};             // Specify the tactics to be used by adding (+) or removing (-) tactics from the default
// tactic sources (default = all available tactics) e.g. "-CUDNN,+CUBLAS" available keys: "CUBLAS"|"CUBLAS_LT"|"CUDNN"|"EDGE_MASK_CONVOLUTIONS"
const char* trt_extra_plugin_lib_paths{nullptr}; // specify extra TensorRT plugin library paths
const char* trt_profile_min_shapes{nullptr}; // Specify the range of the input shapes to build the engine with
const char* trt_profile_max_shapes{nullptr}; // Specify the range of the input shapes to build the engine with
const char* trt_profile_opt_shapes{nullptr}; // Specify the range of the input shapes to build the engine with
int trt_cuda_graph_enable{0}; // Enable CUDA graph in ORT TRT
};
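
A minimal sketch, not part of this diff, of the flow the header comment above recommends: obtain the options through the C API (so every field starts at the defaults now set by the member initializers), override a couple of fields by key, append the EP, and release the options. Error handling is elided; the string keys mirror the struct field names shown above.

// Sketch only: configure the TensorRT EP through OrtTensorRTProviderOptionsV2 via the C API.
#include "onnxruntime_c_api.h"

void AppendTensorrtV2(OrtSessionOptions* session_options) {
  const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);

  OrtTensorRTProviderOptionsV2* trt_options = nullptr;
  api->CreateTensorRTProviderOptions(&trt_options);  // fields start at the defaults shown above

  const char* keys[] = {"device_id", "trt_fp16_enable"};
  const char* values[] = {"0", "1"};
  api->UpdateTensorRTProviderOptions(trt_options, keys, values, 2);

  api->SessionOptionsAppendExecutionProvider_TensorRT_V2(session_options, trt_options);
  api->ReleaseTensorRTProviderOptions(trt_options);  // the session options keep their own copy
}
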
8 changes: 8 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -4544,6 +4544,14 @@ ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_MIGraphX, _In_ OrtSessio
*/
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Dnnl, _In_ OrtSessionOptions* options, int use_arena);

/*
* This is the old way to add the TensorRT provider to the session, please use SessionOptionsAppendExecutionProvider_TensorRT_V2 above to access the latest functionality
* This function always exists, but will only succeed if Onnxruntime was built with TensorRT support and the TensorRT provider shared library exists
*
* \param device_id CUDA device id, starts from zero.
*/
ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtSessionOptions* options, int device_id);

#ifdef __cplusplus
}
#endif
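
A hedged sketch of the legacy single-argument path declared above, assuming ONNX Runtime was built with TensorRT support and the provider shared library is present; the call returns a non-null OrtStatus on failure:

// Sketch only: the old way, specifying just the CUDA device id.
// Prefer SessionOptionsAppendExecutionProvider_TensorRT_V2 for access to the full option set.
OrtStatus* status = OrtSessionOptionsAppendExecutionProvider_Tensorrt(session_options, /*device_id=*/0);
if (status != nullptr) {
  const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
  fprintf(stderr, "TensorRT EP not available: %s\n", api->GetErrorMessage(status));
  api->ReleaseStatus(status);
}
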
@@ -19,7 +19,6 @@
#include "onnxruntime/core/providers/nnapi/nnapi_provider_factory.h"
#include "onnxruntime/core/providers/tvm/tvm_provider_factory.h"
#include "onnxruntime/core/providers/openvino/openvino_provider_factory.h"
#include "onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h"
#include "onnxruntime/core/providers/acl/acl_provider_factory.h"
#include "onnxruntime/core/providers/armnn/armnn_provider_factory.h"
#include "onnxruntime/core/providers/coreml/coreml_provider_factory.h"
1 change: 0 additions & 1 deletion js/node/src/session_options_helper.cc
@@ -16,7 +16,6 @@
#include "core/providers/dml/dml_provider_factory.h"
#endif
#ifdef USE_TENSORRT
#include "core/providers/tensorrt/tensorrt_provider_factory.h"
#include "core/providers/tensorrt/tensorrt_provider_options.h"
#endif
#ifdef USE_COREML
@@ -26,27 +26,16 @@ extern TensorrtLogger& GetTensorrtLogger();
* Note: Current TRT plugin doesn't have APIs to get number of inputs/outputs of the plugin.
* So, TensorRTCustomOp uses variadic inputs/outputs to pass ONNX graph validation.
*/
common::Status CreateTensorRTCustomOpDomainList(TensorrtExecutionProviderInfo& info) {
common::Status CreateTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>& domain_list, const std::string extra_plugin_lib_paths) {
std::unique_ptr<OrtCustomOpDomain> custom_op_domain = std::make_unique<OrtCustomOpDomain>();
custom_op_domain->domain_ = "trt.plugins";

// Load any extra TRT plugin library if any.
// When the TRT plugin library is loaded, the global static object is created and the plugin is registered to TRT registry.
// This is done through macro, for example, REGISTER_TENSORRT_PLUGIN(VisionTransformerPluginCreator).
std::string extra_plugin_lib_paths{""};
if (info.has_trt_options) {
if (!info.extra_plugin_lib_paths.empty()) {
extra_plugin_lib_paths = info.extra_plugin_lib_paths;
}
} else {
const std::string extra_plugin_lib_paths_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kExtraPluginLibPaths);
if (!extra_plugin_lib_paths_env.empty()) {
extra_plugin_lib_paths = extra_plugin_lib_paths_env;
}
}

// extra_plugin_lib_paths has the format of "path_1;path_2....;path_n"
if (!extra_plugin_lib_paths.empty()) {
static bool is_loaded = false;
if (!extra_plugin_lib_paths.empty() && !is_loaded) {
std::stringstream extra_plugin_libs(extra_plugin_lib_paths);
std::string lib;
while (std::getline(extra_plugin_libs, lib, ';')) {
@@ -57,35 +46,59 @@ common::Status CreateTensorRTCustomOpDomainList(TensorrtExecutionProviderInfo& i
LOGS_DEFAULT(WARNING) << "[TensorRT EP]" << status.ToString();
}
}
is_loaded = true;
}

// Get all registered TRT plugins from registry
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Getting all registered TRT plugins from TRT plugin registry ...";
TensorrtLogger trt_logger = GetTensorrtLogger();
initLibNvInferPlugins(&trt_logger, "");
try {
// Get all registered TRT plugins from registry
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Getting all registered TRT plugins from TRT plugin registry ...";
TensorrtLogger trt_logger = GetTensorrtLogger();
initLibNvInferPlugins(&trt_logger, "");

int num_plugin_creator = 0;
auto plugin_creators = getPluginRegistry()->getPluginCreatorList(&num_plugin_creator);
std::unordered_set<std::string> registered_plugin_names;
int num_plugin_creator = 0;
auto plugin_creators = getPluginRegistry()->getPluginCreatorList(&num_plugin_creator);
std::unordered_set<std::string> registered_plugin_names;

for (int i = 0; i < num_plugin_creator; i++) {
auto plugin_creator = plugin_creators[i];
std::string plugin_name(plugin_creator->getPluginName());
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << plugin_name << ", version : " << plugin_creator->getPluginVersion();
for (int i = 0; i < num_plugin_creator; i++) {
auto plugin_creator = plugin_creators[i];
std::string plugin_name(plugin_creator->getPluginName());
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] " << plugin_name << ", version : " << plugin_creator->getPluginVersion();

// plugin has different versions and we only register once
if (registered_plugin_names.find(plugin_name) != registered_plugin_names.end()) {
continue;
}
// plugin has different versions and we only register once
if (registered_plugin_names.find(plugin_name) != registered_plugin_names.end()) {
continue;
}

std::unique_ptr<TensorRTCustomOp> trt_custom_op = std::make_unique<TensorRTCustomOp>(onnxruntime::kTensorrtExecutionProvider, nullptr);
trt_custom_op->SetName(plugin_creator->getPluginName());
custom_op_domain->custom_ops_.push_back(trt_custom_op.release());
registered_plugin_names.insert(plugin_name);
std::unique_ptr<TensorRTCustomOp> trt_custom_op = std::make_unique<TensorRTCustomOp>(onnxruntime::kTensorrtExecutionProvider, nullptr);
trt_custom_op->SetName(plugin_creator->getPluginName());
custom_op_domain->custom_ops_.push_back(trt_custom_op.release());
registered_plugin_names.insert(plugin_name);
}
domain_list.push_back(custom_op_domain.release());
} catch (const std::exception&) {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration. Therefore, TRT EP can't create custom ops for TRT plugins";
}
info.custom_op_domain_list.push_back(custom_op_domain.release());
return Status::OK();
}

return common::Status::OK();
common::Status CreateTensorRTCustomOpDomainList(TensorrtExecutionProviderInfo& info) {
std::vector<OrtCustomOpDomain*> domain_list;
std::string extra_plugin_lib_paths{""};
if (info.has_trt_options) {
if (!info.extra_plugin_lib_paths.empty()) {
extra_plugin_lib_paths = info.extra_plugin_lib_paths;
}
} else {
const std::string extra_plugin_lib_paths_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kExtraPluginLibPaths);
if (!extra_plugin_lib_paths_env.empty()) {
extra_plugin_lib_paths = extra_plugin_lib_paths_env;
}
}
auto status = CreateTensorRTCustomOpDomainList(domain_list, extra_plugin_lib_paths);
if (!domain_list.empty()) {
info.custom_op_domain_list = domain_list;
}
return Status::OK();
}

void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain) {
@@ -13,6 +13,7 @@ using namespace onnxruntime;
namespace onnxruntime {

common::Status LoadDynamicLibrary(onnxruntime::PathString library_name);
common::Status CreateTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>& domain_list, const std::string extra_plugin_lib_paths);
common::Status CreateTensorRTCustomOpDomainList(TensorrtExecutionProviderInfo& info);
void ReleaseTensorRTCustomOpDomain(OrtCustomOpDomain* domain);
void ReleaseTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>& custom_op_domain_list);
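
For illustration only, a rough call-site sketch of the new two-argument CreateTensorRTCustomOpDomainList overload added in this PR; the call site and the plugin library path below are hypothetical, and the snippet is assumed to run inside the onnxruntime namespace where these helpers are declared:

// Sketch: build the "trt.plugins" custom op domain list from an explicit
// semicolon-separated plugin path string instead of a TensorrtExecutionProviderInfo.
std::vector<OrtCustomOpDomain*> domain_list;
const std::string extra_plugin_lib_paths = "/opt/trt_plugins/libmy_plugin.so";  // hypothetical path
common::Status status = CreateTensorRTCustomOpDomainList(domain_list, extra_plugin_lib_paths);
if (status.IsOK()) {
  // Hand the domains to the session options before session creation;
  // whoever keeps them is responsible for releasing them later.
} else {
  ReleaseTensorRTCustomOpDomainList(domain_list);  // cleanup helper declared above
}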