From 560b42fb5f0bed1f0cef76f69cb52616a8cef1f0 Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Wed, 24 Jan 2024 15:52:49 -0800 Subject: [PATCH] use pointer instead of ModelMetadefIdGenerator instance in EP and create from ORT --- onnxruntime/core/framework/model_metadef_id_generator.h | 4 ---- onnxruntime/core/providers/cann/cann_execution_provider.cc | 3 ++- onnxruntime/core/providers/cann/cann_execution_provider.h | 2 +- onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc | 5 +++-- onnxruntime/core/providers/dnnl/dnnl_execution_provider.h | 3 +-- .../core/providers/migraphx/migraphx_execution_provider.cc | 4 +++- .../core/providers/migraphx/migraphx_execution_provider.h | 3 +-- .../core/providers/shared_library/provider_interfaces.h | 2 ++ .../core/providers/shared_library/provider_wrappedtypes.h | 2 ++ onnxruntime/core/session/provider_bridge_ort.cc | 2 ++ 10 files changed, 17 insertions(+), 13 deletions(-) diff --git a/onnxruntime/core/framework/model_metadef_id_generator.h b/onnxruntime/core/framework/model_metadef_id_generator.h index a9b8cd48d0f6c..82f68c42b5c35 100644 --- a/onnxruntime/core/framework/model_metadef_id_generator.h +++ b/onnxruntime/core/framework/model_metadef_id_generator.h @@ -1,9 +1,6 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -// if SHARED_PROVIDER is defined (in provider_api.h), use the definition in provider_wrappedytypes.h to avoid redefinition. -// make sure provider_api.h is included before this header -#ifndef SHARED_PROVIDER #pragma once #include #include "core/common/basic_types.h" @@ -32,4 +29,3 @@ class ModelMetadefIdGenerator { }; } // namespace onnxruntime -#endif diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.cc b/onnxruntime/core/providers/cann/cann_execution_provider.cc index 74629c95abf2c..752b742805a7c 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.cc +++ b/onnxruntime/core/providers/cann/cann_execution_provider.cc @@ -1035,6 +1035,7 @@ CANNExecutionProvider::CANNExecutionProvider(const CANNExecutionProviderInfo& in soc_name_ = aclrtGetSocName(); ORT_ENFORCE(soc_name_ != nullptr, "aclrtGetSocName return nullptr"); + metadef_id_generator_ = ModelMetadefIdGenerator::Create(); } CANNExecutionProvider::~CANNExecutionProvider() { @@ -1196,7 +1197,7 @@ std::unique_ptr CANNExecutionProvider::GetSubGraph( // Generate unique kernel name for CANN subgraph HashValue model_hash = 0; - int id = metadef_id_generator_.GenerateId(graph_viewer, model_hash); + int id = metadef_id_generator_->GenerateId(graph_viewer, model_hash); auto meta_def = IndexedSubGraph_MetaDef::Create(); meta_def->name() = graph_viewer.Name() + "_" + std::to_string(model_hash) + "_" + std::to_string(id); diff --git a/onnxruntime/core/providers/cann/cann_execution_provider.h b/onnxruntime/core/providers/cann/cann_execution_provider.h index 129c44a5a41de..63ae980869c65 100644 --- a/onnxruntime/core/providers/cann/cann_execution_provider.h +++ b/onnxruntime/core/providers/cann/cann_execution_provider.h @@ -81,7 +81,7 @@ class CANNExecutionProvider : public IExecutionProvider { std::unordered_map modelIDs_; std::unordered_map models_; std::unordered_map> names_; - ModelMetadefIdGenerator metadef_id_generator_; + std::unique_ptr metadef_id_generator_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc index 3d5b18540cafc..3271dab13f675 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc +++ b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.cc @@ -76,6 +76,7 @@ DnnlExecutionProvider::DnnlExecutionProvider(const DnnlExecutionProviderInfo& in // Log the number of threads used LOGS_DEFAULT(INFO) << "Allocated " << omp_get_max_threads() << " OpenMP threads for oneDNN ep\n"; #endif // defined(DNNL_OPENMP) + metadef_id_generator_ = ModelMetadefIdGenerator::Create(); } DnnlExecutionProvider::~DnnlExecutionProvider() { @@ -227,7 +228,7 @@ std::vector> DnnlExecutionProvider::GetCapabi // Assign inputs and outputs to subgraph's meta_def HashValue model_hash; - int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash); + int metadef_id = metadef_id_generator_->GenerateId(graph_viewer, model_hash); auto meta_def = ::onnxruntime::IndexedSubGraph_MetaDef::Create(); meta_def->name() = "DNNL_" + std::to_string(model_hash) + "_" + std::to_string(metadef_id); meta_def->domain() = kMSDomain; @@ -262,7 +263,7 @@ std::vector> DnnlExecutionProvider::GetCapabi graph_viewer.ToProto(*model_proto->mutable_graph(), false, true); model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); HashValue model_hash; - int metadef_id = metadef_id_generator_.GenerateId(graph_viewer, model_hash); + int metadef_id = metadef_id_generator_->GenerateId(graph_viewer, model_hash); std::fstream dump("DNNL_" + std::to_string(model_hash) + "_" + std::to_string(metadef_id) + ".onnx", std::ios::out | std::ios::trunc | std::ios::binary); model_proto->SerializeToOstream(dump); } diff --git a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h index f7e4a5e380cd7..b7fcbb7765180 100644 --- a/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h +++ b/onnxruntime/core/providers/dnnl/dnnl_execution_provider.h @@ -13,7 +13,6 @@ #include "core/providers/dnnl/dnnl_op_manager.h" #include "core/providers/dnnl/subgraph/dnnl_subgraph.h" #include "core/providers/dnnl/subgraph/dnnl_subgraph_primitive.h" -#include "core/framework/model_metadef_id_generator.h" namespace onnxruntime { @@ -42,7 +41,7 @@ class DnnlExecutionProvider : public IExecutionProvider { bool debug_log_ = false; // enable fusion by default bool enable_fusion_ = true; - ModelMetadefIdGenerator metadef_id_generator_; + std::unique_ptr metadef_id_generator_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index 163dbd928728e..40e76a0a67782 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -165,6 +165,8 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv MIOPEN_CALL_THROW(miopenCreate(&external_miopen_handle_)); MIOPEN_CALL_THROW(miopenSetStream(external_miopen_handle_, stream_)); + metadef_id_generator_ = ModelMetadefIdGenerator::Create(); + LOGS_DEFAULT(VERBOSE) << "[MIGraphX EP] MIGraphX provider options: " << "device_id: " << device_id_ << ", migraphx_fp16_enable: " << fp16_enable_ @@ -757,7 +759,7 @@ std::unique_ptr MIGraphXExecutionProvider::GetSubGraph(const st // Generate unique kernel name for MIGraphX subgraph uint64_t model_hash = 0; - int id = metadef_id_generator_.GenerateId(graph, model_hash); + int id = metadef_id_generator_->GenerateId(graph, model_hash); std::string subgraph_id = std::to_string(model_hash) + "_" + std::to_string(id); auto meta_def = IndexedSubGraph_MetaDef::Create(); const std::string graph_type = graph.IsSubgraph() ? "subgraph" : "graph"; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index f7482b071cee8..d582338c7e067 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -8,7 +8,6 @@ #include "core/framework/arena_extend_strategy.h" #include "core/framework/execution_provider.h" -#include "core/framework/model_metadef_id_generator.h" #include "core/platform/ort_mutex.h" #include "core/providers/migraphx/migraphx_execution_provider_info.h" #include "core/providers/migraphx/migraphx_inc.h" @@ -99,7 +98,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { AllocatorPtr allocator_; miopenHandle_t external_miopen_handle_ = nullptr; rocblas_handle external_rocblas_handle_ = nullptr; - ModelMetadefIdGenerator metadef_id_generator_; + std::unique_ptr metadef_id_generator_; }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index c001d4078389f..a216b2bfc6d04 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -972,6 +972,8 @@ struct ProviderHost { #endif // ModelMetadefIdGenerator + virtual std::unique_ptr ModelMetadefIdGenerator__construct() = 0; + virtual void ModelMetadefIdGenerator__operator_delete(ModelMetadefIdGenerator* p) = 0; virtual int ModelMetadefIdGenerator__GenerateId(const ModelMetadefIdGenerator* p, const GraphViewer& graph_viewer, HashValue& model_hash) = 0; }; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index 386ac45d8b7f7..f46c76fd3421b 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -1155,6 +1155,8 @@ class TensorSeq final { class ModelMetadefIdGenerator { public: + static std::unique_ptr Create() { return g_host->ModelMetadefIdGenerator__construct(); } + static void operator delete(void* p) { g_host->ModelMetadefIdGenerator__operator_delete(reinterpret_cast(p)); } int GenerateId(const GraphViewer& graph_viewer, HashValue& model_hash) const { return g_host->ModelMetadefIdGenerator__GenerateId(this, graph_viewer, model_hash); } }; diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 824b5804688ce..f8bd7f4aa208c 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1081,6 +1081,8 @@ struct ProviderHostImpl : ProviderHost { void TensorSeq__Reserve(TensorSeq* p, size_t capacity) override { p->Reserve(capacity); } // ModelMetadefIdGenerator(wrapped) + std::unique_ptr ModelMetadefIdGenerator__construct() override { return std::make_unique(); } + void ModelMetadefIdGenerator__operator_delete(ModelMetadefIdGenerator* p) override { delete p; } int ModelMetadefIdGenerator__GenerateId(const ModelMetadefIdGenerator* p, const GraphViewer& graph_viewer, HashValue& model_hash) override { return p->GenerateId(graph_viewer, model_hash); } #if defined(ENABLE_TRAINING) && defined(ORT_USE_NCCL)