Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
yf711 committed Dec 11, 2023
1 parent eefe68b commit a298ba9
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 19 deletions.
20 changes: 9 additions & 11 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2491,9 +2491,8 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
trt_node_name_with_precision, runtime_.get(), profiles_[context->node_name], &max_ctx_mem_size_,
dynamic_range_map, !tactic_sources_.empty(), tactics,
// Below: class private members
sync_stream_after_enqueue_, fp16_enable_, int8_enable_, int8_calibration_cache_available_, dla_enable_,
dla_core_, engine_cache_enable_, cache_path_, context_memory_sharing_enable_, engine_decryption_enable_,
engine_decryption_, engine_encryption_, timing_cache_enable_, global_cache_path_
fp16_enable_, int8_enable_, int8_calibration_cache_available_, dla_enable_,
dla_core_, engine_cache_enable_, cache_path_
};
*state = p.release();
return 0;
Expand All @@ -2517,7 +2516,6 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
const std::unordered_map<std::string, size_t>& input_indexes = (trt_state->input_info)[0];
const std::unordered_map<std::string, size_t>& output_indexes = (trt_state->output_info)[0];
const std::unordered_map<std::string, size_t>& output_types = (trt_state->output_info)[1];
bool sync_stream_after_enqueue = trt_state->sync_stream_after_enqueue;
auto fused_node_name = trt_state->fused_node_name;
auto& shape_ranges = trt_state->input_shape_ranges;
auto trt_builder = trt_state->builder;
Expand Down Expand Up @@ -2708,7 +2706,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd

// Load timing cache from file. Create a fresh cache if the file doesn't exist
std::unique_ptr<nvinfer1::ITimingCache> timing_cache = nullptr;
if (trt_state->timing_cache_enable) {
if (timing_cache_enable_) {
std::vector<char> loaded_timing_cache = loadTimingCacheFile(timing_cache_path);
timing_cache.reset(trt_config->createTimingCache(static_cast<const void*>(loaded_timing_cache.data()), loaded_timing_cache.size()));
if (timing_cache == nullptr) {
Expand Down Expand Up @@ -2749,8 +2747,8 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
size_t engine_size = serializedModel->size();
if (trt_state->engine_decryption_enable) {
// Encrypt engine. The library is not always deployed with the encrypt function, so check if it is available first.
if (trt_state->engine_encryption != nullptr) {
if (!trt_state->engine_encryption(encrypted_engine_cache_path.c_str(), reinterpret_cast<char*>(serializedModel->data()), engine_size)) {
if (engine_encryption_ != nullptr) {
if (!engine_encryption_(encrypted_engine_cache_path.c_str(), reinterpret_cast<char*>(serializedModel->data()), engine_size)) {

Check warning on line 2751 in onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc#L2751

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:2751:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not call engine encryption function encrypt");
}
Expand All @@ -2766,7 +2764,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
}

// serialize and save timing cache
if (trt_state->timing_cache_enable) {
if (timing_cache_enable_) {
auto timing_cache = trt_config->getTimingCache();
std::unique_ptr<nvinfer1::IHostMemory> timingCacheHostData{timing_cache->serialize()};
if (timingCacheHostData == nullptr) {
Expand All @@ -2782,7 +2780,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
}

if (context_update) {
if (trt_state->context_memory_sharing_enable) {
if (context_memory_sharing_enable_) {
*(trt_state->context) = std::unique_ptr<nvinfer1::IExecutionContext>(
trt_state->engine->get()->createExecutionContextWithoutDeviceMemory());
} else {
Expand Down Expand Up @@ -3113,7 +3111,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
}

// Set execution context memory
if (trt_state->context_memory_sharing_enable) {
if (context_memory_sharing_enable_) {
size_t mem_size = trt_engine->getDeviceMemorySize();
if (mem_size > *max_context_mem_size_ptr) {
*max_context_mem_size_ptr = mem_size;
Expand All @@ -3135,7 +3133,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed.");
}

if (sync_stream_after_enqueue) {
if (sync_stream_after_enqueue_) {
cudaStreamSynchronize(stream);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,20 +123,13 @@ struct TensorrtFuncState {
bool filter_tactic_sources = false;
nvinfer1::TacticSources tactic_sources;
// Below: class private members
bool sync_stream_after_enqueue = false;
bool fp16_enable = false;
bool int8_enable = false;
bool int8_calibration_cache_available = false;
bool dla_enable = false;
int dla_core = 0;
bool engine_cache_enable = false;
std::string engine_cache_path;
bool context_memory_sharing_enable = false;
bool engine_decryption_enable = false;
int (*engine_decryption)(const char*, char*, size_t*) = nullptr;
int (*engine_encryption)(const char*, char*, size_t) = nullptr;
bool timing_cache_enable = true;
std::string timing_cache_path;
};

// Holds important information for building valid ORT graph.
Expand Down Expand Up @@ -227,7 +220,7 @@ class TensorrtExecutionProvider : public IExecutionProvider {
bool engine_decryption_enable_ = false;
int (*engine_decryption_)(const char*, char*, size_t*) = nullptr;
int (*engine_encryption_)(const char*, char*, size_t) = nullptr;
bool timing_cache_enable_ = false;
bool timing_cache_enable_ = true;
bool force_timing_cache_match_ = false;
bool detailed_build_log_ = false;
bool cuda_graph_enable_ = false;
Expand Down

0 comments on commit a298ba9

Please sign in to comment.