Skip to content

Commit

Permalink
optimize model clone for VitisAI
Browse files Browse the repository at this point in the history
  • Loading branch information
Yueqing Zhang committed Aug 12, 2024
1 parent 154084e commit f403610
Show file tree
Hide file tree
Showing 9 changed files with 103 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -896,7 +896,7 @@ struct ProviderHost {
virtual NodeArg& Graph__GetOrCreateNodeArg(Graph* p, const std::string& name, const ONNX_NAMESPACE::TypeProto* p_arg_type) = 0;
virtual void Graph__AddOuterScopeNodeArg(Graph* p, const std::string& name) = 0;
virtual void Graph__SetInputs(Graph* p, gsl::span<const NodeArg* const> inputs) = 0;

virtual const std::unordered_map<std::string, int>& Graph__DomainToVersionMap(const Graph* p) const noexcept = 0;
virtual Status Graph__Resolve(Graph* p) = 0;
virtual void Graph__AddInitializedTensor(Graph* p, const ONNX_NAMESPACE::TensorProto& tensor) = 0;
virtual Node& Graph__AddNode(Graph* p, const std::string& name, const std::string& op_type, const std::string& description, const gsl::span<NodeArg* const>& input_args, const gsl::span<NodeArg* const>& output_args, const NodeAttributes* attributes, const std::string& domain) = 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -943,7 +943,7 @@ struct Graph final {
NodeArg& GetOrCreateNodeArg(const std::string& name, const ONNX_NAMESPACE::TypeProto* p_arg_type) { return g_host->Graph__GetOrCreateNodeArg(this, name, p_arg_type); }
void AddOuterScopeNodeArg(const std::string& name) { g_host->Graph__AddOuterScopeNodeArg(this, name); }
void SetInputs(gsl::span<const NodeArg* const> inputs) { g_host->Graph__SetInputs(this, inputs); }

const std::unordered_map<std::string, int>& DomainToVersionMap() const noexcept { return g_host->Graph__DomainToVersionMap(this); }
Status Resolve() { return g_host->Graph__Resolve(this); }
void AddInitializedTensor(const ONNX_NAMESPACE::TensorProto& tensor) { return g_host->Graph__AddInitializedTensor(this, tensor); }
Node& AddNode(const std::string& name, const std::string& op_type, const std::string& description, gsl::span<NodeArg* const> input_args, gsl::span<NodeArg* const> output_args, const NodeAttributes* attributes, const std::string& domain) { return g_host->Graph__AddNode(this, name, op_type, description, input_args, output_args, attributes, domain); }
Expand Down
14 changes: 2 additions & 12 deletions onnxruntime/core/providers/vitisai/imp/global_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -250,17 +250,7 @@ vaip_core::OrtApiForVaip* create_org_api_hook() {
};
the_global_api.model_delete = [](Model* model) { delete model; };

the_global_api.model_clone = [](const Model& const_model) -> Model* {
auto& logger = logging::LoggingManager::DefaultLogger();
auto& model = const_cast<onnxruntime::Model&>(const_model);
auto model_proto = model.ToProto();
auto file_path = model.MainGraph().ModelPath();
auto local_registries = IOnnxRuntimeOpSchemaRegistryList{model.MainGraph().GetSchemaRegistry()};
auto ret = Model::Create(std::move(*model_proto), ToPathString(file_path), &local_registries, logger);
auto status = ret->MainGraph().Resolve();
vai_assert(status.IsOK(), status.ErrorMessage());
return ret.release();
};
the_global_api.model_clone = vaip::model_clone;
the_global_api.model_set_meta_data = [](Model& model, const std::string& key, const std::string& value) {
const_cast<ModelMetaData&>(model.MetaData())[key] = value;
};
Expand Down Expand Up @@ -454,7 +444,7 @@ vaip_core::OrtApiForVaip* create_org_api_hook() {
the_global_api.graph_set_inputs = [](Graph& graph, gsl::span<const NodeArg* const> inputs) {
graph.SetInputs(inputs);
};

the_global_api.node_arg_external_location = vaip::node_arg_external_location;
if (!s_library_vitisaiep.vaip_get_version) {
return reinterpret_cast<vaip_core::OrtApiForVaip*>(&(the_global_api.host_));
} else {
Expand Down
66 changes: 66 additions & 0 deletions onnxruntime/core/providers/vitisai/imp/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,4 +160,70 @@ Node& graph_fuse(Graph& graph, const std::string& name,
}
return fused_node;
}
Model* model_clone(const Model& original_model) {
// create an empty mode
auto& original_graph = const_cast<Model&>(original_model).MainGraph();
auto& logger = logging::LoggingManager::DefaultLogger();
auto file_path = original_graph.ModelPath();
auto local_registries = IOnnxRuntimeOpSchemaRegistryList{original_graph.GetSchemaRegistry()};
auto model_proto = ONNX_NAMESPACE::ModelProto::Create();
auto graph_proto = model_proto->mutable_graph(); // create a graph
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
for (const auto& op : original_graph.DomainToVersionMap()) {
auto* opset_import = model_proto->add_opset_import();
*(opset_import->mutable_domain()) = op.first;
opset_import->set_version(op.second);
}
auto graph_input = graph_proto->mutable_input();
for (const auto& input : original_graph.GetInputs()) {
auto* input_proto = graph_input->Add();
*input_proto = input->ToProto();
}
auto graph_output = graph_proto->mutable_output();
for (const auto& output : original_graph.GetOutputs()) {
auto* output_proto = graph_output->Add();
*output_proto = output->ToProto();
}
for (auto& node : original_graph.Nodes()) {
auto* node_proto = graph_proto->add_node();
node->ToProto(*node_proto, false);
for (auto output : node->OutputDefs()) {
if (output->Exists()) {
auto* value_info = graph_proto->mutable_value_info()->Add();
*value_info = output->ToProto();
}
}
}
auto ptr_to_string = [](const void* g) -> std::string {
return std::to_string((uintptr_t)(g));
};
auto graph_ptr = ptr_to_string(&original_graph);
for (auto& it : original_graph.GetAllInitializedTensors()) {
auto cloned_tensor = graph_proto->add_initializer();
auto original_tensor = it.second;
cloned_tensor->set_name(original_tensor->name());
cloned_tensor->set_data_type(original_tensor->data_type());
auto& dims = original_tensor->dims();
int64_t size = 1;
for (auto i = 0; i < dims.size(); ++i) {
auto dim = dims[i];
cloned_tensor->add_dims(dim);
size = size * dim;
}
constexpr int64_t THRESHOLD = 512;
if (size >= THRESHOLD) {
cloned_tensor->set_data_location(ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL);
auto external_data = cloned_tensor->mutable_external_data();
auto p = external_data->Add();
*p->mutable_key() = "location";
*p->mutable_value() = std::string("<") + graph_ptr;
} else {
*cloned_tensor = *original_tensor;
}
}
auto ret = Model::Create(std::move(*model_proto), file_path, &local_registries, logger);
auto status = ret->MainGraph().Resolve();
vai_assert(status.IsOK(), status.ErrorMessage());
return ret.release();
}
} // namespace vaip
29 changes: 28 additions & 1 deletion onnxruntime/core/providers/vitisai/imp/node_arg.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ void node_arg_set_element_type(NodeArg& node_arg, int type) {
const ONNX_NAMESPACE::TensorProto& node_arg_get_const_data_as_tensor(
const Graph& graph, const NodeArg& node_arg) {
auto tensor_proto = graph.GetConstantInitializer(node_arg.Name(), true);
assert(tensor_proto != nullptr);
vai_assert(tensor_proto != nullptr, (std::string("tensor_proto is not found: name=") + node_arg.Name()));
return *tensor_proto;
}
int node_arg_get_element_type(const NodeArg& node_arg) {
Expand Down Expand Up @@ -104,4 +104,31 @@ NodeArg& node_arg_new(Graph& graph, const std::string& name, const std::vector<i
}
return graph.GetOrCreateNodeArg(name, type_proto.release());
}
int node_arg_external_location(const Graph& graph, const NodeArg& node_arg, std::string& file, size_t& offset, size_t& size, size_t&checksum) {
auto tensor_proto = const_cast<ONNX_NAMESPACE::TensorProto*>(graph.GetConstantInitializer(node_arg.Name(), true));
vai_assert(tensor_proto != nullptr, (std::string("tensor_proto is not found: name=") + node_arg.Name()));
auto ret = 0;
offset = 0;
size = 0;
checksum = 0;
if (tensor_proto->data_location() != ONNX_NAMESPACE::TensorProto_DataLocation::TensorProto_DataLocation_EXTERNAL) {
auto external_data = tensor_proto->mutable_external_data();
auto external_data_size = external_data->size();
for (auto i = 0; i < external_data_size; ++i) {
auto& data = external_data->at(i);
char* end = nullptr;
if (*data.mutable_key() == "location") {
file = *data.mutable_value();
ret = 1;
} else if (*data.mutable_key() == "offset") {
offset = (size_t) std::strtoull(data.mutable_value()->data(), &end, 10);
} else if (*data.mutable_key() == "length") {
size = (size_t)std::strtoull(data.mutable_value()->data(), &end, 10);
} else if (*data.mutable_key() == "checksum") {
checksum = (size_t)std::strtoull(data.mutable_value()->data(), &end, 10);
}
}
}
return ret;
}
} // namespace vaip
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/vitisai/include/vaip/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@ void graph_save(const Graph& graph, const std::string& filename, const std::stri
Node& graph_fuse(Graph& graph, const std::string& name, const std::string& op_type, const std::vector<size_t>& nodes,
const std::vector<std::string>& inputs, const std::vector<std::string>& outputs,
const std::vector<std::string>& constant_initializers);

Model* model_clone(const Model& original_model);
} // namespace vaip
1 change: 1 addition & 0 deletions onnxruntime/core/providers/vitisai/include/vaip/node_arg.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,6 @@ void node_arg_set_element_type(NodeArg& node_arg,
int data_type);
const ONNX_NAMESPACE::TensorProto& node_arg_get_const_data_as_tensor(const Graph& graph,
const NodeArg& node_arg);
int node_arg_external_location(const Graph& graph, const NodeArg& node_arg, std::string& file, size_t& offset, size_t& size, size_t& checksum);

} // namespace vaip
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ struct OrtApi;

namespace vaip_core {

#define VAIP_ORT_API_MAJOR (4u)
#define VAIP_ORT_API_MAJOR (5u)
#define VAIP_ORT_API_MINOR (0u)
#define VAIP_ORT_API_PATCH (0u)
struct OrtApiForVaip {
Expand Down Expand Up @@ -228,6 +228,7 @@ struct OrtApiForVaip {
Model* (*create_empty_model)(const std::filesystem::path& path, const std::vector<std::pair<std::string, int64_t>>& opset); //[91]
void (*graph_set_inputs)(Graph& graph,
gsl::span<const NodeArg* const> inputs); // [92]
int(*node_arg_external_location)(const Graph& graph, const NodeArg& node_arg, std::string& file, size_t& offset, size_t& size, size_t& checksum); // [93]
};

#ifndef USE_VITISAI
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/session/provider_bridge_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1180,7 +1180,7 @@ struct ProviderHostImpl : ProviderHost {
std::unique_ptr<GraphViewer> Graph__CreateGraphViewer(const Graph* p) override { return std::make_unique<GraphViewer>(*p); }
std::unique_ptr<ONNX_NAMESPACE::GraphProto> Graph__ToGraphProto(const Graph* p) override { return std::make_unique<ONNX_NAMESPACE::GraphProto>(p->ToGraphProto()); }
void Graph__SetInputs(Graph* p, gsl::span<const NodeArg* const> inputs) override { p->SetInputs(inputs); }

const std::unordered_map<std::string, int>& Graph__DomainToVersionMap(const Graph* p) const noexcept override { return p->DomainToVersionMap(); };
NodeArg& Graph__GetOrCreateNodeArg(Graph* p, const std::string& name, const ONNX_NAMESPACE::TypeProto* p_arg_type) override { return p->GetOrCreateNodeArg(name, p_arg_type); }
void Graph__AddOuterScopeNodeArg(Graph* p, const std::string& name) override { p->AddOuterScopeNodeArg(name); }

Expand Down

0 comments on commit f403610

Please sign in to comment.