Skip to content

Commit

Permalink
[VitisAI] support vaip create ep context nodes & bug fix (#21506)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->
1. We decided to move the context node creation back to our own repo because it is more flexible to modify.
2. We found a bug related the context node. It would change the inference order. So, we fixed in this PR as well.


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
This is crucial for Microsoft Release next month.

---------

Co-authored-by: Yueqing Zhang <[email protected]>
  • Loading branch information
BoarQing and Yueqing Zhang authored Jul 27, 2024
1 parent 690d745 commit d01fc75
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,7 @@ struct ProviderHost {
virtual ONNX_NAMESPACE::TensorProto* AttributeProto__add_tensors(ONNX_NAMESPACE::AttributeProto* p) = 0;

// GraphProto
virtual std::unique_ptr<ONNX_NAMESPACE::GraphProto> GraphProto__construct() = 0;
virtual void GraphProto__operator_delete(ONNX_NAMESPACE::GraphProto* p) = 0;
virtual void GraphProto__operator_assign(ONNX_NAMESPACE::GraphProto* p, const ONNX_NAMESPACE::GraphProto& v) = 0;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ struct AttributeProto final {
};

struct GraphProto final {
static std::unique_ptr<GraphProto> Create() { return g_host->GraphProto__construct(); }
static void operator delete(void* p) { g_host->GraphProto__operator_delete(reinterpret_cast<GraphProto*>(p)); }
void operator=(const GraphProto& v) { return g_host->GraphProto__operator_assign(this, v); }

Expand Down
50 changes: 50 additions & 0 deletions onnxruntime/core/providers/vitisai/imp/global_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,15 @@ struct OrtVitisAIEpAPI {
uint32_t (*vaip_get_version)();
void (*get_backend_compilation_cache)(const std::string& model_path, const onnxruntime::Graph& graph, const char* json_config, uint8_t compiler_codes, std::string& cache_dir, std::string& cache_key, std::string& cache_data);
void (*restore_backend_compilation_cache)(const std::string& cache_dir, const std::string& cache_key, const std::string& cache_data, const std::string& model_path);
void (*create_ep_context_nodes)(
onnxruntime::Graph& ep_context_graph,
const std::vector<std::unique_ptr<vaip_core::ExecutionProvider>>& eps,
vaip_core::DllSafe<std::vector<Node*>>* ret_value) = nullptr;
void Ensure() {
if (handle_)
return;
auto& env = Provider_GetHost()->Env__Default();
auto& logger = *Provider_GetHost()->LoggingManager_GetDefaultLogger();
#ifdef _WIN32
// this dll is already linked to the executable, normally a test program
handle_ = reinterpret_cast<void*>(GetModuleHandle(TEXT("onnxruntime_vitisai_ep.dll")));
Expand All @@ -81,6 +86,10 @@ struct OrtVitisAIEpAPI {
(void**)&vaip_get_version);
ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "get_compilation_cache", (void**)&get_backend_compilation_cache));
ORT_THROW_IF_ERROR(env.GetSymbolFromLibrary(handle_, "restore_compilation_cache", (void**)&restore_backend_compilation_cache));
status1 = (env.GetSymbolFromLibrary(handle_, "create_ep_context_nodes", (void**)&create_ep_context_nodes));
if (!status1.IsOK()) {
LOGS(logger, WARNING) << "create_ep_context_nodes is not defined, please upgrade onnxruntime_vitisai_ep.dll. However, it still works.";
}
}

private:
Expand Down Expand Up @@ -146,6 +155,24 @@ void restore_backend_compilation_cache(const std::string& cache_dir, const std::
s_library_vitisaiep.restore_backend_compilation_cache(cache_dir, cache_key, cache_data, model_path);
}

bool has_create_ep_context_nodes() {
return s_library_vitisaiep.create_ep_context_nodes != nullptr;
}

std::optional<std::vector<Node*>> create_ep_context_nodes(
onnxruntime::Graph& ep_context_graph,
const std::vector<std::unique_ptr<vaip_core::ExecutionProvider>>& eps) {
if (s_library_vitisaiep.create_ep_context_nodes) {
vaip_core::DllSafe<std::vector<Node*>> nodes;
s_library_vitisaiep.create_ep_context_nodes(ep_context_graph, eps, &nodes);
if (nodes.get()) {
auto ret = std::vector<Node*>(*nodes);
return ret;
}
}
return std::nullopt;
}

struct MyCustomOpKernel : OpKernel {
MyCustomOpKernel(const OpKernelInfo& info, const OrtCustomOp& op) : OpKernel(info), op_(op) {
op_kernel_ =
Expand Down Expand Up @@ -405,6 +432,29 @@ vaip_core::OrtApiForVaip* create_org_api_hook() {
graph.AddInitializedTensor(tensor);
};

the_global_api.get_model_path = [](const Graph& graph) -> const std::filesystem::path& {
return graph.ModelPath();
};

the_global_api.create_empty_model = [](const std::filesystem::path& path, const std::vector<std::pair<std::string, int64_t>>& opset) -> Model* {
auto model_proto = ONNX_NAMESPACE::ModelProto::Create();
auto graph_proto = ONNX_NAMESPACE::GraphProto::Create();
model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION);
for (const auto& op : opset) {
auto* opset_import = model_proto->add_opset_import();
*(opset_import->mutable_domain()) = op.first;
opset_import->set_version(op.second);
}
std::ignore = model_proto->mutable_graph(); // create a graph
auto& logger = logging::LoggingManager::DefaultLogger();
auto model = Model::Create(std::move(*model_proto), path, nullptr, logger);
return model.release();
};

the_global_api.graph_set_inputs = [](Graph& graph, gsl::span<const NodeArg* const> inputs) {
graph.SetInputs(inputs);
};

if (!s_library_vitisaiep.vaip_get_version) {
return reinterpret_cast<vaip_core::OrtApiForVaip*>(&(the_global_api.host_));
} else {
Expand Down
11 changes: 11 additions & 0 deletions onnxruntime/core/providers/vitisai/include/vaip/custom_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,17 @@ class ExecutionProvider {
virtual DllSafe<std::vector<std::string>>
get_meta_def_constant_initializer() const = 0;
virtual std::unique_ptr<CustomOp> compile() const = 0;

public:
inline void set_fused_node(const onnxruntime::Node* fused_node) {
fused_node_ = fused_node;
}
inline const onnxruntime::Node* get_fused_node() const {
return fused_node_;
}

private:
const onnxruntime::Node* fused_node_ = nullptr;
};

class CustomOp {
Expand Down
6 changes: 5 additions & 1 deletion onnxruntime/core/providers/vitisai/include/vaip/global_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,14 @@
#include "vaip/my_ort.h"
#include "vaip/dll_safe.h"
#include "vaip/custom_op.h"

#include <optional>
void initialize_vitisai_ep();
vaip_core::DllSafe<std::vector<std::unique_ptr<vaip_core::ExecutionProvider>>> compile_onnx_model(const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::logging::Logger& logger, const onnxruntime::ProviderOptions& options);
std::shared_ptr<onnxruntime::KernelRegistry> get_kernel_registry_vitisaiep();
const std::vector<OrtCustomOpDomain*>& get_domains_vitisaiep();
void get_backend_compilation_cache(const onnxruntime::PathString& model_path_str, const onnxruntime::GraphViewer& graph_viewer, const onnxruntime::ProviderOptions& options, uint8_t compiler_codes, std::string& cache_dir, std::string& cache_key, std::string& cache_data);
void restore_backend_compilation_cache(const std::string& cache_dir, const std::string& cache_key, const std::string& cache_data, const std::string& model_path);
std::optional<std::vector<onnxruntime::Node*>> create_ep_context_nodes(
onnxruntime::Graph& ep_context_graph,
const std::vector<std::unique_ptr<vaip_core::ExecutionProvider>>& eps);
bool has_create_ep_context_nodes();
11 changes: 8 additions & 3 deletions onnxruntime/core/providers/vitisai/include/vaip/vaip_ort_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@
#include <cassert>
#include <functional>
#include <vector>
#include <filesystem>
struct OrtApi;

namespace vaip_core {

#define VAIP_ORT_API_MAJOR (3u)
#define VAIP_ORT_API_MINOR (1u)
#define VAIP_ORT_API_MAJOR (4u)
#define VAIP_ORT_API_MINOR (0u)
#define VAIP_ORT_API_PATCH (0u)
struct OrtApiForVaip {
uint32_t magic; // 'VAIP' or something else to make sure the following field
Expand Down Expand Up @@ -222,7 +223,11 @@ struct OrtApiForVaip {
const std::vector<int16_t>& data); // [88]
TensorProto* (*tensor_proto_new_bf16)(
const std::string& name, const std::vector<int64_t>& shape,
const std::vector<int16_t>& data); // [89]
const std::vector<int16_t>& data); // [89]
const std::filesystem::path& (*get_model_path)(const Graph& graph); // [90]
Model* (*create_empty_model)(const std::filesystem::path& path, const std::vector<std::pair<std::string, int64_t>>& opset); //[91]
void (*graph_set_inputs)(Graph& graph,
gsl::span<const NodeArg* const> inputs); // [92]
};

#ifndef USE_VITISAI
Expand Down
14 changes: 11 additions & 3 deletions onnxruntime/core/providers/vitisai/vitisai_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,15 @@ const InlinedVector<const Node*> VitisAIExecutionProvider::GetEpContextNodes() c
// All preconditions are supposed to have happened.
if (p_ep_ctx_model_) {
auto& graph = p_ep_ctx_model_->MainGraph();
for (const auto* p_node : graph.Nodes()) {
ep_context_node_ptrs.push_back(p_node);
if (has_create_ep_context_nodes()) {
auto nodes = create_ep_context_nodes(graph, **execution_providers_);
if (nodes.has_value()) {
ep_context_node_ptrs.assign(nodes->begin(), nodes->end());
}
} else {
for (const auto* p_node : graph.Nodes()) {
ep_context_node_ptrs.push_back(p_node);
}
}
}
return ep_context_node_ptrs;
Expand Down Expand Up @@ -187,6 +194,7 @@ common::Status VitisAIExecutionProvider::Compile(const std::vector<FusedNodeAndG
auto& attrs = fused_node_graph.fused_node.get().GetAttributes();
assert(attrs.count("index"));
size_t index = attrs.at("index").i();
(**this->execution_providers_)[index]->set_fused_node(&fused_node_graph.fused_node.get());
compute_info.create_state_func = [this, index](ComputeContext* context, FunctionState* state) {
auto* p = (**this->execution_providers_)[index]->compile().release();
*state = p;
Expand All @@ -204,7 +212,7 @@ common::Status VitisAIExecutionProvider::Compile(const std::vector<FusedNodeAndG
};
node_compute_funcs.push_back(compute_info);
}
if (ep_ctx_enabled_ && p_ep_ctx_model_) {
if (ep_ctx_enabled_ && p_ep_ctx_model_ && !has_create_ep_context_nodes()) {
FulfillEPContextEnablement(fused_nodes_and_graphs);
}
return Status::OK();
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/session/provider_bridge_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,7 @@ struct ProviderHostImpl : ProviderHost {
ONNX_NAMESPACE::TensorProto* AttributeProto__add_tensors(ONNX_NAMESPACE::AttributeProto* p) override { return p->add_tensors(); }

// GraphProto (wrapped)
std::unique_ptr<ONNX_NAMESPACE::GraphProto> GraphProto__construct() override { return std::make_unique<ONNX_NAMESPACE::GraphProto>(); }
void GraphProto__operator_delete(ONNX_NAMESPACE::GraphProto* p) override { delete p; }

const ONNX_NAMESPACE::ValueInfoProto& GraphProto__input(const ONNX_NAMESPACE::GraphProto* p, int index) override { return p->input(index); }
Expand Down

0 comments on commit d01fc75

Please sign in to comment.