Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
wejoncy committed Dec 20, 2024
1 parent 5518e38 commit 70075e5
Showing 1 changed file with 20 additions and 21 deletions.
41 changes: 20 additions & 21 deletions onnxruntime/core/providers/coreml/coreml_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,32 +53,31 @@ CoreMLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
const auto builder_params = coreml::MakeOpBuilderParams(graph_viewer, coreml_version_,
coreml_options_.RequireStaticShape(), coreml_options_.CreateMLProgram());
const auto supported_nodes = coreml::GetSupportedNodes(graph_viewer, builder_params, logger);

const Graph* main_graph = &graph_viewer.GetGraph();
while (main_graph->IsSubgraph()) {
main_graph = main_graph->ParentGraph();
}
const auto& metadata = main_graph->GetModel().MetaData();

std::string user_provided_key = metadata.count(kCOREML_CACHE_KEY) > 0
? metadata.at(kCOREML_CACHE_KEY)
: "";
if (user_provided_key.size() > 64 ||
std::any_of(user_provided_key.begin(), user_provided_key.end(),
[](unsigned char c) { return !std::isalnum(c); })) {
LOGS(logger, ERROR) << "[" << kCOREML_CACHE_KEY << ":" << user_provided_key << "] is not a valid cache key."
<< " It should be alphanumeric and less than 64 characters.";
}
const auto gen_metadef_name =
[&]() {
HashValue model_hash;
int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash);

const Graph* main_graph = &graph_viewer.GetGraph();
while (main_graph->IsSubgraph()) {
main_graph = main_graph->ParentGraph();
}
const auto& metadata = main_graph->GetModel().MetaData();
// use model_hash as the key if user doesn't provide one
// model_hash is a 64-bit hash value of model_path if model_path is not empty,
// otherwise it hashes the graph input names and all the node output names.
// it can't guarantee the uniqueness of the key, so user should manager the key for the best.
std::string user_provided_key = metadata.count(kCOREML_CACHE_KEY) > 0
? metadata.at(kCOREML_CACHE_KEY)
: std::to_string(model_hash);

if (user_provided_key.size() > 64 ||
std::any_of(user_provided_key.begin(), user_provided_key.end(),
[](unsigned char c) { return !std::isalnum(c); })) {
LOGS(logger, ERROR) << "[" << kCOREML_CACHE_KEY << ":" << user_provided_key << "] is not a valid cache key."
<< " It should be alphanumeric and less than 64 characters.";

} else if (user_provided_key.empty()) { // user passed a empty string
if (user_provided_key.empty()) {
// user passed a empty string
// model_hash is a 64-bit hash value of model_path if model_path is not empty,
// otherwise it hashes the graph input names and all the node output names.
// it can't guarantee the uniqueness of the key, so user should manager the key for the best.
user_provided_key = std::to_string(model_hash);
}
// The string format is used by onnxruntime/core/providers/coreml/builders/model_builder.cc::GetModelOutputPath
Expand Down

0 comments on commit 70075e5

Please sign in to comment.