Build function bodies according to the imported global opset. #18833

Merged · 3 commits · Dec 15, 2023
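When inlining a schema-defined function body, Node::TryGetFunctionProto previously always requested the body for the operator's SinceVersion(). With this change it first looks up the opset version the model actually imports for the ONNX domain (via Graph::DomainToVersionMap()) and requests the function body for that version, falling back to SinceVersion() only when the model imports no ONNX-domain opset (e.g., for operators from non-onnx domains). A minimal standalone sketch of that resolution rule, with hypothetical names (the real logic is the get_opset_version lambda in the diff below):

#include <optional>
#include <string>
#include <unordered_map>

// Domain -> imported opset version, as in Graph::DomainToVersionMap().
using DomainToVersionMap = std::unordered_map<std::string, int>;

// Pick the opset version to build a function body for: the model's imported
// ONNX-domain opset when present, otherwise the operator's since-version.
int ResolveFunctionBodyOpset(const DomainToVersionMap& domain_to_version,
                             int since_version) {
  const auto iter = domain_to_version.find("");  // "" is the ONNX domain
  if (iter != domain_to_version.cend()) {
    return iter->second;
  }
  return since_version;
}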
35 changes: 23 additions & 12 deletions onnxruntime/core/graph/graph.cc
@@ -582,6 +582,17 @@ bool Node::TryGetFunctionProto(ONNX_NAMESPACE::FunctionProto& onnx_function_prot
    onnx_function_proto = *func_template_->onnx_func_proto_;
    return true;
  } else if (op_) {
+   auto get_opset_version = [op = op_](Graph* graph) -> std::optional<int> {
+     if (op->domain() == kOnnxDomain) {
+       const auto& domain_to_version = graph->DomainToVersionMap();
+       const auto iter = domain_to_version.find(kOnnxDomain);
+       if (iter != domain_to_version.cend()) {
+         return iter->second;
+       }
+     }
+     return {};
+   };
+
    // Check if this node has a schema defined function proto.
    if (op_->HasContextDependentFunction()) {
      NodeProto node_proto;
@@ -595,8 +606,13 @@ bool Node::TryGetFunctionProto(ONNX_NAMESPACE::FunctionProto& onnx_function_prot
        } else
          input_types.emplace_back();
      }
+
+     auto requested_opset_version = get_opset_version(graph_);
+     if (!requested_opset_version.has_value()) {
+       requested_opset_version = SinceVersion();
+     }
      ONNX_NAMESPACE::FunctionBodyBuildContextImpl function_body_ctx(node_proto, input_types);
-     return op_->BuildContextDependentFunction(function_body_ctx, onnx_function_proto);
+     return op_->BuildContextDependentFunction(function_body_ctx, onnx_function_proto, *requested_opset_version);
    } else if (op_->HasFunction()) {
      const FunctionProto* function_ptr = nullptr;
      // We need to get a function-body suitable for the ONNX opset used by the model.
@@ -605,17 +621,12 @@ bool Node::TryGetFunctionProto(ONNX_NAMESPACE::FunctionProto& onnx_function_prot
      // as the default-version, which is incorrect in the case of functions belonging to
      // non-onnx domains, like MSDOMAIN.

-     // We use the following as a temporary hack.
-     function_ptr = op_->GetFunction(SinceVersion(), false);
-
-     // TODO: Switch to following, once ONNX issue is fixed.
-     // auto& map = graph_->DomainToVersionMap();
-     // const auto iter = map.find(kOnnxDomain);
-     // if (iter != map.end()) {
-     //   function_ptr = op_->GetFunction(iter->second, true);
-     // } else {
-     //   function_ptr = op_->GetFunction();
-     // }
+     auto requested_opset_version = get_opset_version(graph_);
+     if (requested_opset_version.has_value()) {
+       function_ptr = op_->GetFunction(*requested_opset_version, true);
+     } else {
+       function_ptr = op_->GetFunction(SinceVersion(), false);
+     }

      if (function_ptr != nullptr) {
        onnx_function_proto = *function_ptr;
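To make the effect of the change above concrete, here is a hedged worked example of the fallback rule (the numbers are illustrative only, chosen to mirror the test below where the model imports the ONNX domain at opset 18):

#include <cassert>
#include <optional>

// Illustrative sketch: the imported ONNX-domain opset wins; the operator's
// since-version is used only when the model does not import the ONNX domain.
int Resolve(std::optional<int> imported_onnx_opset, int since_version) {
  return imported_onnx_opset.value_or(since_version);
}

int main() {
  // Model imports the ONNX domain at opset 18; op's since-version is 13 (hypothetical):
  assert(Resolve(18, 13) == 18);  // body is built for opset 18
  // Model imports only a custom domain, so there is no ONNX-domain entry:
  assert(Resolve(std::nullopt, 13) == 13);  // falls back to since-version
  return 0;
}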
54 changes: 54 additions & 0 deletions onnxruntime/test/framework/function_test.cc
@@ -614,5 +614,59 @@ TEST(FunctionTest, TestInlinedFunctionDoesNotReserrectNonExistingArgs) {
                                      AsSpan(output_names), &fetches, 0));
}

+/// <summary>
+/// This test covers the issues:
+/// https://github.com/microsoft/onnxruntime/issues/16438
+/// https://github.com/microsoft/onnxruntime/issues/18781
+/// </summary>
+TEST(FunctionTest, Test_GH_issue_16438) {
+  const char* code = R"(
+    <
+    ir_version: 8,
+    opset_import: ["pkg.onnxscript.torch_lib" : 1, "" : 18],
+    producer_name: "pytorch",
+    producer_version: "2.1.0"
+    >
+    torch_jit (float16[5,10,5] input_0) => (double[5,10,5] _val_1) {
+      _val_1 = pkg.onnxscript.torch_lib.aten_special_log_softmax <dim: int = 2, dtype: int = 11> (input_0)
+    }
+    <
+    domain: "pkg.onnxscript.torch_lib",
+    opset_import: ["" : 18]
+    >
+    aten_special_log_softmax <dim, dtype>(self) => (result_8)
+    {
+      tmp = Shape(self)
+      tmp_0 = Size(tmp)
+      int64_0 = Constant<value : tensor = int64 int64_0{0}> ()
+      int64_0_cast = CastLike(int64_0, tmp_0)
+      self_is_scalar = Equal(tmp_0, int64_0_cast)
+      self_4 = If(self_is_scalar) <then_branch : graph = thenGraph_8() => (self_2) {
+        tmp_1 = Constant<value_ints : ints = [0]> ()
+        self_2 = Unsqueeze(self, tmp_1)
+      }, else_branch : graph = elseGraph_8() => (self_3) {
+        self_3 = Identity(self)
+      }>
+      result = LogSoftmax<axis : int = @dim>(self_4)
+      result_5 = Cast<to : int = @dtype>(result)
+      result_8 = If(self_is_scalar) <then_branch : graph = thenGraph_12() => (result_6) {
+        result_6 = Squeeze(result_5)
+      }, else_branch : graph = elseGraph_12() => (result_7) {
+        result_7 = Identity(result_5)
+      }>
+    }
+  )";
+
+  std::string serialized_model;
+  ParseOnnxSource(code, serialized_model);
+  SessionOptions session_options;
+  InferenceSession session_object{session_options, GetEnvironment()};
+
+  std::stringstream sstr(serialized_model);
+  auto status = session_object.Load(sstr);
+  ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
+  status = session_object.Initialize();
+  ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
+}
 }  // namespace test
 }  // namespace onnxruntime