Skip to content

Commit

Permalink
Use default opset only if functions don't use opset (#1564)
Browse files Browse the repository at this point in the history
Fix bug reported in Issue #1559 : look at opsets imported by functions
_before_ using a default.
  • Loading branch information
gramalingam authored May 22, 2024
1 parent fca7401 commit dac54d8
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
18 changes: 18 additions & 0 deletions onnxscript/converter_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,24 @@ def sum(n: INT64) -> INT64:
self.check_run(sum, [np.array(5, dtype=np.int64)], np.array(10, dtype=np.int64))
self.check_run(sum, [np.array(-5, dtype=np.int64)], np.array(0, dtype=np.int64))

def test_function_opset_import(self):
"""Test that model inherits opset version from the function."""
from onnxscript import opset19

@script()
def double(x):
return opset19.Add(x, x)

@script()
def model(x):
return double(x)

model_proto = model.to_model_proto()
onnx_opset_import = [opset for opset in model_proto.opset_import if opset.domain == ""]

self.assertEqual(len(onnx_opset_import), 1)
self.assertEqual(onnx_opset_import[0].version, 19)


if __name__ == "__main__":
unittest.main(verbosity=2)
12 changes: 9 additions & 3 deletions onnxscript/irbuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,13 +370,19 @@ def to_proto(f):
for n in self.stmts:
if n.callee.opset.domain not in opsets:
opsets[n.callee.opset.domain] = n.callee.opset.version

for proto in functions:
if proto.domain not in opsets:
opsets[proto.domain] = 1
# TODO(rama): Handle conflicts with appropriate error/warning message.
for opset in proto.opset_import:
if opset.domain not in opsets:
opsets[opset.domain] = opset.version

if "" not in opsets:
# No operator is using the standard opset.
# A default value is given.
opsets[""] = onnx_opset_version()
for proto in functions:
if proto.domain not in opsets:
opsets[proto.domain] = 1

if "ir_version" not in kwargs:
kwargs["ir_version"] = select_ir_version(opsets[""])
Expand Down

0 comments on commit dac54d8

Please sign in to comment.