diff --git a/onnxscript/_internal/runtime_typing.py b/onnxscript/_internal/runtime_typing.py index 1dae48643..3cf8a8db5 100644 --- a/onnxscript/_internal/runtime_typing.py +++ b/onnxscript/_internal/runtime_typing.py @@ -17,9 +17,11 @@ T = typing.TypeVar("T", bound=typing.Callable[..., typing.Any]) try: - from beartype import beartype as checked + from beartype import beartype as _beartype_decorator from beartype import roar as _roar + checked = typing.cast(typing.Callable[[T], T], _beartype_decorator) + # Beartype warns when we import from typing because the types are deprecated # in Python 3.9. But there will be a long time until we can move to using # the native container types for type annotations (when 3.9 is the lowest