[TensorRT EP] Enhance EP context configs in session options and provider options #19154

Merged
merged 26 commits into from
Jan 21, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
@@ -11,6 +11,8 @@
/// User can only get the instance of OrtTensorRTProviderOptionsV2 via CreateTensorRTProviderOptions.
/// </summary>
struct OrtTensorRTProviderOptionsV2 {
OrtTensorRTProviderOptionsV2& operator=(const OrtTensorRTProviderOptionsV2& other); // copy assignment operator

int device_id{0}; // cuda device id.
int has_user_compute_stream{0}; // indicator of user specified CUDA compute stream.
void* user_compute_stream{nullptr}; // user specified CUDA compute stream.
@@ -47,7 +49,8 @@
const char* trt_profile_opt_shapes{nullptr}; // Specify the range of the input shapes to build the engine with
int trt_cuda_graph_enable{0}; // Enable CUDA graph in ORT TRT
int trt_dump_ep_context_model{0}; // Dump EP context node model
const char* trt_ep_context_file_path{nullptr}; // Specify file name to dump EP context node model.
int trt_ep_context_embed_mode{0}; // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data
int trt_ep_context_compute_capability_enable{1}; // Add GPU compute capability as an EP context node's attribute
int trt_ep_context_compute_capability_enable{1}; // Add GPU compute capability as an EP context node's attribute and check it against the compute capability when running

const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix
};
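For reference, a minimal sketch of how a user could exercise the new fields through the provider-options API (not part of this PR). It assumes the option key strings mirror the struct member names shown above, which is the existing convention for OrtTensorRTProviderOptionsV2:

// Sketch only: enable EP context dumping via OrtTensorRTProviderOptionsV2.
// Assumes the option keys mirror the struct field names shown above.
#include <onnxruntime_cxx_api.h>

void ConfigureTrtEpContext(Ort::SessionOptions& session_options) {
  const OrtApi& api = Ort::GetApi();

  OrtTensorRTProviderOptionsV2* trt_options = nullptr;
  Ort::ThrowOnError(api.CreateTensorRTProviderOptions(&trt_options));

  const char* keys[] = {"trt_engine_cache_enable", "trt_dump_ep_context_model",
                        "trt_ep_context_file_path", "trt_ep_context_embed_mode"};
  const char* values[] = {"1", "1", "/home/user/ep_context_models", "0"};
  Ort::ThrowOnError(api.UpdateTensorRTProviderOptions(trt_options, keys, values, 4));

  Ort::ThrowOnError(api.SessionOptionsAppendExecutionProvider_TensorRT_V2(session_options, trt_options));
  api.ReleaseTensorRTProviderOptions(trt_options);
}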
include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
@@ -249,4 +249,10 @@
// Flag to specify whether to dump the EP context into the Onnx model.
// "0": dump the EP context into separate file, keep the file name in the Onnx model.
// "1": dump the EP context into the Onnx model. (default).
static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode";
static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode";

// Flag to specify whether to dump the EP context node with a "hardware_architecture" attribute and to check this
// attribute against the hardware architecture at inference time.
// "0": disable. (default)
// "1": enable.
static const char* const kOrtSessionOptionEpContextHardwareArchitectureEnable = "ep.context_hardware_architecture_enable";
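The same behaviour can also be driven from session options. Below is a minimal sketch (not from this diff); the "ep.context_enable" and "ep.context_file_path" keys are assumed to be defined elsewhere in this header, while the embed-mode and hardware-architecture keys are the ones shown above:

// Sketch only: configure EP context dumping through session config entries.
// "ep.context_enable" and "ep.context_file_path" are assumed keys; only the
// embed-mode and hardware-architecture keys are introduced or changed in this diff.
#include <onnxruntime_cxx_api.h>

Ort::SessionOptions MakeEpContextSessionOptions() {
  Ort::SessionOptions so;
  so.AddConfigEntry("ep.context_enable", "1");                        // dump the EP context model
  so.AddConfigEntry("ep.context_file_path", "model_ctx.onnx");        // where to write it
  so.AddConfigEntry("ep.context_embed_mode", "0");                    // 0: keep the engine in a separate cache file
  so.AddConfigEntry("ep.context_hardware_architecture_enable", "1");  // record and validate the compute capability
  return so;
}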

86 changes: 82 additions & 4 deletions onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc
@@ -107,6 +107,7 @@ ONNX_NAMESPACE::ModelProto* CreateCtxNodeModel(const GraphViewer& graph_viewer,
engine_data_str.assign(engine_data, size);
}
attr_1->set_s(engine_data_str);
LOGS_DEFAULT(WARNING) << EPCONTEXT_WARNING;
} else {
attr_1->set_s(engine_cache_path);
}
@@ -137,15 +138,90 @@ ONNX_NAMESPACE::ModelProto* CreateCtxNodeModel(const GraphViewer& graph_viewer,
return model_proto.release();
}

/*
* Get "EP context node" model path
*
*
* If ep_context_file_path is provided:
* - If ep_context_file_path is a file:
* - If it's a file name without any path associated with it, return "engine_cache_path/ep_context_file_path".
* - If it's a file name with path associated with it, return "ep_context_file_path".
* - If ep_context_file_path is a directory, return "ep_context_file_path/original_model_name_ctx.onnx".
* If ep_context_file_path is not provided:
* - Return "engine_cache_path/original_model_name_ctx.onnx".
*
*
* Example 1:
* ep_context_file_path = "/home/user/ep_context_model_foler"
* engine_cache_path = "trt_engine.engine"
* original_model_path = "model.onnx"
* => return "/home/user/ep_context_model_folder/model_ctx.onnx"
*
* Example 2:
* ep_context_file_path = "my_ctx_model.onnx"
* engine_cache_path = "/home/user/cache_folder/trt_engine.engine"
* original_model_path = "model.onnx"
* => return "/home/user/cache_folder/my_ctx_model.onnx"
*
* Example 3:
* ep_context_file_path = "/home/user2/ep_context_model_foler/my_ctx_model.onnx"
* engine_cache_path = "trt_engine.engine"
* original_model_path = "model.onnx"
* => return "/home/user2/ep_context_model_foler/my_ctx_model.onnx"
*
* Example 4:
* ep_context_file_path = ""
* engine_cache_path = "/home/user3/cache_folder/trt_engine.engine"
* original_model_path = "model.onnx"
* => return "/home/user3/cache_folder/model_ctx.onnx"
*
*/
std::string GetCtxNodeModelPath(const std::string& ep_context_file_path,
const std::string& engine_cache_path,
const std::string& original_model_path) {
std::string ctx_model_path;

if (!ep_context_file_path.empty() && !std::filesystem::is_directory(ep_context_file_path)) {
std::filesystem::path ctx_model_file_path = ep_context_file_path;
if (ctx_model_file_path.filename().string() == ep_context_file_path) {
std::filesystem::path cache_path = engine_cache_path;
if (cache_path.has_parent_path()) {
ctx_model_path = cache_path.parent_path().append(ep_context_file_path).string();
} else {
ctx_model_path = ep_context_file_path;
}
} else {
ctx_model_path = ep_context_file_path;
}
} else {
std::filesystem::path model_path = original_model_path;
std::filesystem::path model_name_stem = model_path.stem(); // model_name.onnx -> model_name
std::string ctx_model_name = model_name_stem.string() + "_ctx.onnx";

if (std::filesystem::is_directory(ep_context_file_path)) {
std::filesystem::path model_directory = ep_context_file_path;
ctx_model_path = model_directory.append(ctx_model_name).string();
} else {
std::filesystem::path cache_path = engine_cache_path;
if (cache_path.has_parent_path()) {
ctx_model_path = cache_path.parent_path().append(ctx_model_name).string();
} else {
ctx_model_path = ctx_model_name;
}
}
}
return ctx_model_path;
}
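As a quick sanity check of the path rules documented above, a small sketch (not part of the PR's tests) that exercises the four examples. Note that the directory branch calls std::filesystem::is_directory, so Example 1 only takes the directory path if that directory actually exists on disk; the sketch assumes the declaration from onnx_ctx_model_helper.h is visible:

// Sketch: exercise GetCtxNodeModelPath with the four documented cases.
// Assumes the declaration from onnx_ctx_model_helper.h is visible in this translation unit.
#include <iostream>
#include <string>

int main() {
  using onnxruntime::GetCtxNodeModelPath;
  // Example 1 (requires /home/user/ep_context_model_folder to exist): -> ".../model_ctx.onnx"
  std::cout << GetCtxNodeModelPath("/home/user/ep_context_model_folder", "trt_engine.engine", "model.onnx") << "\n";
  // Example 2: a bare file name is placed next to the engine cache
  std::cout << GetCtxNodeModelPath("my_ctx_model.onnx", "/home/user/cache_folder/trt_engine.engine", "model.onnx") << "\n";
  // Example 3: a full path is returned unchanged
  std::cout << GetCtxNodeModelPath("/home/user2/ep_context_model_folder/my_ctx_model.onnx", "trt_engine.engine", "model.onnx") << "\n";
  // Example 4: an empty path falls back to the engine cache directory plus the original model name
  std::cout << GetCtxNodeModelPath("", "/home/user3/cache_folder/trt_engine.engine", "model.onnx") << "\n";
  return 0;
}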

/*
* Dump "EP context node" model
*
*/
void DumpCtxNodeModel(ONNX_NAMESPACE::ModelProto* model_proto,
const std::string engine_cache_path) {
std::fstream dump(engine_cache_path + "_wrapper.onnx", std::ios::out | std::ios::trunc | std::ios::binary);
const std::string& ctx_model_path) {
std::fstream dump(ctx_model_path, std::ios::out | std::ios::trunc | std::ios::binary);
model_proto->SerializeToOstream(dump);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path + "_wrapper.onnx";
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Dumped " + ctx_model_path;
}

Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph_viewer) {
@@ -194,7 +270,7 @@ bool TensorRTCacheModelHandler::ValidateEPCtxNode(const GraphViewer& graph_viewe
auto& attrs = node->GetAttributes();

// Check hardware_architecture(compute_capability) if it's present as an attribute
if (attrs.count(COMPUTE_CAPABILITY) > 0) {
if (compute_capability_enable_ && attrs.count(COMPUTE_CAPABILITY) > 0) {
std::string model_compute_capability = attrs.at(COMPUTE_CAPABILITY).s();
if (model_compute_capability != compute_capability_) {
LOGS_DEFAULT(ERROR) << "The compute capability of the engine cache doesn't match with the GPU's compute capability";
@@ -222,6 +298,8 @@ bool TensorRTCacheModelHandler::ValidateEPCtxNode(const GraphViewer& graph_viewe
return false;
}
}
} else if (embed_mode == 1) {
LOGS_DEFAULT(WARNING) << EPCONTEXT_WARNING;
}
}
return true;
12 changes: 10 additions & 2 deletions onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h
@@ -16,6 +16,9 @@
static const std::string EP_CACHE_CONTEXT = "ep_cache_context";
static const std::string COMPUTE_CAPABILITY = "hardware_architecture";
static const std::string EPCONTEXT_OP_DOMAIN = "com.microsoft";
static const std::string EPCONTEXT_WARNING = "It's suggested to set the ORT graph optimization level to 0 and \
make \"embed_mode\" to 0 (\"ep_cache_context\" is the cache path)\
for the best model loading time";

bool GraphHasCtxNode(const GraphViewer& graph_viewer);
const onnxruntime::Path& GetModelPath(const GraphViewer& graph_viewer);
@@ -28,8 +31,11 @@
bool compute_capability_enable,
std::string compute_capability,
const logging::Logger* logger);
std::string GetCtxNodeModelPath(const std::string& ep_context_file_path,
const std::string& engine_cache_path,
const std::string& original_model_path);
void DumpCtxNodeModel(ONNX_NAMESPACE::ModelProto* model_proto,
const std::string engine_cache_path);
const std::string& ctx_model_path);
void UpdateCtxNodeModelEngineContext(ONNX_NAMESPACE::ModelProto* model_proto,
char* engine_data,
size_t size);
@@ -38,7 +44,8 @@
public:
TensorRTCacheModelHandler(std::unique_ptr<nvinfer1::ICudaEngine>* trt_engine,
nvinfer1::IRuntime* trt_runtime,
std::string compute_capability) : trt_engine_(trt_engine), trt_runtime_(trt_runtime), compute_capability_(compute_capability) {
std::string compute_capability,
bool compute_capability_enable = true) : trt_engine_(trt_engine), trt_runtime_(trt_runtime), compute_capability_(compute_capability), compute_capability_enable_(compute_capability_enable) {

}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TensorRTCacheModelHandler);

@@ -51,5 +58,6 @@
nvinfer1::IRuntime* trt_runtime_;
std::filesystem::path engine_cache_path_;
std::string compute_capability_;
bool compute_capability_enable_;
}; // TRTCacheModelHandler
} // namespace onnxruntime
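For orientation, the node that ValidateEPCtxNode inspects carries roughly the attributes below. This is a sketch using the ONNX protobuf API, not code from this PR; the "EPContext" op type and the "embed_mode" attribute name are assumptions here, while "ep_cache_context", "hardware_architecture" and the "com.microsoft" domain match the constants defined above:

// Sketch only: approximate shape of the EP context node that the cache handler validates.
#include <string>
#include "onnx/onnx_pb.h"  // assumed include path for the ONNX protobuf types

ONNX_NAMESPACE::NodeProto MakeEpContextNodeSketch(const std::string& engine_cache_path,
                                                  const std::string& compute_capability) {
  ONNX_NAMESPACE::NodeProto node;
  node.set_op_type("EPContext");     // assumed op type of the EP context node
  node.set_domain("com.microsoft");  // EPCONTEXT_OP_DOMAIN

  auto* embed_mode = node.add_attribute();
  embed_mode->set_name("embed_mode");  // assumed name; 0: ep_cache_context is a cache path, 1: it holds engine bytes
  embed_mode->set_type(ONNX_NAMESPACE::AttributeProto::INT);
  embed_mode->set_i(0);

  auto* cache = node.add_attribute();
  cache->set_name("ep_cache_context");  // EP_CACHE_CONTEXT
  cache->set_type(ONNX_NAMESPACE::AttributeProto::STRING);
  cache->set_s(engine_cache_path);

  auto* arch = node.add_attribute();
  arch->set_name("hardware_architecture");  // COMPUTE_CAPABILITY, checked when the new flag is enabled
  arch->set_type(ONNX_NAMESPACE::AttributeProto::STRING);
  arch->set_s(compute_capability);
  return node;
}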
32 changes: 24 additions & 8 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -1381,6 +1381,7 @@
profile_opt_shapes = info.profile_opt_shapes;
cuda_graph_enable_ = info.cuda_graph_enable;
dump_ep_context_model_ = info.dump_ep_context_model;
ep_context_file_path_ = info.ep_context_file_path;
ep_context_embed_mode_ = info.ep_context_embed_mode;
ep_context_compute_capability_enable_ = info.ep_context_compute_capability_enable;
} else {
@@ -1543,6 +1544,11 @@
dump_ep_context_model_ = (std::stoi(dump_ep_context_model_env) == 0 ? false : true);
}

const std::string ep_context_file_path_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEpContextComputeCapabilityEnable);

if (!ep_context_file_path_env.empty()) {
ep_context_file_path_ = ep_context_file_path_env;
}

const std::string ep_context_embed_mode_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEpContextEmbedMode);
if (!ep_context_embed_mode_env.empty()) {
ep_context_embed_mode_ = std::stoi(ep_context_embed_mode_env);
@@ -1580,7 +1586,7 @@
dla_core_ = 0;
}

if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_ || !cache_prefix_.empty()) {
if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_) {
if (!cache_path_.empty() && !fs::is_directory(cache_path_)) {
if (!fs::create_directory(cache_path_)) {
throw std::runtime_error("Failed to create directory " + cache_path_);
@@ -1692,6 +1698,10 @@
<< ", trt_profile_max_shapes: " << profile_max_shapes
<< ", trt_profile_opt_shapes: " << profile_opt_shapes
<< ", trt_cuda_graph_enable: " << cuda_graph_enable_
<< ", trt_dump_ep_context_model: " << dump_ep_context_model_
<< ", trt_ep_context_file_path: " << ep_context_file_path_
<< ", trt_ep_context_embed_mode: " << ep_context_embed_mode_
<< ", trt_ep_context_compute_capability_enable: " << ep_context_compute_capability_enable_
<< ", trt_cache_prefix: " << cache_prefix_;
}

@@ -2831,10 +2841,8 @@
std::unique_ptr<nvinfer1::ICudaEngine> trt_engine;
std::unique_ptr<nvinfer1::IExecutionContext> trt_context;

// Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache
// Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity
std::string cache_suffix = "";
std::string cache_path = "";
std::string cache_suffix = "";
// Customize cache prefix if assigned
if (!cache_prefix_.empty()) {
// Generate cache suffix in case user would like to customize cache prefix
@@ -2843,11 +2851,19 @@
} else {
cache_path = GetCachePath(cache_path_, trt_node_name_with_precision);
}

// Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache
// Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity

const std::string cache_path_prefix = cache_path + "_sm" + compute_capability_;
const std::string engine_cache_path = cache_path_prefix + ".engine";
const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted";
const std::string profile_cache_path = cache_path_prefix + ".profile";

// Generate file name for dumping ep context model
if (dump_ep_context_model_ && ctx_model_path_.empty()) {
ctx_model_path_ = GetCtxNodeModelPath(ep_context_file_path_, engine_cache_path, model_path_);
}

if (!has_dynamic_shape) {
std::string timing_cache_path = "";
bool engine_update = false;
@@ -2992,7 +3008,7 @@
ep_context_compute_capability_enable_,
compute_capability_,
GetLogger())};
DumpCtxNodeModel(model_proto.get(), cache_path_prefix);
DumpCtxNodeModel(model_proto.get(), ctx_model_path_);
}
}
}
@@ -3061,7 +3077,7 @@
compute_capability_,
GetLogger()));
if (ep_context_embed_mode_ == 0) {
DumpCtxNodeModel(model_proto_.get(), cache_path_prefix);
DumpCtxNodeModel(model_proto_.get(), ctx_model_path_);
}
}

@@ -3382,7 +3398,7 @@
// dump ep context model
if (dump_ep_context_model_ && ep_context_embed_mode_) {
UpdateCtxNodeModelEngineContext(model_proto_.get(), reinterpret_cast<char*>(serialized_engine->data()), serialized_engine->size());
DumpCtxNodeModel(model_proto_.get(), cache_path_prefix);
DumpCtxNodeModel(model_proto_.get(), ctx_model_path_);
}
context_update = true;
}
@@ -3575,7 +3591,7 @@
std::unordered_map<std::string, size_t> output_types; // TRT engine output name -> ORT output tensor type

// Get engine binary data and deserialize it
auto trt_cache_model_handler = TensorRTCacheModelHandler(&trt_engine, runtime_.get(), compute_capability_);
auto trt_cache_model_handler = TensorRTCacheModelHandler(&trt_engine, runtime_.get(), compute_capability_, ep_context_compute_capability_enable_);

auto status = trt_cache_model_handler.GetEpContextFromGraph(graph_body_viewer);
if (status != Status::OK()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
@@ -293,6 +293,7 @@ class TensorrtExecutionProvider : public IExecutionProvider {
bool force_timing_cache_match_ = false;
bool detailed_build_log_ = false;
bool cuda_graph_enable_ = false;
std::string ctx_model_path_;
std::string cache_prefix_;

// The OrtAllocator object will be get during ep compute time
@@ -301,8 +302,9 @@ class TensorrtExecutionProvider : public IExecutionProvider {

// For create/dump EP context node model
bool dump_ep_context_model_ = false;
std::string ep_context_file_path_;
int ep_context_embed_mode_ = 0;
bool ep_context_compute_capability_enable_ = true;
bool ep_context_compute_capability_enable_ = false;
std::unique_ptr<ONNX_NAMESPACE::ModelProto> model_proto_ = ONNX_NAMESPACE::ModelProto::Create();

std::unordered_set<std::string> control_flow_op_set_ = {"If", "Loop", "Scan"};