Skip to content

Commit

Permalink
use onnxruntime::PathString instead of string
Browse files Browse the repository at this point in the history
  • Loading branch information
HectorSVC committed Nov 17, 2023
1 parent 8253a2d commit 4a761ec
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 38 deletions.
38 changes: 20 additions & 18 deletions onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ Status CreateNodeArgs(const std::vector<std::string>& names,
return Status::OK();
}

Status GetEpContextFromModel(const std::string& ctx_onnx_model_path,
Status GetEpContextFromModel(const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModel& qnn_model,
const logging::Logger& logger) {
Expand All @@ -75,7 +75,7 @@ Status GetEpContextFromModel(const std::string& ctx_onnx_model_path,
}

Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
const std::string& ctx_onnx_model_path,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModel& qnn_model) {
const auto& node = graph_viewer.Nodes().begin();
Expand All @@ -89,11 +89,13 @@ Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
}

std::string external_qnn_context_binary_file_name = node_helper.Get(EP_CACHE_CONTEXT, "");
std::filesystem::path folder_path = std::filesystem::path(ctx_onnx_model_path).parent_path();
std::filesystem::path context_binary_path = folder_path.append(external_qnn_context_binary_file_name);

std::string context_binary_path(std::filesystem::path(ctx_onnx_model_path).parent_path().string() +
"/" + external_qnn_context_binary_file_name);
//std::string context_binary_path(std::filesystem::path(ctx_onnx_model_path).parent_path().string() +

Check warning on line 95 in onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc#L95

Should have a space between // and comment [whitespace/comments] [4]
Raw output
onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc:95:  Should have a space between // and comment  [whitespace/comments] [4]
// "/" + external_qnn_context_binary_file_name);
size_t buffer_size{0};
std::ifstream cache_file(context_binary_path.c_str(), std::ifstream::binary);
std::ifstream cache_file(context_binary_path.string().c_str(), std::ifstream::binary);
ORT_RETURN_IF(!cache_file || !cache_file.good(), "Failed to open cache file.");

cache_file.seekg(0, cache_file.end);
Expand All @@ -113,7 +115,7 @@ Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
}

