From dac54d8d88aa20a888b86ae88b3010fbe4cbbf48 Mon Sep 17 00:00:00 2001 From: "G. Ramalingam" Date: Wed, 22 May 2024 16:25:12 -0700 Subject: [PATCH] Use default opset only if functions don't use opset (#1564) Fix bug reported in Issue #1559 : look at opsets imported by functions _before_ using a default. --- onnxscript/converter_test.py | 18 ++++++++++++++++++ onnxscript/irbuilder.py | 12 +++++++++--- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/onnxscript/converter_test.py b/onnxscript/converter_test.py index 58ed37968..121175755 100644 --- a/onnxscript/converter_test.py +++ b/onnxscript/converter_test.py @@ -675,6 +675,24 @@ def sum(n: INT64) -> INT64: self.check_run(sum, [np.array(5, dtype=np.int64)], np.array(10, dtype=np.int64)) self.check_run(sum, [np.array(-5, dtype=np.int64)], np.array(0, dtype=np.int64)) + def test_function_opset_import(self): + """Test that model inherits opset version from the function.""" + from onnxscript import opset19 + + @script() + def double(x): + return opset19.Add(x, x) + + @script() + def model(x): + return double(x) + + model_proto = model.to_model_proto() + onnx_opset_import = [opset for opset in model_proto.opset_import if opset.domain == ""] + + self.assertEqual(len(onnx_opset_import), 1) + self.assertEqual(onnx_opset_import[0].version, 19) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 3940ba929..90923a3f6 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -370,13 +370,19 @@ def to_proto(f): for n in self.stmts: if n.callee.opset.domain not in opsets: opsets[n.callee.opset.domain] = n.callee.opset.version + + for proto in functions: + if proto.domain not in opsets: + opsets[proto.domain] = 1 + # TODO(rama): Handle conflicts with appropriate error/warning message. + for opset in proto.opset_import: + if opset.domain not in opsets: + opsets[opset.domain] = opset.version + if "" not in opsets: # No operator is using the standard opset. # A default value is given. opsets[""] = onnx_opset_version() - for proto in functions: - if proto.domain not in opsets: - opsets[proto.domain] = 1 if "ir_version" not in kwargs: kwargs["ir_version"] = select_ir_version(opsets[""])