diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index ea4f52f99649d..1de0217c7e1fa 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -326,6 +326,15 @@ class IExecutionProvider { */ virtual std::vector CreatePreferredAllocators() { return std::vector(); }; + /** + * Get the array of pointers for EPContext nodes + * EP needs to implement this if has the requirement to generate the context cache model. Otherwise leave it. + * Default return an empty vector if not provided by the Execution Provider + */ + virtual const InlinedVector GetEpContextNodes() const { + return InlinedVector(); + } + private: const std::string type_; diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index df79cb6e5b21b..8fd51962bf087 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -236,7 +236,7 @@ static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersFil static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes = "session.optimized_model_external_initializers_min_size_in_bytes"; -// Enable EP context feature to dump the partitioned graph which include the EP context into Onnx file. +// Enable EP context feature to dump the partitioned graph which includes the EP context into Onnx file. // The dumped Onnx model with EP context can be used for future inference to avoid the EP graph partitioning/compile overhead. // "0": disable. (default) // "1": enable. diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index e4fe0c7564548..07b465c80745a 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -16,6 +16,7 @@ #include "core/graph/function_utils.h" #include "core/graph/graph_viewer.h" #include "core/graph/model.h" +#include "core/session/onnxruntime_session_options_config_keys.h" // uncomment this line to count non-CUDA ops in ONNX domain // #define COUNT_NON_CUDA_OPS @@ -634,6 +635,100 @@ static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_provide return Status::OK(); } +static Status CreateEpContextModel(const ExecutionProviders& execution_providers, + const Graph& graph, + const std::string& ep_context_path, + const logging::Logger& logger) { + InlinedVector all_ep_context_nodes; + for (const auto& ep : execution_providers) { + const InlinedVector ep_context_nodes = ep->GetEpContextNodes(); + all_ep_context_nodes.insert(all_ep_context_nodes.begin(), ep_context_nodes.begin(), ep_context_nodes.end()); + } + + auto get_ep_context_node = [&all_ep_context_nodes](const std::string& node_name) -> std::pair { + for (auto& node : all_ep_context_nodes) { + if (node_name == node->Name()) { + return std::make_pair(true, node); + } + } + return std::make_pair(false, static_cast(nullptr)); + }; + + onnxruntime::PathString context_cache_path; + PathString model_pathstring = graph.ModelPath().ToPathString(); + if (all_ep_context_nodes.size() > 0) { + if (!ep_context_path.empty()) { + context_cache_path = ToPathString(ep_context_path); + } else if (!model_pathstring.empty()) { + context_cache_path = model_pathstring + ToPathString("_ctx.onnx"); + } + + { +#ifdef _WIN32 + std::wifstream fs(context_cache_path); +#else + std::ifstream fs(context_cache_path); +#endif + ORT_RETURN_IF(fs.good(), "Failed to generate EP context model since the file exist already."); + } + + Model ep_context_model(graph.Name(), false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), + graph.DomainToVersionMap(), {}, logger); + auto& ep_graph = ep_context_model.MainGraph(); + ep_graph.SetDescription(graph.Description()); + + // Set inputs outputs explicitly to make sure the order is same as the user model. + auto inputs = graph.GetInputs(); + auto outputs = graph.GetOutputs(); + + InlinedVector ep_graph_inputs; + ep_graph_inputs.reserve(inputs.size()); + for (auto& input : inputs) { + auto input_arg = graph.GetNodeArg(input->Name()); + auto& ep_graph_input_arg = ep_graph.GetOrCreateNodeArg(input_arg->Name(), input_arg->TypeAsProto()); + ep_graph_inputs.push_back(&ep_graph_input_arg); + } + + InlinedVector ep_graph_outputs; + ep_graph_outputs.reserve(outputs.size()); + for (auto& output : outputs) { + auto output_arg = graph.GetNodeArg(output->Name()); + auto& ep_graph_output_arg = ep_graph.GetOrCreateNodeArg(output_arg->Name(), output_arg->TypeAsProto()); + ep_graph_outputs.push_back(&ep_graph_output_arg); + } + + ep_graph.SetInputs(ep_graph_inputs); + ep_graph.SetOutputs(ep_graph_outputs); + + for (const auto& node : graph.Nodes()) { + // the fused node and EPContext node has same node name + auto ep_context_node = get_ep_context_node(node.Name()); + // Use EpContext node created by the EPs if name matched, otherwise use node from original model + if (ep_context_node.first) { + ep_graph.AddNode(*ep_context_node.second); + } else { + ep_graph.AddNode(node); + } + } + + // handle initializers + for (const auto& input : graph.GetInputsIncludingInitializers()) { + const ONNX_NAMESPACE::TensorProto* initializer = nullptr; + if (graph.GetInitializedTensor(input->Name(), initializer)) { + // There initializer could have duplicates so make sure we only add once + const ONNX_NAMESPACE::TensorProto* subgraph_initializer = nullptr; + if (!ep_graph.GetInitializedTensor(input->Name(), subgraph_initializer)) { + ep_graph.AddInitializedTensor(*initializer); + } + } + } + + ORT_RETURN_IF_ERROR(Model::Save(ep_context_model, context_cache_path)); + } + + return Status::OK(); +} + static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, GraphPartitioner::Mode mode, const ExecutionProviders& execution_providers, KernelRegistryManager& kernel_registry_manager) { @@ -840,6 +935,8 @@ Status GraphPartitioner::InlineFunctionsAOT(Model& model, Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, const layout_transformation::TransformLayoutFunction& transform_layout_function, + const ConfigOptions& config_options, + const logging::Logger& logger, Mode mode, const layout_transformation::DebugGraphFn& debug_graph_fn) const { // It is a greedy partitioning algorithm per provider preferences user provided when calling ONNX RUNTIME right now. @@ -886,7 +983,15 @@ Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, #if !defined(ORT_MINIMAL_BUILD) ORT_RETURN_IF_ERROR(PartitionOnnxFormatModel(partition_params, mode, providers_, kernel_registry_mgr_)); + + bool ep_context_enabled = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextEnable, "0") == "1"; + std::string ep_context_path = config_options.GetConfigOrDefault(kOrtSessionOptionEpContextFilePath, ""); + if (ep_context_enabled) { + ORT_RETURN_IF_ERROR(CreateEpContextModel(providers_, graph, ep_context_path, logger)); + } #else + ORT_UNUSED_PARAMETER(config_options); + ORT_UNUSED_PARAMETER(logger); return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "ONNX models are not supported in this build."); #endif //! defined(ORT_MINIMAL_BUILD) } else { diff --git a/onnxruntime/core/framework/graph_partitioner.h b/onnxruntime/core/framework/graph_partitioner.h index 4fc85c2588260..d1ef193cf1520 100644 --- a/onnxruntime/core/framework/graph_partitioner.h +++ b/onnxruntime/core/framework/graph_partitioner.h @@ -13,6 +13,7 @@ namespace onnxruntime { class ExecutionProviders; class KernelRegistryManager; class Model; +struct ConfigOptions; class GraphPartitioner { public: @@ -31,6 +32,8 @@ class GraphPartitioner { // Run partitioning. Status Partition(Graph& graph, FuncManager& func_mgr, const layout_transformation::TransformLayoutFunction& transform_layout_function, + const ConfigOptions& config_options, + const logging::Logger& logger, Mode mode = Mode::kNormal, const layout_transformation::DebugGraphFn& debug_graph_fn = {}) const; diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc index fd9bf200c45ef..5d3f406f50612 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc @@ -230,8 +230,7 @@ Status ValidateWithContextFile(const onnxruntime::PathString& context_cache_path return Status::OK(); } -Status GenerateCtxCacheOnnxModel(const std::string model_name, - const std::string model_description, +Status GenerateCtxCacheOnnxModel(Model* model, unsigned char* buffer, uint64_t buffer_size, const std::string& sdk_build_version, @@ -240,11 +239,7 @@ Status GenerateCtxCacheOnnxModel(const std::string model_name, const onnxruntime::PathString& context_cache_path, bool qnn_context_embed_mode, const logging::Logger& logger) { - std::unordered_map domain_to_version = {{kOnnxDomain, 11}, {kMSDomain, 1}}; - Model model(model_name, false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(), - domain_to_version, {}, logger); - auto& graph = model.MainGraph(); - graph.SetDescription(model_description); + auto& graph = model->MainGraph(); using namespace ONNX_NAMESPACE; int index = 0; @@ -270,7 +265,7 @@ Status GenerateCtxCacheOnnxModel(const std::string model_name, nullptr, kMSDomain); - // Only dump the context buffer once since all QNN graph are in one single context + // Only dump the context buffer once since all QNN graphs are in one single context if (0 == index) { if (qnn_context_embed_mode) { std::string cache_payload(buffer, buffer + buffer_size); @@ -296,8 +291,6 @@ Status GenerateCtxCacheOnnxModel(const std::string model_name, ep_node.AddAttribute(SOURCE, kQnnExecutionProvider); ++index; } - ORT_RETURN_IF_ERROR(graph.Resolve()); - ORT_RETURN_IF_ERROR(Model::Save(model, context_cache_path)); return Status::OK(); } diff --git a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h index 0011d0f43f5bc..ba6fe23ecd56e 100644 --- a/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h +++ b/onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h @@ -73,8 +73,7 @@ Status GetMetadataFromEpContextModel(const onnxruntime::PathString& ctx_onnx_mod std::string& cache_source, const logging::Logger& logger); -Status GenerateCtxCacheOnnxModel(const std::string model_name, - const std::string model_description, +Status GenerateCtxCacheOnnxModel(Model* model, unsigned char* buffer, uint64_t buffer_size, const std::string& sdk_build_version, diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 04bd58c237141..56eb1f4f59f33 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -613,8 +613,8 @@ Status QNNExecutionProvider::Compile(const std::vector& fused ORT_RETURN_IF(fused_nodes_and_graphs.size() != 1, "Only support single partition for context cache feature."); uint64_t buffer_size(0); auto context_buffer = qnn_backend_manager_->GetContextBinaryBuffer(buffer_size); - ORT_RETURN_IF_ERROR(qnn::GenerateCtxCacheOnnxModel(model_name, - model_description, + qnn_ep_context_model_ = std::make_unique("qnn_ep_context_model", false, logger); + ORT_RETURN_IF_ERROR(qnn::GenerateCtxCacheOnnxModel(qnn_ep_context_model_.get(), context_buffer.get(), buffer_size, qnn_backend_manager_->GetSdkVersion(), @@ -626,4 +626,16 @@ Status QNNExecutionProvider::Compile(const std::vector& fused } return Status::OK(); } + +const InlinedVector QNNExecutionProvider::GetEpContextNodes() const { + InlinedVector ep_context_nodes; + if (qnn_ep_context_model_) { + const auto& graph = qnn_ep_context_model_->MainGraph(); + for (const auto& node : graph.Nodes()) { + ep_context_nodes.push_back(graph.GetNode(node.Index())); + } + } + + return ep_context_nodes; +} } // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index 8b5d0929209ee..d4927f3fa505e 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -9,6 +9,7 @@ #include "core/providers/qnn/builder/qnn_backend_manager.h" #include "core/providers/qnn/builder/qnn_model.h" #include "core/providers/qnn/builder/qnn_graph_configs_helper.h" +#include "core/graph/model.h" namespace onnxruntime { @@ -35,6 +36,8 @@ class QNNExecutionProvider : public IExecutionProvider { DataLayout GetPreferredLayout() const override; + const InlinedVector GetEpContextNodes() const override; + private: bool IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, std::unordered_map& node_unit_supported_result, @@ -66,6 +69,7 @@ class QNNExecutionProvider : public IExecutionProvider { bool disable_cpu_ep_fallback_ = false; // True if CPU EP fallback has been disabled for this session. bool qnn_context_embed_mode_ = true; int32_t vtcm_size_in_mb_ = 0; + std::unique_ptr qnn_ep_context_model_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 93877c8dd66bd..e8853c8824738 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1164,6 +1164,7 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool // Do partitioning based on execution providers' capabilities. ORT_RETURN_IF_ERROR_SESSIONID_(partitioner.Partition(graph, session_state_->GetMutableFuncMgr(), transform_layout_fn, + session_options_.config_options, *session_logger_, mode, debug_graph_fn)); // apply Level2 and higher transformers. @@ -1458,7 +1459,9 @@ namespace { Status PartitionOrtFormatModel(onnxruntime::Graph& graph, const ExecutionProviders& providers, KernelRegistryManager& kernel_registry_manager, - SessionState& session_state) { + SessionState& session_state, + const ConfigOptions& config_options, + const logging::Logger& logger) { layout_transformation::TransformLayoutFunction transform_layout_fn = nullptr; #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) @@ -1479,6 +1482,8 @@ Status PartitionOrtFormatModel(onnxruntime::Graph& graph, ORT_RETURN_IF_ERROR(partitioner.Partition(graph, session_state.GetMutableFuncMgr(), transform_layout_fn, + config_options, + logger, GraphPartitioner::Mode::kOrtFormatLoad)); return Status::OK(); @@ -1833,7 +1838,7 @@ common::Status InferenceSession::Initialize() { #endif // !defined(ORT_MINIMAL_BUILD) } else { ORT_RETURN_IF_ERROR_SESSIONID_(PartitionOrtFormatModel(graph, execution_providers_, kernel_registry_manager_, - *session_state_)); + *session_state_, session_options_.config_options, *session_logger_)); #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) const auto& cpu_ep = *execution_providers_.Get(onnxruntime::kCpuExecutionProvider); diff --git a/onnxruntime/test/framework/session_state_test.cc b/onnxruntime/test/framework/session_state_test.cc index 8990c23e4af39..0c2d8bcb2eb93 100644 --- a/onnxruntime/test/framework/session_state_test.cc +++ b/onnxruntime/test/framework/session_state_test.cc @@ -171,13 +171,16 @@ TEST_P(SessionStateTestP, TestInitializerProcessing) { GraphPartitioner partitioner(krm, execution_providers); ASSERT_STATUS_OK( - partitioner.Partition(graph, session_state.GetMutableFuncMgr(), - [](Graph& graph, bool& modified, const IExecutionProvider& execution_provider, - const layout_transformation::DebugGraphFn& debug_graph_fn) -> Status { - AllocatorPtr cpu_allocator = std::make_shared(); - return layout_transformation::TransformLayoutForEP( - graph, modified, execution_provider, std::move(cpu_allocator), debug_graph_fn); - })); + partitioner.Partition( + graph, session_state.GetMutableFuncMgr(), + [](Graph& graph, bool& modified, const IExecutionProvider& execution_provider, + const layout_transformation::DebugGraphFn& debug_graph_fn) -> Status { + AllocatorPtr cpu_allocator = std::make_shared(); + return layout_transformation::TransformLayoutForEP( + graph, modified, execution_provider, std::move(cpu_allocator), debug_graph_fn); + }, + sess_options.config_options, + DefaultLoggingManager().DefaultLogger())); ASSERT_STATUS_OK(session_state.FinalizeSessionState(oss.str(), krm)); @@ -257,7 +260,9 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { const layout_transformation::DebugGraphFn& debug_graph_fn) -> Status { return layout_transformation::TransformLayoutForEP(graph, modified, execution_provider, cpu_allocator, debug_graph_fn); - })); + }, + sess_options.config_options, + DefaultLoggingManager().DefaultLogger())); ASSERT_STATUS_OK(session_state.FinalizeSessionState(oss.str(), krm)); @@ -314,7 +319,9 @@ TEST(SessionStateTest, TestInitializerMemoryAllocatedUsingNonArenaMemory) { const layout_transformation::DebugGraphFn& debug_graph_fn) -> Status { return layout_transformation::TransformLayoutForEP( graph, modified, execution_provider, cpu_allocator, debug_graph_fn); - })); + }, + sess_options.config_options, + DefaultLoggingManager().DefaultLogger())); // Finalize the session state ASSERT_STATUS_OK(session_state.FinalizeSessionState(oss.str(), krm)); diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index f9064cad3fe12..bc40682cf87b7 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -600,6 +600,51 @@ TEST_F(QnnHTPBackendTests, QnnContextBinaryGeneration2InputTypes) { // Make sure the Qnn context cache binary file is generated EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + + // clean up + ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); +} + +// Generate context cache model from the ONNX models with 2 inputs. +// The generated model should have same input order. +// The input ONNX model is created in the way that the model inputs order +// is different with the order in the graph (topological order). +// It cause issue if the generated model doesn't set the inputs/outputs explicitly. +TEST_F(QnnHTPBackendTests, QnnContextGeneration2InputsOrderIssue) { + ProviderOptions provider_options; +#if defined(_WIN32) + provider_options["backend_path"] = "QnnHtp.dll"; +#else + provider_options["backend_path"] = "libQnnHtp.so"; +#endif + + // Add kMSDomain to cover contrib op like Gelu + const std::unordered_map domain_to_version = {{"", 13}, {kMSDomain, 1}}; + + auto& logging_manager = DefaultLoggingManager(); + logging_manager.SetDefaultLoggerSeverity(logging::Severity::kERROR); + + const std::string context_binary_file = "./qnn_ctx_2_inputs_order_test_gen.onnx"; + Ort::SessionOptions so; + so.AddConfigEntry(kOrtSessionOptionEpContextEnable, "1"); + so.AddConfigEntry(kOrtSessionOptionEpContextFilePath, context_binary_file.c_str()); + + so.AppendExecutionProvider("QNN", provider_options); + + Ort::Session session(*ort_env, ORT_TSTR("testdata/qnn_ctx_2_inputs_order_test.onnx"), so); + + // Make sure the Qnn context cache binary file is generated + EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str())); + + std::shared_ptr model; + ASSERT_STATUS_OK(Model::Load(ToPathString(context_binary_file), model, nullptr, DefaultLoggingManager().DefaultLogger())); + auto inputs = model->MainGraph().GetInputs(); + EXPECT_TRUE(inputs.size() == 2); + EXPECT_TRUE(inputs[0]->Name() == "attention_mask"); + EXPECT_TRUE(inputs[1]->Name() == "Add_input_0"); + + // clean up + ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); } // A repro of QC case 06838696, accuracy issue for Cast + Op (quantized) diff --git a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc index 4ac1f5ddca643..1e938ae9e334b 100644 --- a/onnxruntime/test/providers/qnn/simple_op_htp_test.cc +++ b/onnxruntime/test/providers/qnn/simple_op_htp_test.cc @@ -778,6 +778,8 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCacheEmbedModeTest) { QDQTolerance(), logging::Severity::kERROR, context_binary_file); + // Clean up + ASSERT_EQ(std::remove(context_binary_file.c_str()), 0); } // Run QDQ model on HTP 3 times diff --git a/onnxruntime/test/testdata/qnn_ctx_2_inputs_order_test.onnx b/onnxruntime/test/testdata/qnn_ctx_2_inputs_order_test.onnx new file mode 100644 index 0000000000000..46b212dc1fc0e Binary files /dev/null and b/onnxruntime/test/testdata/qnn_ctx_2_inputs_order_test.onnx differ