From fdf81a387672e592813f7a2b6fc5dce9ca80c905 Mon Sep 17 00:00:00 2001 From: Atanas Dimitrov Date: Wed, 4 Dec 2024 15:55:04 +0200 Subject: [PATCH] Improve type-hinting information --- pyproject.toml | 3 ++ src/spox/_fields.py | 20 ++++++++----- src/spox/_inline.py | 6 ++-- src/spox/_internal_op.py | 4 +-- src/spox/_node.py | 4 +-- src/spox/_standard.py | 1 - src/spox/_value_prop.py | 8 ++++-- src/spox/_var.py | 62 ++++++++++++++++++---------------------- 8 files changed, 57 insertions(+), 51 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8f5e249..fe7bbf8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -52,6 +52,9 @@ namespaces = false [tool.ruff.lint] # Enable the isort rules. extend-select = ["I", "UP"] +ignore = [ + "UP007", # https://docs.astral.sh/ruff/rules/non-pep604-annotation/ +] [tool.ruff.lint.isort] known-first-party = ["spox"] diff --git a/src/spox/_fields.py b/src/spox/_fields.py index f291295..ef20a50 100644 --- a/src/spox/_fields.py +++ b/src/spox/_fields.py @@ -1,6 +1,8 @@ # Copyright (c) QuantCo 2023-2024 # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + import dataclasses import enum import warnings @@ -35,14 +37,18 @@ class VarFieldKind(enum.Enum): class BaseVars: - def __init__(self, vars): + """A collection of `Var`-s used to carry around inputs/outputs of nodes""" + + vars: dict[str, Union[Var, Optional[Var], Sequence[Var]]] + + def __init__(self, vars: dict[str, Union[Var, Optional[Var], Sequence[Var]]]): self.vars = vars def _unpack_to_any(self): """Unpack the stored fields into a tuple of appropriate length, typed as Any.""" return tuple(self.vars.values()) - def _flatten(self): + def _flatten(self) -> Iterator[tuple[str, Optional[Var]]]: """Iterate over the pairs of names and values of fields in this object.""" for key, value in self.vars.items(): if value is None or isinstance(value, Var): @@ -50,11 +56,11 @@ def _flatten(self): else: yield from ((f"{key}_{i}", v) for i, v in enumerate(value)) - def flatten_vars(self): + def flatten_vars(self) -> dict[str, Var]: """Return a flat mapping by name of all the VarInfos in this object.""" return {key: var for key, var in self._flatten() if var is not None} - def __getattr__(self, attr: str) -> Union["Var", Sequence["Var"]]: + def __getattr__(self, attr: str): """Retrieves the attribute if present in the stored variables.""" try: return self.vars[attr] @@ -63,7 +69,7 @@ def __getattr__(self, attr: str) -> Union["Var", Sequence["Var"]]: f"{self.__class__.__name__!r} object has no attribute {attr!r}" ) - def __setattr__(self, attr: str, value: Union["Var", Sequence["Var"]]) -> None: + def __setattr__(self, attr: str, value: Union[Var, Sequence[Var]]) -> None: """Sets the attribute to a value if the attribute is present in the stored variables.""" if attr == "vars": super().__setattr__(attr, value) @@ -74,7 +80,7 @@ def __getitem__(self, key: str): """Allows dictionary-like access to retrieve variables.""" return self.vars[key] - def __setitem__(self, key: str, value) -> None: + def __setitem__(self, key: str, value: Union) -> None: """Allows dictionary-like access to set variables.""" self.vars[key] = value @@ -160,7 +166,7 @@ def vars(self, prop_values: Optional[PropDict] = None) -> BaseVars: if prop_values is None: prop_values = {} - vars_dict: dict[str, Union[Var, Sequence[Var]]] = {} + vars_dict: dict[str, Union[Var, Optional[Var], Sequence[Var]]] = {} for field in dataclasses.fields(self): field_type = self._get_field_type(field) diff --git a/src/spox/_inline.py b/src/spox/_inline.py index 6fa01d7..6ed2af9 100644 --- a/src/spox/_inline.py +++ b/src/spox/_inline.py @@ -111,7 +111,9 @@ def opset_req(self) -> set[tuple[str, int]]: ("", INTERNAL_MIN_OPSET) } - def infer_output_types(self, input_prop_values) -> dict[str, Type]: + def infer_output_types( + self, input_prop_values: _value_prop.PropDict + ) -> dict[str, Type]: # First, type check that we match the ModelProto type requirements for i, var in zip(self.graph.input, self.inputs.inputs): if var.type is not None and not ( @@ -128,7 +130,7 @@ def infer_output_types(self, input_prop_values) -> dict[str, Type]: } def propagate_values( - self, input_prop_values + self, input_prop_values: _value_prop.PropDict ) -> dict[str, _value_prop.PropValueType]: if any( var_info.type is None or input_prop_values.get(var_info.name) is None diff --git a/src/spox/_internal_op.py b/src/spox/_internal_op.py index 84fbc2b..14782de 100644 --- a/src/spox/_internal_op.py +++ b/src/spox/_internal_op.py @@ -88,7 +88,7 @@ def post_init(self, **kwargs): if self.attrs.name is not None: self.outputs.arg._rename(self.attrs.name.value) - def infer_output_types(self, input_prop_values) -> dict[str, Type]: + def infer_output_types(self, input_prop_values: PropDict) -> dict[str, Type]: # Output type is based on the value of the type attribute return {"arg": self.attrs.type.value} @@ -161,7 +161,7 @@ class Outputs(BaseOutputs): inputs: Inputs outputs: Outputs - def infer_output_types(self, input_prop_values) -> dict[str, Type]: + def infer_output_types(self, input_prop_values: PropDict) -> dict[str, Type]: return { f"outputs_{i}": arr.type for i, arr in enumerate(self.inputs.inputs) diff --git a/src/spox/_node.py b/src/spox/_node.py index 383f452..691a7c7 100644 --- a/src/spox/_node.py +++ b/src/spox/_node.py @@ -17,7 +17,7 @@ from ._attributes import AttrGraph from ._debug import STORE_TRACEBACK from ._exceptions import InferenceWarning -from ._fields import BaseAttributes, BaseInputs, BaseOutputs, VarFieldKind +from ._fields import BaseAttributes, BaseInputs, BaseOutputs, BaseVars, VarFieldKind from ._type_system import Type from ._value_prop import PropDict from ._var import _VarInfo @@ -244,7 +244,7 @@ def inference( def get_output_vars( self, input_prop_values: Optional[PropDict] = None, infer_types: bool = True - ): + ) -> BaseVars: if input_prop_values is None: input_prop_values = {} # After typing everything, try to get values for outputs diff --git a/src/spox/_standard.py b/src/spox/_standard.py index 0dde0d3..77c3fb0 100644 --- a/src/spox/_standard.py +++ b/src/spox/_standard.py @@ -103,7 +103,6 @@ def out_value_info(curr_key, curr_var): ] # Initializers, passed in to allow partial data propagation # - used so that operators like Reshape are aware of constant shapes - # TODO: fix this initializers = [] diff --git a/src/spox/_value_prop.py b/src/spox/_value_prop.py index 2e01f86..6bad9f0 100644 --- a/src/spox/_value_prop.py +++ b/src/spox/_value_prop.py @@ -3,11 +3,11 @@ import enum import logging +import typing import warnings from collections.abc import Iterable from dataclasses import dataclass from typing import Callable, Union -from typing import Optional as tOptional import numpy as np import numpy.typing as npt @@ -26,8 +26,10 @@ - PropValue -> Optional, Some (has value) - None -> Optional, Nothing (no value) """ -PropValueType = Union[np.ndarray, Iterable[tOptional["PropValue"]], "PropValue", None] -PropDict = dict[str, Union[Iterable[tOptional["PropValue"]], "PropValue", None]] +PropValueType = Union[ + np.ndarray, Iterable[typing.Optional["PropValue"]], "PropValue", None +] +PropDict = dict[str, Union[Iterable[typing.Optional["PropValue"]], "PropValue", None]] ORTValue = Union[np.ndarray, Iterable, None] RefValue = Union[np.ndarray, Iterable, float, None] diff --git a/src/spox/_var.py b/src/spox/_var.py index c9b86e0..291ea6d 100644 --- a/src/spox/_var.py +++ b/src/spox/_var.py @@ -1,6 +1,8 @@ # Copyright (c) QuantCo 2023-2024 # SPDX-License-Identifier: BSD-3-Clause +from __future__ import annotations + import typing from collections.abc import Iterable, Sequence from typing import Any, Callable, ClassVar, Optional, TypeVar, Union, overload @@ -31,12 +33,12 @@ class _VarInfo: """ type: Optional[_type_system.Type] - _op: "Node" + _op: Node _name: Optional[str] def __init__( self, - op: "Node", + op: Node, type_: Optional[_type_system.Type], ): """The initializer of ``VarInfo`` is protected. Use operator constructors to construct them instead.""" @@ -100,12 +102,12 @@ def unwrap_optional(self) -> _type_system.Optional: """Equivalent to ``self.unwrap_type().unwrap_optional()``.""" return self.unwrap_type().unwrap_optional() - def __copy__(self) -> "_VarInfo": + def __copy__(self) -> _VarInfo: # Simply return `self` to ensure that "copies" are still equal # during the build process return self - def __deepcopy__(self, _) -> "_VarInfo": + def __deepcopy__(self, _) -> _VarInfo: raise ValueError("'VarInfo' objects cannot be deepcopied.") @@ -158,7 +160,7 @@ def __init__( self._var_info = var_info self._value = value - def _get_value(self) -> "_value_prop.ORTValue": + def _get_value(self) -> _value_prop.ORTValue: """Get the propagated value in this Var and convert it to the ORT format. Raises if value is missing.""" if self._value is None: raise ValueError("No propagated value associated with this Var.") @@ -224,66 +226,66 @@ def _which_output(self): def type(self): return self._var_info.type - def __copy__(self) -> "Var": + def __copy__(self) -> Var: # Simply return `self` to ensure that "copies" are still equal # during the build process return self - def __deepcopy__(self, _) -> "Var": + def __deepcopy__(self, _) -> Var: raise ValueError("'Var' objects cannot be deepcopied.") - def __add__(self, other) -> "Var": + def __add__(self, other) -> Var: return Var._operator_dispatcher.add(self, other) - def __sub__(self, other) -> "Var": + def __sub__(self, other) -> Var: return Var._operator_dispatcher.sub(self, other) - def __mul__(self, other) -> "Var": + def __mul__(self, other) -> Var: return Var._operator_dispatcher.mul(self, other) - def __truediv__(self, other) -> "Var": + def __truediv__(self, other) -> Var: return Var._operator_dispatcher.truediv(self, other) - def __floordiv__(self, other) -> "Var": + def __floordiv__(self, other) -> Var: return Var._operator_dispatcher.floordiv(self, other) - def __neg__(self) -> "Var": + def __neg__(self) -> Var: return Var._operator_dispatcher.neg(self) - def __and__(self, other) -> "Var": + def __and__(self, other) -> Var: return Var._operator_dispatcher.and_(self, other) - def __or__(self, other) -> "Var": + def __or__(self, other) -> Var: return Var._operator_dispatcher.or_(self, other) - def __xor__(self, other) -> "Var": + def __xor__(self, other) -> Var: return Var._operator_dispatcher.xor(self, other) - def __invert__(self) -> "Var": + def __invert__(self) -> Var: return Var._operator_dispatcher.not_(self) - def __radd__(self, other) -> "Var": + def __radd__(self, other) -> Var: return Var._operator_dispatcher.add(other, self) - def __rsub__(self, other) -> "Var": + def __rsub__(self, other) -> Var: return Var._operator_dispatcher.sub(other, self) - def __rmul__(self, other) -> "Var": + def __rmul__(self, other) -> Var: return Var._operator_dispatcher.mul(other, self) - def __rtruediv__(self, other) -> "Var": + def __rtruediv__(self, other) -> Var: return Var._operator_dispatcher.truediv(other, self) - def __rfloordiv__(self, other) -> "Var": + def __rfloordiv__(self, other) -> Var: return Var._operator_dispatcher.floordiv(other, self) - def __rand__(self, other) -> "Var": + def __rand__(self, other) -> Var: return Var._operator_dispatcher.and_(other, self) - def __ror__(self, other) -> "Var": + def __ror__(self, other) -> Var: return Var._operator_dispatcher.or_(other, self) - def __rxor__(self, other) -> "Var": + def __rxor__(self, other) -> Var: return Var._operator_dispatcher.xor(other, self) @@ -372,12 +374,4 @@ def create_prop_dict( flattened_vars = BaseVars(kwargs).flatten_vars() - return { - key: ( - var._value - if isinstance(var, Var) - else {k: v._value for k, v in var.items()} - ) - for key, var in flattened_vars.items() - if var is not None - } + return {key: var._value for key, var in flattened_vars.items() if var is not None}