
Commit
Forward detailed logging to TRT logger as well.
gedoensmax committed Feb 28, 2024
1 parent 4838cb6 commit 023c716
Showing 3 changed files with 33 additions and 16 deletions.
37 changes: 24 additions & 13 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -380,8 +380,12 @@ std::shared_ptr<KernelRegistry> TensorrtExecutionProvider::GetKernelRegistry() c
 }
 
 // Per TensorRT documentation, logger needs to be a singleton.
-TensorrtLogger& GetTensorrtLogger() {
-  static TensorrtLogger trt_logger(nvinfer1::ILogger::Severity::kWARNING);
+TensorrtLogger& GetTensorrtLogger(bool verbose_log) {
+  const auto log_level = verbose_log ? nvinfer1::ILogger::Severity::kVERBOSE : nvinfer1::ILogger::Severity::kWARNING;
+  static TensorrtLogger trt_logger(log_level);
+  if (log_level != trt_logger.get_level()) {
+    trt_logger.set_level(verbose_log ? nvinfer1::ILogger::Severity::kVERBOSE : nvinfer1::ILogger::Severity::kWARNING);
+  }
   return trt_logger;
 }
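The function-local static above is constructed only once, so the severity passed on later calls would be silently ignored without the new `set_level`/`get_level` pair, which lets any subsequent call retune the singleton. A minimal self-contained sketch of the same pattern; `MiniLogger`, `GetLogger`, and the severity values are illustrative stand-ins, not the ONNX Runtime or TensorRT types:

```cpp
#include <iostream>

// Illustrative stand-in for nvinfer1::ILogger severities (higher = chattier).
enum class Severity : int { kERROR = 1, kWARNING = 2, kINFO = 3, kVERBOSE = 4 };

class MiniLogger {
 public:
  explicit MiniLogger(Severity v) : verbosity_(v) {}
  void set_level(Severity v) { verbosity_ = v; }
  Severity get_level() const { return verbosity_; }
  void log(Severity s, const char* msg) {
    // Emit only messages at or below the configured threshold.
    if (static_cast<int>(s) <= static_cast<int>(verbosity_)) std::cout << msg << "\n";
  }

 private:
  Severity verbosity_;
};

// Same shape as GetTensorrtLogger: the static is built exactly once,
// but its level can be flipped on any later call.
MiniLogger& GetLogger(bool verbose) {
  const auto level = verbose ? Severity::kVERBOSE : Severity::kWARNING;
  static MiniLogger logger(level);
  if (level != logger.get_level()) logger.set_level(level);
  return logger;
}

int main() {
  GetLogger(false).log(Severity::kVERBOSE, "dropped");      // below threshold
  GetLogger(true).log(Severity::kVERBOSE, "now visible");   // threshold raised
}
```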

@@ -1696,7 +1700,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv

   {
     auto lock = GetApiLock();
-    runtime_ = std::unique_ptr<nvinfer1::IRuntime>(nvinfer1::createInferRuntime(GetTensorrtLogger()));
+    runtime_ = std::unique_ptr<nvinfer1::IRuntime>(nvinfer1::createInferRuntime(GetTensorrtLogger(detailed_build_log_)));
   }
 
   LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] TensorRT provider options: "
@@ -1832,9 +1836,8 @@ Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::
 // Get the pointer to the IBuilder instance.
 // Note: This function is not thread safe. Calls to this function from different threads must be serialized
 // even though it doesn't make sense to have multiple threads initializing the same inference session.
-nvinfer1::IBuilder* TensorrtExecutionProvider::GetBuilder() const {
+nvinfer1::IBuilder* TensorrtExecutionProvider::GetBuilder(TensorrtLogger& trt_logger) const {
   if (!builder_) {
-    TensorrtLogger& trt_logger = GetTensorrtLogger();
     {
       auto lock = GetApiLock();
       builder_ = std::unique_ptr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(trt_logger));
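`GetBuilder` is lazily initialized: the first call creates the builder under the API lock and caches it, which is why the logger must now arrive as a parameter instead of being fetched inside the accessor. A rough sketch of that shape, assuming placeholder `Logger`/`Builder` types and a plain `std::mutex` standing in for `GetApiLock()`:

```cpp
#include <memory>
#include <mutex>

struct Logger {};
struct Builder { explicit Builder(Logger&) {} };

class Provider {
 public:
  // Mirrors the note in the diff: not thread safe by itself; the lock
  // only serializes the underlying library creation call.
  Builder* GetBuilder(Logger& logger) const {
    if (!builder_) {
      std::lock_guard<std::mutex> lock(api_mutex_);  // stand-in for GetApiLock()
      builder_ = std::make_unique<Builder>(logger);
    }
    return builder_.get();
  }

 private:
  mutable std::unique_ptr<Builder> builder_;
  mutable std::mutex api_mutex_;
};

int main() {
  Provider provider;
  Logger logger;
  Builder* b1 = provider.GetBuilder(logger);
  Builder* b2 = provider.GetBuilder(logger);
  return b1 == b2 ? 0 : 1;  // same cached instance both times
}
```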
@@ -2211,10 +2214,14 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect

   // Get supported node list recursively
   SubGraphCollection_t parser_nodes_list;
-  TensorrtLogger& trt_logger = GetTensorrtLogger();
-  auto trt_builder = GetBuilder();
-  const auto explicitBatch = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
-  auto trt_network = std::unique_ptr<nvinfer1::INetworkDefinition>(trt_builder->createNetworkV2(explicitBatch));
+  TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log_);
+  auto trt_builder = GetBuilder(trt_logger);
+  auto network_flags = 0;
+#if NV_TENSORRT_MAJOR > 8
+  network_flags |= fp16_enable_ || int8_enable_ ? 0 : 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED);
+#endif
+  network_flags |= 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
+  auto trt_network = std::unique_ptr<nvinfer1::INetworkDefinition>(trt_builder->createNetworkV2(network_flags));
 
   auto trt_parser = tensorrt_ptr::unique_pointer<nvonnxparser::IParser>(nvonnxparser::createParser(*trt_network, trt_logger));
   trt_parser->supportsModel(string_buf.data(), string_buf.size(), parser_nodes_list, model_path_);
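The flag logic reads: on TensorRT 9 and newer, request a strongly typed network only when the user has not explicitly enabled fp16 or int8 (explicit precision flags would override per-tensor types), and always request explicit batch. A standalone sketch of that computation; the flag bit positions are assumed to match `nvinfer1::NetworkDefinitionCreationFlag` but are hardcoded here for illustration:

```cpp
#include <cstdint>
#include <iostream>

// Assumed bit positions, mirroring nvinfer1::NetworkDefinitionCreationFlag.
enum class CreationFlag : uint32_t { kEXPLICIT_BATCH = 0, kSTRONGLY_TYPED = 1 };

uint32_t MakeNetworkFlags(bool fp16_enable, bool int8_enable, int trt_major) {
  uint32_t flags = 0;
  // TRT 9+: opt in to strongly typed networks only when the user has not
  // explicitly requested fp16/int8 precision.
  if (trt_major > 8 && !fp16_enable && !int8_enable) {
    flags |= 1U << static_cast<uint32_t>(CreationFlag::kSTRONGLY_TYPED);
  }
  flags |= 1U << static_cast<uint32_t>(CreationFlag::kEXPLICIT_BATCH);
  return flags;
}

int main() {
  std::cout << MakeNetworkFlags(false, false, 9) << "\n";  // 3: strongly typed + explicit batch
  std::cout << MakeNetworkFlags(true, false, 9) << "\n";   // 1: explicit batch only
}
```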
@@ -2600,10 +2607,14 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
     model_proto->SerializeToOstream(dump);
   }
 
-  TensorrtLogger& trt_logger = GetTensorrtLogger();
-  auto trt_builder = GetBuilder();
-  const auto explicitBatch = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
-  auto trt_network = std::unique_ptr<nvinfer1::INetworkDefinition>(trt_builder->createNetworkV2(explicitBatch));
+  TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log_);
+  auto trt_builder = GetBuilder(trt_logger);
+  auto network_flags = 0;
+#if NV_TENSORRT_MAJOR > 8
+  network_flags |= fp16_enable_ || int8_enable_ ? 0 : 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED);
+#endif
+  network_flags |= 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
+  auto trt_network = std::unique_ptr<nvinfer1::INetworkDefinition>(trt_builder->createNetworkV2(network_flags));
   auto trt_config = std::unique_ptr<nvinfer1::IBuilderConfig>(trt_builder->createBuilderConfig());
   auto trt_parser = tensorrt_ptr::unique_pointer<nvonnxparser::IParser>(nvonnxparser::createParser(*trt_network, trt_logger));
   trt_parser->parse(string_buf.data(), string_buf.size(), model_path_);
8 changes: 7 additions & 1 deletion onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
@@ -84,6 +84,12 @@ class TensorrtLogger : public nvinfer1::ILogger {
       }
     }
   }
+  void set_level(Severity verbosity) {
+    verbosity_ = verbosity;
+  }
+  Severity get_level() const {
+    return verbosity_;
+  }
 };
 
 namespace tensorrt_ptr {
@@ -547,6 +553,6 @@ class TensorrtExecutionProvider : public IExecutionProvider {
   * Get the pointer to the IBuilder instance.
   * This function only creates the instance at the first time it's being called."
   */
-  nvinfer1::IBuilder* GetBuilder() const;
+  nvinfer1::IBuilder* GetBuilder(TensorrtLogger& trt_logger) const;
 };
 }  // namespace onnxruntime
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc
@@ -9,7 +9,7 @@
 #include <unordered_set>
 
 namespace onnxruntime {
-extern TensorrtLogger& GetTensorrtLogger();
+extern TensorrtLogger& GetTensorrtLogger(bool verbose);
 
 /*
  * Create custom op domain list for TRT plugins.
@@ -56,7 +56,7 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>&
   try {
     // Get all registered TRT plugins from registry
     LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Getting all registered TRT plugins from TRT plugin registry ...";
-    TensorrtLogger trt_logger = GetTensorrtLogger();
+    TensorrtLogger trt_logger = GetTensorrtLogger(false);
     initLibNvInferPlugins(&trt_logger, "");
 
     int num_plugin_creator = 0;
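For context, `initLibNvInferPlugins` registers the plugins bundled with TensorRT, after which the surrounding code counts them via the global plugin registry. A hedged sketch of that enumeration; it assumes a local TensorRT install (link with -lnvinfer -lnvinfer_plugin), and `QuietLogger` is an illustrative stand-in for TensorrtLogger:

```cpp
#include <iostream>

#include <NvInfer.h>
#include <NvInferPlugin.h>

// Minimal logger that only surfaces warnings and errors, i.e. the
// non-verbose behavior of GetTensorrtLogger(false).
class QuietLogger : public nvinfer1::ILogger {
  void log(Severity severity, const char* msg) noexcept override {
    if (severity <= Severity::kWARNING) std::cerr << msg << "\n";
  }
};

int main() {
  QuietLogger logger;
  // Registers the plugins shipped with TensorRT under the default namespace.
  initLibNvInferPlugins(&logger, "");

  int32_t num_plugin_creator = 0;
  auto creators = getPluginRegistry()->getPluginCreatorList(&num_plugin_creator);
  for (int32_t i = 0; i < num_plugin_creator; ++i) {
    std::cout << creators[i]->getPluginName() << "\n";
  }
}
```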
