From 064f8f2cfa39ad824091845033092e2d61b8bef3 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Thu, 14 Dec 2023 17:14:10 -0800 Subject: [PATCH 1/3] Build function bodies according to the imported global opset. Same is for querying ONNX functions. --- onnxruntime/core/graph/graph.cc | 48 ++++++++++++++++++++++++--------- 1 file changed, 36 insertions(+), 12 deletions(-) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index d489a59c4b798..eb5ab3e8f1371 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -582,6 +582,17 @@ bool Node::TryGetFunctionProto(ONNX_NAMESPACE::FunctionProto& onnx_function_prot onnx_function_proto = *func_template_->onnx_func_proto_; return true; } else if (op_) { + auto get_opset_version = [op = op_](Graph* graph) -> std::optional { + if (op->domain() == kOnnxDomain) { + const auto& domain_to_version = graph->DomainToVersionMap(); + const auto iter = domain_to_version.find(kOnnxDomain); + if (iter != domain_to_version.cend()) { + return iter->second; + } + } + return {}; + }; + // Check if this node has a schema defined function proto. if (op_->HasContextDependentFunction()) { NodeProto node_proto; @@ -595,8 +606,13 @@ bool Node::TryGetFunctionProto(ONNX_NAMESPACE::FunctionProto& onnx_function_prot } else input_types.emplace_back(); } + + auto requested_opset_version = get_opset_version(graph_); + if (!requested_opset_version.has_value()) { + requested_opset_version = SinceVersion(); + } ONNX_NAMESPACE::FunctionBodyBuildContextImpl function_body_ctx(node_proto, input_types); - return op_->BuildContextDependentFunction(function_body_ctx, onnx_function_proto); + return op_->BuildContextDependentFunction(function_body_ctx, onnx_function_proto, *requested_opset_version); } else if (op_->HasFunction()) { const FunctionProto* function_ptr = nullptr; // We need to get a function-body suitable for the ONNX opset used by the model. @@ -605,17 +621,25 @@ bool Node::TryGetFunctionProto(ONNX_NAMESPACE::FunctionProto& onnx_function_prot // as the default-version, which is incorrect in the case of functions belonging to // non-onnx domains, like MSDOMAIN. - // We use the following as a temporary hack. - function_ptr = op_->GetFunction(SinceVersion(), false); - - // TODO: Switch to following, once ONNX issue is fixed. - // auto& map = graph_->DomainToVersionMap(); - // const auto iter = map.find(kOnnxDomain); - // if (iter != map.end()) { - // function_ptr = op_->GetFunction(iter->second, true); - // } else { - // function_ptr = op_->GetFunction(); - // } + auto requested_opset_version = get_opset_version(graph_); + + /// When we fix custom function registration then we can apply the following + /// logic for non ONNX domains as we do for ONNX. Currently, we only apply ONNX + /// domain version + /* { + const auto iter = domain_to_version.find(op_->domain()); + if (iter != domain_to_version.cend()) { + function_ptr = op_->GetFunction(iter->second, true); + } else { + function_ptr = op_->GetFunction(); + } + }*/ + + if (requested_opset_version.has_value()) { + function_ptr = op_->GetFunction(*requested_opset_version, true); + } else { + function_ptr = op_->GetFunction(SinceVersion(), false); + } if (function_ptr != nullptr) { onnx_function_proto = *function_ptr; From 59736f6f923b2991d35dee51b141ad2fefc8c016 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Fri, 15 Dec 2023 11:39:30 -0800 Subject: [PATCH 2/3] Address comment, add test --- onnxruntime/core/graph/graph.cc | 13 ----- onnxruntime/test/framework/function_test.cc | 54 +++++++++++++++++++++ 2 files changed, 54 insertions(+), 13 deletions(-) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index eb5ab3e8f1371..baebe2420073b 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -622,19 +622,6 @@ bool Node::TryGetFunctionProto(ONNX_NAMESPACE::FunctionProto& onnx_function_prot // non-onnx domains, like MSDOMAIN. auto requested_opset_version = get_opset_version(graph_); - - /// When we fix custom function registration then we can apply the following - /// logic for non ONNX domains as we do for ONNX. Currently, we only apply ONNX - /// domain version - /* { - const auto iter = domain_to_version.find(op_->domain()); - if (iter != domain_to_version.cend()) { - function_ptr = op_->GetFunction(iter->second, true); - } else { - function_ptr = op_->GetFunction(); - } - }*/ - if (requested_opset_version.has_value()) { function_ptr = op_->GetFunction(*requested_opset_version, true); } else { diff --git a/onnxruntime/test/framework/function_test.cc b/onnxruntime/test/framework/function_test.cc index 9ab78cac3aca4..956477a0e92bb 100644 --- a/onnxruntime/test/framework/function_test.cc +++ b/onnxruntime/test/framework/function_test.cc @@ -614,5 +614,59 @@ TEST(FunctionTest, TestInlinedFunctionDoesNotReserrectNonExistingArgs) { AsSpan(output_names), &fetches, 0)); } +/// +/// This test covers the issues: +/// https://github.com/microsoft/onnxruntime/issues/16438 +/// https://github.com/microsoft/onnxruntime/issues/18781 +/// +TEST(FunctionTest, Test_GH_issue_16438) { + const char* code = R"( + < + ir_version: 8, + opset_import: ["pkg.onnxscript.torch_lib" : 1, "" : 18], + producer_name: "pytorch", + producer_version: "2.1.0" + > + torch_jit (float16[5,10,5] input_0) => (double[5,10,5] _val_1) { + _val_1 = pkg.onnxscript.torch_lib.aten_special_log_softmax (input_0) + } + < + domain: "pkg.onnxscript.torch_lib", + opset_import: ["" : 18] + > + aten_special_log_softmax (self) => (result_8) + { + tmp = Shape(self) + tmp_0 = Size(tmp) + int64_0 = Constant () + int64_0_cast = CastLike(int64_0, tmp_0) + self_is_scalar = Equal(tmp_0, int64_0_cast) + self_4 = If(self_is_scalar) (self_2) { + tmp_1 = Constant () + self_2 = Unsqueeze(self, tmp_1) + }, else_branch : graph = elseGraph_8() => (self_3) { + self_3 = Identity(self) + }> + result = LogSoftmax(self_4) + result_5 = Cast(result) + result_8 = If(self_is_scalar) (result_6) { + result_6 = Squeeze(result_5) + }, else_branch : graph = elseGraph_12() => (result_7) { + result_7 = Identity(result_5) + }> + } + )"; + + std::string serialized_model; + ParseOnnxSource(code, serialized_model); + SessionOptions session_options; + InferenceSession session_object{session_options, GetEnvironment()}; + + std::stringstream sstr(serialized_model); + auto status = session_object.Load(sstr); + ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); + status = session_object.Initialize(); + ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); +} } // namespace test } // namespace onnxruntime From b152170fa640310fffb60bf503d7a03a03db85d7 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Fri, 15 Dec 2023 11:41:06 -0800 Subject: [PATCH 3/3] Lint --- onnxruntime/test/framework/function_test.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/test/framework/function_test.cc b/onnxruntime/test/framework/function_test.cc index 956477a0e92bb..fa3545ef27d72 100644 --- a/onnxruntime/test/framework/function_test.cc +++ b/onnxruntime/test/framework/function_test.cc @@ -656,7 +656,7 @@ TEST(FunctionTest, Test_GH_issue_16438) { }> } )"; - + std::string serialized_model; ParseOnnxSource(code, serialized_model); SessionOptions session_options;