Skip to content

Commit

Permalink
Fix issue that the generated context cache model inputs/outputs order…
Browse files Browse the repository at this point in the history
… is not guaranteed (#19195)

Fix issue that the generated context cache model inputs/outputs order is not guaranteed

### Description
Currently, QNN EP generate the context cache model in Compile() method which only get access to the partitioned graph. And the inputs/outputs order for the partitioned graph is not guaranteed. And EP doesn't have the view of the input user model. Have to move the context cache model generation to a higher level in GraphPartitioner which has the view of the partitioned model.
This is also a break down of PR for multi-partition support.
#18865
  • Loading branch information
HectorSVC authored and rachguo committed Jan 23, 2024
1 parent 5eebd09 commit 3c2065c
Show file tree
Hide file tree
Showing 13 changed files with 210 additions and 26 deletions.
9 changes: 9 additions & 0 deletions include/onnxruntime/core/framework/execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,15 @@ class IExecutionProvider {
*/
virtual std::vector<AllocatorPtr> CreatePreferredAllocators() { return std::vector<AllocatorPtr>(); };

/**
* Get the array of pointers for EPContext nodes
* EP needs to implement this if has the requirement to generate the context cache model. Otherwise leave it.
* Default return an empty vector if not provided by the Execution Provider
*/
virtual const InlinedVector<const Node*> GetEpContextNodes() const {
return InlinedVector<const Node*>();
}

private:
const std::string type_;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersFil
static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes =
"session.optimized_model_external_initializers_min_size_in_bytes";

// Enable EP context feature to dump the partitioned graph which include the EP context into Onnx file.
// Enable EP context feature to dump the partitioned graph which includes the EP context into Onnx file.
// The dumped Onnx model with EP context can be used for future inference to avoid the EP graph partitioning/compile overhead.
// "0": disable. (default)
// "1": enable.
Expand Down
105 changes: 105 additions & 0 deletions onnxruntime/core/framework/graph_partitioner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "core/graph/function_utils.h"
#include "core/graph/graph_viewer.h"
#include "core/graph/model.h"
#include "core/session/onnxruntime_session_options_config_keys.h"

// uncomment this line to count non-CUDA ops in ONNX domain
// #define COUNT_NON_CUDA_OPS
Expand Down Expand Up @@ -634,6 +635,100 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide
return Status::OK();
}

static Status CreateEpContextModel(const ExecutionProviders& execution_providers,
const Graph& graph,
const std::string& ep_context_path,
const logging::Logger& logger) {
InlinedVector<const Node*> all_ep_context_nodes;
for (const auto& ep : execution_providers) {
const InlinedVector<const Node*> ep_context_nodes = ep->GetEpContextNodes();
all_ep_context_nodes.insert(all_ep_context_nodes.begin(), ep_context_nodes.begin(), ep_context_nodes.end());
}

auto get_ep_context_node = [&all_ep_context_nodes](const std::string& node_name) -> std::pair<bool, const Node*> {
for (auto& node : all_ep_context_nodes) {
if (node_name == node->Name()) {
return std::make_pair(true, node);
}
}
return std::make_pair(false, static_cast<const Node*>(nullptr));
};

onnxruntime::PathString context_cache_path;
PathString model_pathstring = graph.ModelPath().ToPathString();
if (all_ep_context_nodes.size() > 0) {
if (!ep_context_path.empty()) {
context_cache_path = ToPathString(ep_context_path);
} else if (!model_pathstring.empty()) {
context_cache_path = model_pathstring + ToPathString("_ctx.onnx");
}

{
#ifdef _WIN32
std::wifstream fs(context_cache_path);
#else
std::ifstream fs(context_cache_path);
#endif
ORT_RETURN_IF(fs.good(), "Failed to generate EP context model since the file exist already.");
}

Model ep_context_model(graph.Name(), false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
graph.DomainToVersionMap(), {}, logger);
auto& ep_graph = ep_context_model.MainGraph();
ep_graph.SetDescription(graph.Description());

// Set inputs outputs explicitly to make sure the order is same as the user model.
auto inputs = graph.GetInputs();
auto outputs = graph.GetOutputs();

InlinedVector<const NodeArg*> ep_graph_inputs;
ep_graph_inputs.reserve(inputs.size());
for (auto& input : inputs) {
auto input_arg = graph.GetNodeArg(input->Name());
auto& ep_graph_input_arg = ep_graph.GetOrCreateNodeArg(input_arg->Name(), input_arg->TypeAsProto());
ep_graph_inputs.push_back(&ep_graph_input_arg);
}

InlinedVector<const NodeArg*> ep_graph_outputs;
ep_graph_outputs.reserve(outputs.size());
for (auto& output : outputs) {
auto output_arg = graph.GetNodeArg(output->Name());
auto& ep_graph_output_arg = ep_graph.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto());
ep_graph_outputs.push_back(&ep_graph_output_arg);
}

ep_graph.SetInputs(ep_graph_inputs);
ep_graph.SetOutputs(ep_graph_outputs);

for (const auto& node : graph.Nodes()) {
// the fused node and EPContext node has same node name
auto ep_context_node = get_ep_context_node(node.Name());
// Use EpContext node created by the EPs if name matched, otherwise use node from original model
if (ep_context_node.first) {
ep_graph.AddNode(*ep_context_node.second);
} else {
ep_graph.AddNode(node);
}
}

// handle initializers
for (const auto& input : graph.GetInputsIncludingInitializers()) {
const ONNX_NAMESPACE::TensorProto* initializer = nullptr;
if (graph.GetInitializedTensor(input->Name(), initializer)) {
// There initializer could have duplicates so make sure we only add once
const ONNX_NAMESPACE::TensorProto* subgraph_initializer = nullptr;
if (!ep_graph.GetInitializedTensor(input->Name(), subgraph_initializer)) {
ep_graph.AddInitializedTensor(*initializer);
}
}
}

ORT_RETURN_IF_ERROR(Model::Save(ep_context_model, context_cache_path));
}

return Status::OK();
}

