From d01fc75ef161a624c4275f89cb068cc1c79d9392 Mon Sep 17 00:00:00 2001 From: Yueqing Zhang Date: Fri, 26 Jul 2024 22:15:57 -0700 Subject: [PATCH] [VitisAI] support vaip create ep context nodes & bug fix (#21506) ### Description 1. We decided to move the context node creation back to our own repo because it is more flexible to modify. 2. We found a bug related the context node. It would change the inference order. So, we fixed in this PR as well. ### Motivation and Context This is crucial for Microsoft Release next month. --------- Co-authored-by: Yueqing Zhang --- .../shared_library/provider_interfaces.h | 1 + .../shared_library/provider_wrappedtypes.h | 1 + .../core/providers/vitisai/imp/global_api.cc | 50 +++++++++++++++++++ .../vitisai/include/vaip/custom_op.h | 11 ++++ .../vitisai/include/vaip/global_api.h | 6 ++- .../vitisai/include/vaip/vaip_ort_api.h | 11 ++-- .../vitisai/vitisai_execution_provider.cc | 14 ++++-- .../core/session/provider_bridge_ort.cc | 1 + 8 files changed, 88 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 382b3ac932520..a9394838aa784 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -388,6 +388,7 @@ struct ProviderHost { virtual ONNX_NAMESPACE::TensorProto* AttributeProto__add_tensors(ONNX_NAMESPACE::AttributeProto* p) = 0; // GraphProto + virtual std::unique_ptr GraphProto__construct() = 0; virtual void GraphProto__operator_delete(ONNX_NAMESPACE::GraphProto* p) = 0; virtual void GraphProto__operator_assign(ONNX_NAMESPACE::GraphProto* p, const ONNX_NAMESPACE::GraphProto& v) = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h index de6c1da1d6430..242c7126f3274 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -146,6 +146,7 @@ struct AttributeProto final { }; struct GraphProto final { + static std::unique_ptr Create() { return g_host->GraphProto__construct(); } static void operator delete(void* p) { g_host->GraphProto__operator_delete(reinterpret_cast(p)); } void operator=(const GraphProto& v) { return g_host->GraphProto__operator_assign(this, v); } diff --git a/onnxruntime/core/providers/vitisai/imp/global_api.cc b/onnxruntime/core/providers/vitisai/imp/global_api.cc index a86a4fb61d54d..df47fa5cee4ab 100644 --- a/onnxruntime/core/providers/vitisai/imp/global_api.cc +++ b/onnxruntime/core/providers/vitisai/imp/global_api.cc @@ -55,10 +55,15 @@ struct OrtVitisAIEpAPI { uint32_t (*vaip_get_version)(); void (*get_backend_compilation_cache)(const std::string& model_path, const onnxruntime::Graph& graph, const char* json_config, uint8_t compiler_codes, std::string& cache_dir, std::string& cache_key, std::string& cache_data); void (*restore_backend_compilation_cache)(const std::string& cache_dir, const std::string& cache_key, const std::string& cache_data, const std::string& model_path); + void (*create_ep_context_nodes)( + onnxruntime::Graph& ep_context_graph, + const std::vector>& eps, + vaip_core::DllSafe>* ret_value) = nullptr; void Ensure() { if (handle_) return; auto& env = Provider_GetHost()->Env__Default(); + auto& logger = *Provider_GetHost()->LoggingManager_GetDefaultLogger(); #ifdef _WIN32 // this dll is already linked to the executable, normally a test program handle_ = reinterpret_cast(GetModuleHandle(TEXT("onnxruntime_vitisai_ep.dll"))); @@ -81,6 +86,10 @@ struct OrtVitisAIEpAPI { (void**)&vaip_get_version); ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "get_compilation_cache", (void**)&get_backend_compilation_cache)); ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "restore_compilation_cache", (void**)&restore_backend_compilation_cache)); + status1 = (env.GetSymbolFromLibrary(handle_, "create_ep_context_nodes", (void**)&create_ep_context_nodes)); + if (!status1.IsOK()) { + LOGS(logger, WARNING) << "create_ep_context_nodes is not defined, please upgrade onnxruntime_vitisai_ep.dll. However, it still works."; + } } private: @@ -146,6 +155,24 @@ void restore_backend_compilation_cache(const std::string& cache_dir, const std:: s_library_vitisaiep.restore_backend_compilation_cache(cache_dir, cache_key, cache_data, model_path); } +bool has_create_ep_context_nodes() { + return s_library_vitisaiep.create_ep_context_nodes != nullptr; +} + +std::optional> create_ep_context_nodes( + onnxruntime::Graph& ep_context_graph, + const std::vector>& eps) { + if (s_library_vitisaiep.create_ep_context_nodes) { + vaip_core::DllSafe> nodes; + s_library_vitisaiep.create_ep_context_nodes(ep_context_graph, eps, &nodes); + if (nodes.get()) { + auto ret = std::vector(*nodes); + return ret; + } + } + return std::nullopt; +} + struct MyCustomOpKernel : OpKernel { MyCustomOpKernel(const OpKernelInfo& info, const OrtCustomOp& op) : OpKernel(info), op_(op) { op_kernel_ = @@ -405,6 +432,29 @@ vaip_core::OrtApiForVaip* create_org_api_hook() { graph.AddInitializedTensor(tensor); }; + the_global_api.get_model_path = [](const Graph& graph) -> const std::filesystem::path& { + return graph.ModelPath(); + }; + + the_global_api.create_empty_model = [](const std::filesystem::path& path, const std::vector>& opset) -> Model* { + auto model_proto = ONNX_NAMESPACE::ModelProto::Create(); + auto graph_proto = ONNX_NAMESPACE::GraphProto::Create(); + model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); + for (const auto& op : opset) { + auto* opset_import = model_proto->add_opset_import(); + *(opset_import->mutable_domain()) = op.first; + opset_import->set_version(op.second); + } + std::ignore = model_proto->mutable_graph(); // create a graph + auto& logger = logging::LoggingManager::DefaultLogger(); + auto model = Model::Create(std::move(*model_proto), path, nullptr, logger); + return model.release(); + }; + + the_global_api.graph_set_inputs = [](Graph& graph, gsl::span inputs) { + graph.SetInputs(inputs); + }; + if (!s_library_vitisaiep.vaip_get_version) { return reinterpret_cast(&(the_global_api.host_)); } else { diff --git a/onnxruntime/core/providers/vitisai/include/vaip/custom_op.h b/onnxruntime/core/providers/vitisai/include/vaip/custom_op.h index d34f7095b704d..5d020e00ff5b7 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/custom_op.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/custom_op.h @@ -26,6 +26,17 @@ class ExecutionProvider { virtual DllSafe> get_meta_def_constant_initializer() const = 0; virtual std::unique_ptr compile() const = 0; + + public: + inline void set_fused_node(const onnxruntime::Node* fused_node) { + fused_node_ = fused_node; + } + inline const onnxruntime::Node* get_fused_node() const { + return fused_node_; + } + + private: + const onnxruntime::Node* fused_node_ = nullptr; }; class CustomOp { diff --git a/onnxruntime/core/providers/vitisai/include/vaip/global_api.h b/onnxruntime/core/providers/vitisai/include/vaip/global_api.h index 3fdbc60bb0ee6..ae2a513a98e32 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/global_api.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/global_api.h @@ -9,10 +9,14 @@ #include "vaip/my_ort.h" #include "vaip/dll_safe.h" #include "vaip/custom_op.h" - +#include void initialize_vitisai_ep(); vaip_core::DllSafe>> compile_onnx_model(const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::logging::Logger& logger, const onnxruntime::ProviderOptions& options); std::shared_ptr get_kernel_registry_vitisaiep(); const std::vector& get_domains_vitisaiep(); void get_backend_compilation_cache(const onnxruntime::PathString& model_path_str, const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::ProviderOptions& options, uint8_t compiler_codes, std::string& cache_dir, std::string& cache_key, std::string& cache_data); void restore_backend_compilation_cache(const std::string& cache_dir, const std::string& cache_key, const std::string& cache_data, const std::string& model_path); +std::optional> create_ep_context_nodes( + onnxruntime::Graph& ep_context_graph, + const std::vector>& eps); +bool has_create_ep_context_nodes(); diff --git a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h index 3346739890484..e6aacfe1f0272 100644 --- a/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h +++ b/onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h @@ -8,12 +8,13 @@ #include #include #include +#include struct OrtApi; namespace vaip_core { -#define VAIP_ORT_API_MAJOR (3u) -#define VAIP_ORT_API_MINOR (1u) +#define VAIP_ORT_API_MAJOR (4u) +#define VAIP_ORT_API_MINOR (0u) #define VAIP_ORT_API_PATCH (0u) struct OrtApiForVaip { uint32_t magic; // 'VAIP' or something else to make sure the following field @@ -222,7 +223,11 @@ struct OrtApiForVaip { const std::vector& data); // [88] TensorProto* (*tensor_proto_new_bf16)( const std::string& name, const std::vector& shape, - const std::vector& data); // [89] + const std::vector& data); // [89] + const std::filesystem::path& (*get_model_path)(const Graph& graph); // [90] + Model* (*create_empty_model)(const std::filesystem::path& path, const std::vector>& opset); //[91] + void (*graph_set_inputs)(Graph& graph, + gsl::span inputs); // [92] }; #ifndef USE_VITISAI diff --git a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc index 58fef537535d2..756bda2199e89 100644 --- a/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc +++ b/onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc @@ -58,8 +58,15 @@ const InlinedVector VitisAIExecutionProvider::GetEpContextNodes() c // All preconditions are supposed to have happened. if (p_ep_ctx_model_) { auto& graph = p_ep_ctx_model_->MainGraph(); - for (const auto* p_node : graph.Nodes()) { - ep_context_node_ptrs.push_back(p_node); + if (has_create_ep_context_nodes()) { + auto nodes = create_ep_context_nodes(graph, **execution_providers_); + if (nodes.has_value()) { + ep_context_node_ptrs.assign(nodes->begin(), nodes->end()); + } + } else { + for (const auto* p_node : graph.Nodes()) { + ep_context_node_ptrs.push_back(p_node); + } } } return ep_context_node_ptrs; @@ -187,6 +194,7 @@ common::Status VitisAIExecutionProvider::Compile(const std::vectorexecution_providers_)[index]->set_fused_node(&fused_node_graph.fused_node.get()); compute_info.create_state_func = [this, index](ComputeContext* context, FunctionState* state) { auto* p = (**this->execution_providers_)[index]->compile().release(); *state = p; @@ -204,7 +212,7 @@ common::Status VitisAIExecutionProvider::Compile(const std::vectoradd_tensors(); } // GraphProto (wrapped) + std::unique_ptr GraphProto__construct() override { return std::make_unique(); } void GraphProto__operator_delete(ONNX_NAMESPACE::GraphProto* p) override { delete p; } const ONNX_NAMESPACE::ValueInfoProto& GraphProto__input(const ONNX_NAMESPACE::GraphProto* p, int index) override { return p->input(index); }