From 59736f6f923b2991d35dee51b141ad2fefc8c016 Mon Sep 17 00:00:00 2001 From: Dmitri Smirnov Date: Fri, 15 Dec 2023 11:39:30 -0800 Subject: [PATCH] Address comment, add test --- onnxruntime/core/graph/graph.cc | 13 ----- onnxruntime/test/framework/function_test.cc | 54 +++++++++++++++++++++ 2 files changed, 54 insertions(+), 13 deletions(-) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index eb5ab3e8f1371..baebe2420073b 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -622,19 +622,6 @@ bool Node::TryGetFunctionProto(ONNX_NAMESPACE::FunctionProto& onnx_function_prot // non-onnx domains, like MSDOMAIN. auto requested_opset_version = get_opset_version(graph_); - - /// When we fix custom function registration then we can apply the following - /// logic for non ONNX domains as we do for ONNX. Currently, we only apply ONNX - /// domain version - /* { - const auto iter = domain_to_version.find(op_->domain()); - if (iter != domain_to_version.cend()) { - function_ptr = op_->GetFunction(iter->second, true); - } else { - function_ptr = op_->GetFunction(); - } - }*/ - if (requested_opset_version.has_value()) { function_ptr = op_->GetFunction(*requested_opset_version, true); } else { diff --git a/onnxruntime/test/framework/function_test.cc b/onnxruntime/test/framework/function_test.cc index 9ab78cac3aca4..956477a0e92bb 100644 --- a/onnxruntime/test/framework/function_test.cc +++ b/onnxruntime/test/framework/function_test.cc @@ -614,5 +614,59 @@ TEST(FunctionTest, TestInlinedFunctionDoesNotReserrectNonExistingArgs) { AsSpan(output_names), &fetches, 0)); } +/// +/// This test covers the issues: +/// https://github.com/microsoft/onnxruntime/issues/16438 +/// https://github.com/microsoft/onnxruntime/issues/18781 +/// +TEST(FunctionTest, Test_GH_issue_16438) { + const char* code = R"( + < + ir_version: 8, + opset_import: ["pkg.onnxscript.torch_lib" : 1, "" : 18], + producer_name: "pytorch", + producer_version: "2.1.0" + > + torch_jit (float16[5,10,5] input_0) => (double[5,10,5] _val_1) { + _val_1 = pkg.onnxscript.torch_lib.aten_special_log_softmax (input_0) + } + < + domain: "pkg.onnxscript.torch_lib", + opset_import: ["" : 18] + > + aten_special_log_softmax (self) => (result_8) + { + tmp = Shape(self) + tmp_0 = Size(tmp) + int64_0 = Constant () + int64_0_cast = CastLike(int64_0, tmp_0) + self_is_scalar = Equal(tmp_0, int64_0_cast) + self_4 = If(self_is_scalar) (self_2) { + tmp_1 = Constant () + self_2 = Unsqueeze(self, tmp_1) + }, else_branch : graph = elseGraph_8() => (self_3) { + self_3 = Identity(self) + }> + result = LogSoftmax(self_4) + result_5 = Cast(result) + result_8 = If(self_is_scalar) (result_6) { + result_6 = Squeeze(result_5) + }, else_branch : graph = elseGraph_12() => (result_7) { + result_7 = Identity(result_5) + }> + } + )"; + + std::string serialized_model; + ParseOnnxSource(code, serialized_model); + SessionOptions session_options; + InferenceSession session_object{session_options, GetEnvironment()}; + + std::stringstream sstr(serialized_model); + auto status = session_object.Load(sstr); + ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); + status = session_object.Initialize(); + ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); +} } // namespace test } // namespace onnxruntime