From 064f8f2cfa39ad824091845033092e2d61b8bef3 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Thu, 14 Dec 2023 17:14:10 -0800 Subject: [PATCH] Build function bodies according to the imported global opset. Same is for querying ONNX functions. --- onnxruntime/core/graph/graph.cc | 48 ++++++++++++++++++++++++--------- 1 file changed, 36 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index d489a59c4b798..eb5ab3e8f1371 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -582,6 +582,17 @@ bool Node::TryGetFunctionProto(ONNX_NAMESPACE::FunctionProto& onnx_function_prot onnx_function_proto = *func_template_->onnx_func_proto_; return true; } else if (op_) { + auto get_opset_version = [op = op_](Graph* graph) -> std::optional { + if (op->domain() == kOnnxDomain) { + const auto& domain_to_version = graph->DomainToVersionMap(); + const auto iter = domain_to_version.find(kOnnxDomain); + if (iter != domain_to_version.cend()) { + return iter->second; + } + } + return {}; + }; + // Check if this node has a schema defined function proto. if (op_->HasContextDependentFunction()) { NodeProto node_proto; @@ -595,8 +606,13 @@ bool Node::TryGetFunctionProto(ONNX_NAMESPACE::FunctionProto& onnx_function_prot } else input_types.emplace_back(); } + + auto requested_opset_version = get_opset_version(graph_); + if (!requested_opset_version.has_value()) { + requested_opset_version = SinceVersion(); + } ONNX_NAMESPACE::FunctionBodyBuildContextImpl function_body_ctx(node_proto, input_types); - return op_->BuildContextDependentFunction(function_body_ctx, onnx_function_proto); + return op_->BuildContextDependentFunction(function_body_ctx, onnx_function_proto, *requested_opset_version); } else if (op_->HasFunction()) { const FunctionProto* function_ptr = nullptr; // We need to get a function-body suitable for the ONNX opset used by the model. @@ -605,17 +621,25 @@ bool Node::TryGetFunctionProto(ONNX_NAMESPACE::FunctionProto& onnx_function_prot // as the default-version, which is incorrect in the case of functions belonging to // non-onnx domains, like MSDOMAIN. - // We use the following as a temporary hack. - function_ptr = op_->GetFunction(SinceVersion(), false); - - // TODO: Switch to following, once ONNX issue is fixed. - // auto& map = graph_->DomainToVersionMap(); - // const auto iter = map.find(kOnnxDomain); - // if (iter != map.end()) { - // function_ptr = op_->GetFunction(iter->second, true); - // } else { - // function_ptr = op_->GetFunction(); - // } + auto requested_opset_version = get_opset_version(graph_); + + /// When we fix custom function registration then we can apply the following + /// logic for non ONNX domains as we do for ONNX. Currently, we only apply ONNX + /// domain version + /* { + const auto iter = domain_to_version.find(op_->domain()); + if (iter != domain_to_version.cend()) { + function_ptr = op_->GetFunction(iter->second, true); + } else { + function_ptr = op_->GetFunction(); + } + }*/ + + if (requested_opset_version.has_value()) { + function_ptr = op_->GetFunction(*requested_opset_version, true); + } else { + function_ptr = op_->GetFunction(SinceVersion(), false); + } if (function_ptr != nullptr) { onnx_function_proto = *function_ptr;