[TensorRT EP] Enhance EP context configs in session options and provider options #19154

Merged Jan 21, 2024 (26 commits)
Changes from 15 commits
include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
@@ -11,6 +11,8 @@
/// User can only get the instance of OrtTensorRTProviderOptionsV2 via CreateTensorRTProviderOptions.
/// </summary>
struct OrtTensorRTProviderOptionsV2 {
OrtTensorRTProviderOptionsV2& operator=(const OrtTensorRTProviderOptionsV2& other); // copy assignment operator

int device_id{0}; // cuda device id.
int has_user_compute_stream{0}; // indicator of user specified CUDA compute stream.
void* user_compute_stream{nullptr}; // user specified CUDA compute stream.
@@ -47,7 +49,7 @@ struct OrtTensorRTProviderOptionsV2 {
const char* trt_profile_opt_shapes{nullptr}; // Specify the range of the input shapes to build the engine with
int trt_cuda_graph_enable{0}; // Enable CUDA graph in ORT TRT
int trt_dump_ep_context_model{0}; // Dump EP context node model
const char* trt_ep_context_file_path{nullptr}; // Specify file name to dump EP context node model.
int trt_ep_context_embed_mode{0}; // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data
int trt_ep_context_compute_capability_enable{1}; // Add GPU compute capability as an EP context node's attribute
const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix
};
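For orientation, here is a hedged sketch (not part of this PR) of how an application could drive the new EP-context fields through the provider-options C API. The `trt_ep_context_*` key strings are assumed to mirror the struct field names above, following the existing `trt_`-prefixed key convention, and the file names are placeholders.

```cpp
#include <onnxruntime_c_api.h>

// Sketch: enable engine caching and EP-context model dumping on the TensorRT EP.
// Each call returns an OrtStatus* that should be checked and released in real code.
void ConfigureTrtEpContext(const OrtApi* api, OrtSessionOptions* session_options) {
  OrtTensorRTProviderOptionsV2* trt_options = nullptr;
  api->CreateTensorRTProviderOptions(&trt_options);

  // Assumed keys, mirroring the struct fields above.
  const char* keys[] = {"trt_engine_cache_enable", "trt_dump_ep_context_model",
                        "trt_ep_context_file_path", "trt_ep_context_embed_mode"};
  const char* values[] = {"1", "1", "ep_context_model.onnx", "0"};
  api->UpdateTensorRTProviderOptions(trt_options, keys, values, 4);

  api->SessionOptionsAppendExecutionProvider_TensorRT_V2(session_options, trt_options);
  api->ReleaseTensorRTProviderOptions(trt_options);
}
```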
108 changes: 91 additions & 17 deletions onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc
@@ -74,7 +74,6 @@
char* engine_data,
size_t size,
const int64_t embed_mode,
bool compute_capability_enable,
std::string compute_capability,
const logging::Logger* logger) {
auto model_build = graph_viewer.CreateModel(*logger);
@@ -107,21 +106,20 @@
engine_data_str.assign(engine_data, size);
}
attr_1->set_s(engine_data_str);
LOGS_DEFAULT(WARNING) << EPCONTEXT_WARNING;
} else {
attr_1->set_s(engine_cache_path);
}
attr_2->set_name(COMPUTE_CAPABILITY);
attr_2->set_type(onnx::AttributeProto_AttributeType_STRING);
attr_2->set_s(compute_capability);

auto node_attributes = ONNX_NAMESPACE::NodeAttributes::Create();
int num_attributes = compute_capability_enable ? 3 : 2;
int num_attributes = 3;
node_attributes->reserve(num_attributes);
node_attributes->emplace(EMBED_MODE, *attr_0);
node_attributes->emplace(EP_CACHE_CONTEXT, *attr_1);

if (compute_capability_enable) {
attr_2->set_name(COMPUTE_CAPABILITY);
attr_2->set_type(onnx::AttributeProto_AttributeType_STRING);
attr_2->set_s(compute_capability);
node_attributes->emplace(COMPUTE_CAPABILITY, *attr_2);
}
node_attributes->emplace(COMPUTE_CAPABILITY, *attr_2);

// Create EP context node
graph_build.AddNode(EPCONTEXT_OP, EPCONTEXT_OP, "", inputs, outputs, node_attributes.get(), EPCONTEXT_OP_DOMAIN);
@@ -137,15 +135,90 @@
return model_proto.release();
}

/*
* Get "EP context node" model path
*
*
 * If ep_context_file_path is provided:
 *   - If ep_context_file_path is a file:
 *     - If it's a bare file name with no path, return "engine cache directory/ep_context_file_path".
 *     - If it's a file name with a path, return "ep_context_file_path" as-is.
 *   - If ep_context_file_path is a directory, return "ep_context_file_path/original_model_name_ctx.onnx".
 * If ep_context_file_path is not provided:
 *   - Return "engine cache directory/original_model_name_ctx.onnx".
*
*
* Example 1:
* ep_context_file_path = "/home/user/ep_context_model_foler"
jywu-msft marked this conversation as resolved.
Show resolved Hide resolved
* engine_cache_path = "trt_engine.engine"
* original_model_path = "model.onnx"
* => return "/home/user/ep_context_model_folder/model_ctx.onnx"
*
* Example 2:
* ep_context_file_path = "my_ctx_model.onnx"
* engine_cache_path = "/home/user/cache_folder/trt_engine.engine"
* original_model_path = "model.onnx"
* => return "/home/user/cache_folder/my_ctx_model.onnx"
*
* Example 3:
* ep_context_file_path = "/home/user2/ep_context_model_foler/my_ctx_model.onnx"
* engine_cache_path = "trt_engine.engine"
* original_model_path = "model.onnx"
* => return "/home/user2/ep_context_model_foler/my_ctx_model.onnx"
*
* Example 4:
* ep_context_file_path = ""
* engine_cache_path = "/home/user3/cache_folder/trt_engine.engine"
* original_model_path = "model.onnx"
* => return "/home/user3/cache_folder/model_ctx.onnx"
*
*/
std::string GetCtxNodeModelPath(const std::string& ep_context_file_path,
const std::string& engine_cache_path,
const std::string& original_model_path) {
std::string ctx_model_path;

if (!ep_context_file_path.empty() && !std::filesystem::is_directory(ep_context_file_path)) {
std::filesystem::path ctx_model_file_path = ep_context_file_path;
if (ctx_model_file_path.filename().string() == ep_context_file_path) {
std::filesystem::path cache_path = engine_cache_path;
if (cache_path.has_parent_path()) {
ctx_model_path = cache_path.parent_path().append(ep_context_file_path).string();
} else {
ctx_model_path = ep_context_file_path;
}
} else {
ctx_model_path = ep_context_file_path;
}
} else {
std::filesystem::path model_path = original_model_path;
std::filesystem::path model_name_stem = model_path.stem(); // model_name.onnx -> model_name
std::string ctx_model_name = model_name_stem.string() + "_ctx.onnx";

if (std::filesystem::is_directory(ep_context_file_path)) {
std::filesystem::path model_directory = ep_context_file_path;
ctx_model_path = model_directory.append(ctx_model_name).string();
} else {
std::filesystem::path cache_path = engine_cache_path;
if (cache_path.has_parent_path()) {
ctx_model_path = cache_path.parent_path().append(ctx_model_name).string();
} else {
ctx_model_path = ctx_model_name;
}
}
}
return ctx_model_path;
}
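A small sketch of what the resolution above yields for Examples 2-4 (Example 1 is omitted because it depends on the directory actually existing for std::filesystem::is_directory() to return true). Paths assume a POSIX-style filesystem and that this translation unit can call GetCtxNodeModelPath.

```cpp
#include <cassert>
#include <string>

// Mirrors Examples 2-4 from the comment above GetCtxNodeModelPath.
void CheckCtxNodeModelPaths() {
  // Example 2: bare file name -> placed next to the engine cache.
  assert(GetCtxNodeModelPath("my_ctx_model.onnx",
                             "/home/user/cache_folder/trt_engine.engine",
                             "model.onnx") == "/home/user/cache_folder/my_ctx_model.onnx");

  // Example 3: file name with a path -> used as-is.
  assert(GetCtxNodeModelPath("/home/user2/ep_context_model_folder/my_ctx_model.onnx",
                             "trt_engine.engine",
                             "model.onnx") == "/home/user2/ep_context_model_folder/my_ctx_model.onnx");

  // Example 4: nothing provided -> "<model stem>_ctx.onnx" next to the engine cache.
  assert(GetCtxNodeModelPath("",
                             "/home/user3/cache_folder/trt_engine.engine",
                             "model.onnx") == "/home/user3/cache_folder/model_ctx.onnx");
}
```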

/*
* Dump "EP context node" model
*
*/
void DumpCtxNodeModel(ONNX_NAMESPACE::ModelProto* model_proto,
const std::string engine_cache_path) {
std::fstream dump(engine_cache_path + "_wrapper.onnx", std::ios::out | std::ios::trunc | std::ios::binary);
const std::string& ctx_model_path) {
std::fstream dump(ctx_model_path, std::ios::out | std::ios::trunc | std::ios::binary);
model_proto->SerializeToOstream(dump);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path + "_wrapper.onnx";
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Dumped " + ctx_model_path;
}

Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph_viewer) {
@@ -193,14 +266,13 @@
auto node = graph_viewer.GetNode(0);
auto& attrs = node->GetAttributes();

// Check hardware_architecture (compute_capability) if it's present as an attribute
// Show a warning if the compute capability doesn't match
if (attrs.count(COMPUTE_CAPABILITY) > 0) {
std::string model_compute_capability = attrs.at(COMPUTE_CAPABILITY).s();
if (model_compute_capability != compute_capability_) {
LOGS_DEFAULT(ERROR) << "The compute capability of the engine cache doesn't match with the GPU's compute capability";
LOGS_DEFAULT(ERROR) << "The compute capability of the engine cache: " << model_compute_capability;
LOGS_DEFAULT(ERROR) << "The compute capability of the GPU: " << compute_capability_;
return false;
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Engine was compiled for a different compatibility level and might not work or perform suboptimal";

LOGS_DEFAULT(WARNING) << "[TensorRT EP] The compute capability of the engine: " << model_compute_capability;
LOGS_DEFAULT(WARNING) << "[TensorRT EP] The compute capability of the GPU: " << compute_capability_;
}
}

@@ -222,6 +294,8 @@
return false;
}
}
} else if (embed_mode == 1) {
LOGS_DEFAULT(WARNING) << EPCONTEXT_WARNING;
}
}
return true;
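For readers who want to see what GetEpContextFromGraph consumes, here is a hedged, standalone sketch (not part of the PR) that inspects a dumped `*_ctx.onnx` model with the ONNX protobuf C++ bindings. The op type "EPContext" and the attribute names correspond to the constants in onnx_ctx_model_helper.h; embed_mode is assumed to be stored as an integer attribute.

```cpp
#include <fstream>
#include <iostream>
#include <iterator>
#include <string>
#include "onnx/onnx_pb.h"

// Usage: inspect_ctx_model model_ctx.onnx
int main(int argc, char** argv) {
  if (argc < 2) return 1;

  // Read the serialized model and parse it into an onnx::ModelProto.
  std::ifstream in(argv[1], std::ios::binary);
  std::string bytes((std::istreambuf_iterator<char>(in)), std::istreambuf_iterator<char>());
  onnx::ModelProto model;
  if (!model.ParseFromString(bytes)) return 1;

  for (const auto& node : model.graph().node()) {
    if (node.op_type() != "EPContext") continue;  // EPCONTEXT_OP
    for (const auto& attr : node.attribute()) {
      if (attr.name() == "embed_mode") {                    // EMBED_MODE
        std::cout << "embed_mode: " << attr.i() << "\n";
      } else if (attr.name() == "hardware_architecture") {  // COMPUTE_CAPABILITY
        std::cout << "hardware_architecture: " << attr.s() << "\n";
      } else if (attr.name() == "ep_cache_context") {       // EP_CACHE_CONTEXT
        // embed_mode 0: engine cache path; embed_mode 1: raw engine bytes.
        std::cout << "ep_cache_context (" << attr.s().size() << " bytes)\n";
      }
    }
  }
  return 0;
}
```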
10 changes: 8 additions & 2 deletions onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h
@@ -16,6 +16,10 @@
static const std::string EP_CACHE_CONTEXT = "ep_cache_context";
static const std::string COMPUTE_CAPABILITY = "hardware_architecture";
static const std::string EPCONTEXT_OP_DOMAIN = "com.microsoft";
static const std::string EPCONTEXT_WARNING =
    "It's suggested to set the ORT graph optimization level to 0 and \
    make \"embed_mode\" to 0 (\"ep_cache_context\" is the cache path)\
    for the best model loading time";

bool GraphHasCtxNode(const GraphViewer& graph_viewer);
const onnxruntime::Path& GetModelPath(const GraphViewer& graph_viewer);
@@ -25,11 +29,13 @@
char* engine_data,
size_t size,
const int64_t embed_mode,
bool compute_capability_enable,
std::string compute_capability,
const logging::Logger* logger);
std::string GetCtxNodeModelPath(const std::string& ep_context_file_path,
const std::string& engine_cache_path,
const std::string& original_model_path);
void DumpCtxNodeModel(ONNX_NAMESPACE::ModelProto* model_proto,
const std::string engine_cache_path);
const std::string& ctx_model_path);
void UpdateCtxNodeModelEngineContext(ONNX_NAMESPACE::ModelProto* model_proto,
char* engine_data,
size_t size);
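The EPCONTEXT_WARNING above suggests running with graph optimizations disabled when loading an EP-context model with embed_mode 1. A minimal C++-API sketch of that session setup follows; the model path is a placeholder and trt_options is assumed to be populated elsewhere.

```cpp
#include <onnxruntime_cxx_api.h>

// Sketch: create a session over a dumped EP-context model with ORT graph
// optimizations disabled, as recommended by EPCONTEXT_WARNING.
Ort::Session CreateEpContextSession(Ort::Env& env, const OrtTensorRTProviderOptionsV2& trt_options) {
  Ort::SessionOptions session_options;
  session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_DISABLE_ALL);
  session_options.AppendExecutionProvider_TensorRT_V2(trt_options);
  return Ort::Session(env, ORT_TSTR("model_ctx.onnx"), session_options);
}
```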
43 changes: 24 additions & 19 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -1079,8 +1079,6 @@
char const* output_name,
size_t output_index,
size_t output_type,
std::vector<IAllocatorUniquePtr<void>>& scratch_buffers,
OrtAllocator* alloc,
cudaStream_t stream) {
auto allocator = allocator_map[output_name].get();
auto& shape = allocator->getOutputShape();
@@ -1381,8 +1379,8 @@
profile_opt_shapes = info.profile_opt_shapes;
cuda_graph_enable_ = info.cuda_graph_enable;
dump_ep_context_model_ = info.dump_ep_context_model;
ep_context_file_path_ = info.ep_context_file_path;
ep_context_embed_mode_ = info.ep_context_embed_mode;
ep_context_compute_capability_enable_ = info.ep_context_compute_capability_enable;
} else {
try {
const std::string max_partition_iterations_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kMaxPartitionIterations);
@@ -1543,16 +1541,16 @@
dump_ep_context_model_ = (std::stoi(dump_ep_context_model_env) == 0 ? false : true);
}

const std::string ep_context_file_path_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEpContextComputeCapabilityEnable);

if (!ep_context_file_path_env.empty()) {
ep_context_file_path_ = ep_context_file_path_env;
}

const std::string ep_context_embed_mode_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEpContextEmbedMode);
if (!ep_context_embed_mode_env.empty()) {
ep_context_embed_mode_ = std::stoi(ep_context_embed_mode_env);
}

const std::string ep_context_compute_capability_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEpContextComputeCapabilityEnable);
if (!ep_context_compute_capability_env.empty()) {
ep_context_compute_capability_enable_ = (std::stoi(ep_context_compute_capability_env) == 0 ? false : true);
}

} catch (const std::invalid_argument& ex) {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Invalid Argument (from environment variables): " << ex.what();
} catch (const std::out_of_range& ex) {
@@ -1580,7 +1578,7 @@
dla_core_ = 0;
}

if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_ || !cache_prefix_.empty()) {
if (engine_cache_enable_ || int8_enable_ || timing_cache_enable_) {
if (!cache_path_.empty() && !fs::is_directory(cache_path_)) {
if (!fs::create_directory(cache_path_)) {
throw std::runtime_error("Failed to create directory " + cache_path_);
@@ -1692,6 +1690,9 @@
<< ", trt_profile_max_shapes: " << profile_max_shapes
<< ", trt_profile_opt_shapes: " << profile_opt_shapes
<< ", trt_cuda_graph_enable: " << cuda_graph_enable_
<< ", trt_dump_ep_context_model: " << dump_ep_context_model_
<< ", trt_ep_context_file_path: " << ep_context_file_path_
<< ", trt_ep_context_embed_mode: " << ep_context_embed_mode_
<< ", trt_cache_prefix: " << cache_prefix_;
}

@@ -2831,10 +2832,8 @@
std::unique_ptr<nvinfer1::ICudaEngine> trt_engine;
std::unique_ptr<nvinfer1::IExecutionContext> trt_context;

// Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache
// Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity
std::string cache_suffix = "";
std::string cache_path = "";
std::string cache_suffix = "";
// Customize cache prefix if assigned
if (!cache_prefix_.empty()) {
// Generate cache suffix in case user would like to customize cache prefix
@@ -2843,11 +2842,19 @@
} else {
cache_path = GetCachePath(cache_path_, trt_node_name_with_precision);
}

// Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache
// Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity

const std::string cache_path_prefix = cache_path + "_sm" + compute_capability_;
const std::string engine_cache_path = cache_path_prefix + ".engine";
const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted";
const std::string profile_cache_path = cache_path_prefix + ".profile";

// Generate file name for dumping ep context model
if (dump_ep_context_model_ && ctx_model_path_.empty()) {
ctx_model_path_ = GetCtxNodeModelPath(ep_context_file_path_, engine_cache_path, model_path_);
}

if (!has_dynamic_shape) {
std::string timing_cache_path = "";
bool engine_update = false;
@@ -2989,10 +2996,9 @@
reinterpret_cast<char*>(serialized_engine->data()),
serialized_engine->size(),
ep_context_embed_mode_,
ep_context_compute_capability_enable_,
compute_capability_,
GetLogger())};
DumpCtxNodeModel(model_proto.get(), cache_path_prefix);
DumpCtxNodeModel(model_proto.get(), ctx_model_path_);
}
}
}
@@ -3057,11 +3063,10 @@
nullptr,
0,
ep_context_embed_mode_,
ep_context_compute_capability_enable_,
compute_capability_,
GetLogger()));
if (ep_context_embed_mode_ == 0) {
DumpCtxNodeModel(model_proto_.get(), cache_path_prefix);
DumpCtxNodeModel(model_proto_.get(), ctx_model_path_);
}
}

@@ -3382,7 +3387,7 @@
// dump ep context model
if (dump_ep_context_model_ && ep_context_embed_mode_) {
UpdateCtxNodeModelEngineContext(model_proto_.get(), reinterpret_cast<char*>(serialized_engine->data()), serialized_engine->size());
DumpCtxNodeModel(model_proto_.get(), cache_path_prefix);
DumpCtxNodeModel(model_proto_.get(), ctx_model_path_);
}
context_update = true;
}
@@ -3521,7 +3526,7 @@
if (index_iter != output_indexes.end()) {
output_index = index_iter->second;
}
auto status = BindKernelOutput(ctx, &mem_info, dds_output_allocator_map, output_name, output_index, output_type, scratch_buffers, alloc, stream);
auto status = BindKernelOutput(ctx, &mem_info, dds_output_allocator_map, output_name, output_index, output_type, stream);

if (status != Status::OK()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage());
}
@@ -3802,7 +3807,7 @@
if (index_iter != output_indexes.end()) {
output_index = index_iter->second;
}
auto status = BindKernelOutput(ctx, &mem_info, dds_output_allocator_map, output_name, output_index, output_type, scratch_buffers, alloc, stream);
auto status = BindKernelOutput(ctx, &mem_info, dds_output_allocator_map, output_name, output_index, output_type, stream);

if (status != Status::OK()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage());
}
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
@@ -293,6 +293,7 @@ class TensorrtExecutionProvider : public IExecutionProvider {
bool force_timing_cache_match_ = false;
bool detailed_build_log_ = false;
bool cuda_graph_enable_ = false;
std::string ctx_model_path_;
std::string cache_prefix_;

// The OrtAllocator object will be obtained during EP compute time
@@ -301,8 +302,8 @@ class TensorrtExecutionProvider : public IExecutionProvider {

// For create/dump EP context node model
bool dump_ep_context_model_ = false;
std::string ep_context_file_path_;
int ep_context_embed_mode_ = 0;
bool ep_context_compute_capability_enable_ = true;
std::unique_ptr<ONNX_NAMESPACE::ModelProto> model_proto_ = ONNX_NAMESPACE::ModelProto::Create();

std::unordered_set<std::string> control_flow_op_set_ = {"If", "Loop", "Scan"};