From f403610d2d8090da50df4640598d354530abe01d Mon Sep 17 00:00:00 2001 From: Yueqing Zhang Date: Mon, 12 Aug 2024 09:08:11 -0500 Subject: [PATCH] optimize model clone for VitisAI --- .../shared_library/provider_interfaces.h | 2 +- .../shared_library/provider_wrappedtypes.h | 2 +- .../core/providers/vitisai/imp/global_api.cc | 14 +--- .../core/providers/vitisai/imp/graph.cc | 66 +++++++++++++++++++ .../core/providers/vitisai/imp/node_arg.cc | 29 +++++++- .../providers/vitisai/include/vaip/graph.h | 2 +- .../providers/vitisai/include/vaip/node_arg.h | 1 + .../vitisai/include/vaip/vaip_ort_api.h | 3 +- .../core/session/provider_bridge_ort.cc | 2 +- 9 files changed, 103 insertions(+), 18 deletions(-) diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index a9394838aa784..4527c0a89303c 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -896,7 +896,7 @@ struct ProviderHost { virtual NodeArg& Graph__GetOrCreateNodeArg(Graph* p, const std::string& name, const ONNX_NAMESPACE::TypeProto* p_arg_type) = 0; virtual void Graph__AddOuterScopeNodeArg(Graph* p, const std::string& name) = 0; virtual void Graph__SetInputs(Graph* p, gsl::span inputs) = 0; - + virtual const std::unordered_map& Graph__DomainToVersionMap(const Graph* p) const noexcept = 0; virtual Status Graph__Resolve(Graph* p) = 0; virtual void Graph__AddInitializedTensor(Graph* p, const ONNX_NAMESPACE::TensorProto& tensor) = 0; virtual Node& Graph__AddNode(Graph* p, const std::string& name, const std::string& op_type, const std::string& description, const gsl::span& input_args, const gsl::span& output_args, const NodeAttributes* attributes, const std::string& domain) = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index 242c7126f3274..d98d91759b164 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -943,7 +943,7 @@ struct Graph final { NodeArg& GetOrCreateNodeArg(const std::string& name, const ONNX_NAMESPACE::TypeProto* p_arg_type) { return g_host->Graph__GetOrCreateNodeArg(this, name, p_arg_type); } void AddOuterScopeNodeArg(const std::string& name) { g_host->Graph__AddOuterScopeNodeArg(this, name); } void SetInputs(gsl::span inputs) { g_host->Graph__SetInputs(this, inputs); } - + const std::unordered_map& DomainToVersionMap() const noexcept { return g_host->Graph__DomainToVersionMap(this); } Status Resolve() { return g_host->Graph__Resolve(this); } void AddInitializedTensor(const ONNX_NAMESPACE::TensorProto& tensor) { return g_host->Graph__AddInitializedTensor(this, tensor); } Node& AddNode(const std::string& name, const std::string& op_type, const std::string& description, gsl::span input_args, gsl::span output_args, const NodeAttributes* attributes, const std::string& domain) { return g_host->Graph__AddNode(this, name, op_type, description, input_args, output_args, attributes, domain); } diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc index df47fa5cee4ab..5e931f9f1947f 100644 --- a/onnxruntime/core/providers/vitisai/imp/global_api.cc +++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc @@ -250,17 +250,7 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { }; the_global_api.model_delete = [](Model* model) { delete model; }; - the_global_api.model_clone = [](const Model& const_model) -> Model* { - auto& logger = logging::LoggingManager::DefaultLogger(); - auto& model = const_cast(const_model); - auto model_proto = model.ToProto(); - auto file_path = model.MainGraph().ModelPath(); - auto local_registries = IOnnxRuntimeOpSchemaRegistryList{model.MainGraph().GetSchemaRegistry()}; - auto ret = Model::Create(std::move(*model_proto), ToPathString(file_path), &local_registries, logger); - auto status = ret->MainGraph().Resolve(); - vai_assert(status.IsOK(), status.ErrorMessage()); - return ret.release(); - }; + the_global_api.model_clone = vaip::model_clone; the_global_api.model_set_meta_data = [](Model& model, const std::string& key, const std::string& value) { const_cast(model.MetaData())[key] = value; }; @@ -454,7 +444,7 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { the_global_api.graph_set_inputs = [](Graph& graph, gsl::span inputs) { graph.SetInputs(inputs); }; - + the_global_api.node_arg_external_location = vaip::node_arg_external_location; if (!s_library_vitisaiep.vaip_get_version) { return reinterpret_cast(&(the_global_api.host_)); } else { diff --git a/onnxruntime/core/providers/vitisai/imp/graph.cc b/onnxruntime/core/providers/vitisai/imp/graph.cc index 3f46fbde8c714..38e3580ebab2a 100644 --- a/onnxruntime/core/providers/vitisai/imp/graph.cc +++ b/onnxruntime/core/providers/vitisai/imp/graph.cc @@ -160,4 +160,70 @@ Node& graph_fuse(Graph& graph, const std::string& name, } return fused_node; } +Model* model_clone(const Model& original_model) { + // create an empty mode + auto& original_graph = const_cast(original_model).MainGraph(); + auto& logger = logging::LoggingManager::DefaultLogger(); + auto file_path = original_graph.ModelPath(); + auto local_registries = IOnnxRuntimeOpSchemaRegistryList{original_graph.GetSchemaRegistry()}; + auto model_proto = ONNX_NAMESPACE::ModelProto::Create(); + auto graph_proto = model_proto->mutable_graph(); // create a graph + model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); + for (const auto& op : original_graph.DomainToVersionMap()) { + auto* opset_import = model_proto->add_opset_import(); + *(opset_import->mutable_domain()) = op.first; + opset_import->set_version(op.second); + } + auto graph_input = graph_proto->mutable_input(); + for (const auto& input : original_graph.GetInputs()) { + auto* input_proto = graph_input->Add(); + *input_proto = input->ToProto(); + } + auto graph_output = graph_proto->mutable_output(); + for (const auto& output : original_graph.GetOutputs()) { + auto* output_proto = graph_output->Add(); + *output_proto = output->ToProto(); + } + for (auto& node : original_graph.Nodes()) { + auto* node_proto = graph_proto->add_node(); + node->ToProto(*node_proto, false); + for (auto output : node->OutputDefs()) { + if (output->Exists()) { + auto* value_info = graph_proto->mutable_value_info()->Add(); + *value_info = output->ToProto(); + } + } + } + auto ptr_to_string = [](const void* g) -> std::string { + return std::to_string((uintptr_t)(g)); + }; + auto graph_ptr = ptr_to_string(&original_graph); + for (auto& it : original_graph.GetAllInitializedTensors()) { + auto cloned_tensor = graph_proto->add_initializer(); + auto original_tensor = it.second; + cloned_tensor->set_name(original_tensor->name()); + cloned_tensor->set_data_type(original_tensor->data_type()); + auto& dims = original_tensor->dims(); + int64_t size = 1; + for (auto i = 0; i < dims.size(); ++i) { + auto dim = dims[i]; + cloned_tensor->add_dims(dim); + size = size * dim; + } + constexpr int64_t THRESHOLD = 512; + if (size >= THRESHOLD) { + cloned_tensor->set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL); + auto external_data = cloned_tensor->mutable_external_data(); + auto p = external_data->Add(); + *p->mutable_key() = "location"; + *p->mutable_value() = std::string("<") + graph_ptr; + } else { + *cloned_tensor = *original_tensor; + } + } + auto ret = Model::Create(std::move(*model_proto), file_path, &local_registries, logger); + auto status = ret->MainGraph().Resolve(); + vai_assert(status.IsOK(), status.ErrorMessage()); + return ret.release(); +} } // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/imp/node_arg.cc b/onnxruntime/core/providers/vitisai/imp/node_arg.cc index a54cbef91c398..aa635998f9361 100644 --- a/onnxruntime/core/providers/vitisai/imp/node_arg.cc +++ b/onnxruntime/core/providers/vitisai/imp/node_arg.cc @@ -70,7 +70,7 @@ void node_arg_set_element_type(NodeArg& node_arg, int type) { const ONNX_NAMESPACE::TensorProto& node_arg_get_const_data_as_tensor( const Graph& graph, const NodeArg& node_arg) { auto tensor_proto = graph.GetConstantInitializer(node_arg.Name(), true); - assert(tensor_proto != nullptr); + vai_assert(tensor_proto != nullptr, (std::string("tensor_proto is not found: name=") + node_arg.Name())); return *tensor_proto; } int node_arg_get_element_type(const NodeArg& node_arg) { @@ -104,4 +104,31 @@ NodeArg& node_arg_new(Graph& graph, const std::string& name, const std::vector(graph.GetConstantInitializer(node_arg.Name(), true)); + vai_assert(tensor_proto != nullptr, (std::string("tensor_proto is not found: name=") + node_arg.Name())); + auto ret = 0; + offset = 0; + size = 0; + checksum = 0; + if (tensor_proto->data_location() != ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL) { + auto external_data = tensor_proto->mutable_external_data(); + auto external_data_size = external_data->size(); + for (auto i = 0; i < external_data_size; ++i) { + auto& data = external_data->at(i); + char* end = nullptr; + if (*data.mutable_key() == "location") { + file = *data.mutable_value(); + ret = 1; + } else if (*data.mutable_key() == "offset") { + offset = (size_t) std::strtoull(data.mutable_value()->data(), &end, 10); + } else if (*data.mutable_key() == "length") { + size = (size_t)std::strtoull(data.mutable_value()->data(), &end, 10); + } else if (*data.mutable_key() == "checksum") { + checksum = (size_t)std::strtoull(data.mutable_value()->data(), &end, 10); + } + } + } + return ret; +} } // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/include/vaip/graph.h b/onnxruntime/core/providers/vitisai/include/vaip/graph.h index 292fb2bb38b2b..ce01d553eb4c5 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/graph.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/graph.h @@ -15,5 +15,5 @@ void graph_save(const Graph& graph, const std::string& filename, const std::stri Node& graph_fuse(Graph& graph, const std::string& name, const std::string& op_type, const std::vector& nodes, const std::vector& inputs, const std::vector& outputs, const std::vector& constant_initializers); - +Model* model_clone(const Model& original_model); } // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/include/vaip/node_arg.h b/onnxruntime/core/providers/vitisai/include/vaip/node_arg.h index fca641c5e11c8..7804b48fb71c5 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/node_arg.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/node_arg.h @@ -28,5 +28,6 @@ void node_arg_set_element_type(NodeArg& node_arg, int data_type); const ONNX_NAMESPACE::TensorProto& node_arg_get_const_data_as_tensor(const Graph& graph, const NodeArg& node_arg); +int node_arg_external_location(const Graph& graph, const NodeArg& node_arg, std::string& file, size_t& offset, size_t& size, size_t& checksum); } // namespace vaip diff --git a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h index e6aacfe1f0272..598695f0f6a55 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h @@ -13,7 +13,7 @@ struct OrtApi; namespace vaip_core { -#define VAIP_ORT_API_MAJOR (4u) +#define VAIP_ORT_API_MAJOR (5u) #define VAIP_ORT_API_MINOR (0u) #define VAIP_ORT_API_PATCH (0u) struct OrtApiForVaip { @@ -228,6 +228,7 @@ struct OrtApiForVaip { Model* (*create_empty_model)(const std::filesystem::path& path, const std::vector>& opset); //[91] void (*graph_set_inputs)(Graph& graph, gsl::span inputs); // [92] + int(*node_arg_external_location)(const Graph& graph, const NodeArg& node_arg, std::string& file, size_t& offset, size_t& size, size_t& checksum); // [93] }; #ifndef USE_VITISAI diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 6d6940590602a..ff841950b4384 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1180,7 +1180,7 @@ struct ProviderHostImpl : ProviderHost { std::unique_ptr Graph__CreateGraphViewer(const Graph* p) override { return std::make_unique(*p); } std::unique_ptr Graph__ToGraphProto(const Graph* p) override { return std::make_unique(p->ToGraphProto()); } void Graph__SetInputs(Graph* p, gsl::span inputs) override { p->SetInputs(inputs); } - + const std::unordered_map& Graph__DomainToVersionMap(const Graph* p) const noexcept override { return p->DomainToVersionMap(); }; NodeArg& Graph__GetOrCreateNodeArg(Graph* p, const std::string& name, const ONNX_NAMESPACE::TypeProto* p_arg_type) override { return p->GetOrCreateNodeArg(name, p_arg_type); } void Graph__AddOuterScopeNodeArg(Graph* p, const std::string& name) override { p->AddOuterScopeNodeArg(name); }