Status LoadQnnCtxFromOnnxModel(const onnxruntime::GraphViewer& graph_viewer,
const std::string& ctx_onnx_model_path,
const onnxruntime::PathString& ctx_onnx_model_path,
bool is_qnn_ctx_model,
bool is_ctx_cache_file_exist,
QnnBackendManager* qnn_backend_manager,
Expand All @@ -126,22 +128,22 @@ Status LoadQnnCtxFromOnnxModel(const onnxruntime::GraphViewer& graph_viewer,
status = GetEpContextFromModel(ctx_onnx_model_path, qnn_backend_manager, qnn_model, logger);
}

if (Status::OK() != status) {
if (!status.IsOK()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to load from EpContextModel. ", status.ErrorMessage());
}

return Status::OK();
}

Status GetMetadataFromEpContextModel(const std::string& ctx_onnx_model_path,
Status GetMetadataFromEpContextModel(const onnxruntime::PathString& ctx_onnx_model_path,
std::string& model_name,
std::string& model_description,
std::string& graph_partition_name,
std::string& cache_source,
const logging::Logger& logger) {
using namespace onnxruntime;

Check warning on line 144 in onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc#L144

Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]
Raw output
onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc:144:  Do not use namespace using-directives.  Use using-declarations instead.  [build/namespaces] [5]
std::shared_ptr<Model> model;
ORT_RETURN_IF_ERROR(Model::Load(ToPathString(ctx_onnx_model_path), model, {}, logger));
ORT_RETURN_IF_ERROR(Model::Load(ctx_onnx_model_path, model, {}, logger));
const auto& graph = GraphViewer(model->MainGraph());
const auto& node = graph.Nodes().begin();
NodeAttrHelper node_helper(*node);
Expand All @@ -155,18 +157,18 @@ Status GetMetadataFromEpContextModel(const std::string& ctx_onnx_model_path,

bool IsContextCacheFileExists(const std::string& customer_context_cache_path,
const onnxruntime::PathString& model_pathstring,
std::string& context_cache_path) {
onnxruntime::PathString& context_cache_path) {
// Use user provided context cache file path if exist, otherwise try model_file.onnx_ctx.onnx by default
if (!customer_context_cache_path.empty()) {
context_cache_path = customer_context_cache_path;
context_cache_path = ToPathString(customer_context_cache_path);
} else if (!model_pathstring.empty()) {
context_cache_path = PathToUTF8String(model_pathstring) + "_qnn_ctx.onnx";
context_cache_path = model_pathstring + ToPathString("_qnn_ctx.onnx");
}

return std::filesystem::is_regular_file(context_cache_path) && std::filesystem::exists(context_cache_path);
}

Status ValidateWithContextFile(const std::string& context_cache_path,
Status ValidateWithContextFile(const onnxruntime::PathString& context_cache_path,
const std::string& model_name,
const std::string& model_description,
const std::string& graph_partition_name,
Expand All @@ -181,7 +183,7 @@ Status ValidateWithContextFile(const std::string& context_cache_path,
graph_partition_name_from_ctx_cache,
cache_source,
logger);
if (Status::OK() != status) {
if (!status.IsOK()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_GRAPH, "Failed to get metadata from EpContextModel.");
}

Expand All @@ -194,9 +196,9 @@ Status ValidateWithContextFile(const std::string& context_cache_path,
if (model_name != model_name_from_ctx_cache ||
model_description != model_description_from_ctx_cache ||
graph_partition_name != graph_partition_name_from_ctx_cache) {
std::string message = onnxruntime::MakeString("Metadata from Onnx file: ",
std::string message = onnxruntime::MakeString("Metadata mismatch. onnx: ",
model_name, " ", model_description, " ", graph_partition_name,
" vs metadata from context cache Onnx file",
" vs epcontext: ",
model_name_from_ctx_cache, " ",
model_description_from_ctx_cache, " ",
graph_partition_name_from_ctx_cache);
Expand All @@ -213,7 +215,7 @@ Status GenerateCtxCacheOnnxModel(const std::string model_name,
const std::string& sdk_build_version,
const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
const std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models,

Check warning on line 217 in onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc#L217

Add #include <memory> for unique_ptr<> [build/include_what_you_use] [4]
Raw output
onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc:217:  Add #include <memory> for unique_ptr<>  [build/include_what_you_use] [4]
const std::string& context_cache_path,
const onnxruntime::PathString& context_cache_path,
bool qnn_context_embed_mode,
const logging::Logger& logger) {
std::unordered_map<std::string, int> domain_to_version = {{kOnnxDomain, 11}, {kMSDomain, 1}};
Expand Down Expand Up @@ -252,7 +254,7 @@ Status GenerateCtxCacheOnnxModel(const std::string model_name,
std::string cache_payload(buffer, buffer + buffer_size);
ep_node.AddAttribute(EP_CACHE_CONTEXT, cache_payload);
} else {
std::string context_bin_path(context_cache_path + "_" + graph_name + ".bin");
onnxruntime::PathString context_bin_path = context_cache_path + ToPathString("_" + graph_name + ".bin");
std::string context_cache_name(std::filesystem::path(context_bin_path).filename().string());
std::ofstream of_stream(context_bin_path.c_str(), std::ofstream::binary);
if (!of_stream) {
Expand Down
14 changes: 7 additions & 7 deletions onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,33 +40,33 @@ Status CreateNodeArgs(const std::vector<std::string>& names,

bool IsContextCacheFileExists(const std::string& customer_context_cache_path,
const onnxruntime::PathString& model_pathstring,
std::string& context_cache_path);
onnxruntime::PathString& context_cache_path);

Status GetEpContextFromModel(const std::string& ctx_onnx_model_path,
Status GetEpContextFromModel(const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModel& qnn_model,
const logging::Logger& logger);

Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
const std::string& ctx_onnx_model_path,
const onnxruntime::PathString& ctx_onnx_model_path,
QnnBackendManager* qnn_backend_manager,
QnnModel& qnn_model);

Status LoadQnnCtxFromOnnxModel(const onnxruntime::GraphViewer& graph_viewer,
const std::string& ctx_onnx_model_path,
const onnxruntime::PathString& ctx_onnx_model_path,
bool is_qnn_ctx_model,
bool is_ctx_cache_file_exist,
QnnBackendManager* qnn_backend_manager,
QnnModel& qnn_model,
const logging::Logger& logger);

Status ValidateWithContextFile(const std::string& context_cache_path,
Status ValidateWithContextFile(const onnxruntime::PathString& context_cache_path,
const std::string& model_name,
const std::string& model_description,
const std::string& graph_partition_name,
const logging::Logger& logger);

Status GetMetadataFromEpContextModel(const std::string& ctx_onnx_model_path,
Status GetMetadataFromEpContextModel(const onnxruntime::PathString& ctx_onnx_model_path,
std::string& model_name,
std::string& model_description,
std::string& graph_partition_name,
Expand All @@ -80,7 +80,7 @@ Status GenerateCtxCacheOnnxModel(const std::string model_name,
const std::string& sdk_build_version,
const std::vector<IExecutionProvider::FusedNodeAndGraph>& fused_nodes_and_graphs,
const std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models,

Check warning on line 82 in onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h#L82

Add #include <memory> for unique_ptr<> [build/include_what_you_use] [4]
Raw output
onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h:82:  Add #include <memory> for unique_ptr<>  [build/include_what_you_use] [4]

Check warning on line 82 in onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h#L82

Add #include <unordered_map> for unordered_map<> [build/include_what_you_use] [4]
Raw output
onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h:82:  Add #include <unordered_map> for unordered_map<>  [build/include_what_you_use] [4]
const std::string& context_cache_path,
const onnxruntime::PathString& context_cache_path,
bool qnn_context_embed_mode,
const logging::Logger& logger);
} // namespace qnn
Expand Down
15 changes: 3 additions & 12 deletions onnxruntime/core/providers/qnn/qnn_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,6 @@ namespace onnxruntime {

constexpr const char* QNN = "QNN";

std::string GetFileNameFromModelPath(onnxruntime::Path model_path) {
auto model_path_components = model_path.GetComponents();
// There's no model path if model loaded from buffer stead of file
if (model_path_components.empty()) {
return "";
}
return PathToUTF8String(model_path_components.back());
}

void QNNExecutionProvider::ParseProfilingLevel(std::string profiling_level_string) {
std::transform(profiling_level_string.begin(),
profiling_level_string.end(),
Expand Down Expand Up @@ -342,7 +333,7 @@ QNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer

// This is for case: QDQ model + Onnx Qnn context cache model
if (context_cache_enabled_ && !is_qnn_ctx_model) {
std::string context_cache_path;
onnxruntime::PathString context_cache_path;
load_from_cached_context = qnn::IsContextCacheFileExists(context_cache_path_cfg_,
graph_viewer.ModelPath().ToPathString(),
context_cache_path);
Expand Down Expand Up @@ -536,7 +527,7 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
bool is_qnn_ctx_model = false;
ORT_RETURN_IF_ERROR(qnn::IsFusedGraphHasCtxNode(fused_nodes_and_graphs, is_qnn_ctx_model));

std::string context_cache_path;
onnxruntime::PathString context_cache_path;
bool is_ctx_file_exist = qnn::IsContextCacheFileExists(context_cache_path_cfg_,
graph_viewer.ModelPath().ToPathString(),
context_cache_path);
Expand All @@ -556,7 +547,7 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
std::unique_ptr<qnn::QnnModel> qnn_model = std::make_unique<qnn::QnnModel>(logger, qnn_backend_manager_.get());
// Load and execute from cached context if exist
ORT_RETURN_IF_ERROR(qnn::LoadQnnCtxFromOnnxModel(graph_viewer,
context_cache_path_cfg_,
context_cache_path,
is_qnn_ctx_model,
is_ctx_file_exist,
qnn_backend_manager_.get(),
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/test/providers/qnn/simple_op_htp_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,7 @@ TEST_F(QnnHTPBackendTests, ContextBinaryCacheNonEmbedModeTest) {
// Check the Onnx skeleton file is generated
EXPECT_TRUE(std::filesystem::exists(context_binary_file.c_str()));
// Check the Qnn context cache binary file is generated
EXPECT_TRUE(std::filesystem::exists("qnn_context_cache_non_embed.onnx_QNN_8283143575221199085_1.bin"));
EXPECT_TRUE(std::filesystem::exists("qnn_context_cache_non_embed.onnx_QNNExecutionProvider_QNN_8283143575221199085_1_0.bin"));

Check warning on line 789 in onnxruntime/test/providers/qnn/simple_op_htp_test.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/test/providers/qnn/simple_op_htp_test.cc#L789

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/test/providers/qnn/simple_op_htp_test.cc:789:  Lines should be <= 120 characters long  [whitespace/line_length] [2]

// 2nd run loads and run from QDQ model + Onnx skeleton file + Qnn context cache binary file
TestQDQModelAccuracy(BuildOpTestCase<float>(op_type, {input_def}, {}, {}),
Expand Down

0 comments on commit 4a761ec

Please sign in to comment.