From 46dd0d3f52183d951640258c09c54adc1daeeeb9 Mon Sep 17 00:00:00 2001
From: Chi Lo <54722500+chilo-ms@users.noreply.github.com>
Date: Thu, 11 Jan 2024 22:20:54 -0800
Subject: [PATCH] [TensorRT EP] Load precompiled TRT engine file directly
(#18217)
When the TRT engine cache (precompiled engine) is present, there is no
need to go through model verification, model optimization, TRT EP's
GetCapability(), TRT EP's model proto reconstruction, TRT parsing and
engine compilation.
This PR makes TRT EP skip those steps and directly load the engine to
perform inference.
The feature request:
https://github.com/microsoft/onnxruntime/issues/18072
Features:
- Replace the original model with a TRT engine wrapped ONNX model. This
can save a lot of time, as mentioned above.
- How to get a TRT engine wrapped ONNX model?
1. Set the `trt_dump_ep_context_model` provider option to "true" and run
inference. You will find "xxx_wrapper.onnx" at the engine cache path
(the same logic as generating the engine cache; see the Python sketch
after this feature list).
2. Use gen_trt_engine_wrapper_onnx_model.py.
- Three provider options are added:
`trt_dump_ep_context_model`: Enable dumping the wrapped ONNX model from
TRT EP.
`trt_ep_context_embed_mode`: Add embed_mode as an attribute. 0 means the
engine cache path, 1 means the engine binary data.
`trt_ep_context_compute_capability_enable`: Add hardware_architecture as
an attribute. When running the model, TRT EP checks that the model's
hardware_architecture is consistent with the GPU's compute capability.
- When the engine cache path is given in the wrapped model, TRT EP first
searches for the engine file using that path relative to the model path;
if it can't find it, it falls back to using the path as-is (which,
depending on the user, could be relative to the working directory or an
absolute path).
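As a rough illustration (not part of this PR's code), dumping the wrapped model
through the Python API could look like the sketch below; the model path, cache
directory and input feed are placeholders.

```python
import onnxruntime as ort

# Hypothetical model and cache paths; the three trt_*ep_context* options are the ones added by this PR.
trt_options = {
    "trt_engine_cache_enable": True,                   # reuse the existing engine-cache logic
    "trt_engine_cache_path": "./trt_cache",
    "trt_dump_ep_context_model": True,                 # dump "xxx_wrapper.onnx" at the engine cache path
    "trt_ep_context_embed_mode": 0,                    # 0 = store engine cache path, 1 = embed engine binary data
    "trt_ep_context_compute_capability_enable": True,  # record hardware_architecture for a load-time consistency check
}

sess = ort.InferenceSession(
    "model.onnx",
    providers=[("TensorrtExecutionProvider", trt_options), "CUDAExecutionProvider"],
)

# Running inference once builds the engine, writes the engine cache and dumps the wrapped model.
# outputs = sess.run(None, {"input": some_input})  # input name/shape depend on the model
```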
Note:
1. This PR includes the change of
https://github.com/microsoft/onnxruntime/pull/17751
Constraints:
1. The whole model should be fully supported by TRT.
2. Users need to make sure the engine is built with min/max/opt
optimization profiles large enough to cover the range of all expected
inputs. TRT EP will simply fail and won't rebuild the engine if an
input shape is out of range at runtime (see the loading sketch below).
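For context, consuming the dumped wrapper model is just a normal session
creation. A minimal sketch, where the file name, input name and shape are
placeholders:

```python
import numpy as np
import onnxruntime as ort

# "xxx_wrapper.onnx" stands for whatever file TRT EP dumped at the engine cache path.
sess = ort.InferenceSession("xxx_wrapper.onnx", providers=["TensorrtExecutionProvider"])

# The input name and shape below are placeholders; shapes must stay inside the min/max
# optimization profiles the engine was built with, since TRT EP fails instead of rebuilding.
outputs = sess.run(None, {"input": np.zeros((1, 3, 224, 224), dtype=np.float32)})
```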
---
docs/ContribOperators.md | 2 +
.../tensorrt/tensorrt_provider_options.h | 3 +
.../core/graph/contrib_ops/contrib_defs.cc | 5 +
.../shared_library/provider_interfaces.h | 3 +
.../shared_library/provider_wrappedtypes.h | 3 +
.../tensorrt/onnx_ctx_model_helper.cc | 229 ++
.../tensorrt/onnx_ctx_model_helper.h | 55 +
.../tensorrt/tensorrt_execution_provider.cc | 2048 ++++++++++-------
.../tensorrt/tensorrt_execution_provider.h | 44 +
.../tensorrt_execution_provider_info.cc | 15 +
.../tensorrt_execution_provider_info.h | 3 +
.../tensorrt_execution_provider_utils.h | 1 +
.../tensorrt/tensorrt_provider_factory.cc | 3 +
.../core/session/provider_bridge_ort.cc | 3 +
.../python/onnxruntime_pybind_state.cc | 22 +
.../gen_trt_engine_wrapper_onnx_model.py | 174 ++
.../python/onnxruntime_test_engine_wrapper.py | 100 +
17 files changed, 1873 insertions(+), 840 deletions(-)
create mode 100644 onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc
create mode 100644 onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h
create mode 100644 onnxruntime/python/tools/tensorrt/gen_trt_engine_wrapper_onnx_model.py
create mode 100644 onnxruntime/test/python/onnxruntime_test_engine_wrapper.py
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index b5b69c15d65c9..45c0e6f822ce9 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -1588,6 +1588,8 @@ This version of the operator has been available since version 1 of the 'com.micr
payload of the execution provider context if embed_mode=1, or path to the context file if embed_mode=0.
ep_sdk_version : string
(Optional) SDK version used to convert the model.
+hardware_architecture : string
+(Optional) Hardware architecture.
main_context : int
Usually each single EPContext associate with a graph partition.But for some case like QNN, it has single EPContext contains all partitions.In that case, the node with ep_cache_context should set main_context=1. Other nodes set main_context=0 and skip ep_cache_context.The path is relative to this Onnx file. Default is 1.
notes : string
diff --git a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
index 680ce1cc5b9a2..daa4089061825 100644
--- a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
+++ b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
@@ -46,4 +46,7 @@ struct OrtTensorRTProviderOptionsV2 {
const char* trt_profile_max_shapes{nullptr}; // Specify the range of the input shapes to build the engine with
const char* trt_profile_opt_shapes{nullptr}; // Specify the range of the input shapes to build the engine with
int trt_cuda_graph_enable{0}; // Enable CUDA graph in ORT TRT
+ int trt_dump_ep_context_model{0}; // Dump EP context node model
+ int trt_ep_context_embed_mode{0}; // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data
+ int trt_ep_context_compute_capability_enable{1}; // Add GPU compute capability as an EP context node's attribute
};
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index 54eb43753931a..982e8fd834b76 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -3230,6 +3230,11 @@ void RegisterContribSchemas() {
"(Optional) SDK version used to convert the model.",
AttributeProto::STRING,
OPTIONAL_VALUE)
+ .Attr(
+ "hardware_architecture",
+ "(Optional) Hardware architecture.",
+ AttributeProto::STRING,
+ OPTIONAL_VALUE)
.Attr(
"partition_name",
"(Optional) partitioned graph name.",
diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h
index 27226005a9c0b..2883d92e90dba 100644
--- a/onnxruntime/core/providers/shared_library/provider_interfaces.h
+++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h
@@ -330,6 +330,7 @@ struct ProviderHost {
virtual int64_t AttributeProto__i(const ONNX_NAMESPACE::AttributeProto* p) = 0;
virtual float AttributeProto__f(const ONNX_NAMESPACE::AttributeProto* p) = 0;
virtual void AttributeProto__set_s(ONNX_NAMESPACE::AttributeProto* p, const ::std::string& value) = 0;
+ virtual void AttributeProto__set_i(ONNX_NAMESPACE::AttributeProto* p, int64_t value) = 0;
virtual const ::std::string& AttributeProto__s(const ONNX_NAMESPACE::AttributeProto* p) = 0;
virtual void AttributeProto__set_name(ONNX_NAMESPACE::AttributeProto* p, const ::std::string& value) = 0;
virtual void AttributeProto__set_type(ONNX_NAMESPACE::AttributeProto* p, ONNX_NAMESPACE::AttributeProto_AttributeType value) = 0;
@@ -351,6 +352,7 @@ struct ProviderHost {
virtual ONNX_NAMESPACE::ValueInfoProtos* GraphProto__mutable_value_info(ONNX_NAMESPACE::GraphProto* p) = 0;
virtual ONNX_NAMESPACE::TensorProtos* GraphProto__mutable_initializer(ONNX_NAMESPACE::GraphProto* p) = 0;
virtual ONNX_NAMESPACE::NodeProto* GraphProto__add_node(ONNX_NAMESPACE::GraphProto* p) = 0;
+ virtual ONNX_NAMESPACE::NodeProto* GraphProto__mutable_node(ONNX_NAMESPACE::GraphProto* p, int index) = 0;
// ModelProto
 virtual std::unique_ptr<ONNX_NAMESPACE::ModelProto> ModelProto__construct() = 0;
@@ -372,6 +374,7 @@ struct ProviderHost {
virtual void NodeProto__operator_assign(ONNX_NAMESPACE::NodeProto* p, const ONNX_NAMESPACE::NodeProto& v) = 0;
virtual int NodeProto__attribute_size(ONNX_NAMESPACE::NodeProto* p) = 0;
virtual const ONNX_NAMESPACE::AttributeProto& NodeProto__attribute(const ONNX_NAMESPACE::NodeProto* p, int index) const = 0;
+ virtual ONNX_NAMESPACE::AttributeProto* NodeProto__mutable_attribute(ONNX_NAMESPACE::NodeProto* p, int index) = 0;
// TensorProto
 virtual std::unique_ptr<ONNX_NAMESPACE::TensorProto> TensorProto__construct() = 0;
diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h
index c0b282b202ef6..149a43222b445 100644
--- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h
+++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h
@@ -74,6 +74,7 @@ struct AttributeProto final {
int64_t i() const { return g_host->AttributeProto__i(this); }
float f() const { return g_host->AttributeProto__f(this); }
void set_s(const ::std::string& value) { return g_host->AttributeProto__set_s(this, value); }
+ void set_i(int64_t value) { return g_host->AttributeProto__set_i(this, value); }
const ::std::string& s() const { return g_host->AttributeProto__s(this); }
void set_name(const ::std::string& value) { return g_host->AttributeProto__set_name(this, value); }
void set_type(AttributeProto_AttributeType value) { return g_host->AttributeProto__set_type(this, value); }
@@ -118,6 +119,7 @@ struct GraphProto final {
ValueInfoProtos* mutable_value_info() { return g_host->GraphProto__mutable_value_info(this); }
TensorProtos* mutable_initializer() { return g_host->GraphProto__mutable_initializer(this); }
NodeProto* add_node() { return g_host->GraphProto__add_node(this); }
+ NodeProto* mutable_node(int index) { return g_host->GraphProto__mutable_node(this, index); }
GraphProto() = delete;
GraphProto(const GraphProto&) = delete;
@@ -148,6 +150,7 @@ struct NodeProto final {
void operator=(const NodeProto& v) { g_host->NodeProto__operator_assign(this, v); }
int attribute_size() { return g_host->NodeProto__attribute_size(this); }
const AttributeProto& attribute(int index) const { return g_host->NodeProto__attribute(this, index); }
+ AttributeProto* mutable_attribute(int index) { return g_host->NodeProto__mutable_attribute(this, index); }
NodeProto() = delete;
NodeProto(const NodeProto&) = delete;
diff --git a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc
new file mode 100644
index 0000000000000..4d8ba6a0891e3
--- /dev/null
+++ b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc
@@ -0,0 +1,229 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include <iostream>
+#include <fstream>
+#include <filesystem>
+
+#include "onnx_ctx_model_helper.h"
+#include "core/providers/cuda/shared_inc/cuda_call.h"
+#include "core/framework/execution_provider.h"
+
+namespace onnxruntime {
+
+/*
+ * Check whether the graph has the EP context contrib op.
+ * The op can contain the precompiled engine info for TRT EP to directly load the engine.
+ *
+ * Note: Please see more details about "EPContext" contrib op in contrib_defs.cc
+ */
+bool GraphHasCtxNode(const GraphViewer& graph_viewer) {
+ for (int i = 0; i < graph_viewer.MaxNodeIndex(); ++i) {
+ auto node = graph_viewer.GetNode(i);
+ if (node != nullptr && node->OpType() == EPCONTEXT_OP) {
+ return true;
+ }
+ }
+ return false;
+}
+
+const onnxruntime::Path& GetModelPath(const GraphViewer& graph_viewer) {
+ // find the top level graph
+ const Graph* cur_graph = &graph_viewer.GetGraph();
+ while (cur_graph->IsSubgraph()) {
+ cur_graph = cur_graph->ParentGraph();
+ }
+
+ const Graph& main_graph = *cur_graph;
+ return main_graph.ModelPath();
+}
+
+std::filesystem::path LocateEngineRelativeToPath(std::string engine_cache_path, const onnxruntime::Path& path) {
+ std::filesystem::path base_path(path.ToPathString());
+ std::filesystem::path parent_path = base_path.parent_path();
+ std::filesystem::path engine_path = parent_path.append(engine_cache_path);
+ return engine_path;
+}
+
+/*
+ * Update ep_cache_context attribute of the EP context node with the given engine binary data
+ */
+void UpdateCtxNodeModelEngineContext(ONNX_NAMESPACE::ModelProto* model_proto,
+ char* engine_data,
+ size_t size) {
+ ONNX_NAMESPACE::GraphProto* graph_proto = model_proto->mutable_graph();
+ ONNX_NAMESPACE::NodeProto* node_proto = graph_proto->mutable_node(0);
+
+ for (int i = 0; i < node_proto->attribute_size(); ++i) {
+ ONNX_NAMESPACE::AttributeProto* attribute_proto = node_proto->mutable_attribute(i);
+ if (attribute_proto->name() == EP_CACHE_CONTEXT) {
+ std::string engine_data_str = "";
+ if (size > 0) {
+ engine_data_str.assign(engine_data, size);
+ }
+ attribute_proto->set_s(engine_data_str);
+ }
+ }
+}
+
+/*
+ * Create "EP context node" model where engine information is embedded
+ */
+ONNX_NAMESPACE::ModelProto* CreateCtxNodeModel(const GraphViewer& graph_viewer,
+ const std::string engine_cache_path,
+ char* engine_data,
+ size_t size,
+ const int64_t embed_mode,
+ bool compute_capability_enable,
+ std::string compute_capability,
+ const logging::Logger* logger) {
+ auto model_build = graph_viewer.CreateModel(*logger);
+ auto& graph_build = model_build->MainGraph();
+
+ // Get graph inputs and outputs
+ std::vector<NodeArg*> inputs, outputs;
+ for (auto input : graph_viewer.GetInputs()) {
+ auto& n_input = graph_build.GetOrCreateNodeArg(input->Name(), input->TypeAsProto());
+ inputs.push_back(&n_input);
+ }
+
+ for (auto output : graph_viewer.GetOutputs()) {
+ auto& n_output = graph_build.GetOrCreateNodeArg(output->Name(), output->TypeAsProto());
+ outputs.push_back(&n_output);
+ }
+
+ // Create EP context node attributes
+ auto attr_0 = ONNX_NAMESPACE::AttributeProto::Create(); // embed_mode
+ auto attr_1 = ONNX_NAMESPACE::AttributeProto::Create(); // ep_cache_context
+ auto attr_2 = ONNX_NAMESPACE::AttributeProto::Create(); // hardware_architecture
+ std::string engine_data_str = "";
+ attr_0->set_name(EMBED_MODE);
+ attr_0->set_type(onnx::AttributeProto_AttributeType_INT);
+ attr_0->set_i(embed_mode);
+ attr_1->set_name(EP_CACHE_CONTEXT);
+ attr_1->set_type(onnx::AttributeProto_AttributeType_STRING);
+ if (embed_mode) {
+ if (size > 0) {
+ engine_data_str.assign(engine_data, size);
+ }
+ attr_1->set_s(engine_data_str);
+ } else {
+ attr_1->set_s(engine_cache_path);
+ }
+ auto node_attributes = ONNX_NAMESPACE::NodeAttributes::Create();
+ int num_attributes = compute_capability_enable ? 3 : 2;
+ node_attributes->reserve(num_attributes);
+ node_attributes->emplace(EMBED_MODE, *attr_0);
+ node_attributes->emplace(EP_CACHE_CONTEXT, *attr_1);
+
+ if (compute_capability_enable) {
+ attr_2->set_name(COMPUTE_CAPABILITY);
+ attr_2->set_type(onnx::AttributeProto_AttributeType_STRING);
+ attr_2->set_s(compute_capability);
+ node_attributes->emplace(COMPUTE_CAPABILITY, *attr_2);
+ }
+
+ // Create EP context node
+ graph_build.AddNode(EPCONTEXT_OP, EPCONTEXT_OP, "", inputs, outputs, node_attributes.get(), EPCONTEXT_OP_DOMAIN);
+ ORT_ENFORCE(graph_build.Resolve().IsOK());
+
+ // Serialize modelproto to string
+ auto new_graph_viewer = graph_build.CreateGraphViewer();
+ auto model = new_graph_viewer->CreateModel(*logger);
+ auto model_proto = model->ToProto();
+ new_graph_viewer->ToProto(*model_proto->mutable_graph(), true, true);
+ model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
+
+ return model_proto.release();
+}
+
+/*
+ * Dump "EP context node" model
+ *
+ */
+void DumpCtxNodeModel(ONNX_NAMESPACE::ModelProto* model_proto,
+ const std::string engine_cache_path) {
+ std::fstream dump(engine_cache_path + "_wrapper.onnx", std::ios::out | std::ios::trunc | std::ios::binary);
+ model_proto->SerializeToOstream(dump);
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path + "_wrapper.onnx";
+}
+
+Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph_viewer) {
+ if (!ValidateEPCtxNode(graph_viewer)) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "It's not a valid EP Context node");
+ }
+ auto node = graph_viewer.GetNode(0);
+ auto& attrs = node->GetAttributes();
+
+ const int64_t embed_mode = attrs.at(EMBED_MODE).i();
+ if (embed_mode) {
+ // Get engine from byte stream
+ const std::string& context_binary = attrs.at(EP_CACHE_CONTEXT).s();
+ *(trt_engine_) = std::unique_ptr<nvinfer1::ICudaEngine>(trt_runtime_->deserializeCudaEngine(const_cast<char*>(context_binary.c_str()),
+ static_cast<size_t>(context_binary.length())));
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Read engine as binary data from \"ep_cache_context\" attribute of ep context node and deserialized it";
+ if (!(*trt_engine_)) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+ "TensorRT EP could not deserialize engine from binary data");
+ }
+ } else {
+ // Get engine from cache file
+ std::ifstream engine_file(engine_cache_path_.string(), std::ios::binary | std::ios::in);
+ engine_file.seekg(0, std::ios::end);
+ size_t engine_size = engine_file.tellg();
+ engine_file.seekg(0, std::ios::beg);
+ std::unique_ptr<char[]> engine_buf{new char[engine_size]};
+ engine_file.read((char*)engine_buf.get(), engine_size);
+ *(trt_engine_) = std::unique_ptr<nvinfer1::ICudaEngine>(trt_runtime_->deserializeCudaEngine(engine_buf.get(), engine_size));
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path_.string();
+ if (!(*trt_engine_)) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+ "TensorRT EP could not deserialize engine from cache: " + engine_cache_path_.string());
+ }
+ }
+ return Status::OK();
+}
+
+/*
+ * The sanity check for EP context contrib op.
+ */
+bool TensorRTCacheModelHandler::ValidateEPCtxNode(const GraphViewer& graph_viewer) {
+ assert(graph_viewer.NumberOfNodes() == 1);
+ assert(graph_viewer.GetNode(0)->OpType() == EPCONTEXT_OP);
+ auto node = graph_viewer.GetNode(0);
+ auto& attrs = node->GetAttributes();
+
+ // Check hardware_architecture(compute_capability) if it's present as an attribute
+ if (attrs.count(COMPUTE_CAPABILITY) > 0) {
+ std::string model_compute_capability = attrs.at(COMPUTE_CAPABILITY).s();
+ if (model_compute_capability != compute_capability_) {
+ LOGS_DEFAULT(ERROR) << "The compute capability of the engine cache doesn't match with the GPU's compute capability";
+ LOGS_DEFAULT(ERROR) << "The compute capability of the engine cache: " << model_compute_capability;
+ LOGS_DEFAULT(ERROR) << "The compute capability of the GPU: " << compute_capability_;
+ return false;
+ }
+ }
+
+ // "embed_mode" attr and "ep_cache_context" attr should be present
+ if (attrs.count(EMBED_MODE) > 0 && attrs.count(EP_CACHE_CONTEXT) > 0) {
+ // ep_cache_context: payload of the execution provider context if embed_mode=1, or path to the context file if embed_mode=0
+ const int64_t embed_mode = attrs.at(EMBED_MODE).i();
+
+ // engine cache path
+ if (embed_mode == 0) {
+ // First assume engine cache path is relative to model path,
+ // If not, then assume the engine cache path is an absolute path.
+ engine_cache_path_ = LocateEngineRelativeToPath(attrs.at(EP_CACHE_CONTEXT).s(), GetModelPath(graph_viewer));
+ auto default_engine_cache_path_ = engine_cache_path_;
+ if (!std::filesystem::exists(engine_cache_path_)) {
+ engine_cache_path_.assign(attrs.at(EP_CACHE_CONTEXT).s());
+ if (!std::filesystem::exists(engine_cache_path_)) {
+ LOGS_DEFAULT(ERROR) << "Can't find " << default_engine_cache_path_.string() << " or " << engine_cache_path_.string() << " TensorRT engine";
+ return false;
+ }
+ }
+ }
+ }
+ return true;
+}
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h
new file mode 100644
index 0000000000000..ab6ea733adfa1
--- /dev/null
+++ b/onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h
@@ -0,0 +1,55 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <string>
+#include <filesystem>
+
+#include "NvInfer.h"
+#include "core/providers/shared_library/provider_api.h"
+
+namespace onnxruntime {
+
+static const std::string EPCONTEXT_OP = "EPContext";
+static const std::string EMBED_MODE = "embed_mode";
+static const std::string EP_CACHE_CONTEXT = "ep_cache_context";
+static const std::string COMPUTE_CAPABILITY = "hardware_architecture";
+static const std::string EPCONTEXT_OP_DOMAIN = "com.microsoft";
+
+bool GraphHasCtxNode(const GraphViewer& graph_viewer);
+const onnxruntime::Path& GetModelPath(const GraphViewer& graph_viewer);
+std::filesystem::path LocateEngineRelativeToPath(std::string engine_cache_path, const onnxruntime::Path& path);
+ONNX_NAMESPACE::ModelProto* CreateCtxNodeModel(const GraphViewer& graph_viewer,
+ const std::string engine_cache_path,
+ char* engine_data,
+ size_t size,
+ const int64_t embed_mode,
+ bool compute_capability_enable,
+ std::string compute_capability,
+ const logging::Logger* logger);
+void DumpCtxNodeModel(ONNX_NAMESPACE::ModelProto* model_proto,
+ const std::string engine_cache_path);
+void UpdateCtxNodeModelEngineContext(ONNX_NAMESPACE::ModelProto* model_proto,
+ char* engine_data,
+ size_t size);
+
+class TensorRTCacheModelHandler {
+ public:
+ TensorRTCacheModelHandler(std::unique_ptr<nvinfer1::ICudaEngine>* trt_engine,
+ nvinfer1::IRuntime* trt_runtime,
+ std::string compute_capability) : trt_engine_(trt_engine), trt_runtime_(trt_runtime), compute_capability_(compute_capability) {
+ }
+ ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TensorRTCacheModelHandler);
+
+ bool ValidateEPCtxNode(const GraphViewer& graph_viewer);
+
+ Status GetEpContextFromGraph(const GraphViewer& graph_viewer);
+
+ private:
+ std::unique_ptr<nvinfer1::ICudaEngine>* trt_engine_;
+ nvinfer1::IRuntime* trt_runtime_;
+ std::filesystem::path engine_cache_path_;
+ std::string compute_capability_;
+}; // TRTCacheModelHandler
+} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
index 4ece068b50fd1..1d4ead019dc27 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc
@@ -11,6 +11,7 @@
#include "tensorrt_execution_provider.h"
#include "tensorrt_execution_provider_utils.h"
#include "tensorrt_execution_provider_custom_ops.h"
+#include "onnx_ctx_model_helper.h"
#include "core/providers/cuda/shared_inc/cuda_call.h"
#include "core/providers/cuda/math/unary_elementwise_ops_impl.h"
#include "core/providers/cuda/gpu_data_transfer.h"
@@ -1378,6 +1379,9 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
profile_max_shapes = info.profile_max_shapes;
profile_opt_shapes = info.profile_opt_shapes;
cuda_graph_enable_ = info.cuda_graph_enable;
+ dump_ep_context_model_ = info.dump_ep_context_model;
+ ep_context_embed_mode_ = info.ep_context_embed_mode;
+ ep_context_compute_capability_enable_ = info.ep_context_compute_capability_enable;
} else {
try {
const std::string max_partition_iterations_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kMaxPartitionIterations);
@@ -1531,6 +1535,22 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const TensorrtExecutionProv
if (!cuda_graph_enable_env.empty()) {
cuda_graph_enable_ = (std::stoi(cuda_graph_enable_env) == 0 ? false : true);
}
+
+ const std::string dump_ep_context_model_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kDumpEpContextModel);
+ if (!dump_ep_context_model_env.empty()) {
+ dump_ep_context_model_ = (std::stoi(dump_ep_context_model_env) == 0 ? false : true);
+ }
+
+ const std::string ep_context_embed_mode_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEpContextEmbedMode);
+ if (!ep_context_embed_mode_env.empty()) {
+ ep_context_embed_mode_ = std::stoi(ep_context_embed_mode_env);
+ }
+
+ const std::string ep_context_compute_capability_env = onnxruntime::GetEnvironmentVar(tensorrt_env_vars::kEpContextComputeCapabilityEnable);
+ if (!ep_context_compute_capability_env.empty()) {
+ ep_context_compute_capability_enable_ = (std::stoi(ep_context_compute_capability_env) == 0 ? false : true);
+ }
+
} catch (const std::invalid_argument& ex) {
LOGS_DEFAULT(WARNING) << "[TensorRT EP] Invalid Argument (from environment variables): " << ex.what();
} catch (const std::out_of_range& ex) {
@@ -2283,6 +2303,19 @@ bool TensorrtExecutionProvider::DetectTensorRTGraphCycles(SubGraphCollection_t&
 std::vector<std::unique_ptr<ComputeCapability>>
TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
const IKernelLookup& /*kernel_lookup*/) const {
+ // Construct subgraph capability from node list
+ std::vector<std::unique_ptr<ComputeCapability>> result;
+
+ // If the model consists of only a single "EPContext" contrib op, it means TRT EP can fetch the precompiled engine info from the node and
+ // load the engine directly without having to go through the processes of graph proto reconstruction, calling TRT parser and engine compilation.
+ // So, simply return the ComputeCapability here.
+ if (graph.NumberOfNodes() == 1 && GraphHasCtxNode(graph)) {
+ SubGraph_t supported_node_vector = {{0}, true};
+ std::unique_ptr<IndexedSubGraph> sub_graph = GetSubGraph(supported_node_vector, graph, TRTGenerateId(graph), 0);
+ result.push_back(ComputeCapability::Create(std::move(sub_graph)));
+ return result;
+ }
+
// Get ModelPath
const auto& path_string = graph.ModelPath().ToPathString();
#ifdef _WIN32
@@ -2371,9 +2404,6 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
}
}
- // Construct subgraph capability from node list
- std::vector<std::unique_ptr<ComputeCapability>> result;
-
// Handle the case where the graph is subgraph of control flow op.
// The purpose is to make control flow op as well as its subgraphs run on TRT.
// Here we need to check whether subgraph is fully supported by TRT and don't fuse the nodes of the subgraph until control flow op level.
@@ -2488,721 +2518,391 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused_nodes_and_graphs,
 input_map[input_defs[i]->Name()] = i;
}
- // Reconstruct graph proto from fused node's function body
- auto model = graph_body_viewer.CreateModel(*GetLogger());
- auto model_proto = model->ToProto();
- graph_body_viewer.ToProto(*model_proto->mutable_graph(), true, true);
- model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
- std::string string_buf;
- model_proto->SerializeToString(string_buf);
-
- if (dump_subgraphs_) {
- // Dump TensorRT subgraphs
- std::fstream dump(fused_node.Name() + ".onnx", std::ios::out | std::ios::trunc | std::ios::binary);
- model_proto->SerializeToOstream(dump);
+ Status status;
+ if (GraphHasCtxNode(graph_body_viewer)) {
+ status = CreateNodeComputeInfoFromPrecompiledEngine(graph_body_viewer, fused_node, input_map, output_map, node_compute_funcs);
+ } else {
+ status = CreateNodeComputeInfoFromGraph(graph_body_viewer, fused_node, input_map, output_map, node_compute_funcs);
+ }
+ if (status != Status::OK()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage());
}
+ }
+ return Status::OK();
+}
- TensorrtLogger& trt_logger = GetTensorrtLogger();
- auto trt_builder = GetBuilder();
- const auto explicitBatch = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
- auto trt_network = std::unique_ptr<nvinfer1::INetworkDefinition>(trt_builder->createNetworkV2(explicitBatch));
- auto trt_config = std::unique_ptr<nvinfer1::IBuilderConfig>(trt_builder->createBuilderConfig());
- auto trt_parser = tensorrt_ptr::unique_pointer<nvonnxparser::IParser>(nvonnxparser::createParser(*trt_network, trt_logger));
- trt_parser->parse(string_buf.data(), string_buf.size(), model_path_);
- trt_config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, max_workspace_size_);
-
- // Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow
- if (fp16_enable_ && layer_norm_fp32_fallback_) {
- for (auto idx = 1; idx < trt_network->getNbLayers() - 1; ++idx) {
- auto layer = trt_network->getLayer(idx);
- auto next_layer = trt_network->getLayer(idx + 1);
- if (layer->getType() == nvinfer1::LayerType::kELEMENTWISE && next_layer->getType() == nvinfer1::LayerType::kREDUCE && (static_cast<nvinfer1::IElementWiseLayer*>(layer))->getOperation() == nvinfer1::ElementWiseOperation::kPOW) {
- LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow";
- layer->setPrecision(nvinfer1::DataType::kFLOAT);
- next_layer->setPrecision(nvinfer1::DataType::kFLOAT);
- layer->setOutputType(0, nvinfer1::DataType::kFLOAT);
- next_layer->setOutputType(0, nvinfer1::DataType::kFLOAT);
- }
- }
- }
-
- int num_inputs = trt_network->getNbInputs();
- int num_outputs = trt_network->getNbOutputs();
- std::unordered_map<std::string, size_t> input_indexes(num_inputs);
- std::unordered_map<std::string, size_t> output_indexes(num_outputs);
- std::unordered_map<std::string, size_t> output_types(num_outputs);
+Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphViewer& graph_body_viewer,
+ const Node& fused_node,
+ std::unordered_map<std::string, size_t>& input_map,
+ std::unordered_map<std::string, size_t>& output_map,
+ std::vector<NodeComputeInfo>& node_compute_funcs) {
+ // Reconstruct graph proto from fused node's function body
+ auto model = graph_body_viewer.CreateModel(*GetLogger());
+ auto model_proto = model->ToProto();
+ graph_body_viewer.ToProto(*model_proto->mutable_graph(), true, true);
+ model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
+ std::string string_buf;
+ model_proto->SerializeToString(string_buf);
+
+ if (dump_subgraphs_) {
+ // Dump TensorRT subgraphs
+ std::fstream dump(fused_node.Name() + ".onnx", std::ios::out | std::ios::trunc | std::ios::binary);
+ model_proto->SerializeToOstream(dump);
+ }
+
+ TensorrtLogger& trt_logger = GetTensorrtLogger();
+ auto trt_builder = GetBuilder();
+ const auto explicitBatch = 1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH);
+ auto trt_network = std::unique_ptr<nvinfer1::INetworkDefinition>(trt_builder->createNetworkV2(explicitBatch));
+ auto trt_config = std::unique_ptr<nvinfer1::IBuilderConfig>(trt_builder->createBuilderConfig());
+ auto trt_parser = tensorrt_ptr::unique_pointer<nvonnxparser::IParser>(nvonnxparser::createParser(*trt_network, trt_logger));
+ trt_parser->parse(string_buf.data(), string_buf.size(), model_path_);
+ trt_config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, max_workspace_size_);
+
+ // Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow
+ if (fp16_enable_ && layer_norm_fp32_fallback_) {
+ for (auto idx = 1; idx < trt_network->getNbLayers() - 1; ++idx) {
+ auto layer = trt_network->getLayer(idx);
+ auto next_layer = trt_network->getLayer(idx + 1);
+ if (layer->getType() == nvinfer1::LayerType::kELEMENTWISE && next_layer->getType() == nvinfer1::LayerType::kREDUCE && (static_cast<nvinfer1::IElementWiseLayer*>(layer))->getOperation() == nvinfer1::ElementWiseOperation::kPOW) {
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Force Pow + Reduce ops in layer norm to run in FP32 to avoid overflow";
+ layer->setPrecision(nvinfer1::DataType::kFLOAT);
+ next_layer->setPrecision(nvinfer1::DataType::kFLOAT);
+ layer->setOutputType(0, nvinfer1::DataType::kFLOAT);
+ next_layer->setOutputType(0, nvinfer1::DataType::kFLOAT);
+ }
+ }
+ }
+
+ int num_inputs = trt_network->getNbInputs();
+ int num_outputs = trt_network->getNbOutputs();
+ std::unordered_map<std::string, size_t> input_indexes(num_inputs);
+ std::unordered_map<std::string, size_t> output_indexes(num_outputs);
+ std::unordered_map<std::string, size_t> output_types(num_outputs);
- /*
- * Initialize shape range for each dynamic shape input tensor:
- * 1) If user explicitly specifies optimization profiles via provider options, TRT EP will create those profiles during EP compile time.
- * It won't make adjustment for profile values during EP compute time.
- *
- * 2) If no explicit optimization profiles provided by user, TRT EP will firstly set min/max/opt shape to [INT_MAX, INT_MIN, INT_MIN].
- * Later in EP compute time, the shape will be adjusted to [min_input_value, max_input_value, max_input_value] based on input tensor value.
- *
- *
- * Once the TRT profiles are created:
- * 1) If all the dynamic shape input tensors have associated profiles explicitly provided by user, those profiles will be applied to TRT builder config
- * and the engine will be built at EP compile time.
- *
- * 2) As long as one of the dynamic shape input tensors has no explicitly associated profile, TRT EP will create default shape as described above,
- * and all the profiles won't be applied and engine won't be built until EP compute time.
- */
- bool has_dynamic_shape = false; // True if input tensor has dynamic shape and no explicit profile is specified, otherwise false.
- bool has_explicit_profile = false;
- bool apply_explicit_profile = false;
- int num_profiles = 0;
- std::vector<nvinfer1::IOptimizationProfile*> trt_profiles;
-
- // Following c++ map data structure is used to help serialize/deserialize profiles where it saves dynamic shape dimension(s) and min/max/opt values for dynamic shape input tensor.
- //
- // (1) Single profile case:
- // For example, assume tensor_a has two dynamic shape dimensions: dim_0 and dim_2, and tensor_b
- // has one dynamic shape dimension: dim_1. The data will be:
- // {
- // tensor_a: {
- // dim_0: [[min_shape, max_shape, opt_shape]],
- // dim_2: [[min_shape, max_shape, opt_shape]]
- // },
- // tensor_b: {
- // dim_1: [[min_shape, max_shape, opt_shape]]
- // }
- // }
- //
- // (2) Multiple profiles case:
- // For example, assume tensor_a has one dynamic shap dimension: dim 0, and tensor_b has one dynamic shape dimension: dim_1,
- // and both of the tensors have two profiles. The data will be:
- // {
- // tensor_a: {
- // dim_0: [[min_shape_0, max_shape_0, opt_shape_0], [min_shape_1, max_shape_1, opt_shape_1]]
- // },
- // tensor_b: {
- // dim_1: [[min_shape_2, max_shape_2, opt_shape_2], [min_shape_3, max_shape_3, opt_shape_3]]
- // }
- // }
- ShapeRangesMap input_explicit_shape_ranges;
- ShapeRangesMap input_implicit_shape_ranges;
-
- if ((!profile_min_shapes_.empty()) && (!profile_max_shapes_.empty()) && (!profile_opt_shapes_.empty())) {
- has_explicit_profile = true;
- num_profiles = GetNumProfiles(profile_min_shapes_);
- for (int i = 0; i < num_profiles; i++) {
- trt_profiles.push_back(trt_builder->createOptimizationProfile());
- }
- }
-
- // Iterate all input tensors to check dynamic shape
- for (unsigned int i = 0, end = num_inputs; i < end; ++i) {
- auto input = trt_network->getInput(i);
- const std::string& input_name = input->getName();
- nvinfer1::Dims dims = input->getDimensions();
- int nb_dims = dims.nbDims;
-
- // Apply explicit optimization profiles provided by user
- if (has_explicit_profile) {
- apply_explicit_profile = ApplyProfileShapesFromProviderOptions(trt_profiles, input, profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_, input_explicit_shape_ranges);
- }
+ /*
+ * Initialize shape range for each dynamic shape input tensor:
+ * 1) If user explicitly specifies optimization profiles via provider options, TRT EP will create those profiles during EP compile time.
+ * It won't make adjustment for profile values during EP compute time.
+ *
+ * 2) If no explicit optimization profiles provided by user, TRT EP will firstly set min/max/opt shape to [INT_MAX, INT_MIN, INT_MIN].
+ * Later in EP compute time, the shape will be adjusted to [min_input_value, max_input_value, max_input_value] based on input tensor value.
+ *
+ *
+ * Once the TRT profiles are created:
+ * 1) If all the dynamic shape input tensors have associated profiles explicitly provided by user, those profiles will be applied to TRT builder config
+ * and the engine will be built at EP compile time.
+ *
+ * 2) As long as one of the dynamic shape input tensors has no explicitly associated profile, TRT EP will create default shape as described above,
+ * and all the profiles won't be applied and engine won't be built until EP compute time.
+ */
+ bool has_dynamic_shape = false; // True if input tensor has dynamic shape and no explicit profile is specified, otherwise false.
+ bool has_explicit_profile = false;
+ bool apply_explicit_profile = false;
+ int num_profiles = 0;
+ std::vector<nvinfer1::IOptimizationProfile*> trt_profiles;
- // If no explicit optimization profile is being applied, TRT EP will later set min/max/opt shape values based on input tensor values at EP compute time
- if (!apply_explicit_profile) {
- if (input->isShapeTensor()) {
- // Shape tensor
- std::vector<std::vector<int64_t>> profile_vector;
- std::vector<int64_t> shape_vector{INT_MAX, INT_MIN, INT_MIN};
- profile_vector.push_back(shape_vector); // only one profile needed
- input_implicit_shape_ranges[input_name][0] = profile_vector;
- has_dynamic_shape = true;
- } else {
- // Execution tensor
- for (int j = 0, end = nb_dims; j < end; ++j) {
- if (dims.d[j] == -1) {
- std::vector<std::vector<int64_t>> profile_vector;
- std::vector<int64_t> shape_vector{INT_MAX, INT_MIN, INT_MIN};
- profile_vector.push_back(shape_vector); // only one profile needed
- input_implicit_shape_ranges[input_name][j] = profile_vector;
- has_dynamic_shape = true;
- }
- }
- }
- apply_explicit_profile = false;
- }
+ // Following c++ map data structure is used to help serialize/deserialize profiles where it saves dynamic shape dimension(s) and min/max/opt values for dynamic shape input tensor.
+ //
+ // (1) Single profile case:
+ // For example, assume tensor_a has two dynamic shape dimensions: dim_0 and dim_2, and tensor_b
+ // has one dynamic shape dimension: dim_1. The data will be:
+ // {
+ // tensor_a: {
+ // dim_0: [[min_shape, max_shape, opt_shape]],
+ // dim_2: [[min_shape, max_shape, opt_shape]]
+ // },
+ // tensor_b: {
+ // dim_1: [[min_shape, max_shape, opt_shape]]
+ // }
+ // }
+ //
+ // (2) Multiple profiles case:
+ // For example, assume tensor_a has one dynamic shape dimension: dim_0, and tensor_b has one dynamic shape dimension: dim_1,
+ // and both of the tensors have two profiles. The data will be:
+ // {
+ // tensor_a: {
+ // dim_0: [[min_shape_0, max_shape_0, opt_shape_0], [min_shape_1, max_shape_1, opt_shape_1]]
+ // },
+ // tensor_b: {
+ // dim_1: [[min_shape_2, max_shape_2, opt_shape_2], [min_shape_3, max_shape_3, opt_shape_3]]
+ // }
+ // }
+ ShapeRangesMap input_explicit_shape_ranges;
+ ShapeRangesMap input_implicit_shape_ranges;
+
+ if ((!profile_min_shapes_.empty()) && (!profile_max_shapes_.empty()) && (!profile_opt_shapes_.empty())) {
+ has_explicit_profile = true;
+ num_profiles = GetNumProfiles(profile_min_shapes_);
+ for (int i = 0; i < num_profiles; i++) {
+ trt_profiles.push_back(trt_builder->createOptimizationProfile());
}
+ }
- // Set explicit profiles in TRT config if all dynamic shape inputs have associated profiles provided by user
+ // Iterate all input tensors to check dynamic shape
+ for (unsigned int i = 0, end = num_inputs; i < end; ++i) {
+ auto input = trt_network->getInput(i);
+ const std::string& input_name = input->getName();
+ nvinfer1::Dims dims = input->getDimensions();
+ int nb_dims = dims.nbDims;
+
+ // Apply explicit optimization profiles provided by user
if (has_explicit_profile) {
- // TRT EP has a constraint here.
- // Users need to provide all the dynamic shape inputs with associated profiles if they want to explicitly specify profiles through provider options.
- if (has_dynamic_shape) {
- std::ostringstream msg;
- msg << "User needs to provide all the dynamic shape inputs with associated profiles if they want to explicitly set profiles through provider options.\n";
- msg << "Please note that main graph could be partitioned into TRT/CUDA/CPU subgraphs, in this case, user also needs to provide shape profiles for the TRT subgraph's input if it's dynamic shape input.\n";
- msg << "Following input(s) has no associated shape profiles provided: ";
- auto begin = input_implicit_shape_ranges.begin();
- auto end = input_implicit_shape_ranges.end();
- auto it = begin;
- if (it != end) {
- msg << it->first;
- ++it;
- }
- for (; it != end; ++it) {
- msg << "," << it->first;
- }
- return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, msg.str());
+ apply_explicit_profile = ApplyProfileShapesFromProviderOptions(trt_profiles, input, profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_, input_explicit_shape_ranges);
+ }
+
+ // If no explicit optimization profile is being applied, TRT EP will later set min/max/opt shape values based on input tensor values at EP compute time
+ if (!apply_explicit_profile) {
+ if (input->isShapeTensor()) {
+ // Shape tensor
+ std::vector<std::vector<int64_t>> profile_vector;
+ std::vector<int64_t> shape_vector{INT_MAX, INT_MIN, INT_MIN};
+ profile_vector.push_back(shape_vector); // only one profile needed
+ input_implicit_shape_ranges[input_name][0] = profile_vector;
+ has_dynamic_shape = true;
} else {
- for (auto trt_profile : trt_profiles) {
- trt_config->addOptimizationProfile(trt_profile);
+ // Execution tensor
+ for (int j = 0, end = nb_dims; j < end; ++j) {
+ if (dims.d[j] == -1) {
+ std::vector<std::vector<int64_t>> profile_vector;
+ std::vector<int64_t> shape_vector{INT_MAX, INT_MIN, INT_MIN};
+ profile_vector.push_back(shape_vector); // only one profile needed
+ input_implicit_shape_ranges[input_name][j] = profile_vector;
+ has_dynamic_shape = true;
+ }
}
}
+ apply_explicit_profile = false;
}
- // If no explicit profile is applied and the input has dynamic shape, TRT EP simply creates one profile by default.
- // It will later set proper min/max/opt shape values duing EP compute time.
- else if (!has_explicit_profile && has_dynamic_shape) {
- trt_profiles.push_back(trt_builder->createOptimizationProfile());
- }
+ }
- // Check platform availability for low precision
- if (fp16_enable_) {
- if (!trt_builder->platformHasFastFp16()) {
- fp16_enable_ = false;
- LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_FP16_ENABLE is set, but platform doesn't support fast native fp16";
+ // Set explicit profiles in TRT config if all dynamic shape inputs have associated profiles provided by user
+ if (has_explicit_profile) {
+ // TRT EP has a constraint here.
+ // Users need to provide all the dynamic shape inputs with associated profiles if they want to explicitly specify profiles through provider options.
+ if (has_dynamic_shape) {
+ std::ostringstream msg;
+ msg << "User needs to provide all the dynamic shape inputs with associated profiles if they want to explicitly set profiles through provider options.\n";
+ msg << "Please note that main graph could be partitioned into TRT/CUDA/CPU subgraphs, in this case, user also needs to provide shape profiles for the TRT subgraph's input if it's dynamic shape input.\n";
+ msg << "Following input(s) has no associated shape profiles provided: ";
+ auto begin = input_implicit_shape_ranges.begin();
+ auto end = input_implicit_shape_ranges.end();
+ auto it = begin;
+ if (it != end) {
+ msg << it->first;
+ ++it;
+ }
+ for (; it != end; ++it) {
+ msg << "," << it->first;
+ }
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, msg.str());
+ } else {
+ for (auto trt_profile : trt_profiles) {
+ trt_config->addOptimizationProfile(trt_profile);
}
}
+ }
+ // If no explicit profile is applied and the input has dynamic shape, TRT EP simply creates one profile by default.
+ // It will later set proper min/max/opt shape values during EP compute time.
+ else if (!has_explicit_profile && has_dynamic_shape) {
+ trt_profiles.push_back(trt_builder->createOptimizationProfile());
+ }
- if (int8_enable_) {
- if (!trt_builder->platformHasFastInt8()) {
- int8_enable_ = false;
- LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_INT8_ENABLE is set, but platform doesn't support fast native int8";
- }
- }
-
- // Load INT8 calibration table
- std::unordered_map<std::string, float> dynamic_range_map;
- if (int8_enable_ && int8_calibration_cache_available_) {
- const std::string calibration_cache_path = GetCachePath(cache_path_, int8_calibration_cache_name_);
- if (!ReadDynamicRange(calibration_cache_path, int8_use_native_tensorrt_calibration_table_, dynamic_range_map)) {
- throw std::runtime_error("Failed to read INT8 calibration table " + calibration_cache_path);
- }
- }
-
- // Set precision flags
- std::string trt_node_name_with_precision = fused_node.Name();
- if (fp16_enable_ && int8_enable_) {
- trt_config->setFlags(1U << static_cast<uint32_t>(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast<uint32_t>(nvinfer1::BuilderFlag::kINT8));
- trt_node_name_with_precision += "_fp16_int8";
- LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 and INT8 mode is enabled";
- } else if (fp16_enable_) {
- trt_config->setFlag(nvinfer1::BuilderFlag::kFP16);
- trt_node_name_with_precision += "_fp16";
- LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 mode is enabled";
- } else if (int8_enable_) {
- trt_config->setFlag(nvinfer1::BuilderFlag::kINT8);
- trt_node_name_with_precision += "_int8";
- LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] INT8 mode is enabled";
- }
-
- // Set DLA
- if (fp16_enable_ || int8_enable_) {
- if (dla_enable_ && dla_core_ >= 0) { // DLA can only run with FP16 and INT8
- int number_of_dla_core = trt_builder->getNbDLACores();
- if (number_of_dla_core == 0) {
- LOGS_DEFAULT(WARNING) << "[TensorRT EP] Try to use DLA core, but platform doesn't have any DLA core";
- dla_enable_ = false;
- } else {
- if (dla_core_ >= number_of_dla_core) {
- LOGS_DEFAULT(WARNING) << "[TensorRT EP] Try to use DLA core #" << dla_core_ << ", but it exceeds platform's maximum DLA core number " << number_of_dla_core << ". Use DLA core 0 instead.";
- dla_core_ = 0;
- }
- LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << dla_core_;
- trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK);
- trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA);
- trt_config->setDLACore(dla_core_);
- trt_node_name_with_precision += "_dlacore" + std::to_string(dla_core_);
+ // Check platform availability for low precision
+ if (fp16_enable_) {
+ if (!trt_builder->platformHasFastFp16()) {
+ fp16_enable_ = false;
+ LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_FP16_ENABLE is set, but platform doesn't support fast native fp16";
+ }
+ }
+
+ if (int8_enable_) {
+ if (!trt_builder->platformHasFastInt8()) {
+ int8_enable_ = false;
+ LOGS_DEFAULT(WARNING) << "[TensorRT EP] ORT_TENSORRT_INT8_ENABLE is set, but platform doesn't support fast native int8";
+ }
+ }
+
+ // Load INT8 calibration table
+ std::unordered_map<std::string, float> dynamic_range_map;
+ if (int8_enable_ && int8_calibration_cache_available_) {
+ const std::string calibration_cache_path = GetCachePath(cache_path_, int8_calibration_cache_name_);
+ if (!ReadDynamicRange(calibration_cache_path, int8_use_native_tensorrt_calibration_table_, dynamic_range_map)) {
+ throw std::runtime_error("Failed to read INT8 calibration table " + calibration_cache_path);
+ }
+ }
+
+ // Set precision flags
+ std::string trt_node_name_with_precision = fused_node.Name();
+ if (fp16_enable_ && int8_enable_) {
+ trt_config->setFlags(1U << static_cast<uint32_t>(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast<uint32_t>(nvinfer1::BuilderFlag::kINT8));
+ trt_node_name_with_precision += "_fp16_int8";
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 and INT8 mode is enabled";
+ } else if (fp16_enable_) {
+ trt_config->setFlag(nvinfer1::BuilderFlag::kFP16);
+ trt_node_name_with_precision += "_fp16";
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] FP16 mode is enabled";
+ } else if (int8_enable_) {
+ trt_config->setFlag(nvinfer1::BuilderFlag::kINT8);
+ trt_node_name_with_precision += "_int8";
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] INT8 mode is enabled";
+ }
+
+ // Set DLA
+ if (fp16_enable_ || int8_enable_) {
+ if (dla_enable_ && dla_core_ >= 0) { // DLA can only run with FP16 and INT8
+ int number_of_dla_core = trt_builder->getNbDLACores();
+ if (number_of_dla_core == 0) {
+ LOGS_DEFAULT(WARNING) << "[TensorRT EP] Try to use DLA core, but platform doesn't have any DLA core";
+ dla_enable_ = false;
+ } else {
+ if (dla_core_ >= number_of_dla_core) {
+ LOGS_DEFAULT(WARNING) << "[TensorRT EP] Try to use DLA core #" << dla_core_ << ", but it exceeds platform's maximum DLA core number " << number_of_dla_core << ". Use DLA core 0 instead.";
+ dla_core_ = 0;
}
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << dla_core_;
+ trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK);
+ trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA);
+ trt_config->setDLACore(dla_core_);
+ trt_node_name_with_precision += "_dlacore" + std::to_string(dla_core_);
}
}
+ }
+
+ // enable sparse weights
+ if (sparsity_enable_) {
+ trt_config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS);
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Sparse weights are allowed";
+ }
- // enable sparse weights
- if (sparsity_enable_) {
- trt_config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS);
- LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Sparse weights are allowed";
- }
#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR == 5
- if (build_heuristics_enable_) {
- trt_config->setFlag(nvinfer1::BuilderFlag::kENABLE_TACTIC_HEURISTIC);
- LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder heuristics are enabled."
- << " For TRT > 8.5, trt_build_heuristics_enable is deprecated, please set builder optimization level as 2 to enable builder heuristics.";
- }
+ if (build_heuristics_enable_) {
+ trt_config->setFlag(nvinfer1::BuilderFlag::kENABLE_TACTIC_HEURISTIC);
+ LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder heuristics are enabled."
+ << " For TRT > 8.5, trt_build_heuristics_enable is deprecated, please set builder optimization level as 2 to enable builder heuristics.";
+ }
#elif NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8
- // for TRT 8.6 onwards, heuristic-based tactic option is automatically enabled by setting builder optimization level 2
- if (build_heuristics_enable_) {
- if (builder_optimization_level_ == 2) {
- LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder heuristics are automatically enabled by builder optimization level 2. trt_build_heuristics_enable is deprecated on TRT 8.6 onwards.";
- } else {
- LOGS_DEFAULT(WARNING) << "[TensorRT EP] trt_build_heuristics_enable is deprecated on TRT 8.6 onwards. Please set builder optimization level as 2 to enable builder heuristics.";
- }
+ // for TRT 8.6 onwards, heuristic-based tactic option is automatically enabled by setting builder optimization level 2
+ if (build_heuristics_enable_) {
+ if (builder_optimization_level_ == 2) {
+ LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder heuristics are automatically enabled by builder optimization level 2. trt_build_heuristics_enable is deprecated on TRT 8.6 onwards.";
+ } else {
+ LOGS_DEFAULT(WARNING) << "[TensorRT EP] trt_build_heuristics_enable is deprecated on TRT 8.6 onwards. Please set builder optimization level as 2 to enable builder heuristics.";
}
+ }
#endif
#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8
- // switch optimizaion level
- if (builder_optimization_level_ != 3) {
- trt_config->setBuilderOptimizationLevel(builder_optimization_level_);
- LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder optimization level is set to " << builder_optimization_level_;
- }
+ // switch optimization level
+ if (builder_optimization_level_ != 3) {
+ trt_config->setBuilderOptimizationLevel(builder_optimization_level_);
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder optimization level is set to " << builder_optimization_level_;
+ }
- // limit auxiliary streams
- if (auxiliary_streams_ >= 0) {
- trt_config->setMaxAuxStreams(auxiliary_streams_);
- LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Auxiliary streams are se to " << auxiliary_streams_;
- }
+ // limit auxiliary streams
+ if (auxiliary_streams_ >= 0) {
+ trt_config->setMaxAuxStreams(auxiliary_streams_);
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Auxiliary streams are se to " << auxiliary_streams_;
+ }
#else
- if (builder_optimization_level_ != 3) {
- LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder optimization level can only be used on TRT 8.6 onwards!";
- }
- if (auxiliary_streams_ >= 0) {
- LOGS_DEFAULT(WARNING) << "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!";
- }
+ if (builder_optimization_level_ != 3) {
+ LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder optimization level can only be used on TRT 8.6 onwards!";
+ }
+ if (auxiliary_streams_ >= 0) {
+ LOGS_DEFAULT(WARNING) << "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!";
+ }
#endif
- // limit used tactic sources
- if (!tactic_sources_.empty()) {
- nvinfer1::TacticSources tactics = trt_config->getTacticSources();
- tactics |= GetTacticSourceFromString(tactic_sources_);
- trt_config->setTacticSources(tactics);
- LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tactic sources are limited using " << tactic_sources_;
- }
-
- // Build TRT engine (if needed) and load TRT engine if:
- // (1) Graph has no dynamic shape input
- // (2) All the dynamic shape inputs have associated explicit profiles specified by user
- //
- // Otherwise engine will be handled at inference time.
- std::unique_ptr<nvinfer1::ICudaEngine> trt_engine;
- std::unique_ptr<nvinfer1::IExecutionContext> trt_context;
- // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache
- // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity
- if (!has_dynamic_shape) {
- const std::string cache_path = GetCachePath(cache_path_, trt_node_name_with_precision);
- const std::string engine_cache_path = cache_path + "_sm" + compute_capability_ + ".engine";
- const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted";
- const std::string profile_cache_path = cache_path + "_sm" + compute_capability_ + ".profile";
- std::string timing_cache_path = "";
- bool engine_update = false;
- if (timing_cache_enable_) {
- timing_cache_path = GetTimingCachePath(global_cache_path_, compute_capability_);
- }
- {
- // ifstream file check, engine serialization/deserialization and engine build are in critical section. It needs lock protection to prevent race condition when inferencing with multithreading.
- auto lock = GetApiLock();
+ // limit used tactic sources
+ if (!tactic_sources_.empty()) {
+ nvinfer1::TacticSources tactics = trt_config->getTacticSources();
+ tactics |= GetTacticSourceFromString(tactic_sources_);
+ trt_config->setTacticSources(tactics);
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tactic sources are limited using " << tactic_sources_;
+ }
- // If explicit profile flag is on and engine cache enable flag is on,
- // we need to compare explicit profiles and profiles used to build the engine in order to decide whether to rebuild the engine.
- if (has_explicit_profile && engine_cache_enable_) {
- engine_update = CompareProfiles(profile_cache_path, profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_);
- if (engine_update) {
- LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Engine will be built";
- } else {
- LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Engine won't be rebuilt";
- }
- }
+ // Build TRT engine (if needed) and load TRT engine if:
+ // (1) Graph has no dynamic shape input
+ // (2) All the dynamic shape inputs have associated explicit profiles specified by user
+ //
+ // Otherwise engine will be handled at inference time.
+ std::unique_ptr<nvinfer1::ICudaEngine> trt_engine;
+ std::unique_ptr<nvinfer1::IExecutionContext> trt_context;
+
+ // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache
+ // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity
+ const std::string cache_path = GetCachePath(cache_path_, trt_node_name_with_precision);
+ const std::string cache_path_prefix = cache_path + "_sm" + compute_capability_;
+ const std::string engine_cache_path = cache_path_prefix + ".engine";
+ const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted";
+ const std::string profile_cache_path = cache_path_prefix + ".profile";
+
+ if (!has_dynamic_shape) {
+ std::string timing_cache_path = "";
+ bool engine_update = false;
+ if (timing_cache_enable_) {
+ timing_cache_path = GetTimingCachePath(global_cache_path_, compute_capability_);
+ }
+ {
+ // ifstream file check, engine serialization/deserialization and engine build are in critical section. It needs lock protection to prevent race condition when inferencing with multithreading.
+ auto lock = GetApiLock();
- std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in);
- if (engine_cache_enable_ && !engine_decryption_enable_ && engine_file && !engine_update) {
- engine_file.seekg(0, std::ios::end);
- size_t engine_size = engine_file.tellg();
- engine_file.seekg(0, std::ios::beg);
- std::unique_ptr<char[]> engine_buf{new char[engine_size]};
- engine_file.read((char*)engine_buf.get(), engine_size);
- trt_engine = std::unique_ptr<nvinfer1::ICudaEngine>(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size));
- LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
- if (trt_engine == nullptr) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
- "TensorRT EP could not deserialize engine from cache: " + engine_cache_path);
- }
- } else if (engine_decryption_enable_ && engine_cache_enable_ && std::filesystem::exists(encrypted_engine_cache_path) && !engine_update) {
- // Decrypt engine
- size_t engine_size = 0;
- if (!engine_decryption_(encrypted_engine_cache_path.c_str(), nullptr, &engine_size)) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
- "TensorRT EP could not get engine buffer size");
- }
- std::unique_ptr<char[]> engine_buf{new char[engine_size]};
- if (!engine_decryption_(encrypted_engine_cache_path.c_str(), &engine_buf[0], &engine_size)) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
- "TensorRT EP could not call engine decryption function decrypt");
- }
- // Deserialize engine
- trt_engine = std::unique_ptr<nvinfer1::ICudaEngine>(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size));
- LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path;
- if (trt_engine == nullptr) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
- "TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path);
- }
+ // If explicit profile flag is on and engine cache enable flag is on,
+ // we need to compare explicit profiles and profiles used to build the engine in order to decide whether to rebuild the engine.
+ if (has_explicit_profile && engine_cache_enable_) {
+ engine_update = CompareProfiles(profile_cache_path, profile_min_shapes_, profile_max_shapes_, profile_opt_shapes_);
+ if (engine_update) {
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Engine will be built";
} else {
- // Set INT8 per tensor dynamic range
- if (int8_enable_ && trt_builder->platformHasFastInt8() && int8_calibration_cache_available_) {
- trt_config->setInt8Calibrator(nullptr);
- if (!SetDynamicRange(*trt_network, dynamic_range_map)) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
- "TensorRT EP could not set INT8 dynamic range for fused node: " + fused_node.Name());
- }
- }
-
- // Load timing cache from file. Create a fresh cache if the file doesn't exist
- std::unique_ptr timing_cache = nullptr;
- if (timing_cache_enable_) {
- std::vector loaded_timing_cache = loadTimingCacheFile(timing_cache_path);
- timing_cache.reset(trt_config->createTimingCache(static_cast(loaded_timing_cache.data()), loaded_timing_cache.size()));
- if (timing_cache == nullptr) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
- "TensorRT EP could not create timing cache: " + timing_cache_path);
- }
- trt_config->setTimingCache(*timing_cache, force_timing_cache_match_);
- if (detailed_build_log_) {
- LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Deserialized timing cache from " + timing_cache_path;
- }
- }
-
- // Build engine
- std::chrono::steady_clock::time_point engine_build_start;
- if (detailed_build_log_) {
- engine_build_start = std::chrono::steady_clock::now();
- }
- std::unique_ptr serialized_engine{trt_builder->buildSerializedNetwork(*trt_network, *trt_config)};
- if (serialized_engine == nullptr) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
- "TensorRT EP failed to create engine from network for fused node: " + fused_node.Name());
- }
- trt_engine = std::unique_ptr(runtime_->deserializeCudaEngine(serialized_engine->data(), serialized_engine->size()));
- if (trt_engine == nullptr) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
- "TensorRT EP failed to deserialize engine for fused node: " + fused_node.Name());
- }
- if (detailed_build_log_) {
- auto engine_build_stop = std::chrono::steady_clock::now();
- LOGS_DEFAULT(INFO) << "TensorRT engine build for " << trt_node_name_with_precision << " took: " << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() << "ms" << std::endl;
- }
- if (engine_cache_enable_) {
- // Serialize engine profile if it has explicit profiles
- if (has_explicit_profile) {
- SerializeProfileV2(profile_cache_path, input_explicit_shape_ranges);
- LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path;
- }
-
- if (engine_decryption_enable_) {
- // Encrypt engine. The library is not always deployed with the encrypt function, so check if it is available first.
- if (engine_encryption_ != nullptr) {
- if (!engine_encryption_(encrypted_engine_cache_path.c_str(), reinterpret_cast(serialized_engine->data()), serialized_engine->size())) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
- "TensorRT EP call to engine encryption library failed");
- }
- LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path;
- } else {
- LOGS_DEFAULT(WARNING) << "[TensorRT EP] Engine cache encryption function is not found. No cache is written to disk";
- }
- } else {
- std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out);
- file.write(reinterpret_cast(serialized_engine->data()), serialized_engine->size());
- LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized engine " + engine_cache_path;
- }
- }
- // serialize and save timing cache
- if (timing_cache_enable_) {
- auto timing_cache = trt_config->getTimingCache();
- std::unique_ptr timingCacheHostData{timing_cache->serialize()};
- if (timingCacheHostData == nullptr) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
- "TensorRT EP could not serialize timing cache: " + timing_cache_path);
- }
- saveTimingCacheFile(timing_cache_path, timingCacheHostData.get());
- if (detailed_build_log_) {
- LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized timing cache " + timing_cache_path;
- }
- }
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Engine won't be rebuilt";
}
}
- // Build context
- // Note: Creating an execution context from an engine is thread safe per TRT doc
- // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading
- if (context_memory_sharing_enable_) {
- size_t mem_size = trt_engine->getDeviceMemorySize();
- if (mem_size > max_ctx_mem_size_) {
- max_ctx_mem_size_ = mem_size;
+ std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in);
+ if (engine_cache_enable_ && !engine_decryption_enable_ && engine_file && !engine_update) {
+ engine_file.seekg(0, std::ios::end);
+ size_t engine_size = engine_file.tellg();
+ engine_file.seekg(0, std::ios::beg);
+ std::unique_ptr engine_buf{new char[engine_size]};
+ engine_file.read((char*)engine_buf.get(), engine_size);
+ trt_engine = std::unique_ptr(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size));
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
+ if (trt_engine == nullptr) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+ "TensorRT EP could not deserialize engine from cache: " + engine_cache_path);
}
- trt_context = std::unique_ptr(trt_engine->createExecutionContextWithoutDeviceMemory());
- } else {
- trt_context = std::unique_ptr(trt_engine->createExecutionContext());
- }
- if (!trt_context) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
- "TensorRT EP could not build execution context for fused node: " + fused_node.Name());
- }
- }
-
- // Create input to index map
- for (int i = 0; i < num_inputs; ++i) {
- auto input = trt_network->getInput(i);
- const std::string& input_name = input->getName();
- const auto& iter = input_map.find(input_name);
- if (iter != input_map.end()) {
- input_indexes[input_name] = iter->second;
- }
- }
-
- // Create output to index and type maps
- const auto& graph_output = model_proto->graph().output();
- for (int i = 0; i < num_outputs; ++i) {
- const std::string& output_name = trt_network->getOutput(i)->getName();
- const auto& iter = output_map.find(output_name);
- if (iter != output_map.end()) {
- output_indexes[output_name] = iter->second;
- }
- const auto& tensor_type = graph_output[i].type().tensor_type();
- output_types[output_name] = tensor_type.elem_type();
- }
-
- // Save TRT engine, other TRT objects and input/output info to map
- parsers_.emplace(fused_node.Name(), std::move(trt_parser));
- engines_.emplace(fused_node.Name(), std::move(trt_engine));
- contexts_.emplace(fused_node.Name(), std::move(trt_context));
- networks_.emplace(fused_node.Name(), std::move(trt_network));
- input_info_[fused_node.Name()].push_back(input_indexes);
- output_info_[fused_node.Name()].push_back(output_indexes);
- output_info_[fused_node.Name()].push_back(output_types);
- input_shape_ranges_[fused_node.Name()] = input_implicit_shape_ranges;
- profiles_.emplace(fused_node.Name(), std::move(trt_profiles));
-
- // Create function state
- // TODO: remove default capture
- NodeComputeInfo compute_info;
- compute_info.create_state_func = [=](ComputeContext* context, FunctionState* state) {
- std::unique_ptr p = std::make_unique();
- // translate tactic sources string to nvinfer1::TacticSources
- nvinfer1::TacticSources tactics = 0;
- if (!tactic_sources_.empty()) {
- tactics = GetTacticSourceFromString(tactic_sources_);
- }
- *p = {context->allocate_func, context->release_func, context->allocator_handle, context->node_name, builder_.get(),
- &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name],
- &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name],
- input_shape_ranges_[context->node_name], sync_stream_after_enqueue_, &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_,
- dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_,
- runtime_.get(), profiles_[context->node_name], context_memory_sharing_enable_, &max_ctx_mem_size_,
- dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_, timing_cache_enable_,
- global_cache_path_, force_timing_cache_match_, detailed_build_log_, build_heuristics_enable_, sparsity_enable_,
- builder_optimization_level_, auxiliary_streams_, !tactic_sources_.empty(), tactics};
- *state = p.release();
- return 0;
- };
-
- // Release function state
- compute_info.release_state_func = [](FunctionState state) {
- delete static_cast(state);
- };
-
- // Create compute function
- compute_info.compute_func = [this](FunctionState state, const OrtApi* api, OrtKernelContext* context) {
- Ort::KernelContext ctx(context);
-
- TensorrtFuncState* trt_state = reinterpret_cast(state);
-
- // The whole compute_function should be considered the critical section where multiple threads may update kernel function state, access one builder, create/serialize/save engine,
- // save profile and serialize/save timing cache. Therefore, those operations should be synchronized across different threads when ORT is using multithreading.
- // More details here, https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading
- std::lock_guard lock(*(trt_state->tensorrt_mu_ptr));
- const std::unordered_map& input_indexes = (trt_state->input_info)[0];
- const std::unordered_map& output_indexes = (trt_state->output_info)[0];
- const std::unordered_map& output_types = (trt_state->output_info)[1];
- bool sync_stream_after_enqueue = trt_state->sync_stream_after_enqueue;
- auto fused_node_name = trt_state->fused_node_name;
- auto& shape_ranges = trt_state->input_shape_ranges;
- auto& dds_output_allocator_map = this->dds_output_allocator_maps_[fused_node_name];
- auto trt_builder = trt_state->builder;
- auto trt_engine = trt_state->engine->get();
- auto trt_context = trt_state->context->get();
- auto trt_profiles = trt_state->profiles;
- auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr;
- int num_inputs = static_cast(input_indexes.size());
- int num_outputs = static_cast(output_indexes.size());
- bool engine_update = false;
- bool context_update = false;
- std::unordered_set input_names;
- std::unordered_map> tensor_shape_values;
-
- OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, device_id_), device_id_);
- if (alloc_ == nullptr) {
- Ort::ThrowOnError(api->KernelContext_GetAllocator(context, &mem_info, &alloc_));
- }
- OrtAllocator* alloc = alloc_;
-
- void* cuda_stream;
- Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &cuda_stream));
- cudaStream_t stream = static_cast(cuda_stream);
-
- // Name the engine cache based on GPU compute capacity and reduce the chance of loading an incompatible cache
- // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capacity
- // Prepare cache name
- const std::string cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->trt_node_name_with_precision);
- const std::string engine_cache_path = cache_path + "_sm" + compute_capability_ + ".engine";
- const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted";
- const std::string profile_cache_path = cache_path + "_sm" + compute_capability_ + ".profile";
- std::string timing_cache_path = "";
- if (timing_cache_enable_) {
- timing_cache_path = GetTimingCachePath(global_cache_path_, compute_capability_);
- }
-
- // Load serialized engine
- if (trt_state->engine_cache_enable && trt_engine == nullptr) {
- std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in);
- std::ifstream profile_file(profile_cache_path, std::ios::binary | std::ios::in);
- if (engine_file && !trt_state->engine_decryption_enable && profile_file) {
- // Deserialize profile
- shape_ranges = DeserializeProfileV2(profile_file);
- LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path;
-
- // Prepare buffer
- engine_file.seekg(0, std::ios::end);
- size_t engine_size = engine_file.tellg();
- engine_file.seekg(0, std::ios::beg);
- std::unique_ptr engine_buf{new char[engine_size]};
- engine_file.read((char*)engine_buf.get(), engine_size);
-
- // Deserialize engine
- // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc
- // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading
- trt_state->engine->reset();
- *(trt_state->engine) = std::unique_ptr(
- trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size));
- if (!(*(trt_state->engine))) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine.");
- }
- LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
- trt_engine = trt_state->engine->get();
- context_update = true;
- } else if (trt_state->engine_decryption_enable && std::filesystem::exists(encrypted_engine_cache_path) && profile_file) {
- shape_ranges = DeserializeProfileV2(profile_file);
- LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path;
- // Decrypt engine
- size_t engine_size = 0;
- if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), nullptr, &engine_size)) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
- "TensorRT EP could not get engine buffer size");
- }
- std::unique_ptr engine_buf{new char[engine_size]};
- if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), &engine_buf[0], &engine_size)) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
- "TensorRT EP could not call engine decryption function decrypt");
- }
- // Deserialize engine
- // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc
- // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading
- trt_state->engine->reset();
- *(trt_state->engine) = std::unique_ptr(trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size));
- if (!(*(trt_state->engine))) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
- "TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path);
- }
- LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path;
- trt_engine = trt_state->engine->get();
- context_update = true;
+ } else if (engine_decryption_enable_ && engine_cache_enable_ && std::filesystem::exists(encrypted_engine_cache_path) && !engine_update) {
+ // Decrypt engine
+ size_t engine_size = 0;
+ if (!engine_decryption_(encrypted_engine_cache_path.c_str(), nullptr, &engine_size)) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+ "TensorRT EP could not get engine buffer size");
}
- }
-
- // Check and update shape ranges for dynamic shape inputs.
- for (int i = 0, end = num_inputs; i < end; ++i) {
- auto input = trt_state->network->get()->getInput(i);
- const std::string& input_name = input->getName();
- input_names.insert(input_name);
-
- // If there is any input tensor in shape_ranges, it means this input tensor has dynamic shape and its profile shape values have not yet resolved.
- // TRT EP will help determine the min/max/opt profile values based on current input tensor value.
- if (shape_ranges.find(input_name) != shape_ranges.end()) {
- auto status = ApplyProfileShapesFromInputTensorValue(trt_profiles, ctx, input, shape_ranges, input_indexes, tensor_shape_values, stream, &engine_update);
- if (status != Status::OK()) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to parse input tensor and generate optimization profiles.");
- }
+ std::unique_ptr engine_buf{new char[engine_size]};
+ if (!engine_decryption_(encrypted_engine_cache_path.c_str(), &engine_buf[0], &engine_size)) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+ "TensorRT EP could not call engine decryption function decrypt");
}
- }
-
- // Regenerate engine
- if (engine_update) {
- // Destroy the IExecutionContext objects before destroying an engine object, otherwise it will lead to undefined behavior.
- trt_state->context->reset();
- trt_state->engine->reset();
- auto trt_config = std::unique_ptr(trt_builder->createBuilderConfig());
- trt_config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, *(trt_state->max_workspace_size_ptr));
- for (auto trt_profile : trt_profiles) {
- trt_config->addOptimizationProfile(trt_profile);
+ // Deserialize engine
+ trt_engine = std::unique_ptr(runtime_->deserializeCudaEngine(engine_buf.get(), engine_size));
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path;
+ if (trt_engine == nullptr) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+ "TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path);
}
-
- // Set INT8 Per Tensor Dynamic range
- if (trt_state->int8_enable && trt_builder->platformHasFastInt8() && trt_state->int8_calibration_cache_available) {
+ } else {
+ // Set INT8 per tensor dynamic range
+ if (int8_enable_ && trt_builder->platformHasFastInt8() && int8_calibration_cache_available_) {
trt_config->setInt8Calibrator(nullptr);
- if (!SetDynamicRange(*trt_state->network->get(), trt_state->dynamic_range_map)) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to set INT8 dynamic range.");
+ if (!SetDynamicRange(*trt_network, dynamic_range_map)) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+ "TensorRT EP could not set INT8 dynamic range for fused node: " + fused_node.Name());
}
}
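
For context, a minimal sketch of what setting the INT8 per-tensor dynamic range amounts to at the TensorRT API level. The uniform ±127 range and the helper name are illustrative, not what SetDynamicRange in this EP actually computes from the calibration table:

```cpp
#include <NvInfer.h>

// Illustrative only: set a symmetric dynamic range on every network input and
// layer output so INT8 can run without an attached calibrator.
bool SetUniformDynamicRange(nvinfer1::INetworkDefinition& network, float max_abs = 127.0f) {
  for (int i = 0; i < network.getNbInputs(); ++i) {
    if (!network.getInput(i)->setDynamicRange(-max_abs, max_abs)) return false;
  }
  for (int i = 0; i < network.getNbLayers(); ++i) {
    nvinfer1::ILayer* layer = network.getLayer(i);
    for (int j = 0; j < layer->getNbOutputs(); ++j) {
      if (!layer->getOutput(j)->setDynamicRange(-max_abs, max_abs)) return false;
    }
  }
  return true;
}
```
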
- // Set precision
- if (trt_state->fp16_enable && trt_state->int8_enable) {
- trt_config->setFlags(1U << static_cast(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast(nvinfer1::BuilderFlag::kINT8));
- } else if (trt_state->fp16_enable) {
- trt_config->setFlag(nvinfer1::BuilderFlag::kFP16);
- } else if (trt_state->int8_enable) {
- trt_config->setFlag(nvinfer1::BuilderFlag::kINT8);
- }
-
- // Set DLA (DLA can only run with FP16 or INT8)
- if ((trt_state->fp16_enable || trt_state->int8_enable) && trt_state->dla_enable) {
- LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << trt_state->dla_core;
- trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK);
- trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA);
- trt_config->setDLACore(trt_state->dla_core);
- }
-
- // enable sparse weights
- if (trt_state->sparsity_enable) {
- trt_config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS);
- LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Sparse weights are allowed";
- }
-#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR == 5
- // enable builder heuristics
- if (trt_state->build_heuristics_enable) {
- trt_config->setFlag(nvinfer1::BuilderFlag::kENABLE_TACTIC_HEURISTIC);
- LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder heuristics are enabled";
- }
-#elif NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8
- // switch optimizaion level
- if (trt_state->builder_optimization_level != 3) {
- trt_config->setBuilderOptimizationLevel(trt_state->builder_optimization_level);
- LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder optimization level is set to " << builder_optimization_level_;
- }
-
- // limit auxiliary streams
- if (trt_state->auxiliary_streams >= 0) {
- trt_config->setMaxAuxStreams(trt_state->auxiliary_streams);
- LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Auxiliary streams are se to " << trt_state->auxiliary_streams;
- }
-#else
- if (trt_state->builder_optimization_level != 3) {
- LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder optimization level can only be used on TRT 8.6 onwards!";
- }
- if (trt_state->auxiliary_streams >= 0) {
- LOGS_DEFAULT(WARNING) << "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!";
- }
-#endif
- // limit used tactic sources
- if (trt_state->filter_tactic_sources) {
- nvinfer1::TacticSources tactics = trt_config->getTacticSources();
- tactics |= trt_state->tactic_sources;
- trt_config->setTacticSources(tactics);
- LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tactic sources are limited using bitmask " << tactics;
- }
-
// Load timing cache from file. Create a fresh cache if the file doesn't exist
std::unique_ptr timing_cache = nullptr;
- if (trt_state->timing_cache_enable) {
+ if (timing_cache_enable_) {
std::vector loaded_timing_cache = loadTimingCacheFile(timing_cache_path);
timing_cache.reset(trt_config->createTimingCache(static_cast(loaded_timing_cache.data()), loaded_timing_cache.size()));
if (timing_cache == nullptr) {
@@ -3216,44 +2916,37 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector serialized_engine;
- {
- auto lock = GetApiLock();
- std::chrono::steady_clock::time_point engine_build_start;
- if (detailed_build_log_) {
- engine_build_start = std::chrono::steady_clock::now();
- }
- serialized_engine = std::unique_ptr(
- trt_builder->buildSerializedNetwork(*trt_state->network->get(), *trt_config));
- if (!serialized_engine) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create engine from network.");
- }
- *(trt_state->engine) = std::unique_ptr(
- trt_state->runtime->deserializeCudaEngine(serialized_engine->data(), serialized_engine->size()));
- if (!(*(trt_state->engine))) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to deserialize engine.");
- }
- if (detailed_build_log_) {
- auto engine_build_stop = std::chrono::steady_clock::now();
- LOGS_DEFAULT(INFO) << "TensorRT engine build for " << trt_state->trt_node_name_with_precision << " took: " << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() << "ms" << std::endl;
- }
+ std::chrono::steady_clock::time_point engine_build_start;
+ if (detailed_build_log_) {
+ engine_build_start = std::chrono::steady_clock::now();
}
- if (!(*(trt_state->engine))) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine.");
+ std::unique_ptr serialized_engine{trt_builder->buildSerializedNetwork(*trt_network, *trt_config)};
+ if (serialized_engine == nullptr) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+ "TensorRT EP failed to create engine from network for fused node: " + fused_node.Name());
}
- trt_engine = trt_state->engine->get();
- if (trt_state->engine_cache_enable) {
- // Serialize engine profile
- SerializeProfileV2(profile_cache_path, shape_ranges);
- LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path;
+ trt_engine = std::unique_ptr(runtime_->deserializeCudaEngine(serialized_engine->data(), serialized_engine->size()));
+ if (trt_engine == nullptr) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+ "TensorRT EP failed to deserialize engine for fused node: " + fused_node.Name());
+ }
+ if (detailed_build_log_) {
+ auto engine_build_stop = std::chrono::steady_clock::now();
+ LOGS_DEFAULT(INFO) << "TensorRT engine build for " << trt_node_name_with_precision << " took: " << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() << "ms" << std::endl;
+ }
+ if (engine_cache_enable_) {
+ // Serialize engine profile if it has explicit profiles
+ if (has_explicit_profile) {
+ SerializeProfileV2(profile_cache_path, input_explicit_shape_ranges);
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path;
+ }
- // Serialize engine
- if (trt_state->engine_decryption_enable) {
+ if (engine_decryption_enable_) {
// Encrypt engine. The library is not always deployed with the encrypt function, so check if it is available first.
- if (trt_state->engine_encryption != nullptr) {
- if (!trt_state->engine_encryption(encrypted_engine_cache_path.c_str(), reinterpret_cast(serialized_engine->data()), serialized_engine->size())) {
+ if (engine_encryption_ != nullptr) {
+ if (!engine_encryption_(encrypted_engine_cache_path.c_str(), reinterpret_cast(serialized_engine->data()), serialized_engine->size())) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
- "TensorRT EP could not call engine encryption function encrypt");
+ "TensorRT EP call to engine encryption library failed");
}
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path;
} else {
@@ -3262,12 +2955,11 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector(serialized_engine->data()), serialized_engine->size());
- LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path;
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized engine " + engine_cache_path;
}
}
-
// serialize and save timing cache
- if (trt_state->timing_cache_enable) {
+ if (timing_cache_enable_) {
auto timing_cache = trt_config->getTimingCache();
std::unique_ptr timingCacheHostData{timing_cache->serialize()};
if (timingCacheHostData == nullptr) {
@@ -3279,183 +2971,859 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector model_proto{CreateCtxNodeModel(graph_body_viewer,
+ engine_cache_path,
+ reinterpret_cast(serialized_engine->data()),
+ serialized_engine->size(),
+ ep_context_embed_mode_,
+ ep_context_compute_capability_enable_,
+ compute_capability_,
+ GetLogger())};
+ DumpCtxNodeModel(model_proto.get(), cache_path_prefix);
+ }
}
+ }
- if (context_update) {
- if (trt_state->context_memory_sharing_enable) {
- *(trt_state->context) = std::unique_ptr(
- trt_state->engine->get()->createExecutionContextWithoutDeviceMemory());
- } else {
- *(trt_state->context) = std::unique_ptr(
- trt_state->engine->get()->createExecutionContext());
+ // Build context
+ // Note: Creating an execution context from an engine is thread safe per TRT doc
+ // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading
+ if (context_memory_sharing_enable_) {
+ size_t mem_size = trt_engine->getDeviceMemorySize();
+ if (mem_size > max_ctx_mem_size_) {
+ max_ctx_mem_size_ = mem_size;
+ }
+ trt_context = std::unique_ptr(trt_engine->createExecutionContextWithoutDeviceMemory());
+ } else {
+ trt_context = std::unique_ptr(trt_engine->createExecutionContext());
+ }
+ if (!trt_context) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+ "TensorRT EP could not build execution context for fused node: " + fused_node.Name());
+ }
+ }
+
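
A small sketch of the context-memory-sharing idea used above: contexts are created without device memory and share one scratch buffer sized for the largest engine seen so far. The cudaMalloc-based allocation and helper names are illustrative; the EP itself sizes and binds the buffer through its allocator at compute time.

```cpp
#include <NvInfer.h>
#include <cuda_runtime_api.h>
#include <cstddef>
#include <memory>

// Illustrative only: execution contexts created without device memory share one
// scratch buffer, grown to the largest getDeviceMemorySize() seen so far.
struct SharedContextMemory {
  void* ptr = nullptr;
  size_t size = 0;
};

std::unique_ptr<nvinfer1::IExecutionContext> MakeContextWithSharedMemory(
    nvinfer1::ICudaEngine& engine, SharedContextMemory& shared) {
  const size_t needed = engine.getDeviceMemorySize();
  if (needed > shared.size) {
    if (shared.ptr != nullptr) cudaFree(shared.ptr);
    cudaMalloc(&shared.ptr, needed);
    shared.size = needed;
  }
  std::unique_ptr<nvinfer1::IExecutionContext> context{
      engine.createExecutionContextWithoutDeviceMemory()};
  if (context != nullptr) {
    context->setDeviceMemory(shared.ptr);  // must be re-bound if the buffer is reallocated later
  }
  return context;
}
```
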
+ // Create input to index map
+ for (int i = 0; i < num_inputs; ++i) {
+ auto input = trt_network->getInput(i);
+ const std::string& input_name = input->getName();
+ const auto& iter = input_map.find(input_name);
+ if (iter != input_map.end()) {
+ input_indexes[input_name] = iter->second;
+ }
+ }
+
+ // Create output to index and type maps
+ const auto& graph_output = model_proto->graph().output();
+ for (int i = 0; i < num_outputs; ++i) {
+ const std::string& output_name = trt_network->getOutput(i)->getName();
+ const auto& iter = output_map.find(output_name);
+ if (iter != output_map.end()) {
+ output_indexes[output_name] = iter->second;
+ }
+ const auto& tensor_type = graph_output[i].type().tensor_type();
+ output_types[output_name] = tensor_type.elem_type();
+ }
+
+ // Save TRT engine, other TRT objects and input/output info to map
+ parsers_.emplace(fused_node.Name(), std::move(trt_parser));
+ engines_.emplace(fused_node.Name(), std::move(trt_engine));
+ contexts_.emplace(fused_node.Name(), std::move(trt_context));
+ networks_.emplace(fused_node.Name(), std::move(trt_network));
+ input_info_[fused_node.Name()].push_back(input_indexes);
+ output_info_[fused_node.Name()].push_back(output_indexes);
+ output_info_[fused_node.Name()].push_back(output_types);
+ input_shape_ranges_[fused_node.Name()] = input_implicit_shape_ranges;
+ profiles_.emplace(fused_node.Name(), std::move(trt_profiles));
+
+ // For a model with dynamic shape inputs, TRT EP first creates a model proto which includes the inputs, the outputs and an empty engine.
+ // TRT EP will serialize the model at inference time because the engine can be updated then, and the updated engine should be included in the model.
+ // However, if embed_mode is 0 (the model only includes the engine path), TRT EP will serialize it here.
+ if (dump_ep_context_model_ && has_dynamic_shape) {
+ model_proto_.reset(CreateCtxNodeModel(graph_body_viewer,
+ engine_cache_path,
+ nullptr,
+ 0,
+ ep_context_embed_mode_,
+ ep_context_compute_capability_enable_,
+ compute_capability_,
+ GetLogger()));
+ if (ep_context_embed_mode_ == 0) {
+ DumpCtxNodeModel(model_proto_.get(), cache_path_prefix);
+ }
+ }
+
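
As a usage sketch for the EP-context dump performed above: the behavior is driven by the `trt_dump_ep_context_model`, `trt_ep_context_embed_mode` and `trt_ep_context_compute_capability_enable` provider options. The snippet uses the standard TensorRT provider-options C API; the exact value strings are illustrative.

```cpp
#include <onnxruntime_cxx_api.h>

// Illustrative only: enable dumping of the TRT-engine-wrapped ONNX model so a
// later session can load the precompiled engine directly.
void AppendTrtWithEpContextDump(Ort::SessionOptions& session_options) {
  const OrtApi& api = Ort::GetApi();
  OrtTensorRTProviderOptionsV2* trt_options = nullptr;
  Ort::ThrowOnError(api.CreateTensorRTProviderOptions(&trt_options));

  const char* keys[] = {"trt_engine_cache_enable",
                        "trt_dump_ep_context_model",
                        "trt_ep_context_embed_mode",                  // "0": engine cache path, "1": engine binary data
                        "trt_ep_context_compute_capability_enable"};  // record hardware arch in the wrapped model
  const char* values[] = {"true", "true", "0", "true"};
  Ort::ThrowOnError(api.UpdateTensorRTProviderOptions(trt_options, keys, values, 4));

  Ort::ThrowOnError(api.SessionOptionsAppendExecutionProvider_TensorRT_V2(
      static_cast<OrtSessionOptions*>(session_options), trt_options));
  api.ReleaseTensorRTProviderOptions(trt_options);
}
```
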
+ // Create function state
+ // TODO: remove default capture
+ NodeComputeInfo compute_info;
+ compute_info.create_state_func = [=](ComputeContext* context, FunctionState* state) {
+ std::unique_ptr p = std::make_unique();
+ // translate tactic sources string to nvinfer1::TacticSources
+ nvinfer1::TacticSources tactics = 0;
+ if (!tactic_sources_.empty()) {
+ tactics = GetTacticSourceFromString(tactic_sources_);
+ }
+ *p = {context->allocate_func, context->release_func, context->allocator_handle, context->node_name, builder_.get(),
+ &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name],
+ &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name],
+ input_shape_ranges_[context->node_name], sync_stream_after_enqueue_, &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_,
+ dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_,
+ runtime_.get(), profiles_[context->node_name], context_memory_sharing_enable_, &max_ctx_mem_size_,
+ dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_, timing_cache_enable_,
+ global_cache_path_, force_timing_cache_match_, detailed_build_log_, build_heuristics_enable_, sparsity_enable_,
+ builder_optimization_level_, auxiliary_streams_, !tactic_sources_.empty(), tactics};
+ *state = p.release();
+ return 0;
+ };
+
+ // Release function state
+ compute_info.release_state_func = [](FunctionState state) {
+ delete static_cast(state);
+ };
+
+ // Create compute function
+ compute_info.compute_func = [this](FunctionState state, const OrtApi* api, OrtKernelContext* context) {
+ Ort::KernelContext ctx(context);
+
+ TensorrtFuncState* trt_state = reinterpret_cast(state);
+
+ // The whole compute_function should be considered the critical section where multiple threads may update kernel function state, access one builder, create/serialize/save engine,
+ // save profile and serialize/save timing cache. Therefore, those operations should be synchronized across different threads when ORT is using multithreading.
+ // More details here, https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading
+ std::lock_guard lock(*(trt_state->tensorrt_mu_ptr));
+ const std::unordered_map& input_indexes = (trt_state->input_info)[0];
+ const std::unordered_map& output_indexes = (trt_state->output_info)[0];
+ const std::unordered_map& output_types = (trt_state->output_info)[1];
+ bool sync_stream_after_enqueue = trt_state->sync_stream_after_enqueue;
+ auto fused_node_name = trt_state->fused_node_name;
+ auto& shape_ranges = trt_state->input_shape_ranges;
+ auto& dds_output_allocator_map = this->dds_output_allocator_maps_[fused_node_name];
+ auto trt_builder = trt_state->builder;
+ auto trt_engine = trt_state->engine->get();
+ auto trt_context = trt_state->context->get();
+ auto trt_profiles = trt_state->profiles;
+ auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr;
+ int num_inputs = static_cast(input_indexes.size());
+ int num_outputs = static_cast(output_indexes.size());
+ bool engine_update = false;
+ bool context_update = false;
+ std::unordered_set input_names;
+ std::unordered_map> tensor_shape_values;
+
+ OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, device_id_), device_id_);
+ if (alloc_ == nullptr) {
+ Ort::ThrowOnError(api->KernelContext_GetAllocator(context, &mem_info, &alloc_));
+ }
+ OrtAllocator* alloc = alloc_;
+
+ void* cuda_stream;
+ Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &cuda_stream));
+ cudaStream_t stream = static_cast(cuda_stream);
+
+ // Name the engine cache based on GPU compute capability and reduce the chance of loading an incompatible cache
+ // Note: Engine cache generated on a GPU with large memory might not be loadable on a GPU with smaller memory, even if they share the same compute capability
+ // Prepare cache name
+ const std::string cache_path = GetCachePath(trt_state->engine_cache_path, trt_state->trt_node_name_with_precision);
+ const std::string cache_path_prefix = cache_path + "_sm" + compute_capability_;
+ const std::string engine_cache_path = cache_path_prefix + ".engine";
+ const std::string encrypted_engine_cache_path = engine_cache_path + ".encrypted";
+ const std::string profile_cache_path = cache_path_prefix + ".profile";
+ std::string timing_cache_path = "";
+ if (timing_cache_enable_) {
+ timing_cache_path = GetTimingCachePath(global_cache_path_, compute_capability_);
+ }
+
+ // Load serialized engine
+ if (trt_state->engine_cache_enable && trt_engine == nullptr) {
+ std::ifstream engine_file(engine_cache_path, std::ios::binary | std::ios::in);
+ std::ifstream profile_file(profile_cache_path, std::ios::binary | std::ios::in);
+ if (engine_file && !trt_state->engine_decryption_enable && profile_file) {
+ // Deserialize profile
+ shape_ranges = DeserializeProfileV2(profile_file);
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path;
+
+ // Prepare buffer
+ engine_file.seekg(0, std::ios::end);
+ size_t engine_size = engine_file.tellg();
+ engine_file.seekg(0, std::ios::beg);
+ std::unique_ptr engine_buf{new char[engine_size]};
+ engine_file.read((char*)engine_buf.get(), engine_size);
+
+ // Deserialize engine
+ // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc
+ // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading
+ trt_state->engine->reset();
+ *(trt_state->engine) = std::unique_ptr(
+ trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size));
+ if (!(*(trt_state->engine))) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine.");
}
- if (!(*(trt_state->context))) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context.");
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path;
+ trt_engine = trt_state->engine->get();
+ context_update = true;
+ } else if (trt_state->engine_decryption_enable && std::filesystem::exists(encrypted_engine_cache_path) && profile_file) {
+ shape_ranges = DeserializeProfileV2(profile_file);
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + profile_cache_path;
+ // Decrypt engine
+ size_t engine_size = 0;
+ if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), nullptr, &engine_size)) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+ "TensorRT EP could not get engine buffer size");
+ }
+ std::unique_ptr engine_buf{new char[engine_size]};
+ if (!trt_state->engine_decryption(encrypted_engine_cache_path.c_str(), &engine_buf[0], &engine_size)) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+ "TensorRT EP could not call engine decryption function decrypt");
}
- trt_context = trt_state->context->get();
+ // Deserialize engine
+ // Note: Deserializing an engine from a TensorRT runtime is thread safe per TRT doc
+ // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading
+ trt_state->engine->reset();
+ *(trt_state->engine) = std::unique_ptr(trt_state->runtime->deserializeCudaEngine(engine_buf.get(), engine_size));
+ if (!(*(trt_state->engine))) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+ "TensorRT EP could not deserialize engine from encrypted cache: " + encrypted_engine_cache_path);
+ }
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Decrypted and DeSerialized " + encrypted_engine_cache_path;
+ trt_engine = trt_state->engine->get();
+ context_update = true;
}
+ }
- // Get input and output binding names
- int total_bindings = trt_engine->getNbIOTensors();
- std::vector input_binding_names, output_binding_names;
- for (int i = 0, end = total_bindings; i < end; ++i) {
- auto const& name = trt_engine->getIOTensorName(i);
- auto const& mode = trt_engine->getTensorIOMode(name);
- if (mode == nvinfer1::TensorIOMode::kINPUT) {
- input_binding_names.push_back(name);
- } else {
- output_binding_names.push_back(name);
+ // Check and update shape ranges for dynamic shape inputs.
+ for (int i = 0, end = num_inputs; i < end; ++i) {
+ auto input = trt_state->network->get()->getInput(i);
+ const std::string& input_name = input->getName();
+ input_names.insert(input_name);
+
+ // If there is any input tensor in shape_ranges, it means this input tensor has a dynamic shape and its profile shape values have not yet been resolved.
+ // TRT EP will help determine the min/max/opt profile values based on the current input tensor value.
+ if (shape_ranges.find(input_name) != shape_ranges.end()) {
+ auto status = ApplyProfileShapesFromInputTensorValue(trt_profiles, ctx, input, shape_ranges, input_indexes, tensor_shape_values, stream, &engine_update);
+ if (status != Status::OK()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to parse input tensor and generate optimization profiles.");
}
}
+ }
- /*
- * Set input shapes and bind input buffers
- */
- std::vector> scratch_buffers;
- for (size_t i = 0, end = input_binding_names.size(); i < end; ++i) {
- char const* input_name = input_binding_names[i];
+ // Regenerate engine
+ if (engine_update) {
+ // Destroy the IExecutionContext objects before destroying an engine object, otherwise it will lead to undefined behavior.
+ trt_state->context->reset();
+ trt_state->engine->reset();
+ auto trt_config = std::unique_ptr(trt_builder->createBuilderConfig());
+ trt_config->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, *(trt_state->max_workspace_size_ptr));
+ for (auto trt_profile : trt_profiles) {
+ trt_config->addOptimizationProfile(trt_profile);
+ }
- size_t input_index = 0;
- const auto iter = input_indexes.find(input_name);
- if (iter != input_indexes.end()) {
- input_index = iter->second;
+ // Set INT8 Per Tensor Dynamic range
+ if (trt_state->int8_enable && trt_builder->platformHasFastInt8() && trt_state->int8_calibration_cache_available) {
+ trt_config->setInt8Calibrator(nullptr);
+ if (!SetDynamicRange(*trt_state->network->get(), trt_state->dynamic_range_map)) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to set INT8 dynamic range.");
}
- auto input_tensor = ctx.GetInput(input_index);
- auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo();
- const auto tensor_shapes = tensor_info.GetShape();
+ }
- // Only use for "shape tensor" input
- std::vector shape_values;
- if (tensor_shape_values.find(input_name) != tensor_shape_values.end()) {
- shape_values = tensor_shape_values[input_name];
+ // Set precision
+ if (trt_state->fp16_enable && trt_state->int8_enable) {
+ trt_config->setFlags(1U << static_cast(nvinfer1::BuilderFlag::kFP16) | 1U << static_cast(nvinfer1::BuilderFlag::kINT8));
+ } else if (trt_state->fp16_enable) {
+ trt_config->setFlag(nvinfer1::BuilderFlag::kFP16);
+ } else if (trt_state->int8_enable) {
+ trt_config->setFlag(nvinfer1::BuilderFlag::kINT8);
+ }
+
+ // Set DLA (DLA can only run with FP16 or INT8)
+ if ((trt_state->fp16_enable || trt_state->int8_enable) && trt_state->dla_enable) {
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] use DLA core " << trt_state->dla_core;
+ trt_config->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK);
+ trt_config->setDefaultDeviceType(nvinfer1::DeviceType::kDLA);
+ trt_config->setDLACore(trt_state->dla_core);
+ }
+
+ // enable sparse weights
+ if (trt_state->sparsity_enable) {
+ trt_config->setFlag(nvinfer1::BuilderFlag::kSPARSE_WEIGHTS);
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Sparse weights are allowed";
+ }
+#if NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR == 5
+ // enable builder heuristics
+ if (trt_state->build_heuristics_enable) {
+ trt_config->setFlag(nvinfer1::BuilderFlag::kENABLE_TACTIC_HEURISTIC);
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder heuristics are enabled";
+ }
+#elif NV_TENSORRT_MAJOR == 8 && NV_TENSORRT_MINOR > 5 || NV_TENSORRT_MAJOR > 8
+ // switch optimization level
+ if (trt_state->builder_optimization_level != 3) {
+ trt_config->setBuilderOptimizationLevel(trt_state->builder_optimization_level);
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Builder optimization level is set to " << builder_optimization_level_;
+ }
+
+ // limit auxiliary streams
+ if (trt_state->auxiliary_streams >= 0) {
+ trt_config->setMaxAuxStreams(trt_state->auxiliary_streams);
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Auxiliary streams are set to " << trt_state->auxiliary_streams;
+ }
+#else
+ if (trt_state->builder_optimization_level != 3) {
+ LOGS_DEFAULT(WARNING) << "[TensorRT EP] Builder optimization level can only be used on TRT 8.6 onwards!";
+ }
+ if (trt_state->auxiliary_streams >= 0) {
+ LOGS_DEFAULT(WARNING) << "[TensorRT EP] Auxiliary streams can only be set on TRT 8.6 onwards!";
+ }
+#endif
+ // limit used tactic sources
+ if (trt_state->filter_tactic_sources) {
+ nvinfer1::TacticSources tactics = trt_config->getTacticSources();
+ tactics |= trt_state->tactic_sources;
+ trt_config->setTacticSources(tactics);
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Tactic sources are limited using bitmask " << tactics;
+ }
+
+ // Load timing cache from file. Create a fresh cache if the file doesn't exist
+ std::unique_ptr timing_cache = nullptr;
+ if (trt_state->timing_cache_enable) {
+ std::vector loaded_timing_cache = loadTimingCacheFile(timing_cache_path);
+ timing_cache.reset(trt_config->createTimingCache(static_cast(loaded_timing_cache.data()), loaded_timing_cache.size()));
+ if (timing_cache == nullptr) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+ "TensorRT EP could not create timing cache: " + timing_cache_path);
+ }
+ trt_config->setTimingCache(*timing_cache, force_timing_cache_match_);
+ if (detailed_build_log_) {
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Deserialized timing cache from " + timing_cache_path;
}
+ }
- auto status = BindContextInput(ctx, trt_engine, trt_context, input_name, input_index, shape_values, scratch_buffers, alloc, stream);
- if (status != Status::OK()) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
+ // Build engine
+ std::unique_ptr serialized_engine;
+ {
+ auto lock = GetApiLock();
+ std::chrono::steady_clock::time_point engine_build_start;
+ if (detailed_build_log_) {
+ engine_build_start = std::chrono::steady_clock::now();
+ }
+ serialized_engine = std::unique_ptr(
+ trt_builder->buildSerializedNetwork(*trt_state->network->get(), *trt_config));
+ if (!serialized_engine) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create engine from network.");
+ }
+ *(trt_state->engine) = std::unique_ptr(
+ trt_state->runtime->deserializeCudaEngine(serialized_engine->data(), serialized_engine->size()));
+ if (!(*(trt_state->engine))) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to deserialize engine.");
}
+ if (detailed_build_log_) {
+ auto engine_build_stop = std::chrono::steady_clock::now();
+ LOGS_DEFAULT(INFO) << "TensorRT engine build for " << trt_state->trt_node_name_with_precision << " took: " << std::chrono::duration_cast(engine_build_stop - engine_build_start).count() << "ms" << std::endl;
+ }
+ }
+ if (!(*(trt_state->engine))) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP Failed to Build Engine.");
+ }
+ trt_engine = trt_state->engine->get();
+ if (trt_state->engine_cache_enable) {
+ // Serialize engine profile
+ SerializeProfileV2(profile_cache_path, shape_ranges);
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + profile_cache_path;
+
+ // Serialize engine
+ if (trt_state->engine_decryption_enable) {
+ // Encrypt engine. The library is not always deployed with the encrypt function, so check if it is available first.
+ if (trt_state->engine_encryption != nullptr) {
+ if (!trt_state->engine_encryption(encrypted_engine_cache_path.c_str(), reinterpret_cast(serialized_engine->data()), serialized_engine->size())) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+ "TensorRT EP could not call engine encryption function encrypt");
+ }
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized and encrypted engine " + encrypted_engine_cache_path;
+ } else {
+ LOGS_DEFAULT(WARNING) << "[TensorRT EP] Engine cache encryption function is not found. No cache is written to disk";
+ }
+ } else {
+ std::ofstream file(engine_cache_path, std::ios::binary | std::ios::out);
+ file.write(reinterpret_cast(serialized_engine->data()), serialized_engine->size());
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path;
+ }
+ }
+
+ // serialize and save timing cache
+ if (trt_state->timing_cache_enable) {
+ auto timing_cache = trt_config->getTimingCache();
+ std::unique_ptr timingCacheHostData{timing_cache->serialize()};
+ if (timingCacheHostData == nullptr) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+ "TensorRT EP could not serialize timing cache: " + timing_cache_path);
+ }
+ saveTimingCacheFile(timing_cache_path, timingCacheHostData.get());
+ if (detailed_build_log_) {
+ LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized timing cache " + timing_cache_path;
+ }
+ }
+
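
A condensed sketch of the timing-cache round trip the code above performs (load-or-create, attach to the builder config, serialize back after the build). File handling is simplified and the helper name is illustrative:

```cpp
#include <NvInfer.h>
#include <fstream>
#include <iterator>
#include <memory>
#include <string>
#include <vector>

// Illustrative only: an empty blob makes createTimingCache() return a fresh cache.
void RoundTripTimingCache(nvinfer1::IBuilderConfig& config, const std::string& path) {
  std::vector<char> blob;
  if (std::ifstream in{path, std::ios::binary}) {
    blob.assign(std::istreambuf_iterator<char>(in), std::istreambuf_iterator<char>());
  }
  std::unique_ptr<nvinfer1::ITimingCache> cache{config.createTimingCache(blob.data(), blob.size())};
  if (cache == nullptr) return;
  config.setTimingCache(*cache, /*ignoreMismatch=*/false);

  // ... build the serialized engine with this config ...

  std::unique_ptr<nvinfer1::IHostMemory> serialized{config.getTimingCache()->serialize()};
  std::ofstream{path, std::ios::binary}
      .write(static_cast<const char*>(serialized->data()), serialized->size());
}
```
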
+ // dump ep context model
+ if (dump_ep_context_model_ && ep_context_embed_mode_) {
+ UpdateCtxNodeModelEngineContext(model_proto_.get(), reinterpret_cast(serialized_engine->data()), serialized_engine->size());
+ DumpCtxNodeModel(model_proto_.get(), cache_path_prefix);
+ }
+ context_update = true;
+ }
+
+ if (context_update) {
+ if (trt_state->context_memory_sharing_enable) {
+ *(trt_state->context) = std::unique_ptr(
+ trt_state->engine->get()->createExecutionContextWithoutDeviceMemory());
+ } else {
+ *(trt_state->context) = std::unique_ptr(
+ trt_state->engine->get()->createExecutionContext());
+ }
+ if (!(*(trt_state->context))) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "TensorRT EP failed to create context.");
}
+ trt_context = trt_state->context->get();
+ }
- /*
- * Set output shapes and bind output buffers
- */
- std::unordered_map buffers;
- buffers.reserve(num_outputs);
- using OutputOrtValue = Ort::UnownedValue;
- std::unordered_map output_tensors;
- output_tensors.reserve(num_outputs);
- std::unordered_map output_dim_sizes;
- output_dim_sizes.reserve(num_outputs);
- std::unordered_set dds_output_set;
+ // Get input and output binding names
+ int total_bindings = trt_engine->getNbIOTensors();
+ std::vector input_binding_names, output_binding_names;
+ for (int i = 0, end = total_bindings; i < end; ++i) {
+ auto const& name = trt_engine->getIOTensorName(i);
+ auto const& mode = trt_engine->getTensorIOMode(name);
+ if (mode == nvinfer1::TensorIOMode::kINPUT) {
+ input_binding_names.push_back(name);
+ } else {
+ output_binding_names.push_back(name);
+ }
+ }
- for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) {
- char const* output_name = output_binding_names[i];
+ /*
+ * Set input shapes and bind input buffers
+ */
+ std::vector> scratch_buffers;
+ for (size_t i = 0, end = input_binding_names.size(); i < end; ++i) {
+ char const* input_name = input_binding_names[i];
+ size_t input_index = 0;
+ const auto iter = input_indexes.find(input_name);
+ if (iter != input_indexes.end()) {
+ input_index = iter->second;
+ }
+ auto input_tensor = ctx.GetInput(input_index);
+ auto tensor_info = input_tensor.GetTensorTypeAndShapeInfo();
+ const auto tensor_shapes = tensor_info.GetShape();
+
+ // Only use for "shape tensor" input
+ std::vector shape_values;
+ if (tensor_shape_values.find(input_name) != tensor_shape_values.end()) {
+ shape_values = tensor_shape_values[input_name];
+ }
+
+ auto status = BindContextInput(ctx, trt_engine, trt_context, input_name, input_index, shape_values, scratch_buffers, alloc, stream);
+ if (status != Status::OK()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
+ }
+ }
+
+ /*
+ * Set output shapes and bind output buffers
+ */
+ std::unordered_map buffers;
+ buffers.reserve(num_outputs);
+ using OutputOrtValue = Ort::UnownedValue;
+ std::unordered_map output_tensors;
+ output_tensors.reserve(num_outputs);
+ std::unordered_map output_dim_sizes;
+ output_dim_sizes.reserve(num_outputs);
+ std::unordered_set dds_output_set;
+
+ for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) {
+ char const* output_name = output_binding_names[i];
+
+ size_t output_index = 0;
+ const auto& index_iter = output_indexes.find(output_name);
+ if (index_iter != output_indexes.end()) {
+ output_index = index_iter->second;
+ }
+
+ size_t output_type = 0;
+ const auto type_iter = output_types.find(output_name);
+ if (type_iter != output_types.end()) {
+ output_type = type_iter->second;
+ }
+
+ Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, output_dim_sizes,
+ dds_output_set, dds_output_allocator_map, scratch_buffers, alloc, buffers);
+ if (status != Status::OK()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
+ }
+ }
+
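
The BindContextInput/BindContextOutput helpers above ultimately reduce to TensorRT's name-based I/O API; a stripped-down sketch for a single-input, single-output engine (device buffers and the input shape are assumed to be prepared by the caller, and full type handling is omitted):

```cpp
#include <NvInfer.h>
#include <cuda_runtime_api.h>

// Illustrative only: minimal name-based binding plus enqueueV3.
bool RunOnce(nvinfer1::ICudaEngine& engine, nvinfer1::IExecutionContext& context,
             void* input_dev, nvinfer1::Dims input_dims, void* output_dev,
             cudaStream_t stream) {
  for (int i = 0; i < engine.getNbIOTensors(); ++i) {
    const char* name = engine.getIOTensorName(i);
    if (engine.getTensorIOMode(name) == nvinfer1::TensorIOMode::kINPUT) {
      if (!context.setInputShape(name, input_dims)) return false;  // needed for dynamic shapes
      if (!context.setTensorAddress(name, input_dev)) return false;
    } else {
      if (!context.setTensorAddress(name, output_dev)) return false;
    }
  }
  return context.enqueueV3(stream);
}
```
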
+ // Set execution context memory
+ if (trt_state->context_memory_sharing_enable) {
+ size_t mem_size = trt_engine->getDeviceMemorySize();
+ if (mem_size > *max_context_mem_size_ptr) {
+ *max_context_mem_size_ptr = mem_size;
+ }
+ trt_context->setDeviceMemory(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr).get());
+ }
+
+ // Start CUDA graph capture.
+ // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is that
+ // current ORT TRT doesn't get the cuda stream until compute time, and graph capture requires a cuda stream.
+ if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured()) {
+ LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model";
+ cuda_graph_.SetStream(stream);
+ CaptureBegin();
+ }
+
+ // Run TRT inference
+ if (!trt_context->enqueueV3(stream)) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed.");
+ }
+
+ if (sync_stream_after_enqueue || dds_output_set.size() > 0) {
+ CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream));
+ }
+
+ // Assign TRT output back to ORT output
+ // (1) Bind TRT DDS output to ORT kernel context output. (It needs to wait until enqueueV3 is finished)
+ // (2) Cast TRT INT32 output to ORT INT64 output or TRT double output to float output
+ for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) {
+ char const* output_name = output_binding_names[i];
+
+ size_t output_type = 0;
+ const auto& iter = output_types.find(output_name);
+ if (iter != output_types.end()) {
+ output_type = iter->second;
+ }
+
+ if (dds_output_set.find(output_name) != dds_output_set.end()) {
size_t output_index = 0;
const auto& index_iter = output_indexes.find(output_name);
if (index_iter != output_indexes.end()) {
output_index = index_iter->second;
}
-
- size_t output_type = 0;
- const auto type_iter = output_types.find(output_name);
- if (type_iter != output_types.end()) {
- output_type = type_iter->second;
- }
-
- Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, output_dim_sizes,
- dds_output_set, dds_output_allocator_map, scratch_buffers, alloc, buffers);
+ auto status = BindKernelOutput(ctx, &mem_info, dds_output_allocator_map, output_name, output_index, output_type, scratch_buffers, alloc, stream);
if (status != Status::OK()) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage());
+ }
+ } else {
+ auto& output_tensor = output_tensors[i];
+ if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
+ auto output_tensor_ptr = output_tensor.GetTensorMutableData();
+ if (output_tensor_ptr != nullptr) {
+ cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]);
+ }
+ } else if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
+ auto output_tensor_ptr = output_tensor.GetTensorMutableData();
+ if (output_tensor_ptr != nullptr) {
+ cuda::Impl_Cast(stream, reinterpret_cast(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]);
+ }
}
}
+ }
- // Set execution context memory
- if (trt_state->context_memory_sharing_enable) {
- size_t mem_size = trt_engine->getDeviceMemorySize();
- if (mem_size > *max_context_mem_size_ptr) {
- *max_context_mem_size_ptr = mem_size;
- }
- trt_context->setDeviceMemory(IAllocator::MakeUniquePtrFromOrtAllocator(alloc, *max_context_mem_size_ptr).get());
+ // End CUDA graph capture.
+ // Note: One reason we don't put the end of graph capture in OnRunEnd() like CUDA EP does is the cuda stream issue mentioned for graph capture
+ // above; another reason is that OnRunEnd() is not synchronized with OnRunStart() and ExecuteGraph() per inference_session.cc.
+ // It's safe to start/end CUDA graph capture in compute_func() here since the cuda graph object is maintained on a per-thread basis.
+ if (cuda_graph_enable_ && !IsGraphCaptured()) {
+ if (IsGraphCaptureAllowed()) {
+ CaptureEnd();
+ // CUDA work issued to a capturing stream doesn’t actually run on the GPU,
+ // so run the captured graph here to actually execute the work.
+ ORT_RETURN_IF_ERROR(ReplayGraph());
+ } else {
+ IncrementRegularRunCountBeforeGraphCapture();
+ }
+ }
+
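
For reference, the CaptureBegin()/CaptureEnd()/ReplayGraph() wrappers used above map onto the plain CUDA graph-capture calls; a bare sketch with error checking omitted and hypothetical names:

```cpp
#include <cuda_runtime_api.h>

// Illustrative only: capture work submitted to `stream` into a graph, then
// launch the instantiated graph so the captured work actually executes.
void CaptureAndReplay(cudaStream_t stream) {
  cudaGraph_t graph = nullptr;
  cudaGraphExec_t graph_exec = nullptr;

  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);  // ~ CaptureBegin()
  // ... enqueueV3(stream) and any other stream work goes here ...
  cudaStreamEndCapture(stream, &graph);                         // ~ CaptureEnd()
  cudaGraphInstantiateWithFlags(&graph_exec, graph, 0);
  cudaGraphLaunch(graph_exec, stream);                          // ~ ReplayGraph()
  cudaStreamSynchronize(stream);

  cudaGraphExecDestroy(graph_exec);
  cudaGraphDestroy(graph);
}
```
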
+ return Status::OK();
+ };
+
+ node_compute_funcs.push_back(compute_info);
+ return Status::OK();
+}
+
+Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(const GraphViewer& graph_body_viewer,
+ const Node& fused_node,
+ std::unordered_map& input_map,
+ std::unordered_map& output_map,
+ std::vector& node_compute_funcs) {
+ std::unique_ptr trt_engine;
+ std::unique_ptr trt_context;
+ std::unordered_map input_indexes; // TRT engine input name -> ORT kernel context input index
+ std::unordered_map output_indexes; // TRT engine output name -> ORT kernel context output index
+ std::unordered_map output_types; // TRT engine output name -> ORT output tensor type
+
+ // Get engine binary data and deserialize it
+ auto trt_cache_model_handler = TensorRTCacheModelHandler(&trt_engine, runtime_.get(), compute_capability_);
+ auto status = trt_cache_model_handler.GetEpContextFromGraph(graph_body_viewer);
+ if (status != Status::OK()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
+ }
+
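
The core of this precompiled-engine path is plain TensorRT runtime deserialization of previously serialized engine bytes; a minimal standalone sketch (logger and file handling simplified, helper names hypothetical):

```cpp
#include <NvInfer.h>
#include <cstdio>
#include <fstream>
#include <iterator>
#include <memory>
#include <string>
#include <vector>

// Minimal logger required by createInferRuntime().
class StderrLogger : public nvinfer1::ILogger {
  void log(Severity severity, const char* msg) noexcept override {
    if (severity <= Severity::kWARNING) std::fprintf(stderr, "[TRT] %s\n", msg);
  }
};

// Illustrative only: read an engine cache file and deserialize it, which is what
// TRT EP does when the EP-context model points at an engine cache.
std::unique_ptr<nvinfer1::ICudaEngine> LoadEngine(const std::string& engine_path) {
  static StderrLogger logger;
  // Kept static because the engine must not be used after its runtime is destroyed.
  static std::unique_ptr<nvinfer1::IRuntime> runtime{nvinfer1::createInferRuntime(logger)};

  std::ifstream file(engine_path, std::ios::binary);
  std::vector<char> bytes((std::istreambuf_iterator<char>(file)), std::istreambuf_iterator<char>());
  return std::unique_ptr<nvinfer1::ICudaEngine>(
      runtime->deserializeCudaEngine(bytes.data(), bytes.size()));
}
```
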
+ // Build context
+ //
+ // Note: Creating an execution context from an engine is thread safe per TRT doc
+ // https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading
+ if (context_memory_sharing_enable_) {
+ size_t mem_size = trt_engine->getDeviceMemorySize();
+ if (mem_size > max_ctx_mem_size_) {
+ max_ctx_mem_size_ = mem_size;
+ }
+ trt_context = std::unique_ptr(trt_engine->createExecutionContextWithoutDeviceMemory());
+ } else {
+ trt_context = std::unique_ptr(trt_engine->createExecutionContext());
+ }
+ if (!trt_context) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
+ "TensorRT EP could not build execution context for fused node: " + fused_node.Name());
+ }
+
+ // Create input/output to index maps
+ for (int32_t i = 0; i < trt_engine->getNbIOTensors(); ++i) {
+ auto const& name = trt_engine->getIOTensorName(i);
+ auto const& mode = trt_engine->getTensorIOMode(name);
+ if (mode == nvinfer1::TensorIOMode::kINPUT) {
+ const auto& iter = input_map.find(name);
+ if (iter != input_map.end()) {
+ input_indexes[name] = iter->second;
+ }
+ } else {
+ const auto& iter = output_map.find(name);
+ if (iter != output_map.end()) {
+ output_indexes[name] = iter->second;
+ }
+ }
+ }
+
+ // Create output to type map
+ for (auto node_arg : graph_body_viewer.GetOutputs()) {
+ auto output_name = node_arg->Name();
+ auto& type = node_arg->TypeAsProto()->tensor_type();
+ output_types[output_name] = type.elem_type();
+ }
+
+ // Save TRT engine, TRT context and input/output info to map
+ engines_.emplace(fused_node.Name(), std::move(trt_engine));
+ contexts_.emplace(fused_node.Name(), std::move(trt_context));
+ input_info_[fused_node.Name()].push_back(input_indexes);
+ output_info_[fused_node.Name()].push_back(output_indexes);
+ output_info_[fused_node.Name()].push_back(output_types);
+
+ // Create function state
+ // TODO: remove default capture
+ NodeComputeInfo compute_info;
+ compute_info.create_state_func = [=](ComputeContext* context, FunctionState* state) {
+ std::unique_ptr<TensorrtShortFuncState> p = std::make_unique<TensorrtShortFuncState>();
+ *p = {context->allocate_func,
+ context->release_func,
+ context->allocator_handle,
+ context->node_name,
+ &engines_[context->node_name],
+ &contexts_[context->node_name],
+ input_info_[context->node_name],
+ output_info_[context->node_name],
+ sync_stream_after_enqueue_,
+ context_memory_sharing_enable_,
+ &max_ctx_mem_size_,
+ &tensorrt_mu_};
+ *state = p.release();
+ return 0;
+ };
+
+ // Release function state
+ compute_info.release_state_func = [](FunctionState state) {
+ delete static_cast<TensorrtShortFuncState*>(state);
+ };
+
+ // Create compute function
+ compute_info.compute_func = [this](FunctionState state, const OrtApi* api, OrtKernelContext* context) {
+ Ort::KernelContext ctx(context);
+
+ TensorrtShortFuncState* trt_state = reinterpret_cast<TensorrtShortFuncState*>(state);
+
+ // The whole compute_function should be considered the critical section.
+ // More details here, https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#threading
+ std::lock_guard<OrtMutex> lock(*(trt_state->tensorrt_mu_ptr));
+
+ const std::unordered_map<std::string, size_t>& input_indexes = (trt_state->input_info)[0];
+ const std::unordered_map<std::string, size_t>& output_indexes = (trt_state->output_info)[0];
+ const std::unordered_map<std::string, size_t>& output_types = (trt_state->output_info)[1];
+ auto fused_node_name = trt_state->fused_node_name;
+ bool sync_stream_after_enqueue = trt_state->sync_stream_after_enqueue;
+ auto& dds_output_allocator_map = this->dds_output_allocator_maps_[fused_node_name];
+ auto trt_engine = trt_state->engine->get();
+ auto trt_context = trt_state->context->get();
+ auto max_context_mem_size_ptr = trt_state->max_context_mem_size_ptr;
+ // int num_inputs = static_cast<int>(input_indexes.size());
+ int num_outputs = static_cast<int>(output_indexes.size());
+
+ OrtMemoryInfo mem_info("", OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, device_id_), device_id_);
+ if (alloc_ == nullptr) {
+ Ort::ThrowOnError(api->KernelContext_GetAllocator(context, &mem_info, &alloc_));
+ }
+ OrtAllocator* alloc = alloc_;
+
+ void* cuda_stream;
+ Ort::ThrowOnError(api->KernelContext_GetGPUComputeStream(context, &cuda_stream));
+ cudaStream_t stream = static_cast<cudaStream_t>(cuda_stream);
+
+ // Get input and output binding names
+ int total_bindings = trt_engine->getNbIOTensors();
+ std::vector<char const*> input_binding_names, output_binding_names;
+ for (int i = 0, end = total_bindings; i < end; ++i) {
+ auto const& name = trt_engine->getIOTensorName(i);
+ auto const& mode = trt_engine->getTensorIOMode(name);
+ if (mode == nvinfer1::TensorIOMode::kINPUT) {
+ input_binding_names.push_back(name);
+ } else {
+ output_binding_names.push_back(name);
}
+ }
+
+ /*
+ * Set input shapes and bind input buffers
+ */
+ std::vector<IAllocatorUniquePtr<void>> scratch_buffers;
+ for (size_t i = 0, end = input_binding_names.size(); i < end; ++i) {
+ char const* input_name = input_binding_names[i];
- // Start CUDA graph capture.
- // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because
- // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream.
- if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured()) {
- LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model";
- cuda_graph_.SetStream(stream);
- CaptureBegin();
+ size_t input_index = 0;
+ const auto iter = input_indexes.find(input_name);
+ if (iter != input_indexes.end()) {
+ input_index = iter->second;
}
- // Run TRT inference
- if (!trt_context->enqueueV3(stream)) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed.");
+ // Only used for "shape tensor" input
+ std::vector<int32_t> shape_values;
+
+ Status status = BindContextInput(ctx, trt_engine, trt_context, input_name, input_index, shape_values, scratch_buffers, alloc, stream);
+ if (status != Status::OK()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
}
+ }
+
+ /*
+ * Set output shapes and bind output buffers
+ */
+ std::unordered_map<char const*, void*> buffers;
+ buffers.reserve(num_outputs);
+ using OutputOrtValue = Ort::UnownedValue;
+ std::unordered_map<size_t, OutputOrtValue> output_tensors;
+ output_tensors.reserve(num_outputs);
+ std::unordered_map<size_t, int> output_dim_sizes;
+ output_dim_sizes.reserve(num_outputs);
+ std::unordered_set<char const*> dds_output_set;
- if (sync_stream_after_enqueue || dds_output_set.size() > 0) {
- CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream));
+ for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) {
+ char const* output_name = output_binding_names[i];
+
+ size_t output_index = 0;
+ const auto& index_iter = output_indexes.find(output_name);
+ if (index_iter != output_indexes.end()) {
+ output_index = index_iter->second;
}
- // Assign TRT output back to ORT output
- // (1) Bind TRT DDS output to ORT kernel context output. (It needs to wait until enqueueV3 is finished)
- // (2) Cast TRT INT32 output to ORT INT64 output or TRT float output to double output
- for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) {
- char const* output_name = output_binding_names[i];
+ size_t output_type = 0;
+ const auto type_iter = output_types.find(output_name);
+ if (type_iter != output_types.end()) {
+ output_type = type_iter->second;
+ }
- size_t output_type = 0;
- const auto& iter = output_types.find(output_name);
- if (iter != output_types.end()) {
- output_type = iter->second;
- }
+ Status status = BindContextOutput(ctx, trt_context, output_name, output_index, output_type, i, output_tensors, output_dim_sizes,
+ dds_output_set, dds_output_allocator_map, scratch_buffers, alloc, buffers);
+ if (status != Status::OK()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, status.ErrorMessage());
+ }
+ }
- if (dds_output_set.find(output_name) != dds_output_set.end()) {
- size_t output_index = 0;
- const auto& index_iter = output_indexes.find(output_name);
- if (index_iter != output_indexes.end()) {
- output_index = index_iter->second;
- }
- auto status = BindKernelOutput(ctx, &mem_info, dds_output_allocator_map, output_name, output_index, output_type, scratch_buffers, alloc, stream);
- if (status != Status::OK()) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage());
+ // Set execution context memory
+ if (trt_state->context_memory_sharing_enable) {
+ size_t mem_size = trt_engine->getDeviceMemorySize();
+ if (mem_size > *max_context_mem_size_ptr) {
+ *max_context_mem_size_ptr = mem_size;
+ }
+ trt_context->setDeviceMemory(IAllocator::MakeUniquePtrFromOrtAllocator<void>(alloc, *max_context_mem_size_ptr).get());
+ }
+
+ // Start CUDA graph capture.
+ // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because
+ // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream.
+ if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured()) {
+ LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model";
+ cuda_graph_.SetStream(stream);
+ CaptureBegin();
+ }
+
+ // Run TRT inference
+ if (!trt_context->enqueueV3(stream)) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed.");
+ }
+
+ if (sync_stream_after_enqueue || dds_output_set.size() > 0) {
+ CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream));
+ }
+
+ // Assign TRT output back to ORT output
+ // (1) Bind TRT DDS output to ORT kernel context output. (It needs to wait until enqueueV3 is finished)
+ // (2) Cast TRT INT32 output to ORT INT64 output or TRT FLOAT output to ORT DOUBLE output
+ for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) {
+ char const* output_name = output_binding_names[i];
+
+ size_t output_type = 0;
+ const auto& iter = output_types.find(output_name);
+ if (iter != output_types.end()) {
+ output_type = iter->second;
+ }
+
+ if (dds_output_set.find(output_name) != dds_output_set.end()) {
+ size_t output_index = 0;
+ const auto& index_iter = output_indexes.find(output_name);
+ if (index_iter != output_indexes.end()) {
+ output_index = index_iter->second;
+ }
+ auto status = BindKernelOutput(ctx, &mem_info, dds_output_allocator_map, output_name, output_index, output_type, scratch_buffers, alloc, stream);
+ if (status != Status::OK()) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, status.ErrorMessage());
+ }
+ } else {
+ auto& output_tensor = output_tensors[i];
+ if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
+ auto output_tensor_ptr = output_tensor.GetTensorMutableData<int64_t>();
+ if (output_tensor_ptr != nullptr) {
+ cuda::Impl_Cast<int32_t, int64_t>(stream, reinterpret_cast<int32_t*>(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]);
}
- } else {
- auto& output_tensor = output_tensors[i];
- if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
- auto output_tensor_ptr = output_tensor.GetTensorMutableData<int64_t>();
- if (output_tensor_ptr != nullptr) {
- cuda::Impl_Cast<int32_t, int64_t>(stream, reinterpret_cast<int32_t*>(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]);
- }
- } else if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
- auto output_tensor_ptr = output_tensor.GetTensorMutableData<double>();
- if (output_tensor_ptr != nullptr) {
- cuda::Impl_Cast<float, double>(stream, reinterpret_cast<float*>(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]);
- }
+ } else if (output_type == ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE) {
+ auto output_tensor_ptr = output_tensor.GetTensorMutableData<double>();
+ if (output_tensor_ptr != nullptr) {
+ cuda::Impl_Cast<float, double>(stream, reinterpret_cast<float*>(buffers[output_name]), output_tensor_ptr, output_dim_sizes[i]);
}
}
}
+ }
- // End CUDA graph capture.
- // Note: One reason we don't put end of graph capture in OnRunEnd() like CUDA EP does is because of cuda stream mentioned in graph capture
- // above, another reason is because OnRunEnd() is not synchronized with OnRunStart() and ExecuteGraph() per inference_session.cc.
- // It's safe to start/end CUDA graph capture in compute_func() here since cuda graph object is maintained by a per thread basis.
- if (cuda_graph_enable_ && !IsGraphCaptured()) {
- if (IsGraphCaptureAllowed()) {
- CaptureEnd();
- // CUDA work issued to a capturing stream doesn’t actually run on the GPU,
- // so run the captured graph here to actually execute the work.
- ORT_RETURN_IF_ERROR(ReplayGraph());
- } else {
- IncrementRegularRunCountBeforeGraphCapture();
- }
+ // End CUDA graph capture.
+ // Note: One reason we don't put the end of graph capture in OnRunEnd() like the CUDA EP does is the cuda stream issue mentioned in the graph capture
+ // note above; another reason is that OnRunEnd() is not synchronized with OnRunStart() and ExecuteGraph() per inference_session.cc.
+ // It's safe to start/end CUDA graph capture in compute_func() here since the cuda graph object is maintained on a per-thread basis.
+ if (cuda_graph_enable_ && !IsGraphCaptured()) {
+ if (IsGraphCaptureAllowed()) {
+ CaptureEnd();
+ // CUDA work issued to a capturing stream doesn’t actually run on the GPU,
+ // so run the captured graph here to actually execute the work.
+ ORT_RETURN_IF_ERROR(ReplayGraph());
+ } else {
+ IncrementRegularRunCountBeforeGraphCapture();
}
+ }
- return Status::OK();
- };
+ return Status::OK();
+ };
- node_compute_funcs.push_back(compute_info);
- }
+ node_compute_funcs.push_back(compute_info);
return Status::OK();
}
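For reference, here is a minimal Python sketch of the workflow this code path enables; the model name, cache directory and wrapper file name below are placeholders, not part of this patch:

    import onnxruntime as ort

    trt_options = {
        "trt_engine_cache_enable": True,
        "trt_engine_cache_path": "./trt_cache",   # placeholder cache directory
        "trt_dump_ep_context_model": True,        # dump the "xxx_wrapper.onnx" next to the engine cache
        "trt_ep_context_embed_mode": 0,           # 0: store engine cache path, 1: embed engine bytes
    }

    # First session: TRT EP parses the model, builds the engine and dumps the engine-wrapped ONNX model.
    sess = ort.InferenceSession("model.onnx", providers=[("TensorrtExecutionProvider", trt_options)])
    # sess.run(...)  # run once so the engine is built for the actual input shapes

    # Later sessions: load the wrapped model directly; TRT EP takes the
    # CreateNodeComputeInfoFromPrecompiledEngine() path and skips parsing and engine building.
    sess2 = ort.InferenceSession("./trt_cache/model_wrapper.onnx",  # placeholder wrapper name
                                 providers=[("TensorrtExecutionProvider", {})])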
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
index bacdf0f3c996c..9b8798e0fc4ca 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h
@@ -46,6 +46,9 @@ static const std::string kProfilesMinShapes = "ORT_TENSORRT_PROFILE_MIN_SHAPES";
static const std::string kProfilesMaxShapes = "ORT_TENSORRT_PROFILE_MAX_SHAPES";
static const std::string kProfilesOptShapes = "ORT_TENSORRT_PROFILE_OPT_SHAPES";
static const std::string kCudaGraphEnable = "ORT_TENSORRT_CUDA_GRAPH_ENABLE";
+static const std::string kDumpEpContextModel = "ORT_DUMP_EP_CONTEXT_MODEL";
+static const std::string kEpContextEmbedMode = "ORT_EP_CONTEXT_EMBED_MODE";
+static const std::string kEpContextComputeCapabilityEnable = "ORT_EP_CONTEXT_COMPUTE_CAPABILITY_ENABLE";
// Old env variable for backward compatibility
static const std::string kEngineCachePath = "ORT_TENSORRT_ENGINE_CACHE_PATH";
} // namespace tensorrt_env_vars
@@ -177,6 +180,22 @@ struct TensorrtFuncState {
bool cuda_graph_enable = 0;
};
+// Minimum information to construct kernel function state for direct engine load code path
+struct TensorrtShortFuncState {
+ AllocateFunc test_allocate_func = nullptr;
+ DestroyFunc test_release_func = nullptr;
+ AllocatorHandle allocator = nullptr;
+ std::string fused_node_name;
+ std::unique_ptr<nvinfer1::ICudaEngine>* engine = nullptr;
+ std::unique_ptr<nvinfer1::IExecutionContext>* context = nullptr;
+ std::vector<std::unordered_map<std::string, size_t>> input_info;
+ std::vector<std::unordered_map<std::string, size_t>> output_info;
+ bool sync_stream_after_enqueue = false;
+ bool context_memory_sharing_enable = false;
+ size_t* max_context_mem_size_ptr = nullptr;
+ OrtMutex* tensorrt_mu_ptr = nullptr;
+};
+
// Holds important information for building valid ORT graph.
struct SubGraphContext {
std::unordered_set<std::string> output_args;
@@ -276,6 +295,12 @@ class TensorrtExecutionProvider : public IExecutionProvider {
// and should be kept for the lifetime of TRT EP object.
OrtAllocator* alloc_ = nullptr;
+ // For create/dump EP context node model
+ bool dump_ep_context_model_ = false;
+ int ep_context_embed_mode_ = 0;
+ bool ep_context_compute_capability_enable_ = true;
+ std::unique_ptr<ONNX_NAMESPACE::ModelProto> model_proto_ = ONNX_NAMESPACE::ModelProto::Create();
+
std::unordered_set<std::string> control_flow_op_set_ = {"If", "Loop", "Scan"};
mutable std::unordered_map<std::string, std::unique_ptr<SubGraphContext>> subgraph_context_map_;
@@ -489,6 +514,25 @@ class TensorrtExecutionProvider : public IExecutionProvider {
*/
bool IsLocalValue(const Graph& graph, const std::string& name) const;
+ /**
+ * Create a vector of NodeComputeInfo instances directly from "TRT engine" wrapped onnx model without
+ * going through the time-consuming processes of model parsing and engine building.
+ */
+ Status CreateNodeComputeInfoFromPrecompiledEngine(const GraphViewer& graph_body_viewer,
+ const Node& fused_node,
+ std::unordered_map<std::string, size_t>& input_map,
+ std::unordered_map<std::string, size_t>& output_map,
+ std::vector<NodeComputeInfo>& node_compute_funcs);
+
+ /**
+ * Create a vector of NodeComputeInfo instances from graph.
+ */
+ Status CreateNodeComputeInfoFromGraph(const GraphViewer& graph_body_viewer,
+ const Node& fused_node,
+ std::unordered_map<std::string, size_t>& input_map,
+ std::unordered_map<std::string, size_t>& output_map,
+ std::vector<NodeComputeInfo>& node_compute_funcs);
+
bool IsGraphCaptureAllowed() const;
void CaptureBegin();
void CaptureEnd();
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc
index 3ead33f9131d9..f7820ac8a08c3 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.cc
@@ -46,6 +46,9 @@ constexpr const char* kProfilesMinShapes = "trt_profile_min_shapes";
constexpr const char* kProfilesMaxShapes = "trt_profile_max_shapes";
constexpr const char* kProfilesOptShapes = "trt_profile_opt_shapes";
constexpr const char* kCudaGraphEnable = "trt_cuda_graph_enable";
+constexpr const char* kDumpEpContextModel = "trt_dump_ep_context_model";
+constexpr const char* kEpContextEmbedMode = "trt_ep_context_embed_mode";
+constexpr const char* kEpContextComputeCapabilityEnable = "trt_ep_context_compute_capability_enable";
} // namespace provider_option_names
} // namespace tensorrt
@@ -97,6 +100,9 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions
.AddAssignmentToReference(tensorrt::provider_option_names::kProfilesMaxShapes, info.profile_max_shapes)
.AddAssignmentToReference(tensorrt::provider_option_names::kProfilesOptShapes, info.profile_opt_shapes)
.AddAssignmentToReference(tensorrt::provider_option_names::kCudaGraphEnable, info.cuda_graph_enable)
+ .AddAssignmentToReference(tensorrt::provider_option_names::kDumpEpContextModel, info.dump_ep_context_model)
+ .AddAssignmentToReference(tensorrt::provider_option_names::kEpContextEmbedMode, info.ep_context_embed_mode)
+ .AddAssignmentToReference(tensorrt::provider_option_names::kEpContextComputeCapabilityEnable, info.ep_context_compute_capability_enable)
.Parse(options)); // add new provider option here.
return info;
@@ -138,6 +144,9 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE
{tensorrt::provider_option_names::kProfilesMaxShapes, MakeStringWithClassicLocale(info.profile_max_shapes)},
{tensorrt::provider_option_names::kProfilesOptShapes, MakeStringWithClassicLocale(info.profile_opt_shapes)},
{tensorrt::provider_option_names::kCudaGraphEnable, MakeStringWithClassicLocale(info.cuda_graph_enable)},
+ {tensorrt::provider_option_names::kDumpEpContextModel, MakeStringWithClassicLocale(info.dump_ep_context_model)},
+ {tensorrt::provider_option_names::kEpContextEmbedMode, MakeStringWithClassicLocale(info.ep_context_embed_mode)},
+ {tensorrt::provider_option_names::kEpContextComputeCapabilityEnable, MakeStringWithClassicLocale(info.ep_context_compute_capability_enable)},
};
return options;
}
@@ -188,6 +197,9 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const OrtTensor
{tensorrt::provider_option_names::kProfilesMaxShapes, kProfilesMaxShapes_},
{tensorrt::provider_option_names::kProfilesOptShapes, kProfilesOptShapes_},
{tensorrt::provider_option_names::kCudaGraphEnable, MakeStringWithClassicLocale(info.trt_cuda_graph_enable)},
+ {tensorrt::provider_option_names::kDumpEpContextModel, MakeStringWithClassicLocale(info.trt_dump_ep_context_model)},
+ {tensorrt::provider_option_names::kEpContextEmbedMode, MakeStringWithClassicLocale(info.trt_ep_context_embed_mode)},
+ {tensorrt::provider_option_names::kEpContextComputeCapabilityEnable, MakeStringWithClassicLocale(info.trt_ep_context_compute_capability_enable)},
};
return options;
}
@@ -279,5 +291,8 @@ void TensorrtExecutionProviderInfo::UpdateProviderOptions(void* provider_options
trt_provider_options_v2.trt_profile_opt_shapes = copy_string_if_needed(internal_options.profile_opt_shapes);
trt_provider_options_v2.trt_cuda_graph_enable = internal_options.cuda_graph_enable;
+ trt_provider_options_v2.trt_dump_ep_context_model = internal_options.dump_ep_context_model;
+ trt_provider_options_v2.trt_ep_context_embed_mode = internal_options.ep_context_embed_mode;
+ trt_provider_options_v2.trt_ep_context_compute_capability_enable = internal_options.ep_context_compute_capability_enable;
}
} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h
index b16543aa3d7dd..76223b7847359 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_info.h
@@ -51,6 +51,9 @@ struct TensorrtExecutionProviderInfo {
std::string profile_max_shapes{""};
std::string profile_opt_shapes{""};
bool cuda_graph_enable{false};
+ bool dump_ep_context_model{false};
+ int ep_context_embed_mode{0};
+ bool ep_context_compute_capability_enable{1};
static TensorrtExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
static ProviderOptions ToProviderOptions(const TensorrtExecutionProviderInfo& info);
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h
index c69299d0ecdeb..07f6f8eb3476f 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider_utils.h
@@ -5,6 +5,7 @@
#include
#include
#include
+#include <filesystem>
#include
#include "flatbuffers/idl.h"
#include "ort_trt_int8_cal_table.fbs.h"
diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc
index 426584553f349..0e29df72f0322 100644
--- a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc
+++ b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc
@@ -116,6 +116,9 @@ struct Tensorrt_Provider : Provider {
info.profile_max_shapes = options.trt_profile_max_shapes == nullptr ? "" : options.trt_profile_max_shapes;
info.profile_opt_shapes = options.trt_profile_opt_shapes == nullptr ? "" : options.trt_profile_opt_shapes;
info.cuda_graph_enable = options.trt_cuda_graph_enable != 0;
+ info.dump_ep_context_model = options.trt_dump_ep_context_model != 0;
+ info.ep_context_embed_mode = options.trt_ep_context_embed_mode;
+ info.ep_context_compute_capability_enable = options.trt_ep_context_compute_capability_enable != 0;
return std::make_shared<TensorrtProviderFactory>(info);
}
diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc
index e3b8dea90a898..e2d46012c097b 100644
--- a/onnxruntime/core/session/provider_bridge_ort.cc
+++ b/onnxruntime/core/session/provider_bridge_ort.cc
@@ -427,6 +427,7 @@ struct ProviderHostImpl : ProviderHost {
int64_t AttributeProto__i(const ONNX_NAMESPACE::AttributeProto* p) override { return p->i(); }
float AttributeProto__f(const ONNX_NAMESPACE::AttributeProto* p) override { return p->f(); }
void AttributeProto__set_s(ONNX_NAMESPACE::AttributeProto* p, const ::std::string& value) override { return p->set_s(value); }
+ void AttributeProto__set_i(ONNX_NAMESPACE::AttributeProto* p, int64_t value) override { return p->set_i(value); }
const ::std::string& AttributeProto__s(const ONNX_NAMESPACE::AttributeProto* p) override { return p->s(); }
void AttributeProto__set_name(ONNX_NAMESPACE::AttributeProto* p, const ::std::string& value) override { return p->set_name(value); }
void AttributeProto__set_type(ONNX_NAMESPACE::AttributeProto* p, ONNX_NAMESPACE::AttributeProto_AttributeType value) override { return p->set_type(value); }
@@ -447,6 +448,7 @@ struct ProviderHostImpl : ProviderHost {
ONNX_NAMESPACE::ValueInfoProtos* GraphProto__mutable_value_info(ONNX_NAMESPACE::GraphProto* p) override { return p->mutable_value_info(); }
ONNX_NAMESPACE::TensorProtos* GraphProto__mutable_initializer(ONNX_NAMESPACE::GraphProto* p) override { return p->mutable_initializer(); }
ONNX_NAMESPACE::NodeProto* GraphProto__add_node(ONNX_NAMESPACE::GraphProto* p) override { return p->add_node(); }
+ ONNX_NAMESPACE::NodeProto* GraphProto__mutable_node(ONNX_NAMESPACE::GraphProto* p, int index) override { return p->mutable_node(index); }
void GraphProto__operator_assign(ONNX_NAMESPACE::GraphProto* p, const ONNX_NAMESPACE::GraphProto& v) override { *p = v; }
@@ -470,6 +472,7 @@ struct ProviderHostImpl : ProviderHost {
void NodeProto__operator_assign(ONNX_NAMESPACE::NodeProto* p, const ONNX_NAMESPACE::NodeProto& v) override { *p = v; }
int NodeProto__attribute_size(ONNX_NAMESPACE::NodeProto* p) override { return p->attribute_size(); }
const ONNX_NAMESPACE::AttributeProto& NodeProto__attribute(const ONNX_NAMESPACE::NodeProto* p, int index) const override { return p->attribute(index); }
+ ONNX_NAMESPACE::AttributeProto* NodeProto__mutable_attribute(ONNX_NAMESPACE::NodeProto* p, int index) override { return p->mutable_attribute(index); }
// TensorProto (wrapped)
std::unique_ptr<ONNX_NAMESPACE::TensorProto> TensorProto__construct() override { return std::make_unique<ONNX_NAMESPACE::TensorProto>(); }
diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc
index 6f383d733edbd..06eb2afdf80f2 100644
--- a/onnxruntime/python/onnxruntime_pybind_state.cc
+++ b/onnxruntime/python/onnxruntime_pybind_state.cc
@@ -713,6 +713,28 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
} else {
ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_cuda_graph_enable' should be 'True' or 'False'. Default value is 'False'.\n");
}
+ } else if (option.first == "trt_dump_ep_context_model") {
+ if (option.second == "True" || option.second == "true") {
+ params.trt_dump_ep_context_model = true;
+ } else if (option.second == "False" || option.second == "false") {
+ params.trt_dump_ep_context_model = false;
+ } else {
+ ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_dump_ep_context_model' should be 'True' or 'False'. Default value is 'False'.\n");
+ }
+ } else if (option.first == "trt_ep_context_embed_mode") {
+ if (!option.second.empty()) {
+ params.trt_ep_context_embed_mode = std::stoi(option.second);
+ } else {
+ ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_ep_context_embed_mode' should be a positive integer number i.e. '1'.\n");
+ }
+ } else if (option.first == "trt_ep_context_compute_capability_enable") {
+ if (option.second == "True" || option.second == "true") {
+ params.trt_ep_context_compute_capability_enable = true;
+ } else if (option.second == "False" || option.second == "false") {
+ params.trt_ep_context_compute_capability_enable = false;
+ } else {
+ ORT_THROW("[ERROR] [TensorRT] The value for the key 'trt_ep_context_compute_capability_enable' should be 'True' or 'False'. Default value is 'False'.\n");
+ }
} else {
ORT_THROW("Invalid TensorRT EP option: ", option.first);
}
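Given the parsing above, a small sketch of how the new options are expected to look from the Python API (boolean values accepted as "True"/"true"/"False"/"false", embed mode parsed with std::stoi); "model.onnx" is a placeholder:

    import onnxruntime as ort

    providers = [(
        "TensorrtExecutionProvider",
        {
            "trt_dump_ep_context_model": "true",                  # dump the engine-wrapped ONNX model
            "trt_ep_context_embed_mode": "1",                     # 0: engine cache path, 1: engine binary data
            "trt_ep_context_compute_capability_enable": "false",  # skip the hardware_arch consistency check
        },
    )]
    sess = ort.InferenceSession("model.onnx", providers=providers)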
diff --git a/onnxruntime/python/tools/tensorrt/gen_trt_engine_wrapper_onnx_model.py b/onnxruntime/python/tools/tensorrt/gen_trt_engine_wrapper_onnx_model.py
new file mode 100644
index 0000000000000..717a0816247e7
--- /dev/null
+++ b/onnxruntime/python/tools/tensorrt/gen_trt_engine_wrapper_onnx_model.py
@@ -0,0 +1,174 @@
+#!/usr/bin/env python3
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+from argparse import ArgumentParser
+
+import onnx
+import tensorrt as trt
+from onnx import TensorProto, helper
+
+
+class TensorRTEngineWrapperCreator:
+ def __init__(self, args):
+ ctx_embed_mode = args.embed_mode
+ engine_cache_path = args.trt_engine_cache_path
+ self.model_name = args.model_name
+ self.dynamic_dim_count = 0
+
+ # Get serialized engine from engine cache
+ with open(engine_cache_path, "rb") as file:
+ engine_buffer = file.read()
+
+ if ctx_embed_mode:
+ ep_cache_context_content = engine_buffer
+ else:
+ ep_cache_context_content = engine_cache_path
+
+ # Deserialize a TRT engine
+ logger = trt.Logger(trt.Logger.WARNING)
+ runtime = trt.Runtime(logger)
+ engine = runtime.deserialize_cuda_engine(engine_buffer)
+ num_bindings = engine.num_bindings
+
+ input_tensors = []
+ output_tensors = []
+ input_tensor_shapes = []
+ output_tensor_shapes = []
+ input_tensor_types = []
+ output_tensor_types = []
+
+ # Get type and shape of each input/output
+ for b_index in range(num_bindings):
+ tensor_name = engine.get_tensor_name(b_index)
+ tensor_shape = engine.get_tensor_shape(tensor_name)
+ tensor_type = engine.get_tensor_dtype(tensor_name)
+ if engine.get_tensor_mode(tensor_name) == trt.TensorIOMode.INPUT:
+ input_tensors.append(tensor_name)
+ input_tensor_shapes.append(tensor_shape)
+ input_tensor_types.append(tensor_type)
+ else:
+ output_tensors.append(tensor_name)
+ output_tensor_shapes.append(tensor_shape)
+ output_tensor_types.append(tensor_type)
+
+ # Note:
+ # The TRT engine should be built with min, max and opt profiles so that dynamic shape input can have dimension of "-1"
+ print(input_tensors)
+ print(input_tensor_types)
+ print(input_tensor_shapes)
+ print(output_tensors)
+ print(output_tensor_types)
+ print(output_tensor_shapes)
+
+ nodes = [
+ helper.make_node(
+ "EPContext",
+ input_tensors,
+ output_tensors,
+ "EPContext",
+ domain="com.microsoft",
+ embed_mode=ctx_embed_mode,
+ ep_cache_context=ep_cache_context_content,
+ ),
+ ]
+
+ model_inputs = []
+ for i in range(len(input_tensors)):
+ model_inputs.append(
+ helper.make_tensor_value_info(
+ input_tensors[i],
+ self.trt_data_type_to_onnx_data_type(input_tensor_types[i]),
+ self.trt_shape_to_ort_shape(input_tensor_shapes[i]),
+ )
+ )
+
+ model_outputs = []
+ for i in range(len(output_tensors)):
+ model_outputs.append(
+ helper.make_tensor_value_info(
+ output_tensors[i],
+ self.trt_data_type_to_onnx_data_type(output_tensor_types[i]),
+ self.trt_shape_to_ort_shape(output_tensor_shapes[i]),
+ )
+ )
+
+ self.graph = helper.make_graph(
+ nodes,
+ "trt_engine_wrapper",
+ model_inputs,
+ model_outputs,
+ )
+
+ def trt_data_type_to_onnx_data_type(self, trt_data_type):
+ if trt_data_type == trt.DataType.FLOAT:
+ return TensorProto.FLOAT
+ elif trt_data_type == trt.DataType.HALF:
+ return TensorProto.FLOAT16
+ elif trt_data_type == trt.DataType.INT8:
+ return TensorProto.INT8
+ elif trt_data_type == trt.DataType.INT32:
+ return TensorProto.INT32
+ elif trt_data_type == trt.DataType.BOOL:
+ return TensorProto.BOOL
+ elif trt_data_type == trt.DataType.UINT8:
+ return TensorProto.UINT8
+ else:
+ return TensorProto.UNDEFINED
+
+ # TRT uses "-1" to represent dynamic dimension
+ # ORT uses symbolic name to represent dynamic dimension
+ # Here we only do the conversion when there is any dynamic dimension in the shape
+ def trt_shape_to_ort_shape(self, trt_data_shape):
+ def has_dynamic_dim(trt_data_shape):
+ if any(dim == -1 for dim in trt_data_shape):
+ return True
+ return False
+
+ if not has_dynamic_dim(trt_data_shape):
+ return trt_data_shape
+
+ ort_data_shape = []
+ if has_dynamic_dim(trt_data_shape):
+ for dim in trt_data_shape:
+ if dim == -1:
+ ort_data_shape.append("free_dim_" + str(self.dynamic_dim_count))
+ self.dynamic_dim_count += 1
+ else:
+ ort_data_shape.append(dim)
+ return ort_data_shape
+
+ def create_model(self):
+ model = helper.make_model(self.graph)
+ onnx.save(model, self.model_name)
+ print(self.model_name + " is created.")
+
+
+def main():
+ parser = ArgumentParser("Generate Onnx model which includes the TensorRT engine binary.")
+ parser.add_argument(
+ "-p", "--trt_engine_cache_path", help="Required. Path to TensorRT engine cache.", required=True, type=str
+ )
+ parser.add_argument(
+ "-e",
+ "--embed_mode",
+ help="mode 0 means the engine cache path and mode 1 means engine binary data",
+ required=False,
+ default=0,
+ type=int,
+ )
+ parser.add_argument(
+ "-m",
+ "--model_name",
+ help="Model name to be created",
+ required=False,
+ default="trt_engine_wrapper.onnx",
+ type=str,
+ )
+ args = parser.parse_args()
+ ctor = TensorRTEngineWrapperCreator(args)
+ ctor.create_model()
+
+
+if __name__ == "__main__":
+ main()
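The script can also be driven programmatically instead of via its CLI; a hedged sketch, assuming the script is importable from the working directory and that "./trt_cache/model.engine" (a placeholder) points at an existing engine cache file:

    from argparse import Namespace

    from gen_trt_engine_wrapper_onnx_model import TensorRTEngineWrapperCreator

    args = Namespace(
        trt_engine_cache_path="./trt_cache/model.engine",  # placeholder: existing TRT engine cache
        embed_mode=0,                                      # 0: store the path, 1: embed engine bytes
        model_name="trt_engine_wrapper.onnx",
    )
    TensorRTEngineWrapperCreator(args).create_model()

This is equivalent to running: python gen_trt_engine_wrapper_onnx_model.py -p ./trt_cache/model.engine -e 0 -m trt_engine_wrapper.onnx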
diff --git a/onnxruntime/test/python/onnxruntime_test_engine_wrapper.py b/onnxruntime/test/python/onnxruntime_test_engine_wrapper.py
new file mode 100644
index 0000000000000..4123318b9f0af
--- /dev/null
+++ b/onnxruntime/test/python/onnxruntime_test_engine_wrapper.py
@@ -0,0 +1,100 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import os
+import unittest
+
+import numpy as np
+import onnx
+from helper import get_name
+from onnx import TensorProto, helper
+
+import onnxruntime as onnxrt
+
+
+class TestInferenceSessionWithCtxNode(unittest.TestCase):
+ trt_engine_cache_path_ = "./trt_engine_cache"
+ ctx_node_model_name_ = "ctx_node.onnx"
+
+ # This test is only for TRT EP to test EPContext node with TRT engine
+ @unittest.skipIf(
+ "TensorrtExecutionProvider" not in onnxrt.get_available_providers(),
+ reason="Test TRT EP only",
+ )
+ def create_ctx_node(self, ctx_embed_mode=0, cache_path=""):
+ if ctx_embed_mode:
+ # Get engine buffer from engine cache
+ with open(cache_path, "rb") as file:
+ engine_buffer = file.read()
+ ep_cache_context_content = engine_buffer
+ else:
+ ep_cache_context_content = cache_path
+
+ nodes = [
+ helper.make_node(
+ "EPContext",
+ ["X"],
+ ["Y"],
+ "EPContext",
+ domain="com.microsoft",
+ embed_mode=ctx_embed_mode,
+ ep_cache_context=ep_cache_context_content,
+ ),
+ ]
+
+ graph = helper.make_graph(
+ nodes,
+ "trt_engine_wrapper",
+ [ # input
+ helper.make_tensor_value_info("X", TensorProto.FLOAT, ["N", 2]),
+ ],
+ [ # output
+ helper.make_tensor_value_info("Y", TensorProto.FLOAT, ["N", 1]),
+ ],
+ )
+ model = helper.make_model(graph)
+ onnx.save(model, self.ctx_node_model_name_)
+
+ def test_ctx_node(self):
+ x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)
+
+ # First session and run to create engine cache
+ providers = [
+ (
+ "TensorrtExecutionProvider",
+ {"trt_engine_cache_enable": True, "trt_engine_cache_path": self.trt_engine_cache_path_},
+ )
+ ]
+ session = onnxrt.InferenceSession(get_name("matmul_2.onnx"), providers=providers)
+ session.run(
+ ["Y"],
+ {"X": x},
+ )
+
+ # Get engine cache name
+ cache_name = ""
+ for f in os.listdir(self.trt_engine_cache_path_):
+ if f.endswith(".engine"):
+ cache_name = f
+ print(cache_name)
+
+ # Second session and run to test ctx node with engine cache path
+ self.create_ctx_node(cache_path=os.path.join(self.trt_engine_cache_path_, cache_name))
+ providers = [("TensorrtExecutionProvider", {})]
+ session = onnxrt.InferenceSession(get_name(self.ctx_node_model_name_), providers=providers)
+ session.run(
+ ["Y"],
+ {"X": x},
+ )
+
+ # Third session and run to test ctx node with engine binary content
+ self.create_ctx_node(ctx_embed_mode=1, cache_path=os.path.join(self.trt_engine_cache_path_, cache_name))
+ session = onnxrt.InferenceSession(get_name(self.ctx_node_model_name_), providers=providers)
+ session.run(
+ ["Y"],
+ {"X": x},
+ )
+
+
+if __name__ == "__main__":
+ unittest.main()