diff --git a/onnxscript/_internal/deprecation.py b/onnxscript/_internal/deprecation.py index 301565c8d..7bf18482a 100644 --- a/onnxscript/_internal/deprecation.py +++ b/onnxscript/_internal/deprecation.py @@ -12,6 +12,12 @@ T = TypeVar("T") +@functools.lru_cache(maxsize=1024) +def _warn_once(message: str): + """Issue a FutureWarning only once per message.""" + warnings.warn(message, category=FutureWarning, stacklevel=3) + + def deprecated(since: str, removed_in: str, instructions: str) -> Callable[[T], T]: """Marks functions as deprecated. @@ -30,12 +36,10 @@ def deprecated(since: str, removed_in: str, instructions: str) -> Callable[[T], def decorator(function): @functools.wraps(function) def wrapper(*args, **kwargs): - warnings.warn( + _warn_once( f"'{function.__module__}.{function.__qualname__}' " f"is deprecated in version {since} and will be " f"removed in {removed_in}. Please {instructions}.", - category=FutureWarning, - stacklevel=2, ) return function(*args, **kwargs) diff --git a/onnxscript/ir/_schemas.py b/onnxscript/ir/_schemas.py new file mode 100644 index 000000000..3422a0c28 --- /dev/null +++ b/onnxscript/ir/_schemas.py @@ -0,0 +1,548 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import collections.abc +import dataclasses +import inspect +import logging +import types +import typing +from typing import Any, Iterator, Mapping, Optional, Sequence, TypeVar, Union + +import onnx + +import onnxscript +from onnxscript import ir + +logger = logging.getLogger(__name__) + + +# A special value to indicate that the default value is not specified +class _Empty: + def __repr__(self): + return "_EMPTY_DEFAULT" + + +_EMPTY_DEFAULT = _Empty() + +# Map from python type to corresponding ONNX AttributeProto type +_PY_TYPE_TO_ATTR_TYPE = { + float: ir.AttributeType.FLOAT, + int: ir.AttributeType.INT, + str: ir.AttributeType.STRING, + bool: ir.AttributeType.INT, + ir.Tensor: ir.AttributeType.TENSOR, + ir.TensorProtocol: ir.AttributeType.TENSOR, + ir.Graph: ir.AttributeType.GRAPH, + ir.GraphProtocol: ir.AttributeType.GRAPH, +} + +# Map from python type to corresponding ONNX AttributeProto type, +# for repeated (i.e., list of) values +_LIST_TYPE_TO_ATTR_TYPE = { + float: ir.AttributeType.FLOATS, + int: ir.AttributeType.INTS, + str: ir.AttributeType.STRINGS, + bool: ir.AttributeType.INTS, + ir.Tensor: ir.AttributeType.TENSORS, + ir.TensorProtocol: ir.AttributeType.TENSORS, + ir.Graph: ir.AttributeType.GRAPHS, + ir.GraphProtocol: ir.AttributeType.GRAPHS, +} + +_ALL_VALUE_TYPES = ( + {ir.TensorType(dtype) for dtype in ir.DataType} + | {ir.SequenceType(ir.TensorType(dtype)) for dtype in ir.DataType} + | {ir.OptionalType(ir.TensorType(dtype)) for dtype in ir.DataType} +) + +# TypeAnnotationValue represents the (value of) valid type-annotations recognized +# by ONNX Script. Currently, it supports +# - float, int, str (primitive attribute types) +# - Sequence[float], Sequence[int], Sequence[str] (attribute types) +# - Tensor types +# - Sequence[Tensor] types +# - Union of above 2 +# - TypeVars with above bounds +# - Above types with annotation attached +TypeAnnotationValue = Any + + +@dataclasses.dataclass(frozen=True) +class TypeConstraintParam: + """Type constraint for a parameter. + + Attributes: + name: Name of the parameter. E.g. "TFloat" + allowed_types: Allowed types for the parameter. + """ + + name: str + allowed_types: set[ir.TypeProtocol] + description: str = "" + + def __hash__(self) -> int: + return hash((self.name, tuple(self.allowed_types))) + + def __str__(self) -> str: + allowed_types_str = " | ".join(str(t) for t in self.allowed_types) + return f"{self.name}={allowed_types_str}" + + @classmethod + def any_tensor(cls, name: str, description: str = "") -> TypeConstraintParam: + return cls(name, {ir.TensorType(dtype) for dtype in ir.DataType}, description) + + @classmethod + def any_value(cls, name: str, description: str = "") -> TypeConstraintParam: + return cls(name, _ALL_VALUE_TYPES, description) # type: ignore[arg-type] + + +@dataclasses.dataclass(frozen=True) +class Parameter: + """A formal parameter of an operator.""" + + name: str + type_constraint: TypeConstraintParam + required: bool + variadic: bool + default: Any = _EMPTY_DEFAULT + # TODO: Add other properties too + + def __str__(self) -> str: + type_str = self.type_constraint.name + if self.has_default(): + return f"{self.name}: {type_str} = {self.default}" + return f"{self.name}: {type_str}" + + def has_default(self) -> bool: + return self.default is not _EMPTY_DEFAULT + + +@dataclasses.dataclass(frozen=True) +class AttributeParameter: + """A parameter in the function signature that represents an ONNX attribute.""" + + name: str + type: ir.AttributeType + required: bool + default: ir.Attr | None = None + + def __str__(self) -> str: + type_str = self.type.name + if self.has_default(): + return f"{self.name}: {type_str} = {self.default}" + return f"{self.name}: {type_str}" + + def has_default(self) -> bool: + return self.default is not None + + +def _get_type_from_str( + type_str: str, +) -> ir.TensorType | ir.SequenceType | ir.OptionalType: + """Converter a type_str from ONNX OpSchema to ir.TypeProtocol. + + A type str has the form of "tensor(float)" or composite type like "seq(tensor(float))". + """ + # Split the type_str a sequence types and dtypes + # 1. Remove the ending ")" + striped = type_str.rstrip(")") + # 2. Split the type_str by "(" + type_parts = striped.split("(") + + # Convert the dtype to ir.DataType + dtype = ir.DataType[type_parts[-1].upper()] + + # Create a place holder type first + type_: ir.TypeProtocol = ir.TensorType(ir.DataType.UNDEFINED) + + # Construct the type + for type_part in reversed(type_parts[:-1]): + if type_part == "tensor": + type_ = ir.TensorType(dtype) + elif type_part == "seq": + type_ = ir.SequenceType(type_) + elif type_part == "optional": + type_ = ir.OptionalType(type_) + else: + raise ValueError(f"Unknown type part: '{type_part}' in type '{type_str}'") + return type_ # type: ignore[return-value] + + +def _convert_formal_parameter( + param: onnx.defs.OpSchema.FormalParameter, + type_constraints: Mapping[str, TypeConstraintParam], +) -> Parameter: + """Convert a formal parameter from ONNX OpSchema to Parameter.""" + if param.type_str in type_constraints: + type_constraint = type_constraints[param.type_str] + else: + # param.type_str can be a plain type like 'int64'. + type_constraint = TypeConstraintParam( + name=param.name, + allowed_types={_get_type_from_str(param.type_str)}, + ) + return Parameter( + name=param.name, + type_constraint=type_constraint, + required=param.option != onnx.defs.OpSchema.FormalParameterOption.Optional, + variadic=param.option == onnx.defs.OpSchema.FormalParameterOption.Variadic, + ) + + +def _is_optional(type_: type) -> bool: + """Returns whether a type_ is an Optional.""" + origin_type = typing.get_origin(type_) + if origin_type is Union and type(None) in typing.get_args(type_): + # Python < 3.10 + return True + if origin_type is Optional: + # Python >= 3.10 + return True + if ( + hasattr(types, "UnionType") + and origin_type is types.UnionType + and type(None) in typing.get_args(type_) + ): + # Python >= 3.10 + return True + return False + + +def _get_attr_type(type_: type) -> ir.AttributeType: + """Obtain the type of the attribute from a Python class.""" + try: + if type_ in _PY_TYPE_TO_ATTR_TYPE: + return _PY_TYPE_TO_ATTR_TYPE[type_] + origin_type = typing.get_origin(type_) + if origin_type is None: + return ir.AttributeType.UNDEFINED + if origin_type in ( + collections.abc.Sequence, + Sequence, + typing.List, + list, + typing.Tuple, + tuple, + ): + inner_type = typing.get_args(type_)[0] + if inner_type in _LIST_TYPE_TO_ATTR_TYPE: + return _LIST_TYPE_TO_ATTR_TYPE[inner_type] + except TypeError: + logger.warning("TypeError when checking %s.", type_, exc_info=True) + return ir.AttributeType.UNDEFINED + + +def _get_type_constraint_name(type_: TypeAnnotationValue) -> str | None: + """Returns the name of the type constraint for a given type annotation. + + Args: + type_: A Python type. + + Returns: + The name of the type constraint if it is a TypeVar. + - Prefixes the name with "Sequence_" if the type annotation is a Sequence[]. + """ + if isinstance(type_, TypeVar): + return type_.__name__ + if _is_optional(type_): + subtypes = typing.get_args(type_) + for subtype in subtypes: + if subtype is type(None): + continue + type_param_name = _get_type_constraint_name(subtype) + return type_param_name if type_param_name else None + origin_type = typing.get_origin(type_) + if isinstance(origin_type, type) and issubclass(origin_type, Sequence): + subtypes = typing.get_args(type_) + type_param_name = _get_type_constraint_name(subtypes[0]) + return f"Sequence_{type_param_name}" if type_param_name else None + return None + + +def _get_allowed_types_from_type_annotation( + type_: TypeAnnotationValue, +) -> set[ir.TypeProtocol]: + """Obtain the allowed types from a type annotation.""" + if type_ is onnxscript.onnx_types.TensorType: + # Any tensor type + return {ir.TensorType(dtype) for dtype in ir.DataType} + + allowed_types: set[ir.TypeProtocol] + + if isinstance(type_, TypeVar): + allowed_types = set() + if constraints := type_.__constraints__: + for constraint in constraints: + allowed_types.update(_get_allowed_types_from_type_annotation(constraint)) + else: + bound = type_.__bound__ + if bound is None: + allowed_types = _ALL_VALUE_TYPES # type: ignore[assignment] + else: + allowed_types.update(_get_allowed_types_from_type_annotation(bound)) + return allowed_types + if hasattr(type_, "dtype"): + # A single tensor type like INT64, FLOAT, etc. + return {ir.TensorType(ir.DataType(type_.dtype))} + if _is_optional(type_): + allowed_types = set() + subtypes = typing.get_args(type_) + for subtype in subtypes: + if subtype is type(None): + continue + allowed_types.update(_get_allowed_types_from_type_annotation(subtype)) + # NOTE: We do not consider dynamic optional types like optional(float) because they are not very useful. + return allowed_types + + origin_type = typing.get_origin(type_) + if origin_type is Union: + allowed_types = set() + subtypes = typing.get_args(type_) + for subtype in subtypes: + assert subtype is not type( + None + ), "Union should not contain None type because it is handled by _is_optional." + allowed_types.update(_get_allowed_types_from_type_annotation(subtype)) + return allowed_types + + if isinstance(origin_type, type) and issubclass(origin_type, Sequence): + subtypes = typing.get_args(type_) + return { + ir.SequenceType(t) for t in _get_allowed_types_from_type_annotation(subtypes[0]) + } + + # Allow everything by default + return _ALL_VALUE_TYPES # type: ignore[return-value] + + +@dataclasses.dataclass +class OpSignature: + """Schema for an operator. + + Attributes: + domain: Domain of the operator. E.g. "". + name: Name of the operator. E.g. "Add". + overload: Overload name of the operator. + params: Input parameters. When the op is an ONNX function definition, + the order is according to the function signature. This mean we can + interleave ONNX inputs and ONNX attributes in the list. + outputs: Output parameters. + """ + + domain: str + name: str + overload: str + params: Sequence[Parameter | AttributeParameter] + outputs: Sequence[Parameter] + params_map: Mapping[str, Parameter | AttributeParameter] = dataclasses.field( + init=False, repr=False + ) + + def __post_init__(self): + self.params_map = {param.name: param for param in self.params} + + def get(self, name: str) -> Parameter | AttributeParameter: + return self.params_map[name] + + def __contains__(self, name: str) -> bool: + return name in self.params_map + + def __iter__(self) -> Iterator[Parameter | AttributeParameter]: + return iter(self.params) + + def __str__(self) -> str: + domain = self.domain or "''" + # TODO: Double check the separator for overload + overload = f"::{self.overload}" if self.overload else "" + params = ", ".join(str(param) for param in self.params) + outputs = ", ".join(str(param.type_constraint.name) for param in self.outputs) + type_constraints = {} + for param in self.params: + if isinstance(param, Parameter): + type_constraints[param.type_constraint.name] = param.type_constraint + for param in self.outputs: + type_constraints[param.type_constraint.name] = param.type_constraint + type_constraints_str = ", ".join( + str(type_constraint) for type_constraint in type_constraints.values() + ) + return f"{domain}::{self.name}{overload}({params}) -> ({outputs}) where {type_constraints_str}" + + @classmethod + def from_op_schema(cls, op_schema: onnx.defs.OpSchema) -> OpSignature: + """Produce an OpSignature from an ONNX OpSchema.""" + type_constraints = { + constraint.type_param_str: TypeConstraintParam( + name=constraint.type_param_str, + allowed_types={ + _get_type_from_str(type_str) for type_str in constraint.allowed_type_strs + }, + description=constraint.description, + ) + for constraint in op_schema.type_constraints + } + + params = [ + _convert_formal_parameter(param, type_constraints) for param in op_schema.inputs + ] + + for param in op_schema.attributes.values(): + default_attr = ( + ir.serde.deserialize_attribute(param.default_value) + if param.default_value is not None + else None + ) + if default_attr is not None: + # Set the name of the default attribute because it may have a different name from the parameter + default_attr.name = param.name + params.append( + AttributeParameter( + name=param.name, + type=ir.AttributeType(param.type), # type: ignore[arg-type] + required=param.required, + default=default_attr, # type: ignore[arg-type] + ) + ) + + outputs = [ + _convert_formal_parameter(param, type_constraints) for param in op_schema.outputs + ] + + return cls( + domain=op_schema.domain, + name=op_schema.name, + overload="", + params=params, + outputs=outputs, + ) + + @classmethod + def from_function( + cls, func, domain: str, name: str | None = None, overload: str = "" + ) -> OpSignature: + """Produce an OpSignature from a function using type annotation.""" + + py_signature = inspect.signature(func) + # Not using inspect.get_annotations because typing.get_type_hints seems to handle more cases + # https://github.com/python/cpython/issues/102405 + type_hints = typing.get_type_hints(func) + + params: list[Parameter | AttributeParameter] = [] + # Create a mapping from type to a unique name + type_constraints: dict[str, TypeConstraintParam] = {} + + for param in py_signature.parameters.values(): + if param.name not in type_hints: + logger.warning( + "Missing annotation for parameter '%s' from %s. Treating as an Input.", + param.name, + py_signature, + ) + type_constraint = TypeConstraintParam.any_value(f"T_{param.name}") + type_constraints[param.name] = type_constraint + params.append( + Parameter( + name=param.name, + type_constraint=type_constraint, + required=param.default is inspect.Parameter.empty, + # TODO: Handle variadic + variadic=False, + default=param.default + if param.default is not inspect.Parameter.empty + else _EMPTY_DEFAULT, + ) + ) + else: + type_ = type_hints[param.name] + if (attr_type := _get_attr_type(type_)) != ir.AttributeType.UNDEFINED: + # Construct the default attribute + if param.default is not inspect.Parameter.empty: + # TODO: Use ir_convenience instead to handle int as float + default = ir.Attr(param.name, attr_type, param.default) + else: + default = None + params.append( + AttributeParameter( + name=param.name, + type=attr_type, + required=param.default is inspect.Parameter.empty, + default=default, + ) + ) + else: + # Obtain the type constraint from the type annotation + + # 1. Get a type constraint name from the type annotation + # If the type annotation is a TypeVar or Optional[TypeVar], get its name + # Otherwise, name it T_{param.name} + type_constraint_name = _get_type_constraint_name(type_) + if type_constraint_name is None: + type_constraint_name = f"T_{param.name}" + + # 2. If the type constraint param is already initialized, use it + if type_constraint_name in type_constraints: + type_constraint = type_constraints[type_constraint_name] + else: + # 3. Otherwise, create a new TypeConstraintParam + type_constraint = TypeConstraintParam( + name=type_constraint_name, + allowed_types=_get_allowed_types_from_type_annotation(type_), + ) + type_constraints[type_constraint_name] = type_constraint + # 4. Create Parameter + params.append( + Parameter( + name=param.name, + type_constraint=type_constraint, + required=param.default is inspect.Parameter.empty, + # TODO: Handle variadic + variadic=False, + default=param.default + if param.default is not inspect.Parameter.empty + else _EMPTY_DEFAULT, + ) + ) + + return_type = type_hints.get("return") + + outputs = [] + if return_type is None: + # No returns + pass + else: + if typing.get_origin(return_type) is tuple: + # Multiple returns + return_types = typing.get_args(return_type) + else: + return_types = [return_type] # type: ignore[assignment] + + for i, return_type_i in enumerate(return_types): + if ( + return_param_name := _get_type_constraint_name(return_type_i) + ) in type_constraints: + type_constraint = type_constraints[return_param_name] + else: + return_param_name = f"TReturn{i}" + type_constraint = TypeConstraintParam( + name=return_param_name, + allowed_types=_get_allowed_types_from_type_annotation(return_type_i), + ) + type_constraints[return_param_name] = type_constraint + outputs.append( + Parameter( + name=return_param_name, + type_constraint=type_constraint, + required=True, + variadic=False, + default=_EMPTY_DEFAULT, + ) + ) + + return cls( + domain=domain, + name=name or func.__name__, + overload=overload, + params=params, + outputs=outputs, + ) diff --git a/onnxscript/ir/_schemas_test.py b/onnxscript/ir/_schemas_test.py new file mode 100644 index 000000000..c134bd7a6 --- /dev/null +++ b/onnxscript/ir/_schemas_test.py @@ -0,0 +1,176 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from __future__ import annotations + +import unittest +from typing import Any, Optional, Sequence, TypeVar, Union + +import parameterized + +import onnxscript +import onnxscript.testing +from onnxscript import FLOAT, INT64, ir +from onnxscript.ir import _schemas + +_TestTypeVarConstraints = TypeVar("_TestTypeVarConstraints", INT64, FLOAT) +_TestTypeVarOneBound = TypeVar("_TestTypeVarOneBound", bound=INT64) +_TestTypeVarTwoBound = TypeVar("_TestTypeVarTwoBound", bound=Union[INT64, FLOAT]) + + +class TypeConversionFunctionsTest(unittest.TestCase): + @parameterized.parameterized.expand( + [ + ( + "tensor_type_all", + onnxscript.onnx_types.TensorType, + {ir.TensorType(dtype) for dtype in ir.DataType}, + ), + ("tensor_type", INT64, {ir.TensorType(ir.DataType.INT64)}), + ( + "tensor_type_union", + Union[INT64, FLOAT], + {ir.TensorType(ir.DataType.INT64), ir.TensorType(ir.DataType.FLOAT)}, + ), + ( + "tensor_type_variadic_shape", + INT64[...], + {ir.TensorType(ir.DataType.INT64)}, + ), + ("tensor_type_shape", INT64[10], {ir.TensorType(ir.DataType.INT64)}), + ( + "type_var_constraints", + _TestTypeVarConstraints, + {ir.TensorType(ir.DataType.INT64), ir.TensorType(ir.DataType.FLOAT)}, + ), + ( + "type_bound_one", + _TestTypeVarOneBound, + {ir.TensorType(ir.DataType.INT64)}, + ), + ( + "type_bound_two", + _TestTypeVarTwoBound, + {ir.TensorType(ir.DataType.INT64), ir.TensorType(ir.DataType.FLOAT)}, + ), + ( + "optional_tensor_type_all", + Optional[onnxscript.onnx_types.TensorType], + {ir.TensorType(dtype) for dtype in ir.DataType}, + ), + ( + "optional_tensor_type", + Optional[INT64], + {ir.TensorType(ir.DataType.INT64)}, + ), + ( + "optional_tensor_type_union", + Optional[Union[INT64, FLOAT]], + {ir.TensorType(ir.DataType.INT64), ir.TensorType(ir.DataType.FLOAT)}, + ), + ( + "optional_tensor_type_variadic_shape", + Optional[INT64[...]], + {ir.TensorType(ir.DataType.INT64)}, + ), + ( + "optional_tensor_type_shape", + Optional[INT64[10]], + {ir.TensorType(ir.DataType.INT64)}, + ), + ( + "optional_type_var_constraints", + Optional[_TestTypeVarConstraints], + {ir.TensorType(ir.DataType.INT64), ir.TensorType(ir.DataType.FLOAT)}, + ), + ( + "optional_type_bound_one", + Optional[_TestTypeVarOneBound], + {ir.TensorType(ir.DataType.INT64)}, + ), + ( + "optional_type_bound_two", + Optional[_TestTypeVarTwoBound], + {ir.TensorType(ir.DataType.INT64), ir.TensorType(ir.DataType.FLOAT)}, + ), + ( + "sequence_type_all", + Sequence[onnxscript.onnx_types.TensorType], + {ir.SequenceType(ir.TensorType(dtype)) for dtype in ir.DataType}, + ), + ( + "sequence_type", + Sequence[INT64], + {ir.SequenceType(ir.TensorType(ir.DataType.INT64))}, + ), + ( + "union_sequence_type", + Union[Sequence[INT64], Sequence[FLOAT]], + { + ir.SequenceType(ir.TensorType(ir.DataType.INT64)), + ir.SequenceType(ir.TensorType(ir.DataType.FLOAT)), + }, + ), + ( + "sequence_type_variadic_shape", + Sequence[INT64[...]], + {ir.SequenceType(ir.TensorType(ir.DataType.INT64))}, + ), + ( + "sequence_type_shape", + Sequence[INT64[10]], + {ir.SequenceType(ir.TensorType(ir.DataType.INT64))}, + ), + ( + "sequence_type_var_constraints", + Sequence[_TestTypeVarConstraints], + { + ir.SequenceType(ir.TensorType(ir.DataType.INT64)), + ir.SequenceType(ir.TensorType(ir.DataType.FLOAT)), + }, + ), + ( + "sequence_type_bound_one", + Sequence[_TestTypeVarOneBound], + {ir.SequenceType(ir.TensorType(ir.DataType.INT64))}, + ), + ( + "sequence_type_bound_two", + Sequence[_TestTypeVarTwoBound], + { + ir.SequenceType(ir.TensorType(ir.DataType.INT64)), + ir.SequenceType(ir.TensorType(ir.DataType.FLOAT)), + }, + ), + ] + ) + def test_pytype_to_ir_type(self, _, pytype: Any, expected: set[ir.TypeProtocol]): + self.assertEqual(_schemas._get_allowed_types_from_type_annotation(pytype), expected) # pylint: disable=protected-access + + @parameterized.parameterized.expand( + [ + ("type_var", _TestTypeVarConstraints, "_TestTypeVarConstraints"), + ("type_var_bound", _TestTypeVarOneBound, "_TestTypeVarOneBound"), + ( + "optional_type_var", + Optional[_TestTypeVarOneBound], + "_TestTypeVarOneBound", + ), + ( + "sequence_type_var", + Sequence[_TestTypeVarOneBound], + "Sequence__TestTypeVarOneBound", + ), + ("normal_type", INT64, None), + ("union_type", Union[INT64, FLOAT], None), + ("optional_type", Optional[INT64], None), + ("sequence_type", Sequence[INT64], None), + ("optional_sequence_type", Optional[Sequence[INT64]], None), + ("optional_union_type", Optional[Union[INT64, FLOAT]], None), + ] + ) + def test_get_type_constraint_name(self, _: str, pytype: Any, expected: str | None): + self.assertEqual(_schemas._get_type_constraint_name(pytype), expected) # pylint: disable=protected-access + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/values.py b/onnxscript/values.py index f47c64f70..89fe1e478 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -25,6 +25,7 @@ from onnxscript import converter as converter_module from onnxscript import irbuilder, sourceinfo, type_annotation from onnxscript._internal import ast_utils, deprecation +from onnxscript.ir import _schemas _ATTRIBUTE_TYPE_TO_PYTHON_TYPE = { onnx.defs.OpSchema.AttrType.FLOAT: float, @@ -173,7 +174,7 @@ def _get_attribute_value(attr_proto: onnx.AttributeProto) -> Any: return onnx.helper.get_attribute_value(attr_proto) -def param_schemas_from_op_schema( +def _param_schemas_from_op_schema( op_schema: onnx.defs.OpSchema, ) -> tuple[ParamSchema, ...]: """Get the parameter schemas from an ONNX OpSchema.""" @@ -222,7 +223,7 @@ def _param_schema_from_function_ir_attr(attr: irbuilder.IRAttributeParameter): ) -def param_schemas_from_function_ir( +def _param_schemas_from_function_ir( function_ir: irbuilder.IRFunction, ) -> tuple[ParamSchema, ...]: """Get the parameter schemas from a FunctionIR.""" @@ -259,7 +260,8 @@ def opset(self) -> Opset: ... @property def op_schema(self) -> Optional[onnx.defs.OpSchema]: ... - def param_schemas(self) -> Optional[tuple[ParamSchema, ...]]: ... + @property + def op_signature(self) -> Optional[_schemas.OpSignature]: ... class Op(OpLike): @@ -274,18 +276,19 @@ class Op(OpLike): """ def __init__( - self, opset: Opset, opname: str, op_schema: Optional[onnx.defs.OpSchema] = None + self, opset: Opset, name: str, op_schema: Optional[onnx.defs.OpSchema] = None ) -> None: self._opset = opset - self._name = opname - self._op_schema = op_schema or opset[opname] + self._name = name + self._op_schema = op_schema or opset[name] + self._signature: Optional[_schemas.OpSignature] = None self._param_schemas: Optional[tuple[ParamSchema, ...]] = None if self._op_schema is None: logger.debug( "An OpSchema was not provided for Op '%s' and " "there is not one found in opset '%s'.", - opname, + name, opset, ) @@ -312,10 +315,36 @@ def opset(self) -> Opset: def op_schema(self) -> Optional[onnx.defs.OpSchema]: return self._op_schema + @deprecation.deprecated( + since="0.1", + removed_in="the future", + instructions="check if '.op_schema' is not None instead", + ) def has_schema(self) -> bool: """Returns True if this op has an OpSchema.""" return self.op_schema is not None + @property + def op_signature(self) -> Optional[_schemas.OpSignature]: + """Returns the signature of this op.""" + if self._signature is not None: + return self._signature + + if self.op_schema is None: + return None + + self._signature = _schemas.OpSignature.from_op_schema(self.op_schema) + return self._signature + + @op_signature.setter + def op_signature(self, value: _schemas.OpSignature): + self._signature = value + + @deprecation.deprecated( + since="0.1", + removed_in="the future", + instructions="use '.op_signature' instead", + ) def param_schemas(self) -> Optional[tuple[ParamSchema, ...]]: """Returns the parameter schemas for this op, if it has one.""" if self._param_schemas is not None: @@ -325,7 +354,7 @@ def param_schemas(self) -> Optional[tuple[ParamSchema, ...]]: if op_schema is None: return None - self._param_schemas = param_schemas_from_op_schema(op_schema) + self._param_schemas = _param_schemas_from_op_schema(op_schema) return self._param_schemas @@ -362,7 +391,7 @@ def as_tuple(self) -> tuple[str, list[str], str]: return (self.name, self.allowed_types, self.description) -def op_schema_from_function_ir( +def _op_schema_from_function_ir( function_ir: irbuilder.IRFunction, opset: Opset ) -> onnx.defs.OpSchema: """Construct an ONNX OpSchema from an IRFunction.""" @@ -486,7 +515,7 @@ def __init__( @property @deprecation.deprecated( since="0.1", - removed_in="0.3", + removed_in="the future", instructions="use '.name' instead", ) def opname(self) -> str: @@ -500,10 +529,28 @@ def op_schema(self) -> Optional[onnx.defs.OpSchema]: if self._op_schema is not None: return self._op_schema - self._op_schema = op_schema_from_function_ir(self.function_ir, self.opset) + self._op_schema = _op_schema_from_function_ir(self.function_ir, self.opset) return self._op_schema + @property + def op_signature(self) -> Optional[_schemas.OpSignature]: + """Returns the signature of this op.""" + if self._signature is not None: + return self._signature + + if self.op_schema is None: + return None + + self._signature = _schemas.OpSignature.from_function( + self.function, domain=self.function_ir.domain, name=self.name + ) + return self._signature + + @op_signature.setter + def op_signature(self, value: _schemas.OpSignature): + self._signature = value + def __getitem__(self, instance): """Returns a lambda to evaluate function using given evaluator instance. @@ -531,6 +578,11 @@ def __call__(self, *args, **kwargs): def __repr__(self) -> str: return f"{self.__class__.__name__}({self.function!r})" + @deprecation.deprecated( + since="0.1", + removed_in="the future", + instructions="use '.op_signature' instead", + ) def param_schemas(self) -> tuple[ParamSchema, ...]: """Returns the parameter schemas of this function.""" if self._param_schemas is not None: @@ -539,7 +591,7 @@ def param_schemas(self) -> tuple[ParamSchema, ...]: # NOTE: We generate the parameter schemas from the function_ir instead # of relying on the auto generated OpSchema because we need to preserve the keyword # argument order from the Python function definition, which is lost in OpSchema. - self._param_schemas = param_schemas_from_function_ir(self.function_ir) + self._param_schemas = _param_schemas_from_function_ir(self.function_ir) return self._param_schemas def to_function_proto(self) -> onnx.FunctionProto: @@ -612,10 +664,33 @@ def op_schema(self) -> Optional[onnx.defs.OpSchema]: return self._op_schema # FIXME(justinchuby): outputs are empty. Need to fix. - self._op_schema = op_schema_from_function_ir(self.function_ir, self._opset) + self._op_schema = _op_schema_from_function_ir(self.function_ir, self._opset) return self._op_schema + @property + def op_signature(self) -> Optional[_schemas.OpSignature]: + """Returns the signature of this op.""" + if self._signature is not None: + return self._signature + + if self.op_schema is None: + return None + + self._signature = _schemas.OpSignature.from_function( + self.func, domain="_traced", name=self.name + ) + return self._signature + + @op_signature.setter + def op_signature(self, value: _schemas.OpSignature): + self._signature = value + + @deprecation.deprecated( + since="0.1", + removed_in="the future", + instructions="use '.op_signature' instead", + ) def param_schemas(self) -> tuple[ParamSchema, ...]: """Returns the parameter schemas of this function.""" if self._param_schemas is not None: @@ -624,7 +699,7 @@ def param_schemas(self) -> tuple[ParamSchema, ...]: # NOTE: We generate the parameter schemas from the function_ir instead # of relying on the auto generated OpSchema because we need to preserve the keyword # argument order from the Python function definition, which is lost in OpSchema. - self._param_schemas = param_schemas_from_function_ir(self.function_ir) + self._param_schemas = _param_schemas_from_function_ir(self.function_ir) return self._param_schemas