From 5d0be84312f8df82199941a771127590bef8d982 Mon Sep 17 00:00:00 2001 From: Yueqing Zhang Date: Thu, 23 Nov 2023 02:20:37 +0000 Subject: [PATCH] lint --- .../core/providers/vitisai/imp/global_api.cc | 199 +++++++----------- .../vitisai/vitisai_execution_provider.cc | 39 ++-- 2 files changed, 92 insertions(+), 146 deletions(-) diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc index 444204f6214cc..aff00a0db3764 100644 --- a/onnxruntime/core/providers/vitisai/imp/global_api.cc +++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc @@ -47,16 +47,22 @@ using json = nlohmann::json; vaip_core::OrtApiForVaip* create_org_api_hook(); struct OrtVitisAIEpAPI { void (*initialize_onnxruntime_vitisai_ep)(vaip_core::OrtApiForVaip* api, std::vector& ret_domain); - std::vector>* (*compile_onnx_model_3)(const std::string& model_path, const onnxruntime::Graph& graph, const char* json_config); - std::vector>* (*compile_onnx_model_with_options)(const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options); + std::vector>* (*compile_onnx_model_3)(const std::string& model_path, + const onnxruntime::Graph& graph, + const char* json_config); + std::vector>* (*compile_onnx_model_with_options)( + const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options); void Ensure() { - if (handle_) - return; - auto full_path = Env::Default().GetRuntimePath() + PathString(LIBRARY_PREFIX ORT_TSTR("onnxruntime_vitisai_ep") LIBRARY_EXTENSION); + if (handle_) return; + auto full_path = Env::Default().GetRuntimePath() + + PathString(LIBRARY_PREFIX ORT_TSTR("onnxruntime_vitisai_ep") LIBRARY_EXTENSION); ORT_THROW_IF_ERROR(Env::Default().LoadDynamicLibrary(full_path, true, &handle_)); - ORT_THROW_IF_ERROR(Env::Default().GetSymbolFromLibrary(handle_, "initialize_onnxruntime_vitisai_ep", (void**)&initialize_onnxruntime_vitisai_ep)); - auto status1 = Env::Default().GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep_with_options", (void**)&compile_onnx_model_with_options); - auto status2 = Env::Default().GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep", (void**)&compile_onnx_model_3); + ORT_THROW_IF_ERROR(Env::Default().GetSymbolFromLibrary(handle_, "initialize_onnxruntime_vitisai_ep", + (void**)&initialize_onnxruntime_vitisai_ep)); + auto status1 = Env::Default().GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep_with_options", + (void**)&compile_onnx_model_with_options); + auto status2 = + Env::Default().GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep", (void**)&compile_onnx_model_3); if (!status1.IsOK() && !status2.IsOK()) { ::onnxruntime::LogRuntimeError(0, status1, __FILE__, static_cast(__FUNCTION__), __LINE__); ORT_THROW(status1); @@ -97,7 +103,8 @@ static std::string config_to_json_str(const onnxruntime::ProviderOptions& config return ""; } } -vaip_core::DllSafe>> compile_onnx_model_with_options(const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options) { +vaip_core::DllSafe>> compile_onnx_model_with_options( + const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options) { if (s_library_vitisaiep.compile_onnx_model_with_options) { return vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_with_options(model_path, graph, options)); } else { @@ -110,17 +117,16 @@ std::vector initialize_vitisai_ep() { s_library_vitisaiep.Ensure(); Status status = Status::OK(); try { - OrtEnv::LoggingManagerConstructionInfo lm_info{nullptr, nullptr, ORT_LOGGING_LEVEL_WARNING, "onnxruntime-vitisai-ep"}; + OrtEnv::LoggingManagerConstructionInfo lm_info{nullptr, nullptr, ORT_LOGGING_LEVEL_WARNING, + "onnxruntime-vitisai-ep"}; std::ignore = OrtEnv::GetInstance(lm_info, status); } catch (onnxruntime::OnnxRuntimeException& /*e*/) { } auto domains = std::vector(); domains.reserve(100); s_library_vitisaiep.initialize_onnxruntime_vitisai_ep(create_org_api_hook(), domains); - auto& domainToVersionRangeInstance = - ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance(); - if (domainToVersionRangeInstance.Map().find("com.xilinx") == - domainToVersionRangeInstance.Map().end()) { + auto& domainToVersionRangeInstance = ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance(); + if (domainToVersionRangeInstance.Map().find("com.xilinx") == domainToVersionRangeInstance.Map().end()) { vaip::register_xir_ops(domains); } @@ -143,17 +149,14 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { the_global_api.model_delete = [](Model* model) { delete model; }; the_global_api.model_clone = [](const Model& model) -> Model* { auto& logger = logging::LoggingManager::DefaultLogger(); - auto model_proto = - const_cast(model).ToProto(); + auto model_proto = const_cast(model).ToProto(); auto file_path = model.ModelPath().ToPathString(); auto ret = std::make_unique(std::move(model_proto), file_path, nullptr, logger); auto status = ret->MainGraph().Resolve(); vai_assert(status.IsOK(), status.ErrorMessage()); return ret.release(); }; - the_global_api.model_set_meta_data = [](Model& model, const std::string& key, - const std::string& value) - -> void { + the_global_api.model_set_meta_data = [](Model& model, const std::string& key, const std::string& value) -> void { const_cast(model.MetaData())[key] = value; }; the_global_api.model_get_meta_data = [](const Model& model, @@ -172,14 +175,9 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { return m.find(key) != m.end() ? 1 : 0; }; - the_global_api.model_main_graph = [](Model& model) -> Graph& { - return model.MainGraph(); - }; - the_global_api.graph_get_model = [](const Graph& graph) -> const Model& { - return graph.GetModel(); - }; - the_global_api.graph_get_inputs_unsafe = - [](const Graph& graph) -> vaip_core::DllSafe> { + the_global_api.model_main_graph = [](Model& model) -> Graph& { return model.MainGraph(); }; + the_global_api.graph_get_model = [](const Graph& graph) -> const Model& { return graph.GetModel(); }; + the_global_api.graph_get_inputs_unsafe = [](const Graph& graph) -> vaip_core::DllSafe> { auto ret = std::vector(); auto inputs = graph.GetInputs(); for (auto input : inputs) { @@ -188,47 +186,35 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { } return vaip_core::DllSafe(std::move(ret)); }; - the_global_api.graph_get_outputs_unsafe = - [](const Graph& graph) -> vaip_core::DllSafe> { + the_global_api.graph_get_outputs_unsafe = [](const Graph& graph) -> vaip_core::DllSafe> { return vaip_core::DllSafe(graph.GetOutputs()); }; - the_global_api.graph_set_outputs = - [](Graph& graph, gsl::span outputs) -> void { + the_global_api.graph_set_outputs = [](Graph& graph, gsl::span outputs) -> void { return graph.SetOutputs(outputs); }; - the_global_api.graph_get_node_arg = - [](const Graph& graph, const std::string& name) -> const NodeArg* { + the_global_api.graph_get_node_arg = [](const Graph& graph, const std::string& name) -> const NodeArg* { return graph.GetNodeArg(name); }; the_global_api.graph_producer_node = [](const Graph& graph, const std::string& name) -> const Node* { return graph.GetProducerNode(name); }; - the_global_api.graph_get_node = [](const Graph& graph, - size_t index) -> const Node* { - return graph.GetNode(index); - }; + the_global_api.graph_get_node = [](const Graph& graph, size_t index) -> const Node* { return graph.GetNode(index); }; the_global_api.graph_save = vaip::graph_save; the_global_api.graph_fuse = vaip::graph_fuse; the_global_api.graph_remove_node = vaip::graph_remove_node; - the_global_api.graph_add_node = - [](Graph& graph, const std::string& name, const std::string& op_type, - const std::string& description, - const std::vector& input_args, - const std::vector& output_args, - vaip_core::NodeAttributes& attributes, - const std::string& domain) -> Node& { - return vaip::graph_add_node( - graph, name, op_type, description, input_args, output_args, - std::move(reinterpret_cast(attributes)), - domain); - }; - - the_global_api.graph_get_all_initialized_tensors = - [](const Graph& graph) -> const InitializedTensorSet& { + the_global_api.graph_add_node = [](Graph& graph, const std::string& name, const std::string& op_type, + const std::string& description, const std::vector& input_args, + const std::vector& output_args, + vaip_core::NodeAttributes& attributes, const std::string& domain) -> Node& { + return vaip::graph_add_node(graph, name, op_type, description, input_args, output_args, + std::move(reinterpret_cast(attributes)), domain); + }; + + the_global_api.graph_get_all_initialized_tensors = [](const Graph& graph) -> const InitializedTensorSet& { return graph.GetAllInitializedTensors(); }; @@ -241,66 +227,46 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { }; the_global_api.graph_get_consumer_nodes_unsafe = - [](const Graph& graph, - const std::string& node_arg_name) -> vaip_core::DllSafe> { + [](const Graph& graph, const std::string& node_arg_name) -> vaip_core::DllSafe> { return vaip_core::DllSafe(graph.GetConsumerNodes(node_arg_name)); }; - the_global_api.graph_nodes_unsafe = - [](const Graph& graph) -> vaip_core::DllSafe> { + the_global_api.graph_nodes_unsafe = [](const Graph& graph) -> vaip_core::DllSafe> { auto& node_refererence = graph.Nodes(); std::vector nodes((size_t)graph.NumberOfNodes(), nullptr); - std::transform(node_refererence.begin(), node_refererence.end(), - nodes.begin(), [](const Node& n) { return &n; }); + std::transform(node_refererence.begin(), node_refererence.end(), nodes.begin(), [](const Node& n) { return &n; }); return vaip_core::DllSafe(std::move(nodes)); }; - the_global_api.graph_get_name = [](const Graph& graph) -> const std::string& { - return graph.Name(); + the_global_api.graph_get_name = [](const Graph& graph) -> const std::string& { return graph.Name(); }; + the_global_api.graph_reverse_dfs_from = [](const Graph& graph, gsl::span from, + const std::function& enter, + const std::function& leave, + const std::function& stop) { + graph.ReverseDFSFrom(from, enter, leave, nullptr, stop); }; - the_global_api.graph_reverse_dfs_from = - [](const Graph& graph, gsl::span from, - const std::function& enter, - const std::function& leave, - const std::function& stop) { - graph.ReverseDFSFrom(from, enter, leave, nullptr, stop); - }; // node the_global_api.node_get_inputs_unsafe = vaip::node_get_inputs; the_global_api.node_get_output_node_args_unsafe = vaip::node_get_output_node_args; - the_global_api.node_op_type = [](const Node& node) -> const std::string& { - return node.OpType(); - }; - the_global_api.node_op_domain = [](const Node& node) -> const std::string& { - return node.Domain(); - }; - the_global_api.node_get_index = [](const Node& node) -> size_t { - return (size_t)node.Index(); - }; - the_global_api.node_get_name = [](const Node& node) -> const std::string& { - return node.Name(); - }; - the_global_api.node_description = [](const Node& node) -> const std::string& { - return node.Description(); - }; + the_global_api.node_op_type = [](const Node& node) -> const std::string& { return node.OpType(); }; + the_global_api.node_op_domain = [](const Node& node) -> const std::string& { return node.Domain(); }; + the_global_api.node_get_index = [](const Node& node) -> size_t { return (size_t)node.Index(); }; + the_global_api.node_get_name = [](const Node& node) -> const std::string& { return node.Name(); }; + the_global_api.node_description = [](const Node& node) -> const std::string& { return node.Description(); }; - the_global_api.node_get_attributes = - [](Node& node) -> vaip_core::NodeAttributes& { - return reinterpret_cast( - node.GetMutableAttributes()); + the_global_api.node_get_attributes = [](Node& node) -> vaip_core::NodeAttributes& { + return reinterpret_cast(node.GetMutableAttributes()); }; the_global_api.node_type_is_fused = [](const Node& node) { return node.NodeType() == onnxruntime::Node::Type::Fused; }; - the_global_api.node_get_function_body = - [](const Node& node) -> const onnxruntime::Graph& { + the_global_api.node_get_function_body = [](const Node& node) -> const onnxruntime::Graph& { assert(node.GetFunctionBody() != nullptr); return node.GetFunctionBody()->Body(); }; // node_arg - the_global_api.node_arg_get_name_unsafe = - [](const NodeArg& node_arg) -> const std::string& { + the_global_api.node_arg_get_name_unsafe = [](const NodeArg& node_arg) -> const std::string& { return node_arg.Name(); }; the_global_api.node_arg_clone = vaip::node_arg_clone; @@ -311,8 +277,7 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { the_global_api.node_arg_set_shape_i64 = vaip::node_arg_set_shape_i64; the_global_api.node_arg_get_denotation_unsafe = vaip::node_arg_get_denotation; the_global_api.node_arg_set_denotation = vaip::node_arg_set_denotation; - the_global_api.node_arg_get_const_data_as_tensor = - vaip::node_arg_get_const_data_as_tensor; + the_global_api.node_arg_get_const_data_as_tensor = vaip::node_arg_get_const_data_as_tensor; the_global_api.node_arg_get_element_type = vaip::node_arg_get_element_type; the_global_api.node_arg_set_element_type = [](NodeArg& node_arg, int type) { @@ -374,16 +339,13 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { }; /// attr proto the_global_api.attr_proto_delete = [](onnx::AttributeProto* v) { delete v; }; - the_global_api.attr_proto_clone = - [](const onnx::AttributeProto& v) -> onnx::AttributeProto* { + the_global_api.attr_proto_clone = [](const onnx::AttributeProto& v) -> onnx::AttributeProto* { return new onnx::AttributeProto(v); }; - the_global_api.attr_proto_get_name = - [](const onnx::AttributeProto& attr_proto) -> const std::string& { + the_global_api.attr_proto_get_name = [](const onnx::AttributeProto& attr_proto) -> const std::string& { return attr_proto.name(); }; - the_global_api.attr_proto_set_name = [](onnx::AttributeProto* attr_proto, - const std::string& name) { + the_global_api.attr_proto_set_name = [](onnx::AttributeProto* attr_proto, const std::string& name) { attr_proto->set_name(name); }; the_global_api.attr_proto_new_int = vaip::attr_proto_new_int; @@ -400,17 +362,14 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { the_global_api.attr_proto_get_ints = vaip::attr_proto_get_ints; the_global_api.attr_proto_get_floats = vaip::attr_proto_get_floats; the_global_api.attr_proto_get_strings = vaip::attr_proto_get_strings; - the_global_api.attr_proto_get_type = - [](const onnx::AttributeProto& attr) -> int { return attr.type(); }; + the_global_api.attr_proto_get_type = [](const onnx::AttributeProto& attr) -> int { return attr.type(); }; /// node attributes the_global_api.node_attributes_new = []() { return reinterpret_cast(new NodeAttributes()); }; - the_global_api.node_attributes_add = [](vaip_core::NodeAttributes& p, - onnx::AttributeProto&& attr) { - reinterpret_cast(p).insert_or_assign(attr.name(), - std::move(attr)); + the_global_api.node_attributes_add = [](vaip_core::NodeAttributes& p, onnx::AttributeProto&& attr) { + reinterpret_cast(p).insert_or_assign(attr.name(), std::move(attr)); }; the_global_api.node_attributes_delete = [](vaip_core::NodeAttributes* p) { delete reinterpret_cast(p); @@ -424,7 +383,8 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { } return &it->second; }; - the_global_api.node_attributes_get_keys = [](vaip_core::NodeAttributes& p) -> vaip_core::DllSafe> { + the_global_api.node_attributes_get_keys = + [](vaip_core::NodeAttributes& p) -> vaip_core::DllSafe> { auto ret = std::vector(); auto& attr = reinterpret_cast(p); ret.reserve(attr.size()); @@ -434,34 +394,29 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { return vaip_core::DllSafe(std::move(ret)); }; /// tensor proto - the_global_api.tensor_proto_get_shape_unsafe = [](const onnx::TensorProto& t) -> vaip_core::DllSafe> { + the_global_api.tensor_proto_get_shape_unsafe = + [](const onnx::TensorProto& t) -> vaip_core::DllSafe> { return vaip_core::DllSafe>(vaip::tensor_proto_get_shape(t)); }; - the_global_api.tensor_proto_data_type = - [](const onnx::TensorProto& t) -> int { return t.data_type(); }; + the_global_api.tensor_proto_data_type = [](const onnx::TensorProto& t) -> int { return t.data_type(); }; the_global_api.tensor_proto_delete = [](onnx::TensorProto* tp) { delete tp; }; - the_global_api.tensor_proto_new_floats = - [](const std::string& name, const std::vector& shape, - const std::vector& data) -> onnx::TensorProto* { - return new onnx::TensorProto{ - vaip::tensor_proto_new_floats(name, shape, data)}; + the_global_api.tensor_proto_new_floats = [](const std::string& name, const std::vector& shape, + const std::vector& data) -> onnx::TensorProto* { + return new onnx::TensorProto{vaip::tensor_proto_new_floats(name, shape, data)}; }; - the_global_api.tensor_proto_new_i32 = - [](const std::string& name, const std::vector& shape, - const std::vector& data) -> onnx::TensorProto* { + the_global_api.tensor_proto_new_i32 = [](const std::string& name, const std::vector& shape, + const std::vector& data) -> onnx::TensorProto* { return new onnx::TensorProto{vaip::tensor_proto_new_i32(name, shape, data)}; }; - the_global_api.tensor_proto_new_i64 = - [](const std::string& name, const std::vector& shape, - const std::vector& data) -> onnx::TensorProto* { + the_global_api.tensor_proto_new_i64 = [](const std::string& name, const std::vector& shape, + const std::vector& data) -> onnx::TensorProto* { return new onnx::TensorProto{vaip::tensor_proto_new_i64(name, shape, data)}; }; - the_global_api.tensor_proto_new_i8 = - [](const std::string& name, const std::vector& shape, - const std::vector& data) -> onnx::TensorProto* { + the_global_api.tensor_proto_new_i8 = [](const std::string& name, const std::vector& shape, + const std::vector& data) -> onnx::TensorProto* { return new onnx::TensorProto{vaip::tensor_proto_new_i8(name, shape, data)}; }; the_global_api.tensor_proto_raw_data_size = vaip::tensor_proto_raw_data_size; diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc index 85cbbfc9c2b41..5f20b32cd6dc4 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc @@ -22,8 +22,7 @@ namespace onnxruntime { constexpr const char* VITISAI = "VITISAI"; static vaip_core::DllSafe>> compile_onnx_model( - const onnxruntime::GraphViewer& graph_viewer, - const logging::Logger& logger, const ProviderOptions& options) { + const onnxruntime::GraphViewer& graph_viewer, const logging::Logger& logger, const ProviderOptions& options) { #ifndef _WIN32 auto model_path = graph_viewer.ModelPath().ToPathString(); #else @@ -36,8 +35,8 @@ static vaip_core::DllSafeGetApi(op_.version), - reinterpret_cast(&info)); + op_kernel_ = + op_.CreateKernel(&op_, OrtGetApiBase()->GetApi(op_.version), reinterpret_cast(&info)); } ~MyCustomOpKernel() override { op_.KernelDestroy(op_kernel_); } @@ -54,8 +53,7 @@ struct MyCustomOpKernel : OpKernel { void* op_kernel_; }; -VitisAIExecutionProvider::VitisAIExecutionProvider( - const ProviderOptions& info) +VitisAIExecutionProvider::VitisAIExecutionProvider(const ProviderOptions& info) : IExecutionProvider{onnxruntime::kVitisAIExecutionProvider}, info_(info) { custom_op_domains_ = initialize_vitisai_ep(); registry_ = std::make_shared(); @@ -76,7 +74,8 @@ void VitisAIExecutionProvider::CreateKernelRegistry() { } } def_builder.Provider(onnxruntime::kVitisAIExecutionProvider); - KernelCreateFn kernel_create_fn = [op](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { + KernelCreateFn kernel_create_fn = [op](FuncManager&, const OpKernelInfo& info, + std::unique_ptr& out) -> Status { out = std::make_unique(info, *op); return Status::OK(); }; @@ -88,9 +87,8 @@ void VitisAIExecutionProvider::CreateKernelRegistry() { std::shared_ptr VitisAIExecutionProvider::GetKernelRegistry() const { return registry_; } -std::vector> -VitisAIExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, - const IKernelLookup& /*kernel_lookup*/) const { +std::vector> VitisAIExecutionProvider::GetCapability( + const onnxruntime::GraphViewer& graph, const IKernelLookup& /*kernel_lookup*/) const { if (graph.IsSubgraph()) { // VITIS AI EP not support sungraph. Assigned to CPU. return {}; @@ -99,8 +97,7 @@ VitisAIExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, // Only compiling a model once is currently supported return {}; } - execution_providers_ = - std::make_unique(compile_onnx_model(graph, *GetLogger(), info_)); + execution_providers_ = std::make_unique(compile_onnx_model(graph, *GetLogger(), info_)); auto result = vaip::GetComputeCapabilityOps(graph, execution_providers_.get(), vitisai_optypes_); size_t index = 0u; for (auto& ep : **execution_providers_) { @@ -110,16 +107,14 @@ VitisAIExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, return result; } -common::Status VitisAIExecutionProvider::Compile( - const std::vector& fused_nodes_and_graphs, - std::vector& node_compute_funcs) { +common::Status VitisAIExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, + std::vector& node_compute_funcs) { for (const auto& fused_node_graph : fused_nodes_and_graphs) { NodeComputeInfo compute_info; const onnx::AttributeProto* attr = graph_utils::GetNodeAttribute(fused_node_graph.fused_node, "index"); assert(attr != nullptr); size_t index = (size_t)attr->i(); - compute_info.create_state_func = [this, index](ComputeContext* context, - FunctionState* state) { + compute_info.create_state_func = [this, index](ComputeContext* context, FunctionState* state) { auto* p = (**this->execution_providers_)[index]->compile().release(); *state = p; return 0; @@ -127,15 +122,11 @@ common::Status VitisAIExecutionProvider::Compile( compute_info.release_state_func = [](FunctionState state) { if (state) { - delete reinterpret_cast( - state); + delete reinterpret_cast(state); } }; - compute_info.compute_func = [](FunctionState state, const OrtApi* api, - OrtKernelContext* context) { - reinterpret_cast( - state) - ->Compute(api, context); + compute_info.compute_func = [](FunctionState state, const OrtApi* api, OrtKernelContext* context) { + reinterpret_cast(state)->Compute(api, context); return Status::OK(); }; node_compute_funcs.push_back(compute_info);