Skip to content

Commit

Permalink
Build function bodies according to the imported global opset.
Browse files Browse the repository at this point in the history
Same is for querying ONNX functions.
  • Loading branch information
yuslepukhin committed Dec 15, 2023
1 parent b129f42 commit 064f8f2
Showing 1 changed file with 36 additions and 12 deletions.
48 changes: 36 additions & 12 deletions onnxruntime/core/graph/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -582,6 +582,17 @@ bool Node::TryGetFunctionProto(ONNX_NAMESPACE::FunctionProto& onnx_function_prot
onnx_function_proto = *func_template_->onnx_func_proto_;
return true;
} else if (op_) {
auto get_opset_version = [op = op_](Graph* graph) -> std::optional<int> {
if (op->domain() == kOnnxDomain) {
const auto& domain_to_version = graph->DomainToVersionMap();
const auto iter = domain_to_version.find(kOnnxDomain);
if (iter != domain_to_version.cend()) {
return iter->second;
}
}
return {};
};

// Check if this node has a schema defined function proto.
if (op_->HasContextDependentFunction()) {
NodeProto node_proto;
Expand All @@ -595,8 +606,13 @@ bool Node::TryGetFunctionProto(ONNX_NAMESPACE::FunctionProto& onnx_function_prot
} else
input_types.emplace_back();
}

auto requested_opset_version = get_opset_version(graph_);
if (!requested_opset_version.has_value()) {
requested_opset_version = SinceVersion();
}
ONNX_NAMESPACE::FunctionBodyBuildContextImpl function_body_ctx(node_proto, input_types);
return op_->BuildContextDependentFunction(function_body_ctx, onnx_function_proto);
return op_->BuildContextDependentFunction(function_body_ctx, onnx_function_proto, *requested_opset_version);
} else if (op_->HasFunction()) {
const FunctionProto* function_ptr = nullptr;
// We need to get a function-body suitable for the ONNX opset used by the model.
Expand All @@ -605,17 +621,25 @@ bool Node::TryGetFunctionProto(ONNX_NAMESPACE::FunctionProto& onnx_function_prot
// as the default-version, which is incorrect in the case of functions belonging to
// non-onnx domains, like MSDOMAIN.

// We use the following as a temporary hack.
function_ptr = op_->GetFunction(SinceVersion(), false);

// TODO: Switch to following, once ONNX issue is fixed.
// auto& map = graph_->DomainToVersionMap();
// const auto iter = map.find(kOnnxDomain);
// if (iter != map.end()) {
// function_ptr = op_->GetFunction(iter->second, true);
// } else {
// function_ptr = op_->GetFunction();
// }
auto requested_opset_version = get_opset_version(graph_);

/// When we fix custom function registration then we can apply the following
/// logic for non ONNX domains as we do for ONNX. Currently, we only apply ONNX
/// domain version
/* {
const auto iter = domain_to_version.find(op_->domain());
if (iter != domain_to_version.cend()) {
function_ptr = op_->GetFunction(iter->second, true);
} else {
function_ptr = op_->GetFunction();
}
}*/

if (requested_opset_version.has_value()) {
function_ptr = op_->GetFunction(*requested_opset_version, true);
} else {
function_ptr = op_->GetFunction(SinceVersion(), false);
}

if (function_ptr != nullptr) {
onnx_function_proto = *function_ptr;
Expand Down

0 comments on commit 064f8f2

Please sign in to comment.