Commit
remove class QnnCacheModelHandler
HectorSVC committed Nov 17, 2023
1 parent 0fa6261 commit 9a8092f
Showing 5 changed files with 141 additions and 146 deletions.
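Summary of the change: the stateful QnnCacheModelHandler class is dissolved into free functions in the qnn namespace, so state that previously lived in class members (model name, description, context cache path, embed mode) is now passed explicitly by callers. As a rough illustration of that pattern only (CacheHandler and the simplified signatures below are stand-ins, not the actual onnxruntime API), here is a minimal, compilable sketch:

#include <iostream>
#include <string>

// Before: state captured in members and reused across calls.
class CacheHandler {
 public:
  explicit CacheHandler(bool embed_mode) : embed_mode_(embed_mode) {}
  void GenerateCtxCacheModel(const std::string& model_name) const {
    std::cout << "generate cache for " << model_name
              << ", embed_mode=" << embed_mode_ << "\n";
  }
 private:
  bool embed_mode_;
};

// After: the same work as a free function; the caller passes everything in.
void GenerateCtxCacheModel(const std::string& model_name, bool embed_mode) {
  std::cout << "generate cache for " << model_name
            << ", embed_mode=" << embed_mode << "\n";
}

int main() {
  CacheHandler handler(/*embed_mode=*/true);
  handler.GenerateCtxCacheModel("model.onnx");  // old, stateful call pattern
  GenerateCtxCacheModel("model.onnx", true);    // new, stateless call pattern
  return 0;
}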
110 changes: 62 additions & 48 deletions onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
@@ -60,10 +60,10 @@ Status CreateNodeArgs(const std::vector<std::string>& names,
return Status::OK();
}

Status QnnCacheModelHandler::GetEpContextFromModel(const std::string& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModel& qnn_model,
const logging::Logger& logger) {
Status GetEpContextFromModel(const std::string& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModel& qnn_model,
const logging::Logger& logger) {
using namespace onnxruntime;
std::shared_ptr<Model> model;
ORT_RETURN_IF_ERROR(Model::Load(ToPathString(ctx_onnx_model_path), model, {}, logger));
@@ -74,10 +74,10 @@ Status QnnCacheModelHandler::GetEpContextFromModel(const std::string& ctx_onnx_m
qnn_model);
}

Status QnnCacheModelHandler::GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
const std::string& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModel& qnn_model) {
Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
const std::string& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModel& qnn_model) {
const auto& node = graph_viewer.Nodes().begin();
NodeAttrHelper node_helper(*node);
bool is_embed_mode = node_helper.Get(EMBED_MODE, true);
@@ -112,12 +112,27 @@ Status QnnCacheModelHandler::GetEpContextFromGraph(const onnxruntime::GraphViewe
qnn_model);
}

Status QnnCacheModelHandler::GetMetadataFromEpContextModel(const std::string& ctx_onnx_model_path,
std::string& model_name,
std::string& model_description,
std::string& graph_partition_name,
std::string& cache_source,
const logging::Logger& logger) {
Status LoadQnnCtxFromOnnxModel(const onnxruntime::GraphViewer& graph_viewer,
const std::string& ctx_onnx_model_path,
bool is_qnn_ctx_model,
bool is_ctx_cache_file_exist,
QnnBackendManager* qnn_backend_manager,
QnnModel& qnn_model,
const logging::Logger& logger) {
if (is_qnn_ctx_model) {
return GetEpContextFromGraph(graph_viewer, ctx_onnx_model_path, qnn_backend_manager, qnn_model);
} else if (is_ctx_cache_file_exist) {
return GetEpContextFromModel(ctx_onnx_model_path, qnn_backend_manager, qnn_model, logger);
}
return Status::OK();
}

Status GetMetadataFromEpContextModel(const std::string& ctx_onnx_model_path,
std::string& model_name,
std::string& model_description,
std::string& graph_partition_name,
std::string& cache_source,
const logging::Logger& logger) {
using namespace onnxruntime;
std::shared_ptr<Model> model;
ORT_RETURN_IF_ERROR(Model::Load(ToPathString(ctx_onnx_model_path), model, {}, logger));
@@ -132,33 +147,29 @@ Status QnnCacheModelHandler::GetMetadataFromEpContextModel(const std::string& ct
return Status::OK();
}

bool QnnCacheModelHandler::IsContextCacheFileExists(const std::string& customer_context_cache_path,
const std::string& model_name,
const std::string& model_description,
const onnxruntime::PathString& model_pathstring) {
model_name_ = model_name;
model_description_ = model_description;
bool IsContextCacheFileExists(const std::string& customer_context_cache_path,
const onnxruntime::PathString& model_pathstring,
std::string& context_cache_path) {
// Use the user-provided context cache file path if given, otherwise default to <model_path>_qnn_ctx.onnx
if (!customer_context_cache_path.empty()) {
context_cache_path_ = customer_context_cache_path;
context_cache_path = customer_context_cache_path;
} else if (!model_pathstring.empty()) {
context_cache_path_ = PathToUTF8String(model_pathstring) + "_qnn_ctx.onnx";
context_cache_path = PathToUTF8String(model_pathstring) + "_qnn_ctx.onnx";
}

ctx_file_exists_ = std::filesystem::is_regular_file(context_cache_path_) && std::filesystem::exists(context_cache_path_);
return ctx_file_exists_;
return std::filesystem::is_regular_file(context_cache_path) && std::filesystem::exists(context_cache_path);
}
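For reference, a standalone sketch of the default cache-path derivation and existence check shown above; the helper name ResolveContextCachePath is illustrative, while the "_qnn_ctx.onnx" suffix and the std::filesystem checks follow the diff:

#include <filesystem>
#include <iostream>
#include <string>

// Prefer the user-supplied path; otherwise derive "<model_path>_qnn_ctx.onnx",
// then report whether that file already exists on disk.
bool ResolveContextCachePath(const std::string& user_path,
                             const std::string& model_path,
                             std::string& context_cache_path) {
  if (!user_path.empty()) {
    context_cache_path = user_path;
  } else if (!model_path.empty()) {
    context_cache_path = model_path + "_qnn_ctx.onnx";
  }
  return std::filesystem::is_regular_file(context_cache_path) &&
         std::filesystem::exists(context_cache_path);
}

int main() {
  std::string resolved;
  const bool found = ResolveContextCachePath("", "mobilenet.onnx", resolved);
  std::cout << resolved << (found ? " (exists)" : " (missing)") << "\n";
  return 0;
}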

Status QnnCacheModelHandler::ValidateWithContextFile(const std::string& model_name,
const std::string& graph_partition_name,
const logging::Logger& logger) {
ORT_RETURN_IF(!ctx_file_exists_, "Qnn context binary file not exist for some reason!");

Status ValidateWithContextFile(const std::string& context_cache_path,
const std::string& model_name,
const std::string& model_description,
const std::string& graph_partition_name,
const logging::Logger& logger) {
std::string model_name_from_ctx_cache;
std::string model_description_from_ctx_cache;
std::string graph_partition_name_from_ctx_cache;
std::string cache_source;
ORT_RETURN_IF_ERROR(GetMetadataFromEpContextModel(context_cache_path_,
ORT_RETURN_IF_ERROR(GetMetadataFromEpContextModel(context_cache_path,
model_name_from_ctx_cache,
model_description_from_ctx_cache,
graph_partition_name_from_ctx_cache,
@@ -175,31 +186,34 @@ Status QnnCacheModelHandler::ValidateWithContextFile(const std::string& model_na
" is different with target: " + model_name +
". Please make sure the context binary file matches the model.");

ORT_RETURN_IF(model_description_ != model_description_from_ctx_cache,
ORT_RETURN_IF(model_description != model_description_from_ctx_cache,
"Model description from context cache metadata: " + model_description_from_ctx_cache +
" is different with target: " + model_description_ +
" is different with target: " + model_description +
". Please make sure the context binary file matches the model.");

ORT_RETURN_IF(graph_partition_name != graph_partition_name_from_ctx_cache && get_capability_round_2_,
ORT_RETURN_IF(graph_partition_name != graph_partition_name_from_ctx_cache,
"Graph name from context cache metadata: " + graph_partition_name_from_ctx_cache +
" is different with target: " + graph_partition_name +
". You may need to re-generate the context binary file.");

get_capability_round_2_ = true;
return Status::OK();
}
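A simplified, standalone illustration of the validation above; CtxMetadata and ValidateCtxMetadata are hypothetical stand-ins for the real Status-based plumbing, but the three field comparisons mirror the checks in the diff:

#include <iostream>
#include <optional>
#include <string>

// Cached metadata read back from the generated *_qnn_ctx.onnx model.
struct CtxMetadata {
  std::string model_name;
  std::string model_description;
  std::string graph_partition_name;
};

// Returns an error message when the cached metadata does not match the target
// model, or std::nullopt when everything lines up.
std::optional<std::string> ValidateCtxMetadata(const CtxMetadata& cached,
                                               const CtxMetadata& target) {
  if (cached.model_name != target.model_name) {
    return "Model name mismatch: " + cached.model_name + " vs " + target.model_name;
  }
  if (cached.model_description != target.model_description) {
    return "Model description mismatch; the context binary does not match the model.";
  }
  if (cached.graph_partition_name != target.graph_partition_name) {
    return "Graph partition name mismatch; re-generate the context binary file.";
  }
  return std::nullopt;
}

int main() {
  CtxMetadata cached{"resnet", "v1", "graph_0"};
  CtxMetadata target{"resnet", "v1", "graph_1"};
  if (auto err = ValidateCtxMetadata(cached, target)) {
    std::cout << *err << "\n";
  }
  return 0;
}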

Status QnnCacheModelHandler::GenerateCtxCacheOnnxModel(unsigned char* buffer,
uint64_t buffer_size,
const std::string& sdk_build_version,
const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
const std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models,
const logging::Logger& logger) {
Status GenerateCtxCacheOnnxModel(const std::string model_name,
const std::string model_description,
unsigned char* buffer,
uint64_t buffer_size,
const std::string& sdk_build_version,
const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
const std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models,
const std::string& context_cache_path,
bool qnn_context_embed_mode,
const logging::Logger& logger) {
std::unordered_map<std::string, int> domain_to_version = {{kOnnxDomain, 11}, {kMSDomain, 1}};
Model model(model_name_, false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
Model model(model_name, false, ModelMetaData(), PathString(), IOnnxRuntimeOpSchemaRegistryList(),
domain_to_version, {}, logger);
auto& graph = model.MainGraph();
graph.SetDescription(model_description_);
graph.SetDescription(model_description);

using namespace ONNX_NAMESPACE;
int index = 0;
@@ -227,13 +241,13 @@ Status QnnCacheModelHandler::GenerateCtxCacheOnnxModel(unsigned char* buffer,

// Only dump the context buffer once since all QNN graphs are in one single context
if (0 == index) {
if (qnn_context_embed_mode_) {
if (qnn_context_embed_mode) {
std::string cache_payload(buffer, buffer + buffer_size);
ep_node.AddAttribute(EP_CACHE_CONTEXT, cache_payload);
} else {
std::string context_cache_path(context_cache_path_ + "_" + graph_name + ".bin");
std::string context_cache_name(std::filesystem::path(context_cache_path).filename().string());
std::ofstream of_stream(context_cache_path.c_str(), std::ofstream::binary);
std::string context_bin_path(context_cache_path + "_" + graph_name + ".bin");
std::string context_cache_name(std::filesystem::path(context_bin_path).filename().string());
std::ofstream of_stream(context_bin_path.c_str(), std::ofstream::binary);
if (!of_stream) {
LOGS(logger, ERROR) << "Failed to open context cache file.";
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to open context cache file.");
Expand All @@ -244,15 +258,15 @@ Status QnnCacheModelHandler::GenerateCtxCacheOnnxModel(unsigned char* buffer,
} else {
ep_node.AddAttribute(MAIN_CONTEXT, static_cast<int64_t>(0));
}
int64_t embed_mode = qnn_context_embed_mode_ ? static_cast<int64_t>(1) : static_cast<int64_t>(0);
int64_t embed_mode = qnn_context_embed_mode ? static_cast<int64_t>(1) : static_cast<int64_t>(0);
ep_node.AddAttribute(EMBED_MODE, embed_mode);
ep_node.AddAttribute(EP_SDK_VER, sdk_build_version);
ep_node.AddAttribute(PARTITION_NAME, graph_name);
ep_node.AddAttribute(SOURCE, kQnnExecutionProvider);
++index;
}
ORT_RETURN_IF_ERROR(graph.Resolve());
ORT_RETURN_IF_ERROR(Model::Save(model, context_cache_path_));
ORT_RETURN_IF_ERROR(Model::Save(model, context_cache_path));

return Status::OK();
}
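In non-embed mode the function above writes the raw context buffer to a sibling "<cache_path>_<graph_name>.bin" file and records only its filename in the node attribute; in embed mode the bytes go into the attribute directly. A hedged, standalone sketch of that split (StoreContextBinary is an illustrative helper, not part of the codebase):

#include <cstdint>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <string>

// Embed mode: return the raw bytes so they can go into the cache-context
// attribute. Otherwise: write "<cache_path>_<graph>.bin" next to the cache
// model and return only the filename, which is what gets recorded.
std::string StoreContextBinary(const unsigned char* buffer, uint64_t buffer_size,
                               const std::string& context_cache_path,
                               const std::string& graph_name, bool embed_mode) {
  if (embed_mode) {
    return std::string(buffer, buffer + buffer_size);
  }
  const std::string bin_path = context_cache_path + "_" + graph_name + ".bin";
  std::ofstream of_stream(bin_path, std::ofstream::binary);
  if (!of_stream) {
    std::cerr << "Failed to open context cache file: " << bin_path << "\n";
    return {};
  }
  of_stream.write(reinterpret_cast<const char*>(buffer),
                  static_cast<std::streamsize>(buffer_size));
  return std::filesystem::path(bin_path).filename().string();
}

int main() {
  const unsigned char ctx[] = {0xde, 0xad, 0xbe, 0xef};
  std::cout << StoreContextBinary(ctx, sizeof(ctx), "model_qnn_ctx.onnx",
                                  "graph_0", /*embed_mode=*/false)
            << "\n";
  return 0;
}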
106 changes: 39 additions & 67 deletions onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
@@ -38,78 +38,50 @@ Status CreateNodeArgs(const std::vector<std::string>& names,
std::vector<NodeArg*>& node_args,
onnxruntime::Graph& graph);

class QnnCacheModelHandler {
public:
QnnCacheModelHandler(bool qnn_context_embed_mode) : qnn_context_embed_mode_(qnn_context_embed_mode) {
}
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QnnCacheModelHandler);

Status LoadQnnCtxFromOnnxModel(const onnxruntime::GraphViewer& graph_viewer,
const std::string& ctx_onnx_model_path,
bool is_qnn_ctx_model,
bool is_ctx_cache_file_exist,
QnnBackendManager* qnn_backend_manager,
QnnModel& qnn_model,
const logging::Logger& logger) {
if (is_qnn_ctx_model) {
return GetEpContextFromGraph(graph_viewer, ctx_onnx_model_path, qnn_backend_manager, qnn_model);
} else if (is_ctx_cache_file_exist) {
return GetEpContextFromModel(ctx_onnx_model_path, qnn_backend_manager, qnn_model, logger);
}
return Status::OK();
}

bool IsContextCacheFileExists(const std::string& customer_context_cache_path,
const std::string& model_name,
const std::string& model_description,
const onnxruntime::PathString& model_pathstring);

bool GetIsContextCacheFileExists() const {
return ctx_file_exists_;
}

Status ValidateWithContextFile(const std::string& model_name,
const std::string& graph_name,
const logging::Logger& logger);
bool IsContextCacheFileExists(const std::string& customer_context_cache_path,
const onnxruntime::PathString& model_pathstring,
std::string& context_cache_path);

Status GetEpContextFromModel(const std::string& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModel& qnn_model,
const logging::Logger& logger);

Status GetMetadataFromEpContextModel(const std::string& ctx_onnx_model_path,
std::string& model_name,
std::string& model_description,
std::string& graph_partition_name,
std::string& cache_source,
const logging::Logger& logger);

Status GenerateCtxCacheOnnxModel(unsigned char* buffer,
uint64_t buffer_size,
const std::string& sdk_build_version,
const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
const std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models,
const logging::Logger& logger);

private:
Status GetEpContextFromModel(const std::string& ctx_onnx_model_path,
Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
const std::string& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModel& qnn_model);

Status LoadQnnCtxFromOnnxModel(const onnxruntime::GraphViewer& graph_viewer,
const std::string& ctx_onnx_model_path,
bool is_qnn_ctx_model,
bool is_ctx_cache_file_exist,
QnnBackendManager* qnn_backend_manager,
QnnModel& qnn_model,
const logging::Logger& logger);

Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
const std::string& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModel& qnn_model);

private:
bool is_metadata_ready_ = false;
// model_name_ to cache_source_ -- metadata get from generated Qnn context binary Onnx model
std::string model_name_ = "";
std::string model_description_ = "";
std::string graph_partition_name_ = "";
std::string cache_source_ = "";

std::string context_cache_path_ = "";
bool ctx_file_exists_ = false;
bool get_capability_round_2_ = false;
bool qnn_context_embed_mode_ = true;
}; // QnnCacheModelHandler
Status ValidateWithContextFile(const std::string& context_cache_path,
const std::string& model_name,
const std::string& model_description,
const std::string& graph_partition_name,
const logging::Logger& logger);

Status GetMetadataFromEpContextModel(const std::string& ctx_onnx_model_path,
std::string& model_name,
std::string& model_description,
std::string& graph_partition_name,
std::string& cache_source,
const logging::Logger& logger);

Status GenerateCtxCacheOnnxModel(const std::string model_name,
const std::string model_description,
unsigned char* buffer,
uint64_t buffer_size,
const std::string& sdk_build_version,
const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
const std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models,
const std::string& context_cache_path,
bool qnn_context_embed_mode,
const logging::Logger& logger);
} // namespace qnn
} // namespace onnxruntime
@@ -22,7 +22,6 @@ namespace onnxruntime {
namespace qnn {

class QnnModel;
class QnnCacheModelHandler;

class QnnBackendManager {
public:
