Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
wejoncy committed Dec 16, 2024
1 parent 4a5772f commit b57aa28
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 6 deletions.
4 changes: 3 additions & 1 deletion onnxruntime/core/providers/coreml/builders/model_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,8 @@ std::string GetModelOutputPath(const CoreMLOptions& coreml_options,
std::string_view cache_key = std::string_view(subgraph_name).substr(0, subgraph_name.find_first_of("_"));
path = MakeString(std::string(coreml_options.ModelCachePath()), "/", cache_key);
ORT_THROW_IF_ERROR(Env::Default().CreateFolder(path));
// Write the model path to a file in the cache directory.
// This is for developers to know what the cached model is as we used a hash for the directory name.
if (!Env::Default().FileExists(ToPathString(path + "/model.txt"))) {
const Graph* main_graph = &graph_viewer.GetGraph();
while (main_graph->IsSubgraph()) {
Expand All @@ -422,7 +424,7 @@ std::string GetModelOutputPath(const CoreMLOptions& coreml_options,
}

path = MakeString(path, "/", subgraph_name);
// Set the model cache path with equireStaticShape and ModelFormat
// Set the model cache path with setting of RequireStaticShape and ModelFormat
if (coreml_options.RequireStaticShape()) {
path += "/static_shape";
} else {
Expand Down
14 changes: 9 additions & 5 deletions onnxruntime/core/providers/coreml/coreml_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,22 @@ CoreMLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
[&]() {
HashValue model_hash;
int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash);
std::string user_provide_key;
std::string user_provided_key;
const Graph* main_graph = &graph_viewer.GetGraph();
while (main_graph->IsSubgraph()) {
main_graph = main_graph->ParentGraph();
}
if (main_graph->GetModel().MetaData().count("CACHE_KEY") > 0) {
user_provide_key = graph_viewer.GetGraph().GetModel().MetaData().at("CACHE_KEY");
user_provided_key = graph_viewer.GetGraph().GetModel().MetaData().at("CACHE_KEY");
} else {
// model_hash is a 64-bit hash value of model_path
user_provide_key = std::to_string(model_hash);
// model_hash is a 64-bit hash value of model_path if model_path is not empty,
// otherwise it hashes the graph input names and all the node output names.
// it can't guarantee the uniqueness of the key, so user should manager the key by themselves for the best.
user_provided_key = std::to_string(model_hash);
}
return MakeString(user_provide_key, "_", COREML, "_", model_hash, "_", metadef_id);
// The string format is used by onnxruntime/core/providers/coreml/builders/model_builder.cc::GetModelOutputPath
// If the format changes, the function should be updated accordingly.
return MakeString(user_provided_key, "_", COREML, "_", model_hash, "_", metadef_id);
};

result = utils::CreateSupportedPartitions(graph_viewer, supported_nodes, {},
Expand Down

0 comments on commit b57aa28

Please sign in to comment.