Skip to content

Commit

Permalink
better hash
Browse files Browse the repository at this point in the history
  • Loading branch information
wejoncy committed Dec 16, 2024
1 parent f492fee commit 5f56c3b
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 11 deletions.
17 changes: 7 additions & 10 deletions onnxruntime/core/providers/coreml/builders/model_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ void CreateEmptyFile(const std::string& filename) {
#endif // defined(COREML_ENABLE_MLPROGRAM)

std::string GetModelOutputPath(const CoreMLOptions& coreml_options,
const std::vector<std::string>& onnx_input_names) {
const std::string& graph_name) {
std::string path;
if (coreml_options.ModelCachePath().empty()) {
// path is used to create the ML Package directory for ML Program, and for the model directly otherwise.
Expand All @@ -400,14 +400,11 @@ std::string GetModelOutputPath(const CoreMLOptions& coreml_options,
path += ".model.mlmodel";
}
} else {
// input names in onnx are unique. so we can use them as the key in the cache.
std::string inputs_collections = std::accumulate(
onnx_input_names.begin(), onnx_input_names.end(), std::string(),
[](const std::string& a, const std::string& b) { return a + "," + b; });
std::hash<std::string> hasher;
// different subgraph has different folders. so we need to hash the inputs.
path = std::string(coreml_options.ModelCachePath()) +
"/" + std::to_string(hasher(inputs_collections));
// graph_name is uniquely generated by
// onnxruntime/core/providers/coreml/coreml_execution_provider.cc::gen_metadef_name
// int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash);
// MakeString(COREML, "_", model_hash, "_", metadef_id);.
path = std::string(coreml_options.ModelCachePath()) + "/" + graph_name;
if (!coreml_options.CreateMLProgram()) {
ORT_THROW_IF_ERROR(Env::Default().CreateFolder(path));
path += "/mlmodel";
Expand All @@ -427,7 +424,7 @@ ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logge
coreml_version_(coreml_version),
coreml_options_(coreml_options),
create_ml_program_(coreml_options.CreateMLProgram()),
model_output_path_(GetModelOutputPath(coreml_options, onnx_input_names)),
model_output_path_(GetModelOutputPath(coreml_options, graph_viewer.Name())),
onnx_input_names_(std::move(onnx_input_names)),
onnx_output_names_(std::move(onnx_output_names)),
coreml_model_(std::make_unique<CoreML::Specification::Model>()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "core/providers/coreml/model/host_utils.h"
#include "core/providers/coreml/model/model.h"
#include "core/providers/coreml/shape_utils.h"
#include "core/graph/model.h"

namespace onnxruntime {

Expand Down Expand Up @@ -57,7 +58,11 @@ CoreMLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
[&]() {
HashValue model_hash;
int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash);
return MakeString(COREML, "_", model_hash, "_", metadef_id);
std::string user_provide_hash;
if (graph_viewer.GetGraph().GetModel().MetaData().count("model_hash") > 0) {
user_provide_hash = graph_viewer.GetGraph().GetModel().MetaData().at("model_hash");
}
return MakeString(user_provide_hash, "_", COREML, "_", model_hash, "_", metadef_id);
};

result = utils::CreateSupportedPartitions(graph_viewer, supported_nodes, {},
Expand Down

0 comments on commit 5f56c3b

Please sign in to comment.