Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CoreML] support coreml model cache #23065

Open
wants to merge 29 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,22 @@
static const char* const kCoremlProviderOption_ProfileComputePlan = "ProfileComputePlan";
// please refer to https://developer.apple.com/documentation/coreml/mlmodelconfiguration/allowlowprecisionaccumulationongpu
static const char* const kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU = "AllowLowPrecisionAccumulationOnGPU";
// Specify the directory in which to cache any CoreML models created from the ONNX model.
// The CoreML EP converts an ONNX subgraph to a CoreML model and saves it to disk.
// If this path is not specified, the model is saved to a temp directory and deleted after the session is closed;
// otherwise, the model is saved to the specified path and the user is responsible for deleting it.

// The EP does not detect whether a cached model matches the current ONNX subgraph, so the user should carefully clear the cache when the model changes.
// The cache key is generated by
// 1. User-provided key in metadata_props if present (preferred)

Check warning on line 71 in include/onnxruntime/core/providers/coreml/coreml_provider_factory.h

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "prefered" is a misspelling of "preferred" Raw Output: ./include/onnxruntime/core/providers/coreml/coreml_provider_factory.h:71:55: "prefered" is a misspelling of "preferred"
// 2. Hash of the model url the inference session was created with
// 3. Hash of the input/output names of the model
wejoncy marked this conversation as resolved.
Show resolved Hide resolved

// The EP won't track model changes and is not responsible for cache management.
wejoncy marked this conversation as resolved.
Show resolved Hide resolved
static const char* const kCoremlProviderOption_ModelCacheDirectory = "ModelCacheDirectory";

// User provided cache-key in metadata_props.
static const char* const kCOREML_CACHE_KEY = "CACHE_KEY";
wejoncy marked this conversation as resolved.
Show resolved Hide resolved

