diff --git a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc index 72b5d9cdcaeb0..8a038d668f5e7 100644 --- a/onnxruntime/core/providers/coreml/coreml_execution_provider.cc +++ b/onnxruntime/core/providers/coreml/coreml_execution_provider.cc @@ -53,32 +53,31 @@ CoreMLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie const auto builder_params = coreml::MakeOpBuilderParams(graph_viewer, coreml_version_, coreml_options_.RequireStaticShape(), coreml_options_.CreateMLProgram()); const auto supported_nodes = coreml::GetSupportedNodes(graph_viewer, builder_params, logger); - + const Graph* main_graph = &graph_viewer.GetGraph(); + while (main_graph->IsSubgraph()) { + main_graph = main_graph->ParentGraph(); + } + const auto& metadata = main_graph->GetModel().MetaData(); + + std::string user_provided_key = metadata.count(kCOREML_CACHE_KEY) > 0 + ? metadata.at(kCOREML_CACHE_KEY) + : ""; + if (user_provided_key.size() > 64 || + std::any_of(user_provided_key.begin(), user_provided_key.end(), + [](unsigned char c) { return !std::isalnum(c); })) { + LOGS(logger, ERROR) << "[" << kCOREML_CACHE_KEY << ":" << user_provided_key << "] is not a valid cache key." + << " It should be alphanumeric and less than 64 characters."; + } const auto gen_metadef_name = [&]() { HashValue model_hash; int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash); - - const Graph* main_graph = &graph_viewer.GetGraph(); - while (main_graph->IsSubgraph()) { - main_graph = main_graph->ParentGraph(); - } - const auto& metadata = main_graph->GetModel().MetaData(); // use model_hash as the key if user doesn't provide one - // model_hash is a 64-bit hash value of model_path if model_path is not empty, - // otherwise it hashes the graph input names and all the node output names. - // it can't guarantee the uniqueness of the key, so user should manager the key for the best. - std::string user_provided_key = metadata.count(kCOREML_CACHE_KEY) > 0 - ? metadata.at(kCOREML_CACHE_KEY) - : std::to_string(model_hash); - - if (user_provided_key.size() > 64 || - std::any_of(user_provided_key.begin(), user_provided_key.end(), - [](unsigned char c) { return !std::isalnum(c); })) { - LOGS(logger, ERROR) << "[" << kCOREML_CACHE_KEY << ":" << user_provided_key << "] is not a valid cache key." - << " It should be alphanumeric and less than 64 characters."; - - } else if (user_provided_key.empty()) { // user passed a empty string + if (user_provided_key.empty()) { + // user passed a empty string + // model_hash is a 64-bit hash value of model_path if model_path is not empty, + // otherwise it hashes the graph input names and all the node output names. + // it can't guarantee the uniqueness of the key, so user should manager the key for the best. user_provided_key = std::to_string(model_hash); } // The string format is used by onnxruntime/core/providers/coreml/builders/model_builder.cc::GetModelOutputPath