handle relative path for 'ep_cache_context' node attribute
chilo-ms committed Jan 20, 2024
1 parent da1207f commit ec7c8f3
Showing 3 changed files with 10 additions and 2 deletions.
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -1599,6 +1599,10 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
LOGS_DEFAULT(ERROR) << "In the case of dumping context model and for security purpose, The trt_engine_cache_path has '..', it's not allowed to point outside the directory.";

cpplint warning (GitHub Actions): onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:1599: Lines should be <= 120 characters long [whitespace/line_length] [2]
}

// Engine cache relative path to context model directory.
// It's used when dumping the "ep_cache_context" node attribute.
engine_cache_relative_path_to_context_model_dir = cache_path_;

// Make cache_path_ to be the relative path of ep_context_file_path_
cache_path_ = GetPathOrParentPathOfCtxModel(ep_context_file_path_).append(cache_path_).string();
}
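The hunk above records the user-configured engine cache directory before rebasing cache_path_ onto the directory of the dumped context model. A minimal sketch of that path handling follows; the concrete values, and the directory assumed to come back from GetPathOrParentPathOfCtxModel, are illustrative assumptions rather than values from the commit.

#include <filesystem>
#include <iostream>
#include <string>

int main() {
  // Illustrative stand-ins for the provider members; values are assumptions.
  std::string cache_path = "trt_engine_cache";          // user-configured trt_engine_cache_path
  std::filesystem::path ctx_model_dir = "ctx_models";   // assumed result of GetPathOrParentPathOfCtxModel(ep_context_file_path_)

  // Remember the configured value as a path relative to the context model
  // directory; it is reused later when writing the "ep_cache_context" attribute.
  std::string engine_cache_relative_dir = cache_path;

  // Rebase the cache directory onto the context model directory so the engine
  // cache is created next to the dumped context model.
  cache_path = (ctx_model_dir / cache_path).string();

  std::cout << engine_cache_relative_dir << "\n";  // trt_engine_cache
  std::cout << cache_path << "\n";                 // ctx_models/trt_engine_cache (POSIX separators)
  return 0;
}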
@@ -3018,7 +3022,8 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
if (dump_ep_context_model_) {
// "ep_cache_context" node attribute should be a relative path to context model directory
if (ep_cache_context_attr_.empty()) {
ep_cache_context_attr_ = std::filesystem::relative(engine_cache_path, ep_context_file_path_).string();
auto cache_file_name = std::filesystem::path(engine_cache_path).filename();
ep_cache_context_attr_ = std::filesystem::path(engine_cache_relative_path_to_context_model_dir).append(cache_file_name.string()).string();

cpplint warning (GitHub Actions): onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:3026: Lines should be <= 120 characters long [whitespace/line_length] [2]
}

std::unique_ptr<ONNX_NAMESPACE::ModelProto> model_proto{CreateCtxModel(graph_body_viewer,
@@ -3090,7 +3095,8 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView
if (dump_ep_context_model_ && has_dynamic_shape) {
// "ep_cache_context" node attribute should be a relative path to context model directory
if (ep_cache_context_attr_.empty()) {
ep_cache_context_attr_ = std::filesystem::relative(engine_cache_path, ep_context_file_path_).string();
auto cache_file_name = std::filesystem::path(engine_cache_path).filename();
ep_cache_context_attr_ = std::filesystem::path(engine_cache_relative_path_to_context_model_dir).append(cache_file_name.string()).string();

cpplint warning (GitHub Actions): onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc:3099: Lines should be <= 120 characters long [whitespace/line_length] [2]
}
model_proto_.reset(CreateCtxModel(graph_body_viewer,
ep_cache_context_attr_,
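Both CreateNodeComputeInfoFromGraph hunks above now build the "ep_cache_context" value from the stored relative directory plus the engine file name, rather than from std::filesystem::relative(engine_cache_path, ep_context_file_path_). A minimal sketch of that composition, with made-up paths (the engine file name and directory layout are assumptions):

#include <filesystem>
#include <iostream>
#include <string>

int main() {
  // Made-up paths mirroring the members used in the diff.
  std::filesystem::path engine_cache_path =
      "ctx_models/trt_engine_cache/TensorrtExecutionProvider_fused.engine";  // assumed full engine path
  std::filesystem::path relative_dir = "trt_engine_cache";  // engine_cache_relative_path_to_context_model_dir

  // "ep_cache_context" becomes <relative dir>/<engine file name>, a path that
  // resolves from the directory containing the dumped context model.
  std::string attr = (relative_dir / engine_cache_path.filename()).string();

  std::cout << attr << "\n";  // trt_engine_cache/TensorrtExecutionProvider_fused.engine
  return 0;
}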
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
@@ -305,6 +305,7 @@ class TensorrtExecutionProvider : public IExecutionProvider {
int ep_context_embed_mode_ = 0;
std::string ctx_model_path_;
std::string ep_cache_context_attr_;
std::string engine_cache_relative_path_to_context_model_dir;
std::unique_ptr<ONNX_NAMESPACE::ModelProto> model_proto_ = ONNX_NAMESPACE::ModelProto::Create();

std::unordered_set<std::string> control_flow_op_set_ = {"If", "Loop", "Scan"};
1 change: 1 addition & 0 deletions onnxruntime/test/providers/tensorrt/tensorrt_basic_test.cc
@@ -527,6 +527,7 @@ TEST(TensorrtExecutionProviderTest, EPContextNode) {
*/
InferenceSession session_object6{so, GetEnvironment()};
OrtTensorRTProviderOptionsV2 params6;
params6.trt_ep_context_embed_mode = 1;
model_name = params5.trt_ep_context_file_path;
execution_provider = TensorrtExecutionProviderWithOptions(&params6);
EXPECT_TRUE(session_object6.RegisterExecutionProvider(std::move(execution_provider)).IsOK());
