Skip to content

Commit

Permalink
Address comment, add test
Browse files Browse the repository at this point in the history
  • Loading branch information
yuslepukhin committed Dec 15, 2023
1 parent 064f8f2 commit 59736f6
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 13 deletions.
13 changes: 0 additions & 13 deletions onnxruntime/core/graph/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -622,19 +622,6 @@ bool Node::TryGetFunctionProto(ONNX_NAMESPACE::FunctionProto& onnx_function_prot
// non-onnx domains, like MSDOMAIN.

auto requested_opset_version = get_opset_version(graph_);

/// When we fix custom function registration then we can apply the following
/// logic for non ONNX domains as we do for ONNX. Currently, we only apply ONNX
/// domain version
/* {
const auto iter = domain_to_version.find(op_->domain());
if (iter != domain_to_version.cend()) {
function_ptr = op_->GetFunction(iter->second, true);
} else {
function_ptr = op_->GetFunction();
}
}*/

if (requested_opset_version.has_value()) {
function_ptr = op_->GetFunction(*requested_opset_version, true);
} else {
Expand Down
54 changes: 54 additions & 0 deletions onnxruntime/test/framework/function_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -614,5 +614,59 @@ TEST(FunctionTest, TestInlinedFunctionDoesNotReserrectNonExistingArgs) {
AsSpan(output_names), &fetches, 0));
}

/// <summary>
/// This test covers the issues:
/// https://github.com/microsoft/onnxruntime/issues/16438
/// https://github.com/microsoft/onnxruntime/issues/18781
/// The model calls a model-local function that lives in a custom (non-ONNX)
/// domain "pkg.onnxscript.torch_lib" whose body imports ONNX opset 18.
/// Loading and initializing the session must resolve the function body
/// against the correct opset version and succeed.
/// </summary>
TEST(FunctionTest, Test_GH_issue_16438) {
  // Model in ONNX textual syntax. NOTE: this literal is parsed at runtime by
  // ParseOnnxSource — its contents must not be altered.
  const char* code = R"(
  <
  ir_version: 8,
  opset_import: ["pkg.onnxscript.torch_lib" : 1, "" : 18],
  producer_name: "pytorch",
  producer_version: "2.1.0"
  >
  torch_jit (float16[5,10,5] input_0) => (double[5,10,5] _val_1) {
  _val_1 = pkg.onnxscript.torch_lib.aten_special_log_softmax <dim: int = 2, dtype: int = 11> (input_0)
  }
  <
  domain: "pkg.onnxscript.torch_lib",
  opset_import: ["" : 18]
  >
  aten_special_log_softmax <dim, dtype>(self) => (result_8)
  {
  tmp = Shape(self)
  tmp_0 = Size(tmp)
  int64_0 = Constant<value : tensor = int64 int64_0{0}> ()
  int64_0_cast = CastLike(int64_0, tmp_0)
  self_is_scalar = Equal(tmp_0, int64_0_cast)
  self_4 = If(self_is_scalar) <then_branch : graph = thenGraph_8() => (self_2) {
  tmp_1 = Constant<value_ints : ints = [0]> ()
  self_2 = Unsqueeze(self, tmp_1)
  }, else_branch : graph = elseGraph_8() => (self_3) {
  self_3 = Identity(self)
  }>
  result = LogSoftmax<axis : int = @dim>(self_4)
  result_5 = Cast<to : int = @dtype>(result)
  result_8 = If(self_is_scalar) <then_branch : graph = thenGraph_12() => (result_6) {
  result_6 = Squeeze(result_5)
  }, else_branch : graph = elseGraph_12() => (result_7) {
  result_7 = Identity(result_5)
  }>
  }
  )";

  // Convert the textual model into a serialized ModelProto byte string.
  std::string serialized_model;
  ParseOnnxSource(code, serialized_model);

  SessionOptions session_options;
  InferenceSession session_object{session_options, GetEnvironment()};

  // Load from an in-memory stream; both Load and Initialize must succeed,
  // which requires the custom-domain function body to be resolved correctly.
  std::stringstream sstr(serialized_model);
  auto status = session_object.Load(sstr);
  ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
  status = session_object.Initialize();
  ASSERT_TRUE(status.IsOK()) << status.ErrorMessage();
}
} // namespace test
} // namespace onnxruntime

0 comments on commit 59736f6

Please sign in to comment.