Skip to content

Commit

Permalink
[VitisAI] 1. KernelDef supports StartVersion and EndVersion (#21519)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->

[VitisAI] 1. KernelDef supports StartVersion and EndVersion
2. CapabilityOps checks domain

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

Co-authored-by: Zhenze Wang <[email protected]>
  • Loading branch information
zz002 and Zhenze Wang authored Jul 27, 2024
1 parent 5af423c commit 690d745
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 4 deletions.
6 changes: 5 additions & 1 deletion onnxruntime/core/providers/vitisai/imp/capability.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,11 @@ GetComputeCapabilityOps(const onnxruntime::GraphViewer& graph,

std::vector<NodeIndex> node_indexs = graph.GetNodesInTopologicalOrder();
node_indexs.erase(std::remove_if(node_indexs.begin(), node_indexs.end(), [&](NodeIndex index) { return all_nodes_included_eps.count(index) > 0; }), node_indexs.end());
node_indexs.erase(std::remove_if(node_indexs.begin(), node_indexs.end(), [&](NodeIndex index) { return all_support_optypes_by_eps.count(graph.GetNode(index)->OpType()) == 0; }), node_indexs.end());
node_indexs.erase(std::remove_if(node_indexs.begin(), node_indexs.end(),
[&](NodeIndex index) {
auto node = graph.GetNode(index);
return all_support_optypes_by_eps.count(node->Domain() + ":" + node->OpType()) == 0; }),
node_indexs.end());

std::vector<std::unique_ptr<ComputeCapability>> result;
for (auto& n : node_indexs) {
Expand Down
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/vitisai/imp/global_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ void create_kernel_registry(std::vector<OrtCustomOpDomain*> domains) {
auto def_builder = KernelDefBuilder::Create();
def_builder->SetName(op->GetName(op));
def_builder->SetDomain(domain->domain_.c_str());
def_builder->SinceVersion(1);
def_builder->SinceVersion(op->GetStartVersion(op), op->GetEndVersion(op));
if (op->version > 12) {
auto input_count = op->GetInputTypeCount(op);
for (auto i = 0u; i < input_count; i++) {
Expand All @@ -183,7 +183,7 @@ void create_kernel_registry(std::vector<OrtCustomOpDomain*> domains) {
def_builder->Provider(onnxruntime::kVitisAIExecutionProvider);
KernelCreateFn kernel_create_fn =
[op](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status {
// out = std::make_unique<MyCustomOpKernel>(info, *op);
out = std::make_unique<MyCustomOpKernel>(info, *op);
return Status::OK();
};
std::ignore = s_kernel_registry_vitisaiep->Register(KernelCreateInfo(def_builder->Build(), kernel_create_fn));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ VitisAIExecutionProvider::VitisAIExecutionProvider(
void VitisAIExecutionProvider::CreateKernelRegistry() {
for (const auto& domain : get_domains_vitisaiep()) {
for (const auto* op : domain->custom_ops_) {
vitisai_optypes_.insert(op->GetName(op));
vitisai_optypes_.insert(domain->domain_ + ":" + op->GetName(op));
}
}
}
Expand Down

0 comments on commit 690d745

Please sign in to comment.