diff --git a/cmake/onnxruntime_providers_vitisai.cmake b/cmake/onnxruntime_providers_vitisai.cmake index 7ac4a82c89a76..0951c2d02664d 100644 --- a/cmake/onnxruntime_providers_vitisai.cmake +++ b/cmake/onnxruntime_providers_vitisai.cmake @@ -15,16 +15,10 @@ "${ONNXRUNTIME_ROOT}/core/providers/vitisai/imp/*.cc" "${ONNXRUNTIME_ROOT}/core/providers/vitisai/imp/*.h" ) - list(REMOVE_ITEM onnxruntime_providers_vitisai_cc_srcs "${ONNXRUNTIME_ROOT}/core/providers/vitisai/onnxruntime_vitisai_ep_stub.cc") source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_vitisai_cc_srcs}) onnxruntime_add_static_library(onnxruntime_providers_vitisai ${onnxruntime_providers_vitisai_cc_srcs}) onnxruntime_add_include_to_target(onnxruntime_providers_vitisai onnxruntime_common onnxruntime_framework onnx onnx_proto) - onnxruntime_add_shared_library(onnxruntime_vitisai_ep ${ONNXRUNTIME_ROOT}/core/providers/vitisai/onnxruntime_vitisai_ep_stub.cc) - onnxruntime_add_include_to_target(onnxruntime_vitisai_ep onnxruntime_common) - target_include_directories(onnxruntime_vitisai_ep PRIVATE "${ONNXRUNTIME_ROOT}" "${ONNXRUNTIME_ROOT}/core/providers/vitisai/include") - target_link_libraries(onnxruntime_providers_vitisai PUBLIC onnxruntime_vitisai_ep PRIVATE onnx protobuf::libprotobuf nlohmann_json::nlohmann_json ) - target_compile_definitions(onnxruntime_vitisai_ep - PRIVATE "-DONNXRUNTIME_VITISAI_EP_STUB=1" "-DONNXRUNTIME_VITISAI_EP_EXPORT_DLL=1") + target_link_libraries(onnxruntime_providers_vitisai PRIVATE onnx protobuf::libprotobuf nlohmann_json::nlohmann_json) if(NOT MSVC) target_compile_options(onnxruntime_providers_vitisai PUBLIC $<$:-U_FORTIFY_SOURCE -D_FORTIFY_SOURCE=0>) endif(NOT MSVC) @@ -49,4 +43,4 @@ LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} FRAMEWORK DESTINATION ${CMAKE_INSTALL_BINDIR}) - endif() \ No newline at end of file + endif() diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc index 59bdd43ec997e..b629c8eff9097 100644 --- a/onnxruntime/core/providers/vitisai/imp/global_api.cc +++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc @@ -2,6 +2,10 @@ // Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. // Licensed under the MIT License. #include "vaip/global_api.h" + +#include +#include + #include "./vai_assert.h" #include "core/common/exceptions.h" #include "core/common/logging/logging.h" @@ -10,10 +14,10 @@ #include "core/graph/model.h" #include "core/session/ort_env.h" +#include "core/session/onnxruntime_cxx_api.h" -#include +#include -#include "core/session/onnxruntime_cxx_api.h" #include "vaip/dll_safe.h" #include "vaip/vaip_ort_api.h" #include "vaip/graph.h" @@ -24,28 +28,107 @@ #include "./attr_proto.h" #include "./register_xir_ops.h" -#include "onnxruntime_vitisai_ep/onnxruntime_vitisai_ep.h" - #include "onnxruntime_config.h" #include "version_info.h" // version_info.hpp.in using namespace onnxruntime; +using json = nlohmann::json; + +// The filename extension for a shared library is different per platform +#ifdef _WIN32 +#define LIBRARY_PREFIX +#define LIBRARY_EXTENSION ORT_TSTR(".dll") +#elif defined(__APPLE__) +#define LIBRARY_PREFIX "lib" +#define LIBRARY_EXTENSION ".dylib" +#else +#define LIBRARY_PREFIX "lib" +#define LIBRARY_EXTENSION ".so" +#endif + vaip_core::OrtApiForVaip* create_org_api_hook(); +struct OrtVitisAIEpAPI { + void (*initialize_onnxruntime_vitisai_ep)(vaip_core::OrtApiForVaip* api, std::vector& ret_domain); + std::vector>* (*compile_onnx_model_3)(const std::string& model_path, + const onnxruntime::Graph& graph, + const char* json_config); + std::vector>* (*compile_onnx_model_with_options)( + const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options); + void Ensure() { + if (handle_) return; + auto full_path = Env::Default().GetRuntimePath() + + PathString(LIBRARY_PREFIX ORT_TSTR("onnxruntime_vitisai_ep") LIBRARY_EXTENSION); + ORT_THROW_IF_ERROR(Env::Default().LoadDynamicLibrary(full_path, true, &handle_)); + ORT_THROW_IF_ERROR(Env::Default().GetSymbolFromLibrary( + handle_, "initialize_onnxruntime_vitisai_ep", reinterpret_cast(&initialize_onnxruntime_vitisai_ep))); + auto status1 = Env::Default().GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep_with_options", + reinterpret_cast(&compile_onnx_model_with_options)); + auto status2 = Env::Default().GetSymbolFromLibrary(handle_, "compile_onnx_model_vitisai_ep", + reinterpret_cast(&compile_onnx_model_3)); + if (!status1.IsOK() && !status2.IsOK()) { + ::onnxruntime::LogRuntimeError(0, status1, __FILE__, static_cast(__FUNCTION__), __LINE__); + ORT_THROW(status1); + } + } + + private: + void* handle_{}; +}; + +static OrtVitisAIEpAPI s_library_vitisaiep; +static std::string config_to_json_str(const onnxruntime::ProviderOptions& config) { + auto iter = config.find("config_file"); + if (iter == config.end()) { + std::cerr << "Error: Key 'config_file' not found in config" << std::endl; + return ""; + } + const auto& filename = config.at("config_file"); + std::ifstream f(filename); + if (!f.is_open()) { + std::cerr << "Error: Failed to open file: " << filename << std::endl; + return ""; + } + nlohmann::json data; + try { + data = nlohmann::json::parse(f); + } catch (const std::exception& e) { + std::cerr << "Error: Failed to parse JSON from file: " << filename << ", Reason: " << e.what() << std::endl; + return ""; + } + for (const auto& entry : config) { + data[entry.first] = entry.second; + } + try { + return data.dump(); + } catch (const std::exception& e) { + std::cerr << "Error: Failed to convert JSON data to string, Reason: " << e.what() << std::endl; + return ""; + } +} +vaip_core::DllSafe>> compile_onnx_model_with_options( + const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options) { + if (s_library_vitisaiep.compile_onnx_model_with_options) { + return vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_with_options(model_path, graph, options)); + } else { + auto json_str = config_to_json_str(options); + return vaip_core::DllSafe(s_library_vitisaiep.compile_onnx_model_3(model_path, graph, json_str.c_str())); + } +} std::vector initialize_vitisai_ep() { + s_library_vitisaiep.Ensure(); Status status = Status::OK(); try { - OrtEnv::LoggingManagerConstructionInfo lm_info{nullptr, nullptr, ORT_LOGGING_LEVEL_WARNING, "onnxruntime-vitisai-ep"}; + OrtEnv::LoggingManagerConstructionInfo lm_info{nullptr, nullptr, ORT_LOGGING_LEVEL_WARNING, + "onnxruntime-vitisai-ep"}; std::ignore = OrtEnv::GetInstance(lm_info, status); } catch (onnxruntime::OnnxRuntimeException& /*e*/) { } auto domains = std::vector(); domains.reserve(100); - onnxruntime_vitisai_ep::initialize_onnxruntime_vitisai_ep(create_org_api_hook(), domains); - auto& domainToVersionRangeInstance = - ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance(); - if (domainToVersionRangeInstance.Map().find("com.xilinx") == - domainToVersionRangeInstance.Map().end()) { + s_library_vitisaiep.initialize_onnxruntime_vitisai_ep(create_org_api_hook(), domains); + auto& domainToVersionRangeInstance = ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance(); + if (domainToVersionRangeInstance.Map().find("com.xilinx") == domainToVersionRangeInstance.Map().end()) { vaip::register_xir_ops(domains); } @@ -68,17 +151,14 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { the_global_api.model_delete = [](Model* model) { delete model; }; the_global_api.model_clone = [](const Model& model) -> Model* { auto& logger = logging::LoggingManager::DefaultLogger(); - auto model_proto = - const_cast(model).ToProto(); + auto model_proto = const_cast(model).ToProto(); auto file_path = model.ModelPath().ToPathString(); auto ret = std::make_unique(std::move(model_proto), file_path, nullptr, logger); auto status = ret->MainGraph().Resolve(); vai_assert(status.IsOK(), status.ErrorMessage()); return ret.release(); }; - the_global_api.model_set_meta_data = [](Model& model, const std::string& key, - const std::string& value) - -> void { + the_global_api.model_set_meta_data = [](Model& model, const std::string& key, const std::string& value) -> void { const_cast(model.MetaData())[key] = value; }; the_global_api.model_get_meta_data = [](const Model& model, @@ -97,14 +177,9 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { return m.find(key) != m.end() ? 1 : 0; }; - the_global_api.model_main_graph = [](Model& model) -> Graph& { - return model.MainGraph(); - }; - the_global_api.graph_get_model = [](const Graph& graph) -> const Model& { - return graph.GetModel(); - }; - the_global_api.graph_get_inputs_unsafe = - [](const Graph& graph) -> vaip_core::DllSafe> { + the_global_api.model_main_graph = [](Model& model) -> Graph& { return model.MainGraph(); }; + the_global_api.graph_get_model = [](const Graph& graph) -> const Model& { return graph.GetModel(); }; + the_global_api.graph_get_inputs_unsafe = [](const Graph& graph) -> vaip_core::DllSafe> { auto ret = std::vector(); auto inputs = graph.GetInputs(); for (auto input : inputs) { @@ -113,47 +188,35 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { } return vaip_core::DllSafe(std::move(ret)); }; - the_global_api.graph_get_outputs_unsafe = - [](const Graph& graph) -> vaip_core::DllSafe> { + the_global_api.graph_get_outputs_unsafe = [](const Graph& graph) -> vaip_core::DllSafe> { return vaip_core::DllSafe(graph.GetOutputs()); }; - the_global_api.graph_set_outputs = - [](Graph& graph, gsl::span outputs) -> void { + the_global_api.graph_set_outputs = [](Graph& graph, gsl::span outputs) -> void { return graph.SetOutputs(outputs); }; - the_global_api.graph_get_node_arg = - [](const Graph& graph, const std::string& name) -> const NodeArg* { + the_global_api.graph_get_node_arg = [](const Graph& graph, const std::string& name) -> const NodeArg* { return graph.GetNodeArg(name); }; the_global_api.graph_producer_node = [](const Graph& graph, const std::string& name) -> const Node* { return graph.GetProducerNode(name); }; - the_global_api.graph_get_node = [](const Graph& graph, - size_t index) -> const Node* { - return graph.GetNode(index); - }; + the_global_api.graph_get_node = [](const Graph& graph, size_t index) -> const Node* { return graph.GetNode(index); }; the_global_api.graph_save = vaip::graph_save; the_global_api.graph_fuse = vaip::graph_fuse; the_global_api.graph_remove_node = vaip::graph_remove_node; - the_global_api.graph_add_node = - [](Graph& graph, const std::string& name, const std::string& op_type, - const std::string& description, - const std::vector& input_args, - const std::vector& output_args, - vaip_core::NodeAttributes& attributes, - const std::string& domain) -> Node& { - return vaip::graph_add_node( - graph, name, op_type, description, input_args, output_args, - std::move(reinterpret_cast(attributes)), - domain); - }; - - the_global_api.graph_get_all_initialized_tensors = - [](const Graph& graph) -> const InitializedTensorSet& { + the_global_api.graph_add_node = [](Graph& graph, const std::string& name, const std::string& op_type, + const std::string& description, const std::vector& input_args, + const std::vector& output_args, + vaip_core::NodeAttributes& attributes, const std::string& domain) -> Node& { + return vaip::graph_add_node(graph, name, op_type, description, input_args, output_args, + std::move(reinterpret_cast(attributes)), domain); + }; + + the_global_api.graph_get_all_initialized_tensors = [](const Graph& graph) -> const InitializedTensorSet& { return graph.GetAllInitializedTensors(); }; @@ -166,66 +229,46 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { }; the_global_api.graph_get_consumer_nodes_unsafe = - [](const Graph& graph, - const std::string& node_arg_name) -> vaip_core::DllSafe> { + [](const Graph& graph, const std::string& node_arg_name) -> vaip_core::DllSafe> { return vaip_core::DllSafe(graph.GetConsumerNodes(node_arg_name)); }; - the_global_api.graph_nodes_unsafe = - [](const Graph& graph) -> vaip_core::DllSafe> { + the_global_api.graph_nodes_unsafe = [](const Graph& graph) -> vaip_core::DllSafe> { auto& node_refererence = graph.Nodes(); - std::vector nodes((size_t)graph.NumberOfNodes(), nullptr); - std::transform(node_refererence.begin(), node_refererence.end(), - nodes.begin(), [](const Node& n) { return &n; }); + std::vector nodes(static_cast(graph.NumberOfNodes()), nullptr); + std::transform(node_refererence.begin(), node_refererence.end(), nodes.begin(), [](const Node& n) { return &n; }); return vaip_core::DllSafe(std::move(nodes)); }; - the_global_api.graph_get_name = [](const Graph& graph) -> const std::string& { - return graph.Name(); + the_global_api.graph_get_name = [](const Graph& graph) -> const std::string& { return graph.Name(); }; + the_global_api.graph_reverse_dfs_from = [](const Graph& graph, gsl::span from, + const std::function& enter, + const std::function& leave, + const std::function& stop) { + graph.ReverseDFSFrom(from, enter, leave, nullptr, stop); }; - the_global_api.graph_reverse_dfs_from = - [](const Graph& graph, gsl::span from, - const std::function& enter, - const std::function& leave, - const std::function& stop) { - graph.ReverseDFSFrom(from, enter, leave, nullptr, stop); - }; // node the_global_api.node_get_inputs_unsafe = vaip::node_get_inputs; the_global_api.node_get_output_node_args_unsafe = vaip::node_get_output_node_args; - the_global_api.node_op_type = [](const Node& node) -> const std::string& { - return node.OpType(); - }; - the_global_api.node_op_domain = [](const Node& node) -> const std::string& { - return node.Domain(); - }; - the_global_api.node_get_index = [](const Node& node) -> size_t { - return (size_t)node.Index(); - }; - the_global_api.node_get_name = [](const Node& node) -> const std::string& { - return node.Name(); - }; - the_global_api.node_description = [](const Node& node) -> const std::string& { - return node.Description(); - }; + the_global_api.node_op_type = [](const Node& node) -> const std::string& { return node.OpType(); }; + the_global_api.node_op_domain = [](const Node& node) -> const std::string& { return node.Domain(); }; + the_global_api.node_get_index = [](const Node& node) -> size_t { return static_cast(node.Index()); }; + the_global_api.node_get_name = [](const Node& node) -> const std::string& { return node.Name(); }; + the_global_api.node_description = [](const Node& node) -> const std::string& { return node.Description(); }; - the_global_api.node_get_attributes = - [](Node& node) -> vaip_core::NodeAttributes& { - return reinterpret_cast( - node.GetMutableAttributes()); + the_global_api.node_get_attributes = [](Node& node) -> vaip_core::NodeAttributes& { + return reinterpret_cast(node.GetMutableAttributes()); }; the_global_api.node_type_is_fused = [](const Node& node) { return node.NodeType() == onnxruntime::Node::Type::Fused; }; - the_global_api.node_get_function_body = - [](const Node& node) -> const onnxruntime::Graph& { + the_global_api.node_get_function_body = [](const Node& node) -> const onnxruntime::Graph& { assert(node.GetFunctionBody() != nullptr); return node.GetFunctionBody()->Body(); }; // node_arg - the_global_api.node_arg_get_name_unsafe = - [](const NodeArg& node_arg) -> const std::string& { + the_global_api.node_arg_get_name_unsafe = [](const NodeArg& node_arg) -> const std::string& { return node_arg.Name(); }; the_global_api.node_arg_clone = vaip::node_arg_clone; @@ -236,8 +279,7 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { the_global_api.node_arg_set_shape_i64 = vaip::node_arg_set_shape_i64; the_global_api.node_arg_get_denotation_unsafe = vaip::node_arg_get_denotation; the_global_api.node_arg_set_denotation = vaip::node_arg_set_denotation; - the_global_api.node_arg_get_const_data_as_tensor = - vaip::node_arg_get_const_data_as_tensor; + the_global_api.node_arg_get_const_data_as_tensor = vaip::node_arg_get_const_data_as_tensor; the_global_api.node_arg_get_element_type = vaip::node_arg_get_element_type; the_global_api.node_arg_set_element_type = [](NodeArg& node_arg, int type) { @@ -299,16 +341,13 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { }; /// attr proto the_global_api.attr_proto_delete = [](onnx::AttributeProto* v) { delete v; }; - the_global_api.attr_proto_clone = - [](const onnx::AttributeProto& v) -> onnx::AttributeProto* { + the_global_api.attr_proto_clone = [](const onnx::AttributeProto& v) -> onnx::AttributeProto* { return new onnx::AttributeProto(v); }; - the_global_api.attr_proto_get_name = - [](const onnx::AttributeProto& attr_proto) -> const std::string& { + the_global_api.attr_proto_get_name = [](const onnx::AttributeProto& attr_proto) -> const std::string& { return attr_proto.name(); }; - the_global_api.attr_proto_set_name = [](onnx::AttributeProto* attr_proto, - const std::string& name) { + the_global_api.attr_proto_set_name = [](onnx::AttributeProto* attr_proto, const std::string& name) { attr_proto->set_name(name); }; the_global_api.attr_proto_new_int = vaip::attr_proto_new_int; @@ -325,17 +364,14 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { the_global_api.attr_proto_get_ints = vaip::attr_proto_get_ints; the_global_api.attr_proto_get_floats = vaip::attr_proto_get_floats; the_global_api.attr_proto_get_strings = vaip::attr_proto_get_strings; - the_global_api.attr_proto_get_type = - [](const onnx::AttributeProto& attr) -> int { return attr.type(); }; + the_global_api.attr_proto_get_type = [](const onnx::AttributeProto& attr) -> int { return attr.type(); }; /// node attributes the_global_api.node_attributes_new = []() { return reinterpret_cast(new NodeAttributes()); }; - the_global_api.node_attributes_add = [](vaip_core::NodeAttributes& p, - onnx::AttributeProto&& attr) { - reinterpret_cast(p).insert_or_assign(attr.name(), - std::move(attr)); + the_global_api.node_attributes_add = [](vaip_core::NodeAttributes& p, onnx::AttributeProto&& attr) { + reinterpret_cast(p).insert_or_assign(attr.name(), std::move(attr)); }; the_global_api.node_attributes_delete = [](vaip_core::NodeAttributes* p) { delete reinterpret_cast(p); @@ -349,7 +385,8 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { } return &it->second; }; - the_global_api.node_attributes_get_keys = [](vaip_core::NodeAttributes& p) -> vaip_core::DllSafe> { + the_global_api.node_attributes_get_keys = + [](vaip_core::NodeAttributes& p) -> vaip_core::DllSafe> { auto ret = std::vector(); auto& attr = reinterpret_cast(p); ret.reserve(attr.size()); @@ -359,34 +396,29 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { return vaip_core::DllSafe(std::move(ret)); }; /// tensor proto - the_global_api.tensor_proto_get_shape_unsafe = [](const onnx::TensorProto& t) -> vaip_core::DllSafe> { + the_global_api.tensor_proto_get_shape_unsafe = + [](const onnx::TensorProto& t) -> vaip_core::DllSafe> { return vaip_core::DllSafe>(vaip::tensor_proto_get_shape(t)); }; - the_global_api.tensor_proto_data_type = - [](const onnx::TensorProto& t) -> int { return t.data_type(); }; + the_global_api.tensor_proto_data_type = [](const onnx::TensorProto& t) -> int { return t.data_type(); }; the_global_api.tensor_proto_delete = [](onnx::TensorProto* tp) { delete tp; }; - the_global_api.tensor_proto_new_floats = - [](const std::string& name, const std::vector& shape, - const std::vector& data) -> onnx::TensorProto* { - return new onnx::TensorProto{ - vaip::tensor_proto_new_floats(name, shape, data)}; + the_global_api.tensor_proto_new_floats = [](const std::string& name, const std::vector& shape, + const std::vector& data) -> onnx::TensorProto* { + return new onnx::TensorProto{vaip::tensor_proto_new_floats(name, shape, data)}; }; - the_global_api.tensor_proto_new_i32 = - [](const std::string& name, const std::vector& shape, - const std::vector& data) -> onnx::TensorProto* { + the_global_api.tensor_proto_new_i32 = [](const std::string& name, const std::vector& shape, + const std::vector& data) -> onnx::TensorProto* { return new onnx::TensorProto{vaip::tensor_proto_new_i32(name, shape, data)}; }; - the_global_api.tensor_proto_new_i64 = - [](const std::string& name, const std::vector& shape, - const std::vector& data) -> onnx::TensorProto* { + the_global_api.tensor_proto_new_i64 = [](const std::string& name, const std::vector& shape, + const std::vector& data) -> onnx::TensorProto* { return new onnx::TensorProto{vaip::tensor_proto_new_i64(name, shape, data)}; }; - the_global_api.tensor_proto_new_i8 = - [](const std::string& name, const std::vector& shape, - const std::vector& data) -> onnx::TensorProto* { + the_global_api.tensor_proto_new_i8 = [](const std::string& name, const std::vector& shape, + const std::vector& data) -> onnx::TensorProto* { return new onnx::TensorProto{vaip::tensor_proto_new_i8(name, shape, data)}; }; the_global_api.tensor_proto_raw_data_size = vaip::tensor_proto_raw_data_size; diff --git a/onnxruntime/core/providers/vitisai/include/onnxruntime_vitisai_ep/onnxruntime_vitisai_ep.h b/onnxruntime/core/providers/vitisai/include/onnxruntime_vitisai_ep/onnxruntime_vitisai_ep.h deleted file mode 100644 index 82f665429c24c..0000000000000 --- a/onnxruntime/core/providers/vitisai/include/onnxruntime_vitisai_ep/onnxruntime_vitisai_ep.h +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. -// Licensed under the MIT License. -#pragma once -#include -#include -#if defined(_WIN32) -#if ONNXRUNTIME_VITISAI_EP_EXPORT_DLL == 1 -#define ONNXRUNTIME_VITISAI_EP_DLL_SPEC __declspec(dllexport) -#else -#define ONNXRUNTIME_VITISAI_EP_DLL_SPEC __declspec(dllimport) -#endif -#else -#define ONNXRUNTIME_VITISAI_EP_DLL_SPEC __attribute__((visibility("default"))) -#endif - -#ifndef USE_VITISAI -#define USE_VITISAI /* mimic VITISAI EP in ORT */ -#endif - -namespace vaip_core { -class ExecutionProvider; -struct OrtApiForVaip; -template -class DllSafe; -} // namespace vaip_core -namespace onnxruntime { -class Graph; -} -struct OrtCustomOpDomain; -namespace onnxruntime_vitisai_ep { - -ONNXRUNTIME_VITISAI_EP_DLL_SPEC void -initialize_onnxruntime_vitisai_ep(vaip_core::OrtApiForVaip* api, - std::vector& ret_domain); -ONNXRUNTIME_VITISAI_EP_DLL_SPEC -vaip_core::DllSafe>> -compile_onnx_model_3(const std::string& model_path, - const onnxruntime::Graph& graph, const char* json_config); -ONNXRUNTIME_VITISAI_EP_DLL_SPEC -int optimize_onnx_model(const std::filesystem::path& model_path_in, - const std::filesystem::path& model_path_out, - const char* json_config); -} // namespace onnxruntime_vitisai_ep - -extern "C" ONNXRUNTIME_VITISAI_EP_DLL_SPEC const vaip_core::OrtApiForVaip* -get_the_global_api(); diff --git a/onnxruntime/core/providers/vitisai/include/vaip/global_api.h b/onnxruntime/core/providers/vitisai/include/vaip/global_api.h index 8da3882b5af99..c446ab3aefcc5 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/global_api.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/global_api.h @@ -2,6 +2,16 @@ // Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. // Licensed under the MIT License. #pragma once +#include +#include +#include + #include "core/session/onnxruntime_cxx_api.h" +#include "core/framework/provider_options.h" +#include "vaip/my_ort.h" +#include "vaip/dll_safe.h" +#include "vaip/custom_op.h" std::vector initialize_vitisai_ep(); +vaip_core::DllSafe>> compile_onnx_model_with_options( + const std::string& model_path, const onnxruntime::Graph& graph, const onnxruntime::ProviderOptions& options); diff --git a/onnxruntime/core/providers/vitisai/onnxruntime_vitisai_ep_stub.cc b/onnxruntime/core/providers/vitisai/onnxruntime_vitisai_ep_stub.cc deleted file mode 100644 index 8244c36f822a4..0000000000000 --- a/onnxruntime/core/providers/vitisai/onnxruntime_vitisai_ep_stub.cc +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved. -// Licensed under the MIT License. -#include "vaip/dll_safe.h" -#include "vaip/vaip_ort_api.h" -#include "vaip/custom_op.h" -#include "onnxruntime_vitisai_ep/onnxruntime_vitisai_ep.h" -#include -#include -using namespace std; - -namespace onnxruntime_vitisai_ep { -static void my_abort() { - cerr << "please install VitisAI package." << endl; - abort(); -} -using namespace vaip_core; -void initialize_onnxruntime_vitisai_ep(OrtApiForVaip* /*api*/, std::vector& /*domain*/) { - my_abort(); - return; -} // namespace onnxruntime_vitisai_ep -DllSafe>> -compile_onnx_model_3(const std::string& /*model_path*/, const Graph& /*graph*/, - const char* /*json_config*/) { - if (1) { // suppress dead code warning - my_abort(); - } - return DllSafe>>(); -} - -} // namespace onnxruntime_vitisai_ep diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc index 32ee6ff652aac..5f20b32cd6dc4 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc @@ -15,8 +15,6 @@ #include "core/session/custom_ops.h" #include "core/session/inference_session.h" -#include "onnxruntime_vitisai_ep/onnxruntime_vitisai_ep.h" - using namespace ONNX_NAMESPACE; namespace onnxruntime { @@ -24,8 +22,7 @@ namespace onnxruntime { constexpr const char* VITISAI = "VITISAI"; static vaip_core::DllSafe>> compile_onnx_model( - const onnxruntime::GraphViewer& graph_viewer, - const logging::Logger& logger, const char* json_config) { + const onnxruntime::GraphViewer& graph_viewer, const logging::Logger& logger, const ProviderOptions& options) { #ifndef _WIN32 auto model_path = graph_viewer.ModelPath().ToPathString(); #else @@ -33,12 +30,13 @@ static vaip_core::DllSafe strconverter; auto model_path = strconverter.to_bytes(graph_viewer.ModelPath().ToPathString()); #endif - return onnxruntime_vitisai_ep::compile_onnx_model_3(model_path, graph_viewer.GetGraph(), json_config); + return compile_onnx_model_with_options(model_path, graph_viewer.GetGraph(), options); } + struct MyCustomOpKernel : OpKernel { MyCustomOpKernel(const OpKernelInfo& info, const OrtCustomOp& op) : OpKernel(info), op_(op) { - op_kernel_ = op_.CreateKernel(&op_, OrtGetApiBase()->GetApi(op_.version), - reinterpret_cast(&info)); + op_kernel_ = + op_.CreateKernel(&op_, OrtGetApiBase()->GetApi(op_.version), reinterpret_cast(&info)); } ~MyCustomOpKernel() override { op_.KernelDestroy(op_kernel_); } @@ -55,8 +53,7 @@ struct MyCustomOpKernel : OpKernel { void* op_kernel_; }; -VitisAIExecutionProvider::VitisAIExecutionProvider( - const VitisAIExecutionProviderInfo& info) +VitisAIExecutionProvider::VitisAIExecutionProvider(const ProviderOptions& info) : IExecutionProvider{onnxruntime::kVitisAIExecutionProvider}, info_(info) { custom_op_domains_ = initialize_vitisai_ep(); registry_ = std::make_shared(); @@ -77,7 +74,8 @@ void VitisAIExecutionProvider::CreateKernelRegistry() { } } def_builder.Provider(onnxruntime::kVitisAIExecutionProvider); - KernelCreateFn kernel_create_fn = [op](FuncManager&, const OpKernelInfo& info, std::unique_ptr& out) -> Status { + KernelCreateFn kernel_create_fn = [op](FuncManager&, const OpKernelInfo& info, + std::unique_ptr& out) -> Status { out = std::make_unique(info, *op); return Status::OK(); }; @@ -89,9 +87,8 @@ void VitisAIExecutionProvider::CreateKernelRegistry() { std::shared_ptr VitisAIExecutionProvider::GetKernelRegistry() const { return registry_; } -std::vector> -VitisAIExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, - const IKernelLookup& /*kernel_lookup*/) const { +std::vector> VitisAIExecutionProvider::GetCapability( + const onnxruntime::GraphViewer& graph, const IKernelLookup& /*kernel_lookup*/) const { if (graph.IsSubgraph()) { // VITIS AI EP not support sungraph. Assigned to CPU. return {}; @@ -100,9 +97,7 @@ VitisAIExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, // Only compiling a model once is currently supported return {}; } - auto opt_str = info_.get_json_config_str(); // String - execution_providers_ = - std::make_unique(compile_onnx_model(graph, *GetLogger(), opt_str)); + execution_providers_ = std::make_unique(compile_onnx_model(graph, *GetLogger(), info_)); auto result = vaip::GetComputeCapabilityOps(graph, execution_providers_.get(), vitisai_optypes_); size_t index = 0u; for (auto& ep : **execution_providers_) { @@ -112,16 +107,14 @@ VitisAIExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph, return result; } -common::Status VitisAIExecutionProvider::Compile( - const std::vector& fused_nodes_and_graphs, - std::vector& node_compute_funcs) { +common::Status VitisAIExecutionProvider::Compile(const std::vector& fused_nodes_and_graphs, + std::vector& node_compute_funcs) { for (const auto& fused_node_graph : fused_nodes_and_graphs) { NodeComputeInfo compute_info; const onnx::AttributeProto* attr = graph_utils::GetNodeAttribute(fused_node_graph.fused_node, "index"); assert(attr != nullptr); size_t index = (size_t)attr->i(); - compute_info.create_state_func = [this, index](ComputeContext* context, - FunctionState* state) { + compute_info.create_state_func = [this, index](ComputeContext* context, FunctionState* state) { auto* p = (**this->execution_providers_)[index]->compile().release(); *state = p; return 0; @@ -129,15 +122,11 @@ common::Status VitisAIExecutionProvider::Compile( compute_info.release_state_func = [](FunctionState state) { if (state) { - delete reinterpret_cast( - state); + delete reinterpret_cast(state); } }; - compute_info.compute_func = [](FunctionState state, const OrtApi* api, - OrtKernelContext* context) { - reinterpret_cast( - state) - ->Compute(api, context); + compute_info.compute_func = [](FunctionState state, const OrtApi* api, OrtKernelContext* context) { + reinterpret_cast(state)->Compute(api, context); return Status::OK(); }; node_compute_funcs.push_back(compute_info); diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h index 5bdfc8c18fb6d..e86b53339d4d2 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.h @@ -4,6 +4,10 @@ #pragma once #include +#include +#include +#include +#include #include "core/framework/execution_provider.h" #include "core/framework/customregistry.h" @@ -18,34 +22,19 @@ class ExecutionProvider; } // namespace vaip_core namespace onnxruntime { -// Information needed to construct execution providers. -struct VitisAIExecutionProviderInfo { - VitisAIExecutionProviderInfo(const ProviderOptions& provider_options); - - const char* get_json_config_str() const { - return json_config_.c_str(); - } - - private: - ProviderOptions provider_options_; - const std::string json_config_; -}; - // Logical device representation. class VitisAIExecutionProvider : public IExecutionProvider { public: - explicit VitisAIExecutionProvider(const VitisAIExecutionProviderInfo& info); + explicit VitisAIExecutionProvider(const ProviderOptions& info); ~VitisAIExecutionProvider() = default; - std::vector> - GetCapability(const onnxruntime::GraphViewer& graph, - const IKernelLookup& /*kernel_lookup*/) const override; + std::vector> GetCapability(const onnxruntime::GraphViewer& graph, + const IKernelLookup& /*kernel_lookup*/) const override; int GetDeviceId() const { return 0; } - common::Status Compile( - const std::vector& fused_nodes_and_graphs, - std::vector& node_compute_funcs) override; + common::Status Compile(const std::vector& fused_nodes_and_graphs, + std::vector& node_compute_funcs) override; std::shared_ptr GetKernelRegistry() const override; private: @@ -54,7 +43,7 @@ class VitisAIExecutionProvider : public IExecutionProvider { using my_ep_uptr_t = std::shared_ptr; // we have to hide the implementation by forward declaration. mutable my_ep_uptr_t execution_providers_; - VitisAIExecutionProviderInfo info_; + ProviderOptions info_; std::vector custom_op_domains_; std::shared_ptr registry_; std::set vitisai_optypes_; diff --git a/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc b/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc index 763a3efd1b35b..4c416124ca8f2 100755 --- a/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_provider_factory.cc @@ -3,56 +3,37 @@ #include "vitisai_provider_factory_creator.h" +#include +#include + #include "vaip/global_api.h" #include "./vitisai_execution_provider.h" #include "core/framework/execution_provider.h" #include "core/session/abi_session_options_impl.h" -#include "nlohmann/json.hpp" -#include -#include -#include +#include "core/providers/shared_library/provider_host_api.h" using namespace onnxruntime; -using json = nlohmann::json; namespace onnxruntime { -static std::string ConfigToJsonStr(const std::unordered_map& config) { - const auto& filename = config.at("config_file"); - std::ifstream f(filename); - json data = json::parse(f); - for (const auto& entry : config) { - data[entry.first] = entry.second; - } - return data.dump(); -} - -VitisAIExecutionProviderInfo::VitisAIExecutionProviderInfo(const ProviderOptions& provider_options) : provider_options_(provider_options), json_config_{ConfigToJsonStr(provider_options)} {} - struct VitisAIProviderFactory : IExecutionProviderFactory { - VitisAIProviderFactory(const VitisAIExecutionProviderInfo& info) : info_(info) {} + VitisAIProviderFactory(const ProviderOptions& info) : info_(info) {} ~VitisAIProviderFactory() = default; std::unique_ptr CreateProvider() override; private: - VitisAIExecutionProviderInfo info_; + ProviderOptions info_; }; std::unique_ptr VitisAIProviderFactory::CreateProvider() { return std::make_unique(info_); } -std::shared_ptr -CreateExecutionProviderFactory_VITISAI(const VitisAIExecutionProviderInfo& info) { - initialize_vitisai_ep(); - return std::make_shared(info); -} - -std::shared_ptr VitisAIProviderFactoryCreator::Create(const ProviderOptions& provider_options) { +std::shared_ptr VitisAIProviderFactoryCreator::Create( + const ProviderOptions& provider_options) { initialize_vitisai_ep(); - auto info = VitisAIExecutionProviderInfo{provider_options}; - return std::make_shared(info); + return std::make_shared(provider_options); } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/vitisai/vitisai_provider_factory_creator.h b/onnxruntime/core/providers/vitisai/vitisai_provider_factory_creator.h index 9e0583275d1b6..9bb7cfa062a0f 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_provider_factory_creator.h +++ b/onnxruntime/core/providers/vitisai/vitisai_provider_factory_creator.h @@ -9,9 +9,6 @@ #include "core/framework/provider_options.h" namespace onnxruntime { - -struct VitisAIExecutionProviderInfo; - struct VitisAIProviderFactoryCreator { static std::shared_ptr Create(const ProviderOptions& provider_options); }; diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index a5bcbce89bac6..6827f2c9dfd91 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -85,13 +85,6 @@ struct OrtStatus { #define BACKEND_TVM "" #endif -#if USE_VITISAI -#define BACKEND_VITISAI "-VITISAI" -#include "core/providers/vitisai/vitisai_execution_provider.h" -#else -#define BACKEND_VITISAI "" -#endif - #if USE_OPENBLAS #define BACKEND_OPENBLAS "-OPENBLAS" #else @@ -451,9 +444,6 @@ std::shared_ptr CreateExecutionProviderFactory_Dnnl(c std::shared_ptr CreateExecutionProviderFactory_Tvm(const tvm::TvmEPOptions& info); std::shared_ptr CreateExecutionProviderFactory_Tvm(const char* params); #endif -std::shared_ptr CreateExecutionProviderFactory_VITISAI(const char* backend_type, int device_id, - const char* export_runtime_module, - const char* load_runtime_module); std::shared_ptr CreateExecutionProviderFactory_ACL(int use_arena); std::shared_ptr CreateExecutionProviderFactory_ArmNN(int use_arena); std::shared_ptr CreateExecutionProviderFactory_DML(int device_id);