diff --git a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_ir.py b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_ir.py index aeefd2599..015c0e2be 100644 --- a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_ir.py +++ b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_ir.py @@ -235,7 +235,7 @@ def eval_function( # type: ignore[override] else: # Python constants are scalars return 0 - elif function.experimental_traceable: + elif function.traceable: # Trace the function call instead of adding the function as a node return function.function(*args, **kwargs) diff --git a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py index c07ba3ce8..a00df9f93 100644 --- a/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py +++ b/onnxscript/function_libs/torch_lib/graph_building/_graph_building_torch.py @@ -388,7 +388,7 @@ def eval_function( # type: ignore[override] else: # Python constants are scalars return 0 - elif function.experimental_traceable: + elif function.traceable: # Trace the function call instead of adding the function as a node return function.function(*args, **kwargs) diff --git a/onnxscript/function_libs/torch_lib/registration.py b/onnxscript/function_libs/torch_lib/registration.py index 05d8f6217..2b3e6577e 100644 --- a/onnxscript/function_libs/torch_lib/registration.py +++ b/onnxscript/function_libs/torch_lib/registration.py @@ -141,7 +141,7 @@ def wrapper( else: assert isinstance(func, FunctionType) processed_func = onnxscript.script(opset=custom_opset)(func) - processed_func.experimental_traceable = traceable + processed_func.traceable = traceable assert registry is not None for name_ in _check_and_normalize_names(name): diff --git a/onnxscript/values.py b/onnxscript/values.py index 31ebe3000..fc4846b5d 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -5,6 +5,7 @@ from __future__ import annotations import dataclasses +import functools import inspect import logging import types @@ -477,8 +478,11 @@ def __init__( self._param_schemas: Optional[tuple[ParamSchema, ...]] = None self._op_schema: Optional[onnx.defs.OpSchema] = None + # Allow the object to be inspected as a function + functools.update_wrapper(self, pyfun) + # Experimental fields - self.experimental_traceable = False + self.traceable = False @property @deprecation.deprecated( @@ -570,6 +574,9 @@ def __init__(self, opset: Opset, func: types.FunctionType): super().__init__(opset, func.__name__) self.func = func + # Allow the object to be inspected as a function + functools.update_wrapper(self, func) + def __call__(self, *args, **kwargs): return self.func(*args, **kwargs) diff --git a/onnxscript/values_test.py b/onnxscript/values_test.py index ed21ff277..f5d08ad72 100644 --- a/onnxscript/values_test.py +++ b/onnxscript/values_test.py @@ -1,3 +1,11 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +from __future__ import annotations + +import inspect +import typing import unittest import onnxscript @@ -15,6 +23,48 @@ def function(input1, input2, attr1: int, attr2: int = 1): self.assertEqual(traced_function.name, function.__name__) self.assertEqual(traced_function.func, function) + def test_param_schemas_in_correct_order_with_mixed_inputs_and_attrs(self): + opset = values.Opset("test", 1) + + def function(input1, input2, attr1: int, attr2: float, input3, attr3: str = "default"): + return opset.CustomOp(input1 + input2, input3, attr1, attr2, attr3) + + traced_function = values.TracedOnnxFunction(opset, function) + param_schemas = traced_function.param_schemas() + expected_ordered_param_names = [ + "input1", + "input2", + "attr1", + "attr2", + "input3", + "attr3", + ] + self.assertEqual(len(param_schemas), len(expected_ordered_param_names)) + for i, param_schema in enumerate(param_schemas): + self.assertEqual(param_schema.name, expected_ordered_param_names[i]) + + def test_it_preserves_the_function_signature(self): + opset = values.Opset("test", 1) + + def function(input1, input2, attr1: int, attr2: float, input3, attr3: str = "default"): + return opset.CustomOp(input1 + input2, input3, attr1, attr2, attr3) + + traced_function = values.TracedOnnxFunction(opset, function) + signature = inspect.signature(traced_function) + self.assertEqual(signature.parameters["input1"].name, "input1") + self.assertEqual(signature.parameters["input2"].name, "input2") + self.assertEqual(signature.parameters["attr1"].name, "attr1") + self.assertEqual(signature.parameters["attr2"].name, "attr2") + self.assertEqual(signature.parameters["input3"].name, "input3") + self.assertEqual(signature.parameters["attr3"].name, "attr3") + + annotations = typing.get_type_hints(traced_function) + self.assertEqual(annotations["attr1"], int) + self.assertEqual(annotations["attr2"], float) + self.assertEqual(annotations["attr3"], str) + + +class OnnxFunctionTest(unittest.TestCase): def test_param_schemas_in_correct_order_with_mixed_inputs_and_attrs(self): opset = values.Opset("test", 1) @@ -34,3 +84,27 @@ def function(input1, input2, attr1: int, attr2: float, input3, attr3: str = "def self.assertEqual(len(param_schemas), len(expected_ordered_param_names)) for i, param_schema in enumerate(param_schemas): self.assertEqual(param_schema.name, expected_ordered_param_names[i]) + + def test_it_preserves_the_function_signature(self): + opset = values.Opset("test", 1) + + @onnxscript.script(default_opset=opset) + def function(input1, input2, attr1: int, attr2: float, input3, attr3: str = "default"): + return opset.CustomOp(input1 + input2, input3, attr1, attr2, attr3) + + signature = inspect.signature(function) + self.assertEqual(signature.parameters["input1"].name, "input1") + self.assertEqual(signature.parameters["input2"].name, "input2") + self.assertEqual(signature.parameters["attr1"].name, "attr1") + self.assertEqual(signature.parameters["attr2"].name, "attr2") + self.assertEqual(signature.parameters["input3"].name, "input3") + self.assertEqual(signature.parameters["attr3"].name, "attr3") + + annotations = typing.get_type_hints(function) + self.assertEqual(annotations["attr1"], int) + self.assertEqual(annotations["attr2"], float) + self.assertEqual(annotations["attr3"], str) + + +if __name__ == "__main__": + unittest.main()