diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc
index eb5ab3e8f1371..baebe2420073b 100644
--- a/onnxruntime/core/graph/graph.cc
+++ b/onnxruntime/core/graph/graph.cc
@@ -622,19 +622,6 @@ bool Node::TryGetFunctionProto(ONNX_NAMESPACE::FunctionProto& onnx_function_prot
// non-onnx domains, like MSDOMAIN.
auto requested_opset_version = get_opset_version(graph_);
-
- /// When we fix custom function registration then we can apply the following
- /// logic for non ONNX domains as we do for ONNX. Currently, we only apply ONNX
- /// domain version
- /* {
- const auto iter = domain_to_version.find(op_->domain());
- if (iter != domain_to_version.cend()) {
- function_ptr = op_->GetFunction(iter->second, true);
- } else {
- function_ptr = op_->GetFunction();
- }
- }*/
-
if (requested_opset_version.has_value()) {
function_ptr = op_->GetFunction(*requested_opset_version, true);
} else {
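Note: the branch kept above always resolves the function body against the opset version the model imports for the ONNX domain. A minimal sketch of that selection, assuming a domain-to-version map like the one a model's opset imports carry (the helper name and map are illustrative, not onnxruntime's actual internals):

```cpp
#include <optional>
#include <string>
#include <unordered_map>

// Illustrative stand-in for get_opset_version(graph_): look up the version
// the model imports for the default ONNX domain (the empty string "").
std::optional<int> GetRequestedOnnxOpsetVersion(
    const std::unordered_map<std::string, int>& domain_to_version) {
  const auto iter = domain_to_version.find("");
  if (iter != domain_to_version.cend()) {
    return iter->second;  // request the function body matching this opset
  }
  return std::nullopt;  // caller falls back to the schema's default body
}
```

The commented-out block deleted here sketched the same per-domain lookup for non-ONNX domains; as the surviving comment notes, only the ONNX domain version is applied until custom function registration supports versioned lookup.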
diff --git a/onnxruntime/test/framework/function_test.cc b/onnxruntime/test/framework/function_test.cc
index 9ab78cac3aca4..956477a0e92bb 100644
--- a/onnxruntime/test/framework/function_test.cc
+++ b/onnxruntime/test/framework/function_test.cc
@@ -614,5 +614,59 @@ TEST(FunctionTest, TestInlinedFunctionDoesNotReserrectNonExistingArgs) {
AsSpan(output_names), &fetches, 0));
}
+///
+/// This test covers the issues:
+/// https://github.com/microsoft/onnxruntime/issues/16438
+/// https://github.com/microsoft/onnxruntime/issues/18781
+///
+TEST(FunctionTest, Test_GH_issue_16438) {
+  const char* code = R"(
+    <
+    ir_version: 8,
+    opset_import: ["pkg.onnxscript.torch_lib" : 1, "" : 18],
+    producer_name: "pytorch",
+    producer_version: "2.1.0"
+    >
+    torch_jit (float16[5,10,5] input_0) => (double[5,10,5] _val_1) {
+      _val_1 = pkg.onnxscript.torch_lib.aten_special_log_softmax (input_0)
+    }
+    <
+    domain: "pkg.onnxscript.torch_lib",
+    opset_import: ["" : 18]
+    >
+    aten_special_log_softmax (self) => (result_8)
+    {
+      tmp = Shape(self)
+      tmp_0 = Size(tmp)
+      int64_0 = Constant <value = int64 int64_0 {0}> ()
+      int64_0_cast = CastLike(int64_0, tmp_0)
+      self_is_scalar = Equal(tmp_0, int64_0_cast)
+      self_4 = If(self_is_scalar) <then_branch : graph = thenGraph_8() => (self_2) {
+        tmp_1 = Constant <value_ints = [0]> ()
+        self_2 = Unsqueeze(self, tmp_1)
+      }, else_branch : graph = elseGraph_8() => (self_3) {
+        self_3 = Identity(self)
+      }>
+      result = LogSoftmax(self_4)
+      result_5 = Cast <to = 11> (result)
+      result_8 = If(self_is_scalar) <then_branch : graph = thenGraph_12() => (result_6) {
+        result_6 = Squeeze(result_5)
+      }, else_branch : graph = elseGraph_12() => (result_7) {
+        result_7 = Identity(result_5)
+      }>
+    }
+ )";
+
+ std::string serialized_model;
+ ParseOnnxSource(code, serialized_model);
+ SessionOptions session_options;
+ InferenceSession session_object{session_options, GetEnvironment()};
+
+ std::stringstream sstr(serialized_model);
+ auto status = session_object.Load(sstr);
+ ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
+ status = session_object.Initialize();
+ ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
+}
} // namespace test
} // namespace onnxruntime
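For reference, the ParseOnnxSource helper used in the new test turns the ONNX textual syntax above into serialized model bytes. A rough sketch of such a helper, assuming it wraps onnx's text parser (OnnxParser from onnx/defs/parser.h); this illustrates the flow, not the test utility's actual definition:

```cpp
#include <stdexcept>
#include <string>

#include "onnx/defs/parser.h"  // ONNX_NAMESPACE::OnnxParser
#include "onnx/onnx_pb.h"      // ONNX_NAMESPACE::ModelProto

// Assumed shape of the helper: parse textual ONNX into a ModelProto,
// then serialize it so InferenceSession::Load can read it from a stream.
std::string ParseToSerializedModel(const char* text) {
  ONNX_NAMESPACE::ModelProto model;
  ONNX_NAMESPACE::OnnxParser parser(text);
  auto parse_status = parser.Parse(model);
  if (!parse_status.IsOK()) {
    throw std::runtime_error(parse_status.ErrorMessage());
  }
  std::string serialized;
  model.SerializeToString(&serialized);
  return serialized;
}
```

Feeding the serialized bytes through a std::stringstream, as the test does, exercises the same Load path a model file would take, so the inliner sees the custom-domain function exactly as issues 16438 and 18781 reported it.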