Skip to content

Commit

Permalink
[TensorRT EP] Load precompiled TRT engine file directly (#18217)
Browse files Browse the repository at this point in the history
When the TRT engine cache (precompiled engine) is present, it doesn't
make sense to go over the processes of model verification, model
optimization, TRT EP's GetCapability(), TRT EP's model proto
reconstruction, calling TRT parser and engine compilation.
This PR makes TRT EP skip those processes and directly load the engine
to perform inference.

The feature request:
#18072

Features:

- Replace original model with TRT engine wrapped ONNX model. It can save
a lot of time as mentioned above.

- How to get TRT engine wrapped ONNX model?
1. Set `trt_dump_ep_context_model` provider option to "true" and run the
inference. You will find the "xxx_wrapper.onnx" at the engine cache
path. (The same logic of generating engine cache)
    2. Use gen_trt_engine_wrapper_onnx_model.py

- Three provider options are added, 
`trt_dump_ep_context_model`: Enable dump wrapped onnx model by TRT EP
`trt_ep_context_embed_mode`: Add embed_mode as attribute. 0 means engine
cache path, 1 means engine binary data.
`trt_ep_context_compute_capability_enable`: Add hardware_arch as
attribute. When running the model, TRT EP will check consistency between
model's hardware_arch and GPU's compute capability.

- When the engine cache path is given in the wrapped model, TRT EP will
first search for the engine file using the path (relative to model
path), if it can't find it, it will change to use the path as it is
(depends on user, could be relative to working dir or absolute path)

Note: 

1. This PR includes the change of
#17751


Constraints:

1. The whole model should be fully supported by TRT. 
4. Users need to make sure the engine is built with min/max/opt
optimization profiles that large enough to cover the range of all
inputs. TRT EP will simply fail and won't rebuild the engine if the
input shape is out of range during runtime.
  • Loading branch information
chilo-ms authored Jan 12, 2024
1 parent b6d8283 commit 46dd0d3
Show file tree
Hide file tree
Showing 17 changed files with 1,873 additions and 840 deletions.
2 changes: 2 additions & 0 deletions docs/ContribOperators.md
Original file line number Diff line number Diff line change
Expand Up @@ -1588,6 +1588,8 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>payload of the execution provider context if embed_mode=1, or path to the context file if embed_mode=0.</dd>
<dt><tt>ep_sdk_version</tt> : string</dt>
<dd>(Optional) SDK version used to convert the model.</dd>
<dt><tt>hardware_architecture</tt> : string</dt>
<dd>(Optional) Hardware architecture.</dd>
<dt><tt>main_context</tt> : int</dt>
<dd>Usually each single EPContext associate with a graph partition.But for some case like QNN, it has single EPContext contains all partitions.In that case, the node with ep_cache_context should set main_context=1. Other nodes set main_context=0 and skip ep_cache_context.The path is relative to this Onnx file. Default is 1.</dd>
<dt><tt>notes</tt> : string</dt>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,7 @@ struct OrtTensorRTProviderOptionsV2 {
const char* trt_profile_max_shapes{nullptr}; // Specify the range of the input shapes to build the engine with
const char* trt_profile_opt_shapes{nullptr}; // Specify the range of the input shapes to build the engine with
int trt_cuda_graph_enable{0}; // Enable CUDA graph in ORT TRT
int trt_dump_ep_context_model{0}; // Dump EP context node model
int trt_ep_context_embed_mode{0}; // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data
int trt_ep_context_compute_capability_enable{1}; // Add GPU compute capability as an EP context node's attribute
};
5 changes: 5 additions & 0 deletions onnxruntime/core/graph/contrib_ops/contrib_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3230,6 +3230,11 @@ void RegisterContribSchemas() {
"(Optional) SDK version used to convert the model.",
AttributeProto::STRING,
OPTIONAL_VALUE)
.Attr(
"hardware_architecture",
"(Optional) Hardware architecture.",
AttributeProto::STRING,
OPTIONAL_VALUE)
.Attr(
"partition_name",
"(Optional) partitioned graph name.",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ struct ProviderHost {
virtual int64_t AttributeProto__i(const ONNX_NAMESPACE::AttributeProto* p) = 0;
virtual float AttributeProto__f(const ONNX_NAMESPACE::AttributeProto* p) = 0;
virtual void AttributeProto__set_s(ONNX_NAMESPACE::AttributeProto* p, const ::std::string& value) = 0;
virtual void AttributeProto__set_i(ONNX_NAMESPACE::AttributeProto* p, int64_t value) = 0;
virtual const ::std::string& AttributeProto__s(const ONNX_NAMESPACE::AttributeProto* p) = 0;
virtual void AttributeProto__set_name(ONNX_NAMESPACE::AttributeProto* p, const ::std::string& value) = 0;
virtual void AttributeProto__set_type(ONNX_NAMESPACE::AttributeProto* p, ONNX_NAMESPACE::AttributeProto_AttributeType value) = 0;
Expand All @@ -351,6 +352,7 @@ struct ProviderHost {
virtual ONNX_NAMESPACE::ValueInfoProtos* GraphProto__mutable_value_info(ONNX_NAMESPACE::GraphProto* p) = 0;
virtual ONNX_NAMESPACE::TensorProtos* GraphProto__mutable_initializer(ONNX_NAMESPACE::GraphProto* p) = 0;
virtual ONNX_NAMESPACE::NodeProto* GraphProto__add_node(ONNX_NAMESPACE::GraphProto* p) = 0;
virtual ONNX_NAMESPACE::NodeProto* GraphProto__mutable_node(ONNX_NAMESPACE::GraphProto* p, int index) = 0;

// ModelProto
virtual std::unique_ptr<ONNX_NAMESPACE::ModelProto> ModelProto__construct() = 0;
Expand All @@ -372,6 +374,7 @@ struct ProviderHost {
virtual void NodeProto__operator_assign(ONNX_NAMESPACE::NodeProto* p, const ONNX_NAMESPACE::NodeProto& v) = 0;
virtual int NodeProto__attribute_size(ONNX_NAMESPACE::NodeProto* p) = 0;
virtual const ONNX_NAMESPACE::AttributeProto& NodeProto__attribute(const ONNX_NAMESPACE::NodeProto* p, int index) const = 0;
virtual ONNX_NAMESPACE::AttributeProto* NodeProto__mutable_attribute(ONNX_NAMESPACE::NodeProto* p, int index) = 0;

// TensorProto
virtual std::unique_ptr<ONNX_NAMESPACE::TensorProto> TensorProto__construct() = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ struct AttributeProto final {
int64_t i() const { return g_host->AttributeProto__i(this); }
float f() const { return g_host->AttributeProto__f(this); }
void set_s(const ::std::string& value) { return g_host->AttributeProto__set_s(this, value); }
void set_i(int64_t value) { return g_host->AttributeProto__set_i(this, value); }
const ::std::string& s() const { return g_host->AttributeProto__s(this); }
void set_name(const ::std::string& value) { return g_host->AttributeProto__set_name(this, value); }
void set_type(AttributeProto_AttributeType value) { return g_host->AttributeProto__set_type(this, value); }
Expand Down Expand Up @@ -118,6 +119,7 @@ struct GraphProto final {
ValueInfoProtos* mutable_value_info() { return g_host->GraphProto__mutable_value_info(this); }
TensorProtos* mutable_initializer() { return g_host->GraphProto__mutable_initializer(this); }
NodeProto* add_node() { return g_host->GraphProto__add_node(this); }
NodeProto* mutable_node(int index) { return g_host->GraphProto__mutable_node(this, index); }

GraphProto() = delete;
GraphProto(const GraphProto&) = delete;
Expand Down Expand Up @@ -148,6 +150,7 @@ struct NodeProto final {
void operator=(const NodeProto& v) { g_host->NodeProto__operator_assign(this, v); }
int attribute_size() { return g_host->NodeProto__attribute_size(this); }
const AttributeProto& attribute(int index) const { return g_host->NodeProto__attribute(this, index); }
AttributeProto* mutable_attribute(int index) { return g_host->NodeProto__mutable_attribute(this, index); }

NodeProto() = delete;
NodeProto(const NodeProto&) = delete;
Expand Down
229 changes: 229 additions & 0 deletions onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,229 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <iostream>
#include <fstream>
#include <filesystem>

#include "onnx_ctx_model_helper.h"
#include "core/providers/cuda/shared_inc/cuda_call.h"
#include "core/framework/execution_provider.h"

namespace onnxruntime {

/*
* Check whether the graph has the EP context contrib op.
* The op can contain the precompiled engine info for TRT EP to directly load the engine.
*
* Note: Please see more details about "EPContext" contrib op in contrib_defs.cc
*/
bool GraphHasCtxNode(const GraphViewer& graph_viewer) {
for (int i = 0; i < graph_viewer.MaxNodeIndex(); ++i) {
auto node = graph_viewer.GetNode(i);
if (node != nullptr && node->OpType() == EPCONTEXT_OP) {
return true;
}
}
return false;
}

const onnxruntime::Path& GetModelPath(const GraphViewer& graph_viewer) {
// find the top level graph
const Graph* cur_graph = &graph_viewer.GetGraph();
while (cur_graph->IsSubgraph()) {
cur_graph = cur_graph->ParentGraph();
}

const Graph& main_graph = *cur_graph;
return main_graph.ModelPath();
}

std::filesystem::path LocateEngineRelativeToPath(std::string engine_cache_path, const onnxruntime::Path& path) {
std::filesystem::path base_path(path.ToPathString());
std::filesystem::path parent_path = base_path.parent_path();
std::filesystem::path engine_path = parent_path.append(engine_cache_path);
return engine_path;
}

/*
* Update ep_cache_context attribute of the EP context node with the given engine binary data
*/
void UpdateCtxNodeModelEngineContext(ONNX_NAMESPACE::ModelProto* model_proto,
char* engine_data,
size_t size) {
ONNX_NAMESPACE::GraphProto* graph_proto = model_proto->mutable_graph();
ONNX_NAMESPACE::NodeProto* node_proto = graph_proto->mutable_node(0);

for (int i = 0; i < node_proto->attribute_size(); ++i) {
ONNX_NAMESPACE::AttributeProto* attribute_proto = node_proto->mutable_attribute(i);
if (attribute_proto->name() == EP_CACHE_CONTEXT) {
std::string engine_data_str = "";
if (size > 0) {
engine_data_str.assign(engine_data, size);
}
attribute_proto->set_s(engine_data_str);
}
}
}

/*
* Create "EP context node" model where engine information is embedded
*/
ONNX_NAMESPACE::ModelProto* CreateCtxNodeModel(const GraphViewer& graph_viewer,
const std::string engine_cache_path,
char* engine_data,
size_t size,
const int64_t embed_mode,
bool compute_capability_enable,
std::string compute_capability,
const logging::Logger* logger) {
auto model_build = graph_viewer.CreateModel(*logger);
auto& graph_build = model_build->MainGraph();

// Get graph inputs and outputs
std::vector<onnxruntime::NodeArg*> inputs, outputs;
for (auto input : graph_viewer.GetInputs()) {
auto& n_input = graph_build.GetOrCreateNodeArg(input->Name(), input->TypeAsProto());
inputs.push_back(&n_input);
}

for (auto output : graph_viewer.GetOutputs()) {
auto& n_output = graph_build.GetOrCreateNodeArg(output->Name(), output->TypeAsProto());
outputs.push_back(&n_output);
}

// Create EP context node attributes
auto attr_0 = ONNX_NAMESPACE::AttributeProto::Create(); // embed_mode
auto attr_1 = ONNX_NAMESPACE::AttributeProto::Create(); // ep_cache_context
auto attr_2 = ONNX_NAMESPACE::AttributeProto::Create(); // hardware_architecture
std::string engine_data_str = "";
attr_0->set_name(EMBED_MODE);
attr_0->set_type(onnx::AttributeProto_AttributeType_INT);
attr_0->set_i(embed_mode);
attr_1->set_name(EP_CACHE_CONTEXT);
attr_1->set_type(onnx::AttributeProto_AttributeType_STRING);
if (embed_mode) {
if (size > 0) {
engine_data_str.assign(engine_data, size);
}
attr_1->set_s(engine_data_str);
} else {
attr_1->set_s(engine_cache_path);
}
auto node_attributes = ONNX_NAMESPACE::NodeAttributes::Create();
int num_attributes = compute_capability_enable ? 3 : 2;
node_attributes->reserve(num_attributes);
node_attributes->emplace(EMBED_MODE, *attr_0);
node_attributes->emplace(EP_CACHE_CONTEXT, *attr_1);

if (compute_capability_enable) {
attr_2->set_name(COMPUTE_CAPABILITY);
attr_2->set_type(onnx::AttributeProto_AttributeType_STRING);
attr_2->set_s(compute_capability);
node_attributes->emplace(COMPUTE_CAPABILITY, *attr_2);
}

// Create EP context node
graph_build.AddNode(EPCONTEXT_OP, EPCONTEXT_OP, "", inputs, outputs, node_attributes.get(), EPCONTEXT_OP_DOMAIN);
ORT_ENFORCE(graph_build.Resolve().IsOK());

// Serialize modelproto to string
auto new_graph_viewer = graph_build.CreateGraphViewer();
auto model = new_graph_viewer->CreateModel(*logger);
auto model_proto = model->ToProto();
new_graph_viewer->ToProto(*model_proto->mutable_graph(), true, true);
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);

return model_proto.release();
}

/*
* Dump "EP context node" model
*
*/
void DumpCtxNodeModel(ONNX_NAMESPACE::ModelProto* model_proto,
const std::string engine_cache_path) {
std::fstream dump(engine_cache_path + "_wrapper.onnx", std::ios::out | std::ios::trunc | std::ios::binary);
model_proto->SerializeToOstream(dump);
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Serialized " + engine_cache_path + "_wrapper.onnx";
}

Status TensorRTCacheModelHandler::GetEpContextFromGraph(const GraphViewer& graph_viewer) {
if (!ValidateEPCtxNode(graph_viewer)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL, "It's not a valid EP Context node");
}
auto node = graph_viewer.GetNode(0);
auto& attrs = node->GetAttributes();

const int64_t embed_mode = attrs.at(EMBED_MODE).i();
if (embed_mode) {
// Get engine from byte stream
const std::string& context_binary = attrs.at(EP_CACHE_CONTEXT).s();
*(trt_engine_) = std::unique_ptr<nvinfer1::ICudaEngine>(trt_runtime_->deserializeCudaEngine(const_cast<char*>(context_binary.c_str()),
static_cast<size_t>(context_binary.length())));
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] Read engine as binary data from \"ep_cache_context\" attribute of ep context node and deserialized it";
if (!(*trt_engine_)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not deserialize engine from binary data");
}
} else {
// Get engine from cache file
std::ifstream engine_file(engine_cache_path_.string(), std::ios::binary | std::ios::in);
engine_file.seekg(0, std::ios::end);
size_t engine_size = engine_file.tellg();
engine_file.seekg(0, std::ios::beg);
std::unique_ptr<char[]> engine_buf{new char[engine_size]};
engine_file.read((char*)engine_buf.get(), engine_size);
*(trt_engine_) = std::unique_ptr<nvinfer1::ICudaEngine>(trt_runtime_->deserializeCudaEngine(engine_buf.get(), engine_size));
LOGS_DEFAULT(VERBOSE) << "[TensorRT EP] DeSerialized " + engine_cache_path_.string();
if (!(*trt_engine_)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, EP_FAIL,
"TensorRT EP could not deserialize engine from cache: " + engine_cache_path_.string());
}
}
return Status::OK();
}

/*
* The sanity check for EP context contrib op.
*/
bool TensorRTCacheModelHandler::ValidateEPCtxNode(const GraphViewer& graph_viewer) {
assert(graph_viewer.NumberOfNodes() == 1);
assert(graph_viewer.GetNode(0)->OpType() == EPCONTEXT_OP);
auto node = graph_viewer.GetNode(0);
auto& attrs = node->GetAttributes();

// Check hardware_architecture(compute_capability) if it's present as an attribute
if (attrs.count(COMPUTE_CAPABILITY) > 0) {
std::string model_compute_capability = attrs.at(COMPUTE_CAPABILITY).s();
if (model_compute_capability != compute_capability_) {
LOGS_DEFAULT(ERROR) << "The compute capability of the engine cache doesn't match with the GPU's compute capability";
LOGS_DEFAULT(ERROR) << "The compute capability of the engine cache: " << model_compute_capability;
LOGS_DEFAULT(ERROR) << "The compute capability of the GPU: " << compute_capability_;
return false;
}
}

// "embed_mode" attr and "ep_cache_context" attr should be present
if (attrs.count(EMBED_MODE) > 0 && attrs.count(EP_CACHE_CONTEXT) > 0) {
// ep_cache_context: payload of the execution provider context if embed_mode=1, or path to the context file if embed_mode=0
const int64_t embed_mode = attrs.at(EMBED_MODE).i();

// engine cache path
if (embed_mode == 0) {
// First assume engine cache path is relatvie to model path,
// If not, then assume the engine cache path is an absolute path.
engine_cache_path_ = LocateEngineRelativeToPath(attrs.at(EP_CACHE_CONTEXT).s(), GetModelPath(graph_viewer));
auto default_engine_cache_path_ = engine_cache_path_;
if (!std::filesystem::exists(engine_cache_path_)) {
engine_cache_path_.assign(attrs.at(EP_CACHE_CONTEXT).s());
if (!std::filesystem::exists(engine_cache_path_)) {
LOGS_DEFAULT(ERROR) << "Can't find " << default_engine_cache_path_.string() << " or " << engine_cache_path_.string() << " TensorRT engine";
return false;
}
}
}
}
return true;
}
} // namespace onnxruntime
55 changes: 55 additions & 0 deletions onnxruntime/core/providers/tensorrt/onnx_ctx_model_helper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include <string>
#include <filesystem>

#include "NvInfer.h"
#include "core/providers/shared_library/provider_api.h"

namespace onnxruntime {

static const std::string EPCONTEXT_OP = "EPContext";
static const std::string EMBED_MODE = "embed_mode";
static const std::string EP_CACHE_CONTEXT = "ep_cache_context";
static const std::string COMPUTE_CAPABILITY = "hardware_architecture";
static const std::string EPCONTEXT_OP_DOMAIN = "com.microsoft";

bool GraphHasCtxNode(const GraphViewer& graph_viewer);
const onnxruntime::Path& GetModelPath(const GraphViewer& graph_viewer);
std::filesystem::path LocateEngineRelativeToPath(std::string engine_cache_path, const onnxruntime::Path& path);
ONNX_NAMESPACE::ModelProto* CreateCtxNodeModel(const GraphViewer& graph_viewer,
const std::string engine_cache_path,
char* engine_data,
size_t size,
const int64_t embed_mode,
bool compute_capability_enable,
std::string compute_capability,
const logging::Logger* logger);
void DumpCtxNodeModel(ONNX_NAMESPACE::ModelProto* model_proto,
const std::string engine_cache_path);
void UpdateCtxNodeModelEngineContext(ONNX_NAMESPACE::ModelProto* model_proto,
char* engine_data,
size_t size);

class TensorRTCacheModelHandler {
public:
TensorRTCacheModelHandler(std::unique_ptr<nvinfer1::ICudaEngine>* trt_engine,
nvinfer1::IRuntime* trt_runtime,
std::string compute_capability) : trt_engine_(trt_engine), trt_runtime_(trt_runtime), compute_capability_(compute_capability) {
}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(TensorRTCacheModelHandler);

bool ValidateEPCtxNode(const GraphViewer& graph_viewer);

Status GetEpContextFromGraph(const GraphViewer& graph_viewer);

private:
std::unique_ptr<nvinfer1::ICudaEngine>* trt_engine_;
nvinfer1::IRuntime* trt_runtime_;
std::filesystem::path engine_cache_path_;
std::string compute_capability_;
}; // TRTCacheModelHandler
} // namespace onnxruntime
Loading

0 comments on commit 46dd0d3

Please sign in to comment.