Skip to content

Commit

Permalink
Preserve function signatures for OnnxFunction; stabilize traceable
Browse files Browse the repository at this point in the history
…option (#1587)

- Rename the `experimental_traceable` property to `traceable`
- Preserve function signatures for OnnxFunction and TracedFunction

Fixes #401
  • Loading branch information
justinchuby authored Jun 4, 2024
1 parent 1efd8e6 commit b007b12
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def eval_function( # type: ignore[override]
else:
# Python constants are scalars
return 0
elif function.experimental_traceable:
elif function.traceable:
# Trace the function call instead of adding the function as a node
return function.function(*args, **kwargs)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def eval_function( # type: ignore[override]
else:
# Python constants are scalars
return 0
elif function.experimental_traceable:
elif function.traceable:
# Trace the function call instead of adding the function as a node
return function.function(*args, **kwargs)

Expand Down
2 changes: 1 addition & 1 deletion onnxscript/function_libs/torch_lib/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def wrapper(
else:
assert isinstance(func, FunctionType)
processed_func = onnxscript.script(opset=custom_opset)(func)
processed_func.experimental_traceable = traceable
processed_func.traceable = traceable

assert registry is not None
for name_ in _check_and_normalize_names(name):
Expand Down
9 changes: 8 additions & 1 deletion onnxscript/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

import dataclasses
import functools
import inspect
import logging
import types
Expand Down Expand Up @@ -477,8 +478,11 @@ def __init__(
self._param_schemas: Optional[tuple[ParamSchema, ...]] = None
self._op_schema: Optional[onnx.defs.OpSchema] = None

# Allow the object to be inspected as a function
functools.update_wrapper(self, pyfun)

# Experimental fields
self.experimental_traceable = False
self.traceable = False

@property
@deprecation.deprecated(
Expand Down Expand Up @@ -570,6 +574,9 @@ def __init__(self, opset: Opset, func: types.FunctionType):
super().__init__(opset, func.__name__)
self.func = func

# Allow the object to be inspected as a function
functools.update_wrapper(self, func)

def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)

Expand Down
74 changes: 74 additions & 0 deletions onnxscript/values_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from __future__ import annotations

import inspect
import typing
import unittest

import onnxscript
Expand All @@ -15,6 +23,48 @@ def function(input1, input2, attr1: int, attr2: int = 1):
self.assertEqual(traced_function.name, function.__name__)
self.assertEqual(traced_function.func, function)

def test_param_schemas_in_correct_order_with_mixed_inputs_and_attrs(self):
opset = values.Opset("test", 1)

def function(input1, input2, attr1: int, attr2: float, input3, attr3: str = "default"):
return opset.CustomOp(input1 + input2, input3, attr1, attr2, attr3)

traced_function = values.TracedOnnxFunction(opset, function)
param_schemas = traced_function.param_schemas()
expected_ordered_param_names = [
"input1",
"input2",
"attr1",
"attr2",
"input3",
"attr3",
]
self.assertEqual(len(param_schemas), len(expected_ordered_param_names))
for i, param_schema in enumerate(param_schemas):
self.assertEqual(param_schema.name, expected_ordered_param_names[i])

def test_it_preserves_the_function_signature(self):
opset = values.Opset("test", 1)

def function(input1, input2, attr1: int, attr2: float, input3, attr3: str = "default"):
return opset.CustomOp(input1 + input2, input3, attr1, attr2, attr3)

traced_function = values.TracedOnnxFunction(opset, function)
signature = inspect.signature(traced_function)
self.assertEqual(signature.parameters["input1"].name, "input1")
self.assertEqual(signature.parameters["input2"].name, "input2")
self.assertEqual(signature.parameters["attr1"].name, "attr1")
self.assertEqual(signature.parameters["attr2"].name, "attr2")
self.assertEqual(signature.parameters["input3"].name, "input3")
self.assertEqual(signature.parameters["attr3"].name, "attr3")

annotations = typing.get_type_hints(traced_function)
self.assertEqual(annotations["attr1"], int)
self.assertEqual(annotations["attr2"], float)
self.assertEqual(annotations["attr3"], str)


class OnnxFunctionTest(unittest.TestCase):
def test_param_schemas_in_correct_order_with_mixed_inputs_and_attrs(self):
opset = values.Opset("test", 1)

Expand All @@ -34,3 +84,27 @@ def function(input1, input2, attr1: int, attr2: float, input3, attr3: str = "def
self.assertEqual(len(param_schemas), len(expected_ordered_param_names))
for i, param_schema in enumerate(param_schemas):
self.assertEqual(param_schema.name, expected_ordered_param_names[i])

def test_it_preserves_the_function_signature(self):
opset = values.Opset("test", 1)

@onnxscript.script(default_opset=opset)
def function(input1, input2, attr1: int, attr2: float, input3, attr3: str = "default"):
return opset.CustomOp(input1 + input2, input3, attr1, attr2, attr3)

signature = inspect.signature(function)
self.assertEqual(signature.parameters["input1"].name, "input1")
self.assertEqual(signature.parameters["input2"].name, "input2")
self.assertEqual(signature.parameters["attr1"].name, "attr1")
self.assertEqual(signature.parameters["attr2"].name, "attr2")
self.assertEqual(signature.parameters["input3"].name, "input3")
self.assertEqual(signature.parameters["attr3"].name, "attr3")

annotations = typing.get_type_hints(function)
self.assertEqual(annotations["attr1"], int)
self.assertEqual(annotations["attr2"], float)
self.assertEqual(annotations["attr3"], str)


if __name__ == "__main__":
unittest.main()

0 comments on commit b007b12

Please sign in to comment.