From b57aa2853cf77d67ae5776611c992bc04e1c8507 Mon Sep 17 00:00:00 2001 From: wejoncy Date: Mon, 16 Dec 2024 17:26:10 +0800 Subject: [PATCH] address comments --- .../providers/coreml/builders/model_builder.cc | 4 +++- .../providers/coreml/coreml_execution_provider.cc | 14 +++++++++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/onnxruntime/core/providers/coreml/builders/model_builder.cc b/onnxruntime/core/providers/coreml/builders/model_builder.cc index 6cc242c017d62..db2af3d4dddc1 100644 --- a/onnxruntime/core/providers/coreml/builders/model_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/model_builder.cc @@ -410,6 +410,8 @@ std::string GetModelOutputPath(const CoreMLOptions& coreml_options, std::string_view cache_key = std::string_view(subgraph_name).substr(0, subgraph_name.find_first_of("_")); path = MakeString(std::string(coreml_options.ModelCachePath()), "/", cache_key); ORT_THROW_IF_ERROR(Env::Default().CreateFolder(path)); + // Write the model path to a file in the cache directory. + // This is for developers to know what the cached model is as we used a hash for the directory name. if (!Env::Default().FileExists(ToPathString(path + "/model.txt"))) { const Graph* main_graph = &graph_viewer.GetGraph(); while (main_graph->IsSubgraph()) { @@ -422,7 +424,7 @@ std::string GetModelOutputPath(const CoreMLOptions& coreml_options, } path = MakeString(path, "/", subgraph_name); - // Set the model cache path with equireStaticShape and ModelFormat + // Set the model cache path with setting of RequireStaticShape and ModelFormat if (coreml_options.RequireStaticShape()) { path += "/static_shape"; } else { diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc index 9ea111a105cdb..fa725d5b2bd46 100644 --- a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc +++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc @@ -58,18 +58,22 @@ CoreMLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie [&]() { HashValue model_hash; int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash); - std::string user_provide_key; + std::string user_provided_key; const Graph* main_graph = &graph_viewer.GetGraph(); while (main_graph->IsSubgraph()) { main_graph = main_graph->ParentGraph(); } if (main_graph->GetModel().MetaData().count("CACHE_KEY") > 0) { - user_provide_key = graph_viewer.GetGraph().GetModel().MetaData().at("CACHE_KEY"); + user_provided_key = graph_viewer.GetGraph().GetModel().MetaData().at("CACHE_KEY"); } else { - // model_hash is a 64-bit hash value of model_path - user_provide_key = std::to_string(model_hash); + // model_hash is a 64-bit hash value of model_path if model_path is not empty, + // otherwise it hashes the graph input names and all the node output names. + // it can't guarantee the uniqueness of the key, so user should manager the key by themselves for the best. + user_provided_key = std::to_string(model_hash); } - return MakeString(user_provide_key, "_", COREML, "_", model_hash, "_", metadef_id); + // The string format is used by onnxruntime/core/providers/coreml/builders/model_builder.cc::GetModelOutputPath + // If the format changes, the function should be updated accordingly. + return MakeString(user_provided_key, "_", COREML, "_", model_hash, "_", metadef_id); }; result = utils::CreateSupportedPartitions(graph_viewer, supported_nodes, {},