static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, GraphPartitioner::Mode mode,
const ExecutionProviders& execution_providers,
KernelRegistryManager& kernel_registry_manager) {
Expand Down Expand Up @@ -840,6 +935,8 @@ Status GraphPartitioner::InlineFunctionsAOT(Model& model,

Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
const layout_transformation::TransformLayoutFunction& transform_layout_function,
const ConfigOptions& config_options,
const logging::Logger& logger,
Mode mode,
const layout_transformation::DebugGraphFn& debug_graph_fn) const {
// It is a greedy partitioning algorithm per provider preferences user provided when calling ONNX RUNTIME right now.
Expand Down Expand Up @@ -886,7 +983,15 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr,
#if !defined(ORT_MINIMAL_BUILD)
ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode,
providers_, kernel_registry_mgr_));

bool ep_context_enabled = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1";
std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, "");
if (ep_context_enabled) {
ORT_RETURN_IF_ERROR(CreateEpContextModel(providers_, graph, ep_context_path, logger));
}
#else
ORT_UNUSED_PARAMETER(config_options);
ORT_UNUSED_PARAMETER(logger);
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ONNX models are not supported in this build.");
#endif //! defined(ORT_MINIMAL_BUILD)
} else {
Expand Down
3 changes: 3 additions & 0 deletions onnxruntime/core/framework/graph_partitioner.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ namespace onnxruntime {
class ExecutionProviders;
class KernelRegistryManager;
class Model;
struct ConfigOptions;

class GraphPartitioner {
public:
Expand All @@ -31,6 +32,8 @@ class GraphPartitioner {
// Run partitioning.
Status Partition(Graph& graph, FuncManager& func_mgr,
const layout_transformation::TransformLayoutFunction& transform_layout_function,
const ConfigOptions& config_options,
const logging::Logger& logger,
Mode mode = Mode::kNormal,
const layout_transformation::DebugGraphFn& debug_graph_fn = {}) const;

Expand Down
13 changes: 3 additions & 10 deletions onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,7 @@ Status ValidateWithContextFile(const onnxruntime::PathString& context_cache_path
return Status::OK();
}

Status GenerateCtxCacheOnnxModel(const std::string model_name,
const std::string model_description,
Status GenerateCtxCacheOnnxModel(Model* model,
unsigned char* buffer,
uint64_t buffer_size,
const std::string& sdk_build_version,
Expand All @@ -240,11 +239,7 @@ Status GenerateCtxCacheOnnxModel(const std::string model_name,
const onnxruntime::PathString& context_cache_path,
bool qnn_context_embed_mode,
const logging::Logger& logger) {
std::unordered_map<std::string, int> domain_to_version = {{kOnnxDomain, 11}, {kMSDomain, 1}};
Model model(model_name, false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
domain_to_version, {}, logger);
auto& graph = model.MainGraph();
graph.SetDescription(model_description);
auto& graph = model->MainGraph();

using namespace ONNX_NAMESPACE;
int index = 0;
Expand All @@ -270,7 +265,7 @@ Status GenerateCtxCacheOnnxModel(const std::string model_name,
nullptr,
kMSDomain);

// Only dump the context buffer once since all QNN graph are in one single context
// Only dump the context buffer once since all QNN graphs are in one single context
if (0 == index) {
if (qnn_context_embed_mode) {
std::string cache_payload(buffer, buffer + buffer_size);
Expand All @@ -296,8 +291,6 @@ Status GenerateCtxCacheOnnxModel(const std::string model_name,
ep_node.AddAttribute(SOURCE, kQnnExecutionProvider);
++index;
}
ORT_RETURN_IF_ERROR(graph.Resolve());
ORT_RETURN_IF_ERROR(Model::Save(model, context_cache_path));

return Status::OK();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,7 @@ Status GetMetadataFromEpContextModel(const onnxruntime::PathString& ctx_onnx_mod
std::string& cache_source,
const logging::Logger& logger);

Status GenerateCtxCacheOnnxModel(const std::string model_name,
const std::string model_description,
Status GenerateCtxCacheOnnxModel(Model* model,
unsigned char* buffer,
uint64_t buffer_size,
const std::string& sdk_build_version,
Expand Down
16 changes: 14 additions & 2 deletions onnxruntime/core/providers/qnn/qnn_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -613,8 +613,8 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
ORT_RETURN_IF(fused_nodes_and_graphs.size() != 1, "Only support single partition for context cache feature.");
uint64_t buffer_size(0);
auto context_buffer = qnn_backend_manager_->GetContextBinaryBuffer(buffer_size);
ORT_RETURN_IF_ERROR(qnn::GenerateCtxCacheOnnxModel(model_name,
model_description,
qnn_ep_context_model_ = std::make_unique<Model>("qnn_ep_context_model", false, logger);
ORT_RETURN_IF_ERROR(qnn::GenerateCtxCacheOnnxModel(qnn_ep_context_model_.get(),
context_buffer.get(),
buffer_size,
qnn_backend_manager_->GetSdkVersion(),
Expand All @@ -626,4 +626,16 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
}
return Status::OK();
}

const InlinedVector<const Node*> QNNExecutionProvider::GetEpContextNodes() const {
InlinedVector<const Node*> ep_context_nodes;
if (qnn_ep_context_model_) {
const auto& graph = qnn_ep_context_model_->MainGraph();
for (const auto& node : graph.Nodes()) {
ep_context_nodes.push_back(graph.GetNode(node.Index()));
}
}

return ep_context_nodes;
}
} // namespace onnxruntime
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/qnn/qnn_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "core/providers/qnn/builder/qnn_backend_manager.h"
#include "core/providers/qnn/builder/qnn_model.h"
#include "core/providers/qnn/builder/qnn_graph_configs_helper.h"
#include "core/graph/model.h"

namespace onnxruntime {

Expand All @@ -35,6 +36,8 @@ class QNNExecutionProvider : public IExecutionProvider {

DataLayout GetPreferredLayout() const override;

const InlinedVector<const Node*> GetEpContextNodes() const override;

private:
bool IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit,
std::unordered_map<const NodeUnit*, bool>& node_unit_supported_result,
Expand Down Expand Up @@ -66,6 +69,7 @@ class QNNExecutionProvider : public IExecutionProvider {
bool disable_cpu_ep_fallback_ = false; // True if CPU EP fallback has been disabled for this session.
bool qnn_context_embed_mode_ = true;
int32_t vtcm_size_in_mb_ = 0;
std::unique_ptr<onnxruntime::Model> qnn_ep_context_model_;
};

} // namespace onnxruntime
9 changes: 7 additions & 2 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1164,6 +1164,7 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool

// Do partitioning based on execution providers' capabilities.
ORT_RETURN_IF_ERROR_SESSIONID_(partitioner.Partition(graph, session_state_->GetMutableFuncMgr(), transform_layout_fn,
session_options_.config_options, *session_logger_,
mode, debug_graph_fn));

// apply Level2 and higher transformers.
Expand Down Expand Up @@ -1458,7 +1459,9 @@ namespace {
Status PartitionOrtFormatModel(onnxruntime::Graph& graph,
const ExecutionProviders& providers,
KernelRegistryManager& kernel_registry_manager,
SessionState& session_state) {
SessionState& session_state,
const ConfigOptions& config_options,
const logging::Logger& logger) {
layout_transformation::TransformLayoutFunction transform_layout_fn = nullptr;

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
Expand All @@ -1479,6 +1482,8 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph,
ORT_RETURN_IF_ERROR(partitioner.Partition(graph,
session_state.GetMutableFuncMgr(),
transform_layout_fn,
config_options,
logger,
GraphPartitioner::Mode::kOrtFormatLoad));

return Status::OK();
Expand Down Expand Up @@ -1833,7 +1838,7 @@ common::Status InferenceSession::Initialize() {
#endif // !defined(ORT_MINIMAL_BUILD)
} else {
ORT_RETURN_IF_ERROR_SESSIONID_(PartitionOrtFormatModel(graph, execution_providers_, kernel_registry_manager_,
*session_state_));
*session_state_, session_options_.config_options, *session_logger_));

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider);
Expand Down
25 changes: 16 additions & 9 deletions onnxruntime/test/framework/session_state_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,16 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) {

GraphPartitioner partitioner(krm, execution_providers);
ASSERT_STATUS_OK(
partitioner.Partition(graph, session_state.GetMutableFuncMgr(),
[](Graph& graph, bool& modified, const IExecutionProvider& execution_provider,
const layout_transformation::DebugGraphFn& debug_graph_fn) -> Status {
AllocatorPtr cpu_allocator = std::make_shared<CPUAllocator>();
return layout_transformation::TransformLayoutForEP(
graph, modified, execution_provider, std::move(cpu_allocator), debug_graph_fn);
}));
partitioner.Partition(
graph, session_state.GetMutableFuncMgr(),
[](Graph& graph, bool& modified, const IExecutionProvider& execution_provider,
const layout_transformation::DebugGraphFn& debug_graph_fn) -> Status {
AllocatorPtr cpu_allocator = std::make_shared<CPUAllocator>();
return layout_transformation::TransformLayoutForEP(
graph, modified, execution_provider, std::move(cpu_allocator), debug_graph_fn);
},
sess_options.config_options,
DefaultLoggingManager().DefaultLogger()));

ASSERT_STATUS_OK(session_state.FinalizeSessionState(oss.str(), krm));

Expand Down Expand Up @@ -257,7 +260,9 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) {
const layout_transformation::DebugGraphFn& debug_graph_fn) -> Status {
return layout_transformation::TransformLayoutForEP(graph, modified, execution_provider,
cpu_allocator, debug_graph_fn);
}));
},
sess_options.config_options,
DefaultLoggingManager().DefaultLogger()));

ASSERT_STATUS_OK(session_state.FinalizeSessionState(oss.str(), krm));

Expand Down Expand Up @@ -314,7 +319,9 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) {
const layout_transformation::DebugGraphFn& debug_graph_fn) -> Status {
return layout_transformation::TransformLayoutForEP(
graph, modified, execution_provider, cpu_allocator, debug_graph_fn);
}));
},
sess_options.config_options,
DefaultLoggingManager().DefaultLogger()));

// Finalize the session state
ASSERT_STATUS_OK(session_state.FinalizeSessionState(oss.str(), krm));
Expand Down
Loading

0 comments on commit 3c2065c

Please sign in to comment.