
Commit

use pointer instead of ModelMetadefIdGenerator instance in EP and create from ORT
jslhcl committed Jan 24, 2024
1 parent 67dab28 commit 560b42f
Showing 10 changed files with 17 additions and 13 deletions.
4 changes: 0 additions & 4 deletions onnxruntime/core/framework/model_metadef_id_generator.h
@@ -1,9 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// if SHARED_PROVIDER is defined (in provider_api.h), use the definition in provider_wrappedtypes.h to avoid redefinition.
// make sure provider_api.h is included before this header
#ifndef SHARED_PROVIDER
#pragma once
#include <unordered_map>
#include "core/common/basic_types.h"
@@ -32,4 +29,3 @@ class ModelMetadefIdGenerator {
};

} // namespace onnxruntime
#endif
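
With the new guard, the entire real definition is compiled out for shared-library EP builds, which instead pick up the proxy added to provider_wrappedtypes.h later in this commit. An annotated, abridged sketch of the resulting header structure (the comments are editorial, not part of the file):

#ifndef SHARED_PROVIDER            // defined by provider_api.h in shared-library EP builds
#pragma once
#include <unordered_map>
#include "core/common/basic_types.h"

namespace onnxruntime {
class ModelMetadefIdGenerator {    // real implementation, visible only to ORT itself and statically linked EPs
  // ...
};
}  // namespace onnxruntime
#endif                             // shared-library EPs see the g_host-backed proxy instead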
3 changes: 2 additions & 1 deletion onnxruntime/core/providers/cann/cann_execution_provider.cc
@@ -1035,6 +1035,7 @@ CANNExecutionProvider::CANNExecutionProvider(const CANNExecutionProviderInfo& in

soc_name_ = aclrtGetSocName();
ORT_ENFORCE(soc_name_ != nullptr, "aclrtGetSocName return nullptr");
metadef_id_generator_ = ModelMetadefIdGenerator::Create();
}

CANNExecutionProvider::~CANNExecutionProvider() {
@@ -1196,7 +1197,7 @@ std::unique_ptr<IndexedSubGraph> CANNExecutionProvider::GetSubGraph(

// Generate unique kernel name for CANN subgraph
HashValue model_hash = 0;
int id = metadef_id_generator_.GenerateId(graph_viewer, model_hash);
int id = metadef_id_generator_->GenerateId(graph_viewer, model_hash);
auto meta_def = IndexedSubGraph_MetaDef::Create();
meta_def->name() = graph_viewer.Name() + "_" + std::to_string(model_hash) + "_" + std::to_string(id);

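The two CANN hunks above show the pattern every EP in this commit follows: construct the generator through the new factory in the EP constructor, then dereference the pointer wherever a metadef id is needed. A minimal sketch of that pattern using a hypothetical helper class (MyEpMetadefNamer and the "MyEP_" prefix are illustrative, not part of this commit):

#include <memory>
#include <string>
#include "core/providers/shared_library/provider_api.h"  // wrapped ModelMetadefIdGenerator, GraphViewer, HashValue

namespace onnxruntime {

class MyEpMetadefNamer {
 public:
  // allocate through ORT so the object lives on the ORT side of the DLL boundary
  MyEpMetadefNamer() : metadef_id_generator_(ModelMetadefIdGenerator::Create()) {}

  std::string MakeName(const GraphViewer& graph_viewer) const {
    HashValue model_hash = 0;
    int id = metadef_id_generator_->GenerateId(graph_viewer, model_hash);
    return "MyEP_" + std::to_string(model_hash) + "_" + std::to_string(id);
  }

 private:
  std::unique_ptr<ModelMetadefIdGenerator> metadef_id_generator_;
};

}  // namespace onnxruntime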
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cann/cann_execution_provider.h
@@ -81,7 +81,7 @@ class CANNExecutionProvider : public IExecutionProvider {
std::unordered_map<std::string, uint32_t> modelIDs_;
std::unordered_map<std::string, std::string> models_;
std::unordered_map<std::string, std::unordered_map<std::size_t, std::string>> names_;
ModelMetadefIdGenerator metadef_id_generator_;
std::unique_ptr<ModelMetadefIdGenerator> metadef_id_generator_;
};

} // namespace onnxruntime
5 changes: 3 additions & 2 deletions onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc
@@ -76,6 +76,7 @@ DnnlExecutionProvider::DnnlExecutionProvider(const DnnlExecutionProviderInfo& in
// Log the number of threads used
LOGS_DEFAULT(INFO) << "Allocated " << omp_get_max_threads() << " OpenMP threads for oneDNN ep\n";
#endif // defined(DNNL_OPENMP)
metadef_id_generator_ = ModelMetadefIdGenerator::Create();
}

DnnlExecutionProvider::~DnnlExecutionProvider() {
@@ -227,7 +228,7 @@ std::vector<std::unique_ptr<ComputeCapability>> DnnlExecutionProvider::GetCapabi

// Assign inputs and outputs to subgraph's meta_def
HashValue model_hash;
int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash);
int metadef_id = metadef_id_generator_->GenerateId(graph_viewer, model_hash);
auto meta_def = ::onnxruntime::IndexedSubGraph_MetaDef::Create();
meta_def->name() = "DNNL_" + std::to_string(model_hash) + "_" + std::to_string(metadef_id);
meta_def->domain() = kMSDomain;
@@ -262,7 +263,7 @@ std::vector<std::unique_ptr<ComputeCapability>> DnnlExecutionProvider::GetCapabi
graph_viewer.ToProto(*model_proto->mutable_graph(), false, true);
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
HashValue model_hash;
int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash);
int metadef_id = metadef_id_generator_->GenerateId(graph_viewer, model_hash);
std::fstream dump("DNNL_" + std::to_string(model_hash) + "_" + std::to_string(metadef_id) + ".onnx", std::ios::out | std::ios::trunc | std::ios::binary);
model_proto->SerializeToOstream(dump);
}
3 changes: 1 addition & 2 deletions onnxruntime/core/providers/dnnl/dnnl_execution_provider.h
@@ -13,7 +13,6 @@
#include "core/providers/dnnl/dnnl_op_manager.h"
#include "core/providers/dnnl/subgraph/dnnl_subgraph.h"
#include "core/providers/dnnl/subgraph/dnnl_subgraph_primitive.h"
#include "core/framework/model_metadef_id_generator.h"

namespace onnxruntime {

@@ -42,7 +41,7 @@ class DnnlExecutionProvider : public IExecutionProvider {
bool debug_log_ = false;
// enable fusion by default
bool enable_fusion_ = true;
ModelMetadefIdGenerator metadef_id_generator_;
std::unique_ptr<ModelMetadefIdGenerator> metadef_id_generator_;
};

} // namespace onnxruntime
4 changes: 3 additions & 1 deletion onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
@@ -165,6 +165,8 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv
MIOPEN_CALL_THROW(miopenCreate(&external_miopen_handle_));
MIOPEN_CALL_THROW(miopenSetStream(external_miopen_handle_, stream_));

metadef_id_generator_ = ModelMetadefIdGenerator::Create();

LOGS_DEFAULT(VERBOSE) << "[MIGraphX EP] MIGraphX provider options: "
<< "device_id: " << device_id_
<< ", migraphx_fp16_enable: " << fp16_enable_
@@ -757,7 +759,7 @@ std::unique_ptr<IndexedSubGraph> MIGraphXExecutionProvider::GetSubGraph(const st

// Generate unique kernel name for MIGraphX subgraph
uint64_t model_hash = 0;
int id = metadef_id_generator_.GenerateId(graph, model_hash);
int id = metadef_id_generator_->GenerateId(graph, model_hash);
std::string subgraph_id = std::to_string(model_hash) + "_" + std::to_string(id);
auto meta_def = IndexedSubGraph_MetaDef::Create();
const std::string graph_type = graph.IsSubgraph() ? "subgraph" : "graph";
3 changes: 1 addition & 2 deletions onnxruntime/core/providers/migraphx/migraphx_execution_provider.h
@@ -8,7 +8,6 @@

#include "core/framework/arena_extend_strategy.h"
#include "core/framework/execution_provider.h"
#include "core/framework/model_metadef_id_generator.h"
#include "core/platform/ort_mutex.h"
#include "core/providers/migraphx/migraphx_execution_provider_info.h"
#include "core/providers/migraphx/migraphx_inc.h"
@@ -99,7 +98,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider {
AllocatorPtr allocator_;
miopenHandle_t external_miopen_handle_ = nullptr;
rocblas_handle external_rocblas_handle_ = nullptr;
ModelMetadefIdGenerator metadef_id_generator_;
std::unique_ptr<ModelMetadefIdGenerator> metadef_id_generator_;

cpplint: onnxruntime/core/providers/migraphx/migraphx_execution_provider.h:101:  Add #include <memory> for unique_ptr<>  [build/include_what_you_use] [4]
};

} // namespace onnxruntime
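
The cpplint note above flags the one follow-up the new member needs in the EP headers; a minimal illustration of the presumable fix (hedged, abridged sketch, not a change made in this commit):

// migraphx_execution_provider.h (abridged): the pointer member plus the include cpplint asks for
#include <memory>  // std::unique_ptr

class MIGraphXExecutionProvider : public IExecutionProvider {
  // ...
 private:
  std::unique_ptr<ModelMetadefIdGenerator> metadef_id_generator_;
};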
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/shared_library/provider_interfaces.h
@@ -972,6 +972,8 @@ struct ProviderHost {
#endif

// ModelMetadefIdGenerator
virtual std::unique_ptr<ModelMetadefIdGenerator> ModelMetadefIdGenerator__construct() = 0;

cpplint: onnxruntime/core/providers/shared_library/provider_interfaces.h:975:  Add #include <memory> for unique_ptr<>  [build/include_what_you_use] [4]
virtual void ModelMetadefIdGenerator__operator_delete(ModelMetadefIdGenerator* p) = 0;
virtual int ModelMetadefIdGenerator__GenerateId(const ModelMetadefIdGenerator* p, const GraphViewer& graph_viewer, HashValue& model_hash) = 0;

cpplint: onnxruntime/core/providers/shared_library/provider_interfaces.h:977:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
};

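The virtuals added to ProviderHost follow the usual provider-bridge recipe for exposing an ORT-side type to shared-library EPs: one factory, one deleter, and one trampoline per member function the EP calls. A hypothetical sketch of the same recipe applied to some other ORT type Foo (Foo and DoWork are illustrative names, not part of this commit):

// In provider_interfaces.h, on struct ProviderHost:
virtual std::unique_ptr<Foo> Foo__construct() = 0;            // 1. factory
virtual void Foo__operator_delete(Foo* p) = 0;                // 2. deleter, runs on the ORT side
virtual int Foo__DoWork(const Foo* p, int arg) = 0;           // 3. trampoline for each method the EP needs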
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/shared_library/provider_wrappedtypes.h
@@ -1155,6 +1155,8 @@ class TensorSeq final {

class ModelMetadefIdGenerator {
public:
static std::unique_ptr<ModelMetadefIdGenerator> Create() { return g_host->ModelMetadefIdGenerator__construct(); }

cpplint: onnxruntime/core/providers/shared_library/provider_wrappedtypes.h:1158:  Add #include <memory> for unique_ptr<>  [build/include_what_you_use] [4]
static void operator delete(void* p) { g_host->ModelMetadefIdGenerator__operator_delete(reinterpret_cast<ModelMetadefIdGenerator*>(p)); }

cpplint: onnxruntime/core/providers/shared_library/provider_wrappedtypes.h:1159:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
int GenerateId(const GraphViewer& graph_viewer, HashValue& model_hash) const { return g_host->ModelMetadefIdGenerator__GenerateId(this, graph_viewer, model_hash); }

cpplint: onnxruntime/core/providers/shared_library/provider_wrappedtypes.h:1160:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
};

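provider_wrappedtypes.h then gives the EP a same-named proxy whose members forward through g_host, which is exactly what the ModelMetadefIdGenerator wrapper above does. Continuing the hypothetical Foo example:

// In provider_wrappedtypes.h, the proxy the shared-library EP compiles against:
class Foo {
 public:
  static std::unique_ptr<Foo> Create() { return g_host->Foo__construct(); }
  static void operator delete(void* p) { g_host->Foo__operator_delete(reinterpret_cast<Foo*>(p)); }
  int DoWork(int arg) const { return g_host->Foo__DoWork(this, arg); }
};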
2 changes: 2 additions & 0 deletions onnxruntime/core/session/provider_bridge_ort.cc
@@ -1081,6 +1081,8 @@ struct ProviderHostImpl : ProviderHost {
void TensorSeq__Reserve(TensorSeq* p, size_t capacity) override { p->Reserve(capacity); }

// ModelMetadefIdGenerator(wrapped)
std::unique_ptr<ModelMetadefIdGenerator> ModelMetadefIdGenerator__construct() override { return std::make_unique<ModelMetadefIdGenerator>(); }

cpplint: onnxruntime/core/session/provider_bridge_ort.cc:1084:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
void ModelMetadefIdGenerator__operator_delete(ModelMetadefIdGenerator* p) override { delete p; }
int ModelMetadefIdGenerator__GenerateId(const ModelMetadefIdGenerator* p, const GraphViewer& graph_viewer, HashValue& model_hash) override { return p->GenerateId(graph_viewer, model_hash); }

cpplint: onnxruntime/core/session/provider_bridge_ort.cc:1086:  Lines should be <= 120 characters long  [whitespace/line_length] [2]

#if defined(ENABLE_TRAINING) && defined(ORT_USE_NCCL)
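Finally, ProviderHostImpl in provider_bridge_ort.cc supplies the real implementations, mirroring the overrides added above. Completing the hypothetical Foo example:

// In provider_bridge_ort.cc, on struct ProviderHostImpl : ProviderHost:
std::unique_ptr<Foo> Foo__construct() override { return std::make_unique<Foo>(); }
void Foo__operator_delete(Foo* p) override { delete p; }
int Foo__DoWork(const Foo* p, int arg) override { return p->DoWork(arg); }

With these three pieces in place, a shared-library EP only ever touches the pointer: construction, every GenerateId call, and destruction all execute inside ORT, which is why the EP members in this commit change from a by-value ModelMetadefIdGenerator to std::unique_ptr<ModelMetadefIdGenerator>.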
