Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TRT detailed log and strong typed networks #19695

Merged
merged 1 commit into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 24 additions & 13 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -382,8 +382,12 @@
}

// Per TensorRT documentation, logger needs to be a singleton.
// Per TensorRT documentation, logger needs to be a singleton.
//
// The singleton is created lazily with the severity requested by the first
// caller; later callers that request a different verbosity update the level
// on the existing instance instead of creating a new logger.
TensorrtLogger& GetTensorrtLogger(bool verbose_log) {
  const auto log_level = verbose_log ? nvinfer1::ILogger::Severity::kVERBOSE : nvinfer1::ILogger::Severity::kWARNING;
  static TensorrtLogger trt_logger(log_level);
  if (log_level != trt_logger.get_level()) {
    // Reuse the level computed above rather than re-deriving it, so the
    // initial-construction and update paths can never disagree.
    trt_logger.set_level(log_level);
  }
  return trt_logger;
}

Expand Down Expand Up @@ -1558,7 +1562,7 @@

{
auto lock = GetApiLock();
runtime_ = std::unique_ptr<nvinfer1::IRuntime>(nvinfer1::createInferRuntime(GetTensorrtLogger()));
runtime_ = std::unique_ptr<nvinfer1::IRuntime>(nvinfer1::createInferRuntime(GetTensorrtLogger(detailed_build_log_)));

Check warning on line 1565 in onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:1565: Lines should be <= 120 characters long [whitespace/line_length] [2]
}

LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] TensorRT provider options: "
Expand Down Expand Up @@ -1695,9 +1699,8 @@
// Get the pointer to the IBuilder instance.
// Note: This function is not thread safe. Calls to this function from different threads must be serialized
// even though it doesn't make sense to have multiple threads initializing the same inference session.
nvinfer1::IBuilder* TensorrtExecutionProvider::GetBuilder() const {
nvinfer1::IBuilder* TensorrtExecutionProvider::GetBuilder(TensorrtLogger& trt_logger) const {
if (!builder_) {
TensorrtLogger& trt_logger = GetTensorrtLogger();
{
auto lock = GetApiLock();
builder_ = std::unique_ptr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(trt_logger));
Expand Down Expand Up @@ -2074,10 +2077,14 @@

// Get supported node list recursively
SubGraphCollection_t parser_nodes_list;
TensorrtLogger& trt_logger = GetTensorrtLogger();
auto trt_builder = GetBuilder();
const auto explicitBatch = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
auto trt_network = std::unique_ptr<nvinfer1::INetworkDefinition>(trt_builder->createNetworkV2(explicitBatch));
TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log_);
auto trt_builder = GetBuilder(trt_logger);
auto network_flags = 0;
#if NV_TENSORRT_MAJOR > 8
network_flags |= fp16_enable_ || int8_enable_ ? 0 : 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED);

Check warning on line 2084 in onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:2084: Lines should be <= 120 characters long [whitespace/line_length] [2]
#endif
network_flags |= 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
auto trt_network = std::unique_ptr<nvinfer1::INetworkDefinition>(trt_builder->createNetworkV2(network_flags));

auto trt_parser = tensorrt_ptr::unique_pointer<nvonnxparser::IParser>(nvonnxparser::createParser(*trt_network, trt_logger));
trt_parser->supportsModel(string_buf.data(), string_buf.size(), parser_nodes_list, model_path_);
Expand Down Expand Up @@ -2463,10 +2470,14 @@
model_proto->SerializeToOstream(dump);
}

TensorrtLogger& trt_logger = GetTensorrtLogger();
auto trt_builder = GetBuilder();
const auto explicitBatch = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
auto trt_network = std::unique_ptr<nvinfer1::INetworkDefinition>(trt_builder->createNetworkV2(explicitBatch));
TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log_);
auto trt_builder = GetBuilder(trt_logger);
auto network_flags = 0;
#if NV_TENSORRT_MAJOR > 8
network_flags |= fp16_enable_ || int8_enable_ ? 0 : 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED);

Check warning on line 2477 in onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:2477: Lines should be <= 120 characters long [whitespace/line_length] [2]
#endif
network_flags |= 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
auto trt_network = std::unique_ptr<nvinfer1::INetworkDefinition>(trt_builder->createNetworkV2(network_flags));
auto trt_config = std::unique_ptr<nvinfer1::IBuilderConfig>(trt_builder->createBuilderConfig());
auto trt_parser = tensorrt_ptr::unique_pointer<nvonnxparser::IParser>(nvonnxparser::createParser(*trt_network, trt_logger));
trt_parser->parse(string_buf.data(), string_buf.size(), model_path_);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,12 @@ class TensorrtLogger : public nvinfer1::ILogger {
}
}
}
// Update the minimum severity this logger will report. Allows the shared
// singleton logger's verbosity to be changed after construction.
void set_level(Severity verbosity) {
verbosity_ = verbosity;
}
// Return the current minimum severity this logger reports.
Severity get_level() const {
return verbosity_;
}
};

namespace tensorrt_ptr {
Expand Down Expand Up @@ -548,6 +554,6 @@ class TensorrtExecutionProvider : public IExecutionProvider {
* Get the pointer to the IBuilder instance.
* This function only creates the instance at the first time it's being called."
*/
nvinfer1::IBuilder* GetBuilder() const;
nvinfer1::IBuilder* GetBuilder(TensorrtLogger& trt_logger) const;
};
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
#include "tensorrt_execution_provider.h"

namespace onnxruntime {
extern TensorrtLogger& GetTensorrtLogger();
extern TensorrtLogger& GetTensorrtLogger(bool verbose);

/*
* Create custom op domain list for TRT plugins.
Expand Down Expand Up @@ -57,7 +57,7 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>&
try {
// Get all registered TRT plugins from registry
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Getting all registered TRT plugins from TRT plugin registry ...";
TensorrtLogger trt_logger = GetTensorrtLogger();
TensorrtLogger trt_logger = GetTensorrtLogger(false);
initLibNvInferPlugins(&trt_logger, "");

int num_plugin_creator = 0;
Expand Down
Loading