Skip to content

Commit

Permalink
ExecutionProvider API refactor - make GenerateMetaDefId a standalone …
Browse files Browse the repository at this point in the history
…function, decouple it from EP (microsoft#18977)

### Description
<!-- Describe your changes. -->
Make the EP member function GenerateMetaDefId a standalone function
that is decoupled from the EP.


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This change is part of the ExecutionProvider API refactoring: we will first
make the ExecutionProvider API clean, in preparation for the later EPv2 work.
  • Loading branch information
jslhcl authored Jan 26, 2024
1 parent fc44f96 commit 7d4dc66
Show file tree
Hide file tree
Showing 32 changed files with 187 additions and 147 deletions.
35 changes: 3 additions & 32 deletions include/onnxruntime/core/framework/execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,11 @@ enum class DataLayout {

class IExecutionProvider {
protected:
IExecutionProvider(const std::string& type, bool use_metadef_id_creator = false)
: IExecutionProvider(type, OrtDevice(), use_metadef_id_creator) {}
IExecutionProvider(const std::string& type)
: IExecutionProvider(type, OrtDevice()) {}

IExecutionProvider(const std::string& type, OrtDevice device, bool use_metadef_id_creator = false)
IExecutionProvider(const std::string& type, OrtDevice device)
: default_device_(device), type_{type} {
if (use_metadef_id_creator) {
metadef_id_generator_ = std::make_unique<ModelMetadefIdGenerator>();
}
}

/*
Expand Down Expand Up @@ -274,19 +271,6 @@ class IExecutionProvider {
return logger_;
}

/** Generate a unique id that can be used in a MetaDef name. Values are unique for a model instance.
The model hash is also returned if you wish to include that in the MetaDef name to ensure uniqueness across models.
@param graph_viewer[in] Graph viewer that GetCapability was called with. Can be for the main graph or nested graph.
@param model_hash[out] Returns the hash for the main (i.e. top level) graph in the model.
This is created using the model path if available,
or the model input names and the output names from all nodes in the main graph.
@remarks e.g. the TensorRT Execution Provider is used in multiple sessions and the underlying infrastructure caches
compiled kernels, so the name must be unique and deterministic across models and sessions.
NOTE: Ideally this would be a protected method, but to work across the EP bridge it has to be public and
virtual, and ModelMetadefIdGenerator must be defined in the header as well.
*/
virtual int GenerateMetaDefId(const onnxruntime::GraphViewer& graph_viewer, HashValue& model_hash) const;

virtual std::unique_ptr<profiling::EpProfiler> GetProfiler() {
return {};
}
Expand Down Expand Up @@ -340,18 +324,5 @@ class IExecutionProvider {

// It will be set when this object is registered to a session
const logging::Logger* logger_ = nullptr;

// Helper to generate ids that are unique to a model and deterministic, even if the execution provider is shared
// across multiple sessions.
class ModelMetadefIdGenerator {
public:
// Writes the hash of the model that owns graph_viewer to model_hash and returns the current id for that model;
// the stored id is incremented on each call.
int GenerateId(const onnxruntime::GraphViewer& graph_viewer, HashValue& model_hash);

private:
std::unordered_map<HashValue, HashValue> main_graph_hash_; // map graph instance hash to model contents hash
std::unordered_map<HashValue, int> model_metadef_id_; // current unique id for model
};

std::unique_ptr<ModelMetadefIdGenerator> metadef_id_generator_;
};
} // namespace onnxruntime
73 changes: 0 additions & 73 deletions onnxruntime/core/framework/execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,77 +35,4 @@ common::Status IExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
}

#endif

// Generates the next metadef id for the model that owns graph_viewer.
// graph_viewer may view the main graph or a nested subgraph; the top-level graph is located first.
// model_hash is an out parameter: it is reset to 0 and then set to the hash of the top-level graph's model
// (model path if non-empty, otherwise graph input names plus node output names).
// Returns the current counter value stored for that model hash and post-increments the counter.
int IExecutionProvider::ModelMetadefIdGenerator::GenerateId(const onnxruntime::GraphViewer& graph_viewer,
HashValue& model_hash) {
model_hash = 0;

// find the top level graph
const Graph* cur_graph = &graph_viewer.GetGraph();
while (cur_graph->IsSubgraph()) {
cur_graph = cur_graph->ParentGraph();
}

uint32_t instance_hash[4] = {0, 0, 0, 0};

const Graph& main_graph = *cur_graph;

// hash the bytes in the Graph instance. we can't just use the address as a new Graph instance may use
// the same memory (unit tests prove this can occur). the raw bytes of the Graph instance should be a unique
// fingerprint for the instance that can be used as the key to the hash of the model path/contents.
MurmurHash3::x86_128(&main_graph, gsl::narrow_cast<int32_t>(sizeof(Graph)), instance_hash[0], &instance_hash);
// combine the first two 32-bit words of the 128-bit result into a 64-bit key
HashValue graph_instance_hash = instance_hash[0] | (uint64_t(instance_hash[1]) << 32);

// if we've already hashed this main graph instance use the cached value
auto entry = main_graph_hash_.find(graph_instance_hash);
if (entry != main_graph_hash_.cend()) {
model_hash = entry->second;
} else {
uint32_t hash[4] = {0, 0, 0, 0};

// prefer path the model was loaded from
// this may not be available if the model was loaded from a stream or in-memory bytes
const auto& model_path_str = main_graph.ModelPath().ToPathString();
if (!model_path_str.empty()) {
MurmurHash3::x86_128(model_path_str.data(), gsl::narrow_cast<int32_t>(model_path_str.size()), hash[0], &hash);
} else {
// each call seeds the hash with the current hash[0], so successive strings are chained into one running value
auto hash_str = [&hash](const std::string& str) {
MurmurHash3::x86_128(str.data(), gsl::narrow_cast<int32_t>(str.size()), hash[0], &hash);
};

// fingerprint the main graph by hashing graph inputs and the ordered outputs from each node
for (const auto* node_arg : main_graph.GetInputsIncludingInitializers()) {
hash_str(node_arg->Name());
}

// note: process nodes in order defined in model to be deterministic
for (const auto& node : main_graph.Nodes()) {
for (const auto* node_arg : node.OutputDefs()) {
if (node_arg->Exists()) {
hash_str(node_arg->Name());
}
}
}
}

model_hash = hash[0] | (uint64_t(hash[1]) << 32);

// cache the result so later calls for the same Graph instance skip the hashing above
main_graph_hash_[graph_instance_hash] = model_hash;
}

// return the current unique id, and increment to update
return model_metadef_id_[model_hash]++;
}

// Forwards to metadef_id_generator_->GenerateId while holding a lock.
// ORT_ENFORCE guards against a missing generator; the generator is only created when the IExecutionProvider
// constructor is called with use_metadef_id_creator set to true.
int IExecutionProvider::GenerateMetaDefId(const onnxruntime::GraphViewer& graph_viewer, HashValue& model_hash) const {
ORT_ENFORCE(metadef_id_generator_,
"IExecutionProvider constructor must be called with true for use_metadef_id_creator");

// if the EP is shared across multiple sessions there's a very small potential for concurrency issues.
// use a lock when generating an id to be paranoid
// the mutex is a function-local static, so the same lock is shared by every caller of this function
static OrtMutex mutex;
std::lock_guard<OrtMutex> lock(mutex);
return metadef_id_generator_->GenerateId(graph_viewer, model_hash);
}

} // namespace onnxruntime
75 changes: 75 additions & 0 deletions onnxruntime/core/framework/model_metadef_id_generator.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <unordered_map>
#include "model_metadef_id_generator.h"
#include "core/platform/ort_mutex.h"
#include "core/graph/graph_viewer.h"
#include "core/framework/murmurhash3.h"