#ifdef __cplusplus
extern "C" {
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/platform/env.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ class Env {
#ifdef _WIN32
/// \brief Returns true if the directory exists.
virtual bool FolderExists(const std::wstring& path) const = 0;
virtual bool FileExists(const std::wstring& path) const = 0;
/// \brief Recursively creates the directory, if it doesn't exist.
virtual common::Status CreateFolder(const std::wstring& path) const = 0;
// Mainly for use with protobuf library
Expand All @@ -206,6 +207,7 @@ class Env {
#endif
/// \brief Returns true if the directory exists.
virtual bool FolderExists(const std::string& path) const = 0;
virtual bool FileExists(const std::string& path) const = 0;
/// \brief Recursively creates the directory, if it doesn't exist.
virtual common::Status CreateFolder(const std::string& path) const = 0;
// Recursively deletes the directory and its contents.
Expand Down
8 changes: 8 additions & 0 deletions onnxruntime/core/platform/posix/env.cc
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,14 @@ class PosixEnv : public Env {
return S_ISDIR(sb.st_mode);
}

// Returns true iff |path| names an existing regular file (directories,
// sockets, FIFOs, etc. return false, as does any stat() failure).
bool FileExists(const std::string& path) const override {
  struct stat stat_buf;
  const bool stat_ok = stat(path.c_str(), &stat_buf) == 0;
  return stat_ok && S_ISREG(stat_buf.st_mode);
}

common::Status CreateFolder(const std::string& path) const override {
size_t pos = 0;
do {
Expand Down
10 changes: 10 additions & 0 deletions onnxruntime/core/platform/windows/env.cc
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,16 @@ bool WindowsEnv::FolderExists(const std::string& path) const {
return (attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_DIRECTORY);
}

bool WindowsEnv::FileExists(const std::wstring& path) const {
DWORD attributes = GetFileAttributesW(path.c_str());
return (attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_NORMAL);
}

bool WindowsEnv::FileExists(const std::string& path) const {
DWORD attributes = GetFileAttributesA(path.c_str());
return (attributes != INVALID_FILE_ATTRIBUTES) && (attributes & FILE_ATTRIBUTE_NORMAL);
}

common::Status WindowsEnv::CreateFolder(const std::wstring& path) const {
size_t pos = 0;
do {
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/platform/windows/env.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ class WindowsEnv : public Env {
MappedMemoryPtr& mapped_memory) const override;
bool FolderExists(const std::wstring& path) const override;
bool FolderExists(const std::string& path) const override;
bool FileExists(const std::wstring& path) const override;
bool FileExists(const std::string& path) const override;
common::Status CreateFolder(const std::wstring& path) const override;
common::Status CreateFolder(const std::string& path) const override;
common::Status DeleteFolder(const PathString& path) const override;
Expand Down
83 changes: 74 additions & 9 deletions onnxruntime/core/providers/coreml/builders/model_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -390,13 +390,57 @@ void CreateEmptyFile(const std::string& filename) {

#endif // defined(COREML_ENABLE_MLPROGRAM)

std::string GetModelOutputPath(bool create_ml_program) {
// path is used to create the ML Package directory for ML Program, and for the model directly otherwise.
auto path = util::GetTemporaryFilePath();
if (!create_ml_program) {
path += ".model.mlmodel";
}
std::string GetModelOutputPath(const CoreMLOptions& coreml_options,
const GraphViewer& graph_viewer) {
const std::string& subgraph_name = graph_viewer.Name();
std::string path;
if (coreml_options.ModelCacheDirectory().empty()) {
// path is used to create the ML Package directory for ML Program, and for the model directly otherwise.
path = util::GetTemporaryFilePath();
if (!coreml_options.CreateMLProgram()) {
path += ".model.mlmodel";
}
} else {
// subgraph_name is uniquely generated by
// onnxruntime/core/providers/coreml/coreml_execution_provider.cc::gen_metadef_name
// int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash);
// MakeString(user_provide_key, "_", COREML, "_", model_hash, "_", metadef_id);
std::string_view cache_key = std::string_view(subgraph_name)
.substr(0, subgraph_name.find_first_of("_"));
skottmckay marked this conversation as resolved.
Show resolved Hide resolved
// subgraph_short_name is metadef_id
std::string_view subgraph_short_name = std::string_view(subgraph_name)
.substr(subgraph_name.find_last_of("_") + 1);
path = MakeString(std::string(coreml_options.ModelCacheDirectory()), "/", cache_key);
ORT_THROW_IF_ERROR(Env::Default().CreateFolder(path));
wejoncy marked this conversation as resolved.
Show resolved Hide resolved
// Write the model path to a file in the cache directory.
// This is for developers to know what the cached model is as we used a hash for the directory name.
if (!Env::Default().FileExists(ToPathString(path + "/model.txt"))) {
wejoncy marked this conversation as resolved.
Show resolved Hide resolved
const Graph* main_graph = &graph_viewer.GetGraph();
while (main_graph->IsSubgraph()) {
main_graph = main_graph->ParentGraph();
}
std::ofstream file(path + "/model.txt");
ORT_ENFORCE(file.is_open(), "Failed to open file ", path + "/model.txt");
wejoncy marked this conversation as resolved.
Show resolved Hide resolved
file << main_graph->ModelPath().string();
file.close();
}

path = MakeString(path, "/", subgraph_short_name);
// Set the model cache path with setting of RequireStaticShape and ModelFormat
if (coreml_options.RequireStaticShape()) {
path += "_static";
} else {
path += "_dynamic";
}

if (coreml_options.CreateMLProgram()) {
path += "_mlprogram";
} else {
path += "_nn";
}
ORT_THROW_IF_ERROR(Env::Default().CreateFolder(path));
path += "/model";
}
return path;
}
} // namespace
Expand All @@ -410,10 +454,21 @@ ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logge
coreml_version_(coreml_version),
coreml_options_(coreml_options),
create_ml_program_(coreml_options.CreateMLProgram()),
model_output_path_(GetModelOutputPath(create_ml_program_)),
model_output_path_(GetModelOutputPath(coreml_options, graph_viewer)),
onnx_input_names_(std::move(onnx_input_names)),
onnx_output_names_(std::move(onnx_output_names)),
coreml_model_(std::make_unique<CoreML::Specification::Model>()) {
// GetTemporaryFilePath() always produce a unique path for the model and this is not existed
// Mlprogram will create a folder while NN create a file
if (Env::Default().FolderExists(ToPathString(model_output_path_)) ||
Env::Default().FileExists(ToPathString(model_output_path_))) {
is_model_cached_ = true;
LOGS(logger, INFO) << "Model is already cached in " << model_output_path_
<< " and will be reused. If you want to update the model or hit other issues, "
<< "please consider to clear the cache and retry.";
return;
}

if (create_ml_program_) {
#if defined(COREML_ENABLE_MLPROGRAM)
coreml_model_->set_specificationversion(CoreMLSpecVersion());
Expand Down Expand Up @@ -847,6 +902,10 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i

input_output_info_.emplace(name, OnnxTensorInfo{data_type, shape});

if (IsModelCached()) {
return Status::OK();
}

#if defined(COREML_ENABLE_MLPROGRAM)
if (create_ml_program_) {
if (is_input) {
Expand Down Expand Up @@ -1056,8 +1115,14 @@ Status ModelBuilder::Build(const GraphViewer& graph_viewer, const logging::Logge
ModelBuilder builder(graph_viewer, logger, coreml_version, coreml_options,
std::move(onnx_input_names), std::move(onnx_output_names));

ORT_RETURN_IF_ERROR(builder.CreateModel());
ORT_RETURN_IF_ERROR(builder.SaveModel());
if (!builder.IsModelCached()) {
ORT_RETURN_IF_ERROR(builder.CreateModel());
ORT_RETURN_IF_ERROR(builder.SaveModel());
} else {
// runtime requires the input/output names to be passed
ORT_RETURN_IF_ERROR(builder.RegisterModelInputs());
ORT_RETURN_IF_ERROR(builder.RegisterModelOutputs());
}

return builder.LoadModel(model);
}
Expand Down
6 changes: 4 additions & 2 deletions onnxruntime/core/providers/coreml/builders/model_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ class ModelBuilder {
// We only support CoreML 3 and later so the spec version is always version + 1.
int32_t CoreMLVersion() const { return coreml_version_; }
int32_t CoreMLSpecVersion() const { return coreml_version_ + 1; }
bool IsModelCached() const { return is_model_cached_; }

// Returns true if we are creating an ML Program
bool CreateMLProgram() const {
Expand Down Expand Up @@ -218,8 +219,9 @@ class ModelBuilder {
const logging::Logger& logger_;
const int32_t coreml_version_;
CoreMLOptions coreml_options_;
const bool create_ml_program_; // ML Program (CoreML5, iOS 15+, macOS 12+) or NeuralNetwork (old)
const std::string model_output_path_; // create_ml_program_ ? dir for mlpackage : filename for mlmodel
const bool create_ml_program_; // ML Program (CoreML5, iOS 15+, macOS 12+) or NeuralNetwork (old)
std::string model_output_path_; // create_ml_program_ ? dir for mlpackage : filename for mlmodel
bool is_model_cached_{false};

std::vector<std::string> onnx_input_names_;
std::vector<std::string> onnx_output_names_;
Expand Down
27 changes: 26 additions & 1 deletion onnxruntime/core/providers/coreml/coreml_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "core/providers/coreml/model/host_utils.h"
#include "core/providers/coreml/model/model.h"
#include "core/providers/coreml/shape_utils.h"
#include "core/graph/model.h"

namespace onnxruntime {

Expand Down Expand Up @@ -57,7 +58,31 @@ CoreMLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie
[&]() {
  HashValue model_hash;
  int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash);
  std::string user_provided_key;
  // Metadata (including any user-provided cache key) lives on the main graph's
  // model, so walk up from a subgraph before looking it up.
  const Graph* main_graph = &graph_viewer.GetGraph();
  while (main_graph->IsSubgraph()) {
    main_graph = main_graph->ParentGraph();
  }
  if (main_graph->GetModel().MetaData().count(kCOREML_CACHE_KEY) > 0) {
    // Read from the same (main) graph that was checked above; graph_viewer may
    // be a subgraph whose own model metadata does not contain the key.
    user_provided_key = main_graph->GetModel().MetaData().at(kCOREML_CACHE_KEY);
    // The key is used as a directory name, so restrict it to at most 64
    // alphanumeric characters; otherwise replace it with a hash of itself.
    if (user_provided_key.size() > 64 ||
        std::any_of(user_provided_key.begin(), user_provided_key.end(),
                    [](unsigned char c) { return !std::isalnum(c); })) {
      user_provided_key = std::to_string(std::hash<std::string>{}(user_provided_key));
    }
    // invalid (empty) cache-key: fall back to the model hash
    if (user_provided_key.empty()) {
      user_provided_key = std::to_string(model_hash);
    }
  } else {
    // model_hash is a 64-bit hash value of model_path if model_path is not empty,
    // otherwise it hashes the graph input names and all the node output names.
    // It can't guarantee the uniqueness of the key, so the user should manage the key for best results.
    user_provided_key = std::to_string(model_hash);
  }
  // The string format is used by onnxruntime/core/providers/coreml/builders/model_builder.cc::GetModelOutputPath
  // If the format changes, the function should be updated accordingly.
  return MakeString(user_provided_key, "_", COREML, "_", model_hash, "_", metadef_id);
};

result = utils::CreateSupportedPartitions(graph_viewer, supported_nodes, {},
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/coreml/coreml_options.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "core/providers/coreml/coreml_provider_factory.h" // defines flags
#include "core/providers/coreml/model/host_utils.h"
#include "core/providers/coreml/builders/helper.h"
#include "core/platform/env.h"

namespace onnxruntime {

Expand Down Expand Up @@ -71,6 +72,7 @@ void CoreMLOptions::ValidateAndParseProviderOption(const ProviderOptions& option
kCoremlProviderOption_SpecializationStrategy,
kCoremlProviderOption_ProfileComputePlan,
kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU,
kCoremlProviderOption_ModelCacheDirectory,
};
// Validate the options
for (const auto& option : options) {
Expand Down Expand Up @@ -103,6 +105,8 @@ void CoreMLOptions::ValidateAndParseProviderOption(const ProviderOptions& option
profile_compute_plan_ = option.second == "1";
} else if (kCoremlProviderOption_AllowLowPrecisionAccumulationOnGPU == option.first) {
allow_low_precision_accumulation_on_gpu_ = option.second == "1";
} else if (kCoremlProviderOption_ModelCacheDirectory == option.first) {
model_cache_directory_ = option.second;
}
}
}
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/coreml/coreml_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
std::string strategy_;
bool profile_compute_plan_{false};
bool allow_low_precision_accumulation_on_gpu_{false};
// path to store the converted coreml model
std::string model_cache_directory_;

Check warning on line 21 in onnxruntime/core/providers/coreml/coreml_options.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/coreml/coreml_options.h:21: Add #include <string> for string [build/include_what_you_use] [4]

public:
explicit CoreMLOptions(uint32_t coreml_flags);
Expand All @@ -32,6 +34,8 @@
bool UseStrategy(std::string_view strategy) const { return strategy_ == strategy; }
bool ProfileComputePlan() const { return profile_compute_plan_ && create_mlprogram_; }

std::string_view ModelCacheDirectory() const { return model_cache_directory_; }

private:
void ValidateAndParseProviderOption(const ProviderOptions& options);
};
Expand Down
Loading
Loading