diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
index 81346671f2aad..abdbd35c00a92 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -380,8 +380,12 @@ std::shared_ptr<KernelRegistry> TensorrtExecutionProvider::GetKernelRegistry() c
 }
 
 // Per TensorRT documentation, logger needs to be a singleton.
-TensorrtLogger& GetTensorrtLogger() {
-  static TensorrtLogger trt_logger(nvinfer1::ILogger::Severity::kWARNING);
+TensorrtLogger& GetTensorrtLogger(bool verbose_log) {
+  const auto log_level = verbose_log ? nvinfer1::ILogger::Severity::kVERBOSE : nvinfer1::ILogger::Severity::kWARNING;
+  static TensorrtLogger trt_logger(log_level);
+  if (log_level != trt_logger.get_level()) {
+    trt_logger.set_level(verbose_log ? nvinfer1::ILogger::Severity::kVERBOSE : nvinfer1::ILogger::Severity::kWARNING);
+  }
   return trt_logger;
 }
 
@@ -1696,7 +1700,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
   {
     auto lock = GetApiLock();
-    runtime_ = std::unique_ptr<nvinfer1::IRuntime>(nvinfer1::createInferRuntime(GetTensorrtLogger()));
+    runtime_ = std::unique_ptr<nvinfer1::IRuntime>(nvinfer1::createInferRuntime(GetTensorrtLogger(detailed_build_log_)));
   }
 
   LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] TensorRT provider options: "
@@ -1832,9 +1836,8 @@ Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::
 // Get the pointer to the IBuilder instance.
 // Note: This function is not thread safe. Calls to this function from different threads must be serialized
 // even though it doesn't make sense to have multiple threads initializing the same inference session.
-nvinfer1::IBuilder* TensorrtExecutionProvider::GetBuilder() const {
+nvinfer1::IBuilder* TensorrtExecutionProvider::GetBuilder(TensorrtLogger& trt_logger) const {
   if (!builder_) {
-    TensorrtLogger& trt_logger = GetTensorrtLogger();
     {
       auto lock = GetApiLock();
       builder_ = std::unique_ptr<nvinfer1::IBuilder>(nvinfer1::createInferBuilder(trt_logger));
@@ -2211,10 +2214,14 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect
 
   // Get supported node list recursively
   SubGraphCollection_t parser_nodes_list;
-  TensorrtLogger& trt_logger = GetTensorrtLogger();
-  auto trt_builder = GetBuilder();
-  const auto explicitBatch = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
-  auto trt_network = std::unique_ptr<nvinfer1::INetworkDefinition>(trt_builder->createNetworkV2(explicitBatch));
+  TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log_);
+  auto trt_builder = GetBuilder(trt_logger);
+  auto network_flags = 0;
+#if NV_TENSORRT_MAJOR > 8
+  network_flags |= fp16_enable_ || int8_enable_ ? 0 : 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED);
+#endif
+  network_flags |= 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
+  auto trt_network = std::unique_ptr<nvinfer1::INetworkDefinition>(trt_builder->createNetworkV2(network_flags));
   auto trt_parser = tensorrt_ptr::unique_pointer<nvonnxparser::IParser>(nvonnxparser::createParser(*trt_network, trt_logger));
 
   trt_parser->supportsModel(string_buf.data(), string_buf.size(), parser_nodes_list, model_path_);
@@ -2600,10 +2607,14 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
     model_proto->SerializeToOstream(dump);
   }
 
-  TensorrtLogger& trt_logger = GetTensorrtLogger();
-  auto trt_builder = GetBuilder();
-  const auto explicitBatch = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
-  auto trt_network = std::unique_ptr<nvinfer1::INetworkDefinition>(trt_builder->createNetworkV2(explicitBatch));
+  TensorrtLogger& trt_logger = GetTensorrtLogger(detailed_build_log_);
+  auto trt_builder = GetBuilder(trt_logger);
+  auto network_flags = 0;
+#if NV_TENSORRT_MAJOR > 8
+  network_flags |= fp16_enable_ || int8_enable_ ? 0 : 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kSTRONGLY_TYPED);
+#endif
+  network_flags |= 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
+  auto trt_network = std::unique_ptr<nvinfer1::INetworkDefinition>(trt_builder->createNetworkV2(network_flags));
   auto trt_config = std::unique_ptr<nvinfer1::IBuilderConfig>(trt_builder->createBuilderConfig());
   auto trt_parser = tensorrt_ptr::unique_pointer<nvonnxparser::IParser>(nvonnxparser::createParser(*trt_network, trt_logger));
   trt_parser->parse(string_buf.data(), string_buf.size(), model_path_);
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
index 26f6b2dcc3020..43f96ec68ce74 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
@@ -84,6 +84,12 @@ class TensorrtLogger : public nvinfer1::ILogger {
       }
     }
   }
+  void set_level(Severity verbosity) {
+    verbosity_ = verbosity;
+  }
+  Severity get_level() const {
+    return verbosity_;
+  }
 };
 
 namespace tensorrt_ptr {
@@ -547,6 +553,6 @@ class TensorrtExecutionProvider : public IExecutionProvider {
   /**
    * Get the pointer to the IBuilder instance.
    * This function only creates the instance at the first time it's being called."
   */
-  nvinfer1::IBuilder* GetBuilder() const;
+  nvinfer1::IBuilder* GetBuilder(TensorrtLogger& trt_logger) const;
 };
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc
index eb340ba1e64b6..241bce72f7477 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_custom_ops.cc
@@ -9,7 +9,7 @@
 #include <unordered_set>
 
 namespace onnxruntime {
-extern TensorrtLogger& GetTensorrtLogger();
+extern TensorrtLogger& GetTensorrtLogger(bool verbose);
 
 /*
  * Create custom op domain list for TRT plugins.
@@ -56,7 +56,7 @@ common::Status CreateTensorRTCustomOpDomainList(std::vector<OrtCustomOpDomain*>&
   try {
     // Get all registered TRT plugins from registry
     LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Getting all registered TRT plugins from TRT plugin registry ...";
-    TensorrtLogger trt_logger = GetTensorrtLogger();
+    TensorrtLogger trt_logger = GetTensorrtLogger(false);
     initLibNvInferPlugins(&trt_logger, "");
 
     int num_plugin_creator = 0;