diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
index 7618accbac2a2..5604f6c8463e8 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -2489,11 +2489,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
           &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name],
           input_shape_ranges_[context->node_name], &tensorrt_mu_, &max_workspace_size_, trt_node_name_with_precision,
           runtime_.get(), profiles_[context->node_name], &max_ctx_mem_size_,
-          dynamic_range_map, !tactic_sources_.empty(), tactics,
-          // Below: class private members
-          fp16_enable_, int8_enable_, int8_calibration_cache_available_, dla_enable_,
-          dla_core_, engine_cache_enable_, cache_path_
-          };
+          dynamic_range_map, !tactic_sources_.empty(), tactics};
     *state = p.release();
     return 0;
   };
@@ -2547,7 +2543,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
-    const std::string cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->trt_node_name_with_precision);
+    const std::string cache_path = GetCachePath(cache_path_, trt_state->trt_node_name_with_precision);
     const std::string engine_cache_path = cache_path + "_sm" + compute_capability + ".engine";
     const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted";
     const std::string profile_cache_path = cache_path + "_sm" + compute_capability + ".profile";
@@ -2557,7 +2553,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
-    if (trt_state->engine_cache_enable && trt_engine == nullptr) {
+    if (engine_cache_enable_ && trt_engine == nullptr) {
       std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in);
       std::ifstream profile_file(profile_cache_path, std::ios::binary | std::ios::in);
       if (engine_file && !engine_decryption_enable_ && profile_file) {
@@ -2641,7 +2637,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
-      if (trt_state->int8_enable && trt_builder->platformHasFastInt8() && trt_state->int8_calibration_cache_available) {
+      if (int8_enable_ && trt_builder->platformHasFastInt8() && int8_calibration_cache_available_) {
         trt_config->setInt8Calibrator(nullptr);
         if (!SetDynamicRange(*trt_state->network->get(), trt_state->dynamic_range_map)) {
           return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to set INT8 dynamic range.");
@@ -2649,20 +2645,20 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
-      if (trt_state->fp16_enable && trt_state->int8_enable) {
+      if (fp16_enable_ && int8_enable_) {
         trt_config->setFlags(1U << static_cast<uint32_t>(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast<uint32_t>(nvinfer1::BuilderFlag::kINT8));
-      } else if (trt_state->fp16_enable) {
+      } else if (fp16_enable_) {
         trt_config->setFlag(nvinfer1::BuilderFlag::kFP16);
-      } else if (trt_state->int8_enable) {
+      } else if (int8_enable_) {
         trt_config->setFlag(nvinfer1::BuilderFlag::kINT8);
       }
 
       // Set DLA (DLA can only run with FP16 or INT8)
-      if ((trt_state->fp16_enable || trt_state->int8_enable) && trt_state->dla_enable) {
-        LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << trt_state->dla_core;
+      if ((fp16_enable_ || int8_enable_) && dla_enable_) {
+        LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << dla_core_;
         trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK);
         trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA);
-        trt_config->setDLACore(trt_state->dla_core);
+        trt_config->setDLACore(dla_core_);
       }
 
       // enable sparse weights
@@ -2737,7 +2733,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
       trt_engine = trt_state->engine->get();
-      if (trt_state->engine_cache_enable) {
+      if (engine_cache_enable_) {
         // Serialize engine profile
         SerializeProfileV2(profile_cache_path, shape_ranges);
         LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path;
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
index 1d857b9ecf207..3e42df3884834 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
@@ -123,12 +123,7 @@ struct TensorrtFuncState {
   bool filter_tactic_sources = false;
   nvinfer1::TacticSources tactic_sources;
   // Below: class private members
-  bool fp16_enable = false;
-  bool int8_enable = false;
-  bool int8_calibration_cache_available = false;
-  bool dla_enable = false;
-  int dla_core = 0;
-  bool engine_cache_enable = false;
+  std::string engine_cache_path;
 };
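
The pattern this patch relies on, stated in isolation: the per-node compute function is a lambda created inside a member function of the execution provider, so it can capture `this` and read immutable provider-level configuration (precision flags, DLA core, cache path) directly, leaving only genuinely per-node data in the function-state struct. Below is a minimal standalone C++ sketch of that idea; `Provider`, `FuncState`, and `MakeComputeFunc` are hypothetical names for illustration, not the ONNX Runtime API.

#include <functional>
#include <iostream>
#include <string>
#include <utility>

// Per-node state: keeps only data that differs between fused nodes.
struct FuncState {
  std::string node_name;
  // bool fp16_enable = false;  // removed: this merely mirrored a class member
};

class Provider {
 public:
  Provider(bool fp16, std::string cache)
      : fp16_enable_(fp16), cache_path_(std::move(cache)) {}

  // Returns a compute function for one node. Capturing `this` lets the lambda
  // read fp16_enable_ / cache_path_ directly instead of carrying copies in
  // FuncState, assuming the Provider outlives the returned lambda.
  std::function<void(FuncState&)> MakeComputeFunc() {
    return [this](FuncState& state) {
      std::cout << "node=" << state.node_name
                << " fp16=" << fp16_enable_    // class member, read directly
                << " cache=" << cache_path_ << "\n";
    };
  }

 private:
  bool fp16_enable_;        // fixed at construction, like the EP's options
  std::string cache_path_;
};

int main() {
  Provider provider(/*fp16=*/true, "/tmp/trt_cache");
  FuncState state{"fused_node_0"};
  provider.MakeComputeFunc()(state);  // node=fused_node_0 fp16=1 cache=/tmp/trt_cache
}

The trade-off is a lifetime coupling: the compute function is only valid while the provider is alive. The diff already assumes this, since the lambdas captured `this` to reach `engines_`, `input_info_`, and the other per-node maps before this cleanup.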