Commit 77a62f2

add trt_dump_ep_context_model, trt_ep_context_embed_mode, trt_ep_context_compute_capability_enable
chilo-ms committed Nov 23, 2023
1 parent 8f7c7ac commit 77a62f2
Showing 9 changed files with 112 additions and 16 deletions.
include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
@@ -46,4 +46,7 @@ struct OrtTensorRTProviderOptionsV2 {
const char* trt_profile_max_shapes{nullptr}; // Specify the range of the input shapes to build the engine with
const char* trt_profile_opt_shapes{nullptr}; // Specify the range of the input shapes to build the engine with
int trt_cuda_graph_enable{0}; // Enable CUDA graph in ORT TRT
int trt_dump_ep_context_model{0}; // Dump EP context node model
int trt_ep_context_embed_mode{0}; // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data
int trt_ep_context_compute_capability_enable{1}; // Add GPU compute capability as an EP context node's attribute
};
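For orientation, here is a minimal usage sketch (not part of this commit) showing the three new fields being set. Only the field names and their meanings are taken from the struct above; the include path and the helper function are illustrative, and wiring the struct into a session through the usual C API calls is assumed rather than shown in this diff.

```cpp
// Hypothetical usage sketch; only the three field names come from this diff.
#include "core/providers/tensorrt/tensorrt_provider_options.h"  // include path assumed

void ConfigureEpContextOptions(OrtTensorRTProviderOptionsV2& trt_options) {
  trt_options.trt_dump_ep_context_model = 1;  // dump the EP context node model
  trt_options.trt_ep_context_embed_mode = 0;  // 0 = context is engine cache path, 1 = engine binary data
  trt_options.trt_ep_context_compute_capability_enable = 1;  // record hardware_arch on the node
}
```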
16 changes: 14 additions & 2 deletions onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc
@@ -83,6 +83,8 @@ ONNX_NAMESPACE::ModelProto* CreateCtxNodeModel(const GraphViewer& graph_viewer,
char* engine_data,
size_t size,
const int64_t embed_mode,
bool compute_capability_enable,
int device_id,
const logging::Logger* logger) {
auto model_build = graph_viewer.CreateModel(*logger);
auto& graph_build = model_build->MainGraph();
@@ -102,6 +104,7 @@ ONNX_NAMESPACE::ModelProto* CreateCtxNodeModel(const GraphViewer& graph_viewer,
// Create EP context node attributes
auto attr_0 = ONNX_NAMESPACE::AttributeProto::Create(); // embed_mode
auto attr_1 = ONNX_NAMESPACE::AttributeProto::Create(); // ep_cache_context
auto attr_2 = ONNX_NAMESPACE::AttributeProto::Create(); // hardware_arch
std::string engine_data_str = "";
attr_0->set_name(EMBED_MODE);
attr_0->set_type(onnx::AttributeProto_AttributeType_INT);
@@ -117,10 +120,19 @@ ONNX_NAMESPACE::ModelProto* CreateCtxNodeModel(const GraphViewer& graph_viewer,
attr_1->set_s(engine_cache_path);
}
auto node_attributes = ONNX_NAMESPACE::NodeAttributes::Create();
int num_attributes = 2;
int num_attributes = compute_capability_enable ? 3 : 2;
node_attributes->reserve(num_attributes);
node_attributes->emplace(EMBED_MODE, *attr_0);
node_attributes->emplace(EP_CACHE_CONTEXT, *attr_1);

if (compute_capability_enable) {
cudaDeviceProp prop;
CUDA_CALL_THROW(cudaGetDeviceProperties(&prop, device_id));
attr_2->set_name(COMPUTE_CAPABILITY);
attr_2->set_type(onnx::AttributeProto_AttributeType_STRING);
attr_2->set_s(GetComputeCapacityString(prop));
node_attributes->emplace(COMPUTE_CAPABILITY, *attr_2);
}

// Create EP context node
graph_build.AddNode(EPCONTEXT_OP, EPCONTEXT_OP, "", inputs, outputs, node_attributes.get(), EPCONTEXT_OP_DOMAIN);
@@ -145,7 +157,7 @@ void DumpCtxNodeModel(ONNX_NAMESPACE::ModelProto* model_proto,
std::string string_buf;
model_proto->SerializeToString(string_buf);

// Dump out EP context node model
// Dump EP context node model to disk
std::fstream dump(engine_cache_path + "_wrapper.onnx", std::ios::out | std::ios::trunc | std::ios::binary);
model_proto->SerializeToOstream(dump);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path + "_wrapper.onnx";
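The new compute-capability attribute is derived from cudaGetDeviceProperties, as the hunk above shows. For reference, a stand-alone sketch of one plausible derivation follows; the exact formatting is up to GetComputeCapacityString, which this diff does not show, so the "major.minor" form here is an assumption.

```cpp
// Hedged sketch: derive a compute-capability string for a CUDA device.
// The "major.minor" formatting is an assumption; the commit's
// GetComputeCapacityString may format it differently (e.g. "80" for sm_80).
#include <cuda_runtime_api.h>
#include <string>

std::string ComputeCapabilityOf(int device_id) {
  cudaDeviceProp prop{};
  cudaGetDeviceProperties(&prop, device_id);  // error handling omitted for brevity
  return std::to_string(prop.major) + "." + std::to_string(prop.minor);
}
```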
6 changes: 2 additions & 4 deletions onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h
@@ -12,13 +12,9 @@
namespace onnxruntime {

static const std::string EPCONTEXT_OP = "EPContext";
static const std::string MAIN_CONTEXT = "main_context";
static const std::string EMBED_MODE = "embed_mode";
static const std::string EP_CACHE_CONTEXT = "ep_cache_context";
static const std::string EP_SDK_VER = "ep_sdk_version";
static const std::string COMPUTE_CAPABILITY = "hardware_arch";
static const std::string PARTITION_NAME = "partition_name";
static const std::string SOURCE = "source";
static const std::string EPCONTEXT_OP_DOMAIN = "com.microsoft";

bool GraphHasCtxNode(const GraphViewer& graph_viewer);
@@ -29,6 +25,8 @@ ONNX_NAMESPACE::ModelProto* CreateCtxNodeModel(const GraphViewer& graph_viewer,
char* engine_data,
size_t size,
const int64_t embed_mode,
bool compute_capability_enable,
int device_id,
const logging::Logger* logger);
void DumpCtxNodeModel(ONNX_NAMESPACE::ModelProto* model_proto,
const std::string engine_cache_path);
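Putting the constants above together: an EPContext node has op type EPContext in the com.microsoft domain and carries embed_mode, ep_cache_context and (with this commit) hardware_arch attributes. Below is an illustrative sketch of that shape using the ONNX protobuf C++ API; the attribute values are hypothetical, and the real node is built through ORT's graph APIs as shown in the .cc hunks.

```cpp
// Illustrative sketch of the EPContext node's shape; values are hypothetical.
#include "onnx/onnx_pb.h"

onnx::NodeProto MakeEpContextNodeSketch() {
  onnx::NodeProto node;
  node.set_op_type("EPContext");         // EPCONTEXT_OP
  node.set_domain("com.microsoft");      // EPCONTEXT_OP_DOMAIN

  auto* embed = node.add_attribute();    // EMBED_MODE
  embed->set_name("embed_mode");
  embed->set_type(onnx::AttributeProto::INT);
  embed->set_i(0);                       // 0 = ep_cache_context holds a path

  auto* ctx = node.add_attribute();      // EP_CACHE_CONTEXT
  ctx->set_name("ep_cache_context");
  ctx->set_type(onnx::AttributeProto::STRING);
  ctx->set_s("model_sm80.engine");       // engine cache path when embed_mode == 0

  auto* arch = node.add_attribute();     // COMPUTE_CAPABILITY
  arch->set_name("hardware_arch");
  arch->set_type(onnx::AttributeProto::STRING);
  arch->set_s("80");                     // format determined by GetComputeCapacityString
  return node;
}
```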
52 changes: 44 additions & 8 deletions onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -1308,6 +1308,9 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
profile_max_shapes = info.profile_max_shapes;
profile_opt_shapes = info.profile_opt_shapes;
cuda_graph_enable_ = info.cuda_graph_enable;
dump_ep_context_model_ = info.dump_ep_context_model;
ep_context_embed_mode_ = info.ep_context_embed_mode;
ep_context_compute_capability_enable_ = info.ep_context_compute_capability_enable;
} else {
try {
const std::string max_partition_iterations_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kMaxPartitionIterations);
@@ -1461,6 +1464,22 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
if (!cuda_graph_enable_env.empty()) {
cuda_graph_enable_ = (std::stoi(cuda_graph_enable_env) == 0 ? false : true);
}

const std::string dump_ep_context_model_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDumpEpContextModel);
if (!dump_ep_context_model_env.empty()) {
dump_ep_context_model_ = (std::stoi(dump_ep_context_model_env) == 0 ? false : true);
}

const std::string ep_context_embed_mode_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEpContextEmbedMode);
if (!ep_context_embed_mode_env.empty()) {
ep_context_embed_mode_ = std::stoi(ep_context_embed_mode_env);
}

const std::string ep_context_compute_capability_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEpContextComputeCapabilityEnable);
if (!ep_context_compute_capability_env.empty()) {
ep_context_compute_capability_enable_ = (std::stoi(ep_context_compute_capability_env) == 0 ? false : true);
}

} catch (const std::invalid_argument& ex) {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Invalid Argument (from environment variables): " << ex.what();
} catch (const std::out_of_range& ex) {
@@ -2978,11 +2997,12 @@ Status TensorrtExecutionProvider::CreateNodeComputeFromGraph(const GraphViewer&
CUDA_CALL_THROW(cudaGetDeviceProperties(&prop, device_id_));
std::string compute_capability = GetComputeCapacity(prop);
const std::string cache_path = GetCachePath(cache_path_, trt_node_name_with_precision);
const std::string cache_path_prefix = cache_path + "_sm" + compute_capability;
const std::string engine_cache_path = cache_path_prefix + ".engine";
const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted";
const std::string profile_cache_path = cache_path_prefix + ".profile";

if (!has_dynamic_shape) {
const std::string engine_cache_path = cache_path + "_sm" + compute_capability + ".engine";
const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted";
const std::string profile_cache_path = cache_path + "_sm" + compute_capability + ".profile";
std::string timing_cache_path = "";
bool engine_update = false;
if (timing_cache_enable_) {
@@ -3123,8 +3143,10 @@ Status TensorrtExecutionProvider::CreateNodeComputeFromGraph(const GraphViewer&
reinterpret_cast<char*>(serialized_engine->data()),
serialized_engine->size(),
ep_context_embed_mode_,
ep_context_compute_capability_enable_,
device_id_,
GetLogger())};
DumpCtxNodeModel(model_proto.get(), cache_path + "_sm" + compute_capability);
DumpCtxNodeModel(model_proto.get(), cache_path_prefix);
}
}
}
@@ -3180,8 +3202,21 @@ Status TensorrtExecutionProvider::CreateNodeComputeFromGraph(const GraphViewer&
input_shape_ranges_[fused_node.Name()] = input_implicit_shape_ranges;
profiles_.emplace(fused_node.Name(), std::move(trt_profiles));

// For a model with dynamic shape inputs, TRT EP first creates a model proto that contains the inputs, outputs and an empty engine.
// TRT EP then serializes the model at inference time, because the engine can be updated and the updated engine should be included in the model.
// However, if embed_mode is 0 (the context holds only the engine cache path), TRT EP serializes it here.
if (dump_ep_context_model_ && has_dynamic_shape) {
model_proto_.reset(CreateCtxNodeModel(graph_body_viewer, cache_path + "_sm" + compute_capability, nullptr, 0, ep_context_embed_mode_, GetLogger()));
model_proto_.reset(CreateCtxNodeModel(graph_body_viewer,
engine_cache_path,
nullptr,
0,
ep_context_embed_mode_,
ep_context_compute_capability_enable_,
device_id_,
GetLogger()));
if (ep_context_embed_mode_ == 0) {
DumpCtxNodeModel(model_proto_.get(), cache_path_prefix);
}
}

// Create function state
@@ -3259,9 +3294,10 @@ Status TensorrtExecutionProvider::CreateNodeComputeFromGraph(const GraphViewer&

// Prepare cache name
const std::string cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->trt_node_name_with_precision);
const std::string engine_cache_path = cache_path + "_sm" + compute_capability + ".engine";
const std::string cache_path_prefix = cache_path + "_sm" + compute_capability;
const std::string engine_cache_path = cache_path_prefix + ".engine";
const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted";
const std::string profile_cache_path = cache_path + "_sm" + compute_capability + ".profile";
const std::string profile_cache_path = cache_path_prefix + ".profile";
std::string timing_cache_path = "";
if (timing_cache_enable_) {
timing_cache_path = GetTimingCachePath(global_cache_path_, prop);
@@ -3497,7 +3533,7 @@ Status TensorrtExecutionProvider::CreateNodeComputeFromGraph(const GraphViewer&

if (dump_ep_context_model_ && ep_context_embed_mode_) {
UpdateCtxNodeModelEngineContext(model_proto_.get(), reinterpret_cast<char*>(serialized_engine->data()), serialized_engine->size());
DumpCtxNodeModel(model_proto_.get(), cache_path + "_sm" + compute_capability);
DumpCtxNodeModel(model_proto_.get(), cache_path_prefix);
}
context_update = true;
}
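A side effect of this change is that the cache-file names are now built once from cache_path_prefix instead of repeating the cache_path + "_sm" + compute_capability concatenation. A small sketch of the resulting names, with hypothetical inputs; the suffixes come from this diff, and "_wrapper.onnx" is the one DumpCtxNodeModel appends.

```cpp
// Sketch of the cache-file naming consolidated behind cache_path_prefix.
// The input values are hypothetical; the suffixes come from this diff.
#include <string>

void CacheNamingSketch() {
  const std::string cache_path = "trt_cache/model_fp16";  // hypothetical GetCachePath() result
  const std::string compute_capability = "80";            // hypothetical, e.g. an sm_80 GPU
  const std::string cache_path_prefix = cache_path + "_sm" + compute_capability;
  const std::string engine_cache_path = cache_path_prefix + ".engine";
  const std::string profile_cache_path = cache_path_prefix + ".profile";
  const std::string ctx_model_path = cache_path_prefix + "_wrapper.onnx";  // DumpCtxNodeModel output
}
```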
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
@@ -46,6 +46,9 @@ static const std::string kProfilesMinShapes = "ORT_TENSORRT_PROFILE_MIN_SHAPES";
static const std::string kProfilesMaxShapes = "ORT_TENSORRT_PROFILE_MAX_SHAPES";
static const std::string kProfilesOptShapes = "ORT_TENSORRT_PROFILE_OPT_SHAPES";
static const std::string kCudaGraphEnable = "ORT_TENSORRT_CUDA_GRAPH_ENABLE";
static const std::string kDumpEpContextModel = "ORT_DUMP_EP_CONTEXT_MODEL";
static const std::string kEpContextEmbedMode = "ORT_EP_CONTEXT_EMBED_MODE";
static const std::string kEpContextComputeCapabilityEnable = "ORT_EP_CONTEXT_COMPUTE_CAPABILITY_ENABLE";
// Old env variable for backward compatibility
static const std::string kEngineCachePath = "ORT_TENSORRT_ENGINE_CACHE_PATH";
} // namespace tensorrt_env_vars
@@ -315,8 +318,9 @@ class TensorrtExecutionProvider : public IExecutionProvider {
OrtAllocator* alloc_ = nullptr;

// For creating/dumping the EP context node model
bool dump_ep_context_model_ = true;
int ep_context_embed_mode_ = 1;
bool dump_ep_context_model_ = false;
int ep_context_embed_mode_ = 0;
bool ep_context_compute_capability_enable_ = true;
std::unique_ptr<ONNX_NAMESPACE::ModelProto> model_proto_ = ONNX_NAMESPACE::ModelProto::Create();

std::unordered_set<std::string> control_flow_op_set_ = {"If", "Loop", "Scan"};
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc
@@ -46,6 +46,9 @@ constexpr const char* kProfilesMinShapes = "trt_profile_min_shapes";
constexpr const char* kProfilesMaxShapes = "trt_profile_max_shapes";
constexpr const char* kProfilesOptShapes = "trt_profile_opt_shapes";
constexpr const char* kCudaGraphEnable = "trt_cuda_graph_enable";
constexpr const char* kDumpEpContextModel = "trt_dump_ep_context_model";
constexpr const char* kEpContextEmbedMode = "trt_ep_context_embed_mode";
constexpr const char* kEpContextComputeCapabilityEnable = "trt_ep_context_compute_capability_enable";
} // namespace provider_option_names
} // namespace tensorrt

@@ -97,6 +100,9 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions
.AddAssignmentToReference(tensorrt::provider_option_names::kProfilesMaxShapes, info.profile_max_shapes)
.AddAssignmentToReference(tensorrt::provider_option_names::kProfilesOptShapes, info.profile_opt_shapes)
.AddAssignmentToReference(tensorrt::provider_option_names::kCudaGraphEnable, info.cuda_graph_enable)
.AddAssignmentToReference(tensorrt::provider_option_names::kDumpEpContextModel, info.dump_ep_context_model)
.AddAssignmentToReference(tensorrt::provider_option_names::kEpContextEmbedMode, info.ep_context_embed_mode)
.AddAssignmentToReference(tensorrt::provider_option_names::kEpContextComputeCapabilityEnable, info.ep_context_compute_capability_enable)
.Parse(options)); // add new provider option here.

return info;
@@ -138,6 +144,9 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE
{tensorrt::provider_option_names::kProfilesMaxShapes, MakeStringWithClassicLocale(info.profile_max_shapes)},
{tensorrt::provider_option_names::kProfilesOptShapes, MakeStringWithClassicLocale(info.profile_opt_shapes)},
{tensorrt::provider_option_names::kCudaGraphEnable, MakeStringWithClassicLocale(info.cuda_graph_enable)},
{tensorrt::provider_option_names::kDumpEpContextModel, MakeStringWithClassicLocale(info.dump_ep_context_model)},
{tensorrt::provider_option_names::kEpContextEmbedMode, MakeStringWithClassicLocale(info.ep_context_embed_mode)},
{tensorrt::provider_option_names::kEpContextComputeCapabilityEnable, MakeStringWithClassicLocale(info.ep_context_compute_capability_enable)},
};
return options;
}
@@ -188,6 +197,9 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor
{tensorrt::provider_option_names::kProfilesMaxShapes, kProfilesMaxShapes_},
{tensorrt::provider_option_names::kProfilesOptShapes, kProfilesOptShapes_},
{tensorrt::provider_option_names::kCudaGraphEnable, MakeStringWithClassicLocale(info.trt_cuda_graph_enable)},
{tensorrt::provider_option_names::kDumpEpContextModel, MakeStringWithClassicLocale(info.trt_dump_ep_context_model)},
{tensorrt::provider_option_names::kEpContextEmbedMode, MakeStringWithClassicLocale(info.trt_ep_context_embed_mode)},
{tensorrt::provider_option_names::kEpContextComputeCapabilityEnable, MakeStringWithClassicLocale(info.trt_ep_context_compute_capability_enable)},
};
return options;
}
@@ -279,5 +291,8 @@ void TensorrtExecutionProviderInfo::UpdateProviderOptions(void* provider_options
trt_provider_options_v2.trt_profile_opt_shapes = copy_string_if_needed(internal_options.profile_opt_shapes);

trt_provider_options_v2.trt_cuda_graph_enable = internal_options.cuda_graph_enable;
trt_provider_options_v2.trt_dump_ep_context_model = internal_options.dump_ep_context_model;
trt_provider_options_v2.trt_ep_context_embed_mode = internal_options.ep_context_embed_mode;
trt_provider_options_v2.trt_ep_context_compute_capability_enable = internal_options.ep_context_compute_capability_enable;
}
} // namespace onnxruntime
onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h
@@ -51,6 +51,9 @@ struct TensorrtExecutionProviderInfo {
std::string profile_max_shapes{""};
std::string profile_opt_shapes{""};
bool cuda_graph_enable{false};
bool dump_ep_context_model{false};
int ep_context_embed_mode{0};
bool ep_context_compute_capability_enable{true};

static TensorrtExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
static ProviderOptions ToProviderOptions(const TensorrtExecutionProviderInfo& info);
onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc
@@ -116,6 +116,9 @@ struct Tensorrt_Provider : Provider {
info.profile_max_shapes = options.trt_profile_max_shapes == nullptr ? "" : options.trt_profile_max_shapes;
info.profile_opt_shapes = options.trt_profile_opt_shapes == nullptr ? "" : options.trt_profile_opt_shapes;
info.cuda_graph_enable = options.trt_cuda_graph_enable != 0;
info.dump_ep_context_model = options.trt_dump_ep_context_model != 0;
info.ep_context_embed_mode = options.trt_ep_context_embed_mode;
info.ep_context_compute_capability_enable = options.trt_ep_context_compute_capability_enable != 0;

return std::make_shared<TensorrtProviderFactory>(info);
}
22 changes: 22 additions & 0 deletions onnxruntime/python/onnxruntime_pybind_state.cc
@@ -717,6 +717,28 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_cuda_graph_enable' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_dump_ep_context_model") {
if (option.second == "True" || option.second == "true") {
params.trt_dump_ep_context_model = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_dump_ep_context_model = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_dump_ep_context_model' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "trt_ep_context_embed_mode") {
if (!option.second.empty()) {
params.trt_ep_context_embed_mode = std::stoi(option.second);
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_ep_context_embed_mode' should be a positive integer number i.e. '1'.\n");
}
} else if (option.first == "trt_ep_context_compute_capability_enable") {
if (option.second == "True" || option.second == "true") {
params.trt_ep_context_compute_capability_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.trt_ep_context_compute_capability_enable = false;
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_ep_context_compute_capability_enable' should be 'True' or 'False'. Default value is 'False'.\n");
}
} else {
ORT_THROW("Invalid TensorRT EP option: ", option.first);
}
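From the Python side these options arrive as strings, so the parsing above accepts 'True'/'true' and 'False'/'false' for the two booleans and an integer string for the embed mode. A sketch of the key/value pairs the code consumes; the keys and accepted literals come from this diff, while the container choice is illustrative.

```cpp
// Illustrative map of the string-valued options parsed above.
#include <string>
#include <unordered_map>

const std::unordered_map<std::string, std::string> kTrtEpContextOptions = {
    {"trt_dump_ep_context_model", "True"},                 // 'True'/'true' or 'False'/'false'
    {"trt_ep_context_embed_mode", "1"},                    // parsed with std::stoi
    {"trt_ep_context_compute_capability_enable", "False"}, // 'True'/'true' or 'False'/'false'
};
```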