namespace onnxruntime {
// Produces the next metadef id for the model that owns `graph_viewer`.
// `model_hash` is reset and then receives the hash of the top-level graph's model: the model path when one is
// available, otherwise a fingerprint built from the graph input names and every existing node output name.
// The return value is the id currently stored for that model hash; the stored id is incremented afterwards.
int ModelMetadefIdGenerator::GenerateId(const onnxruntime::GraphViewer& graph_viewer,
                                        HashValue& model_hash) const {
  // An EP shared between sessions could call in concurrently, so id generation is serialized.
  // The mutex is a function-local static and is therefore shared by all generator instances.
  static OrtMutex mutex;
  std::lock_guard<OrtMutex> lock(mutex);
  model_hash = 0;

  // Climb from the viewed graph up to the top-level (main) graph.
  const Graph* top_graph = &graph_viewer.GetGraph();
  for (; top_graph->IsSubgraph(); top_graph = top_graph->ParentGraph()) {
  }
  const Graph& main_graph = *top_graph;

  // Key the cache on a hash of the Graph object's raw bytes rather than its address: a new Graph instance may
  // be placed at the same address as an old one, whereas its bytes should fingerprint the instance.
  uint32_t instance_hash[4] = {0, 0, 0, 0};
  MurmurHash3::x86_128(&main_graph, gsl::narrow_cast<int32_t>(sizeof(Graph)), instance_hash[0], &instance_hash);
  const HashValue graph_instance_hash = instance_hash[0] | (static_cast<uint64_t>(instance_hash[1]) << 32);

  // Cache hit: reuse the model hash computed earlier for this graph instance.
  const auto cached = main_graph_hash_.find(graph_instance_hash);
  if (cached != main_graph_hash_.cend()) {
    model_hash = cached->second;
    return model_metadef_id_[model_hash]++;
  }

  uint32_t hash[4] = {0, 0, 0, 0};

  // The path may be empty when the model was loaded from a stream or from in-memory bytes.
  const auto& model_path_str = main_graph.ModelPath().ToPathString();
  if (model_path_str.empty()) {
    // No path: chain-hash names instead. Each call seeds with the running hash[0] and overwrites `hash`.
    auto mix_in = [&hash](const std::string& text) {
      MurmurHash3::x86_128(text.data(), gsl::narrow_cast<int32_t>(text.size()), hash[0], &hash);
    };

    // Graph inputs (including initializers) first ...
    for (const auto* input_arg : main_graph.GetInputsIncludingInitializers()) {
      mix_in(input_arg->Name());
    }

    // ... then node outputs, visiting nodes in model order so the result is deterministic.
    for (const auto& node : main_graph.Nodes()) {
      for (const auto* output_arg : node.OutputDefs()) {
        if (!output_arg->Exists()) {
          continue;
        }
        mix_in(output_arg->Name());
      }
    }
  } else {
    // Preferred source: the path the model was loaded from.
    MurmurHash3::x86_128(model_path_str.data(), gsl::narrow_cast<int32_t>(model_path_str.size()), hash[0], &hash);
  }

  model_hash = hash[0] | (static_cast<uint64_t>(hash[1]) << 32);

  // Remember the mapping; the key is known to be absent because the lookup above missed under the lock.
  main_graph_hash_.emplace(graph_instance_hash, model_hash);

  // Hand out the current id for this model and advance the counter.
  return model_metadef_id_[model_hash]++;
}

} // namespace onnxruntime
31 changes: 31 additions & 0 deletions onnxruntime/core/framework/model_metadef_id_generator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include <unordered_map>
#include "core/common/basic_types.h"
namespace onnxruntime {
class GraphViewer;

/// <summary>
/// Helper to generate ids that are unique to a model and deterministic, even if the execution provider is shared
/// across multiple sessions.
/// </summary>
class ModelMetadefIdGenerator {
public:
/** Generate a unique id that can be used in a MetaDef name. Values are unique for a model instance.
The model hash is also returned if you wish to include that in the MetaDef name to ensure uniqueness across models.
@param graph_viewer[in] Graph viewer that GetCapability was called with. Can be for the main graph or nested graph.
@param model_hash[out] Returns the hash for the main (i.e. top level) graph in the model.
This is created using the model path if available,
or the model input names and the output names from all nodes in the main graph.
@return The current id for the model identified by model_hash; the stored id is incremented on each call.
@remarks The method is const but updates the mutable caches declared below.
*/
int GenerateId(const onnxruntime::GraphViewer& graph_viewer, HashValue& model_hash) const;

private:
// mutable as these are caches so we can minimize the hashing required on each usage of GenerateId
mutable std::unordered_map<HashValue, HashValue> main_graph_hash_; // map graph instance hash to model contents hash
mutable std::unordered_map<HashValue, int> model_metadef_id_; // current unique id for model
};

} // namespace onnxruntime
6 changes: 3 additions & 3 deletions onnxruntime/core/providers/cann/cann_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include <map>
#include <unordered_set>

#include "core/providers/shared_library/provider_api.h"
#define ORT_API_MANUAL_INIT
#include "core/session/onnxruntime_cxx_api.h"
#include "core/providers/cann/cann_execution_provider.h"
Expand Down Expand Up @@ -1029,13 +1028,14 @@ Status RegisterCANNKernels(KernelRegistry& kernel_registry) {
} // namespace cann

CANNExecutionProvider::CANNExecutionProvider(const CANNExecutionProviderInfo& info)
: IExecutionProvider{onnxruntime::kCannExecutionProvider, OrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, info.device_id), true}, info_{info} {
: IExecutionProvider{onnxruntime::kCannExecutionProvider, OrtDevice(OrtDevice::NPU, OrtDevice::MemType::DEFAULT, info.device_id)}, info_{info} {
InitProviderOrtApi();

CANN_CALL_THROW(aclrtSetDevice(info_.device_id));

soc_name_ = aclrtGetSocName();
ORT_ENFORCE(soc_name_ != nullptr, "aclrtGetSocName return nullptr");
metadef_id_generator_ = ModelMetadefIdGenerator::Create();
}

CANNExecutionProvider::~CANNExecutionProvider() {
Expand Down Expand Up @@ -1197,7 +1197,7 @@ std::unique_ptr<IndexedSubGraph> CANNExecutionProvider::GetSubGraph(

// Generate unique kernel name for CANN subgraph
HashValue model_hash = 0;
int id = GenerateMetaDefId(graph_viewer, model_hash);
int id = metadef_id_generator_->GenerateId(graph_viewer, model_hash);
auto meta_def = IndexedSubGraph_MetaDef::Create();
meta_def->name() = graph_viewer.Name() + "_" + std::to_string(model_hash) + "_" + std::to_string(id);

Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/providers/cann/cann_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ class CANNExecutionProvider : public IExecutionProvider {
std::unordered_map<std::string, uint32_t> modelIDs_;
std::unordered_map<std::string, std::string> models_;
std::unordered_map<std::string, std::unordered_map<std::size_t, std::string>> names_;
std::unique_ptr<ModelMetadefIdGenerator> metadef_id_generator_;
};

} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ namespace onnxruntime {
constexpr const char* COREML = "CoreML";

CoreMLExecutionProvider::CoreMLExecutionProvider(uint32_t coreml_flags)
: IExecutionProvider{onnxruntime::kCoreMLExecutionProvider, true},
: IExecutionProvider{onnxruntime::kCoreMLExecutionProvider},
coreml_flags_(coreml_flags) {
}

Expand Down Expand Up @@ -54,7 +54,7 @@ CoreMLExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_vie

const auto gen_metadef_name = [&]() {
HashValue model_hash;
int metadef_id = GenerateMetaDefId(graph_viewer, model_hash);
int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash);
return MakeString(COREML, "_", model_hash, "_", metadef_id);
};

Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/coreml/coreml_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#pragma once

#include "core/framework/execution_provider.h"
#include "core/framework/model_metadef_id_generator.h"
#include "core/providers/coreml/coreml_provider_factory.h"

namespace onnxruntime {
Expand Down Expand Up @@ -34,5 +35,6 @@ class CoreMLExecutionProvider : public IExecutionProvider {
#ifdef __APPLE__
std::unordered_map<std::string, std::unique_ptr<onnxruntime::coreml::Model>> coreml_models_;
#endif
ModelMetadefIdGenerator metadef_id_generator_;
};
} // namespace onnxruntime
13 changes: 6 additions & 7 deletions onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
#pragma warning(disable : 4996)
#endif

#include "core/providers/dnnl/dnnl_execution_provider.h"

#include <fstream>
#include <iomanip>
#include <unordered_set>
Expand All @@ -16,6 +14,7 @@

#include "core/platform/ort_mutex.h"
#include "core/providers/shared_library/provider_api.h"
#include "core/providers/dnnl/dnnl_execution_provider.h"

#include "core/providers/dnnl/dnnl_fwd.h"
#include "core/providers/dnnl/dnnl_node_capability.h"
Expand All @@ -30,7 +29,7 @@ constexpr const char* DNNL = "Dnnl";
constexpr const char* DNNL_CPU = "DnnlCpu";

DnnlExecutionProvider::DnnlExecutionProvider(const DnnlExecutionProviderInfo& info)
: IExecutionProvider{onnxruntime::kDnnlExecutionProvider, true},
: IExecutionProvider{onnxruntime::kDnnlExecutionProvider},
info_(info) {
InitProviderOrtApi();

Expand Down Expand Up @@ -77,8 +76,8 @@ DnnlExecutionProvider::DnnlExecutionProvider(const DnnlExecutionProviderInfo& in
// Log the number of threads used
LOGS_DEFAULT(INFO) << "Allocated " << omp_get_max_threads() << " OpenMP threads for oneDNN ep\n";
#endif // defined(DNNL_OPENMP)

} // namespace onnxruntime
metadef_id_generator_ = ModelMetadefIdGenerator::Create();
}

DnnlExecutionProvider::~DnnlExecutionProvider() {
}
Expand Down Expand Up @@ -229,7 +228,7 @@ std::vector<std::unique_ptr<ComputeCapability>> DnnlExecutionProvider::GetCapabi

// Assign inputs and outputs to subgraph's meta_def
HashValue model_hash;
int metadef_id = GenerateMetaDefId(graph_viewer, model_hash);
int metadef_id = metadef_id_generator_->GenerateId(graph_viewer, model_hash);
auto meta_def = ::onnxruntime::IndexedSubGraph_MetaDef::Create();
meta_def->name() = "DNNL_" + std::to_string(model_hash) + "_" + std::to_string(metadef_id);
meta_def->domain() = kMSDomain;
Expand Down Expand Up @@ -264,7 +263,7 @@ std::vector<std::unique_ptr<ComputeCapability>> DnnlExecutionProvider::GetCapabi
graph_viewer.ToProto(*model_proto->mutable_graph(), false, true);
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
HashValue model_hash;
int metadef_id = GenerateMetaDefId(graph_viewer, model_hash);
int metadef_id = metadef_id_generator_->GenerateId(graph_viewer, model_hash);
std::fstream dump("DNNL_" + std::to_string(model_hash) + "_" + std::to_string(metadef_id) + ".onnx", std::ios::out | std::ios::trunc | std::ios::binary);
model_proto->SerializeToOstream(dump);
}
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/providers/dnnl/dnnl_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class DnnlExecutionProvider : public IExecutionProvider {
bool debug_log_ = false;
// enable fusion by default
bool enable_fusion_ = true;
std::unique_ptr<ModelMetadefIdGenerator> metadef_id_generator_;
};

} // namespace onnxruntime
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/js/js_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -682,7 +682,7 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
using namespace js;

JsExecutionProvider::JsExecutionProvider(const JsExecutionProviderInfo& info)
: IExecutionProvider{kJsExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0), true},
: IExecutionProvider{kJsExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)},
preferred_data_layout_{info.data_layout} {
}

Expand Down
Loading

0 comments on commit 7d4dc66

Please sign in to comment.