Skip to content

Commit

Permalink
Improve type-hinting information
Browse files Browse the repository at this point in the history
  • Loading branch information
neNasko1 committed Dec 4, 2024
1 parent b9f922c commit fdf81a3
Show file tree
Hide file tree
Showing 8 changed files with 57 additions and 51 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ namespaces = false
[tool.ruff.lint]
# Enable the isort rules.
extend-select = ["I", "UP"]
ignore = [
"UP007", # https://docs.astral.sh/ruff/rules/non-pep604-annotation/
]

[tool.ruff.lint.isort]
known-first-party = ["spox"]
Expand Down
20 changes: 13 additions & 7 deletions src/spox/_fields.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) QuantCo 2023-2024
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import dataclasses
import enum
import warnings
Expand Down Expand Up @@ -35,26 +37,30 @@ class VarFieldKind(enum.Enum):


class BaseVars:
def __init__(self, vars):
"""A collection of `Var`-s used to carry around inputs/outputs of nodes"""

vars: dict[str, Union[Var, Optional[Var], Sequence[Var]]]

def __init__(self, vars: dict[str, Union[Var, Optional[Var], Sequence[Var]]]):
self.vars = vars

def _unpack_to_any(self):
"""Unpack the stored fields into a tuple of appropriate length, typed as Any."""
return tuple(self.vars.values())

def _flatten(self):
def _flatten(self) -> Iterator[tuple[str, Optional[Var]]]:
"""Iterate over the pairs of names and values of fields in this object."""
for key, value in self.vars.items():
if value is None or isinstance(value, Var):
yield key, value
else:
yield from ((f"{key}_{i}", v) for i, v in enumerate(value))

def flatten_vars(self):
def flatten_vars(self) -> dict[str, Var]:
"""Return a flat mapping by name of all the VarInfos in this object."""
return {key: var for key, var in self._flatten() if var is not None}

def __getattr__(self, attr: str) -> Union["Var", Sequence["Var"]]:
def __getattr__(self, attr: str):
"""Retrieves the attribute if present in the stored variables."""
try:
return self.vars[attr]
Expand All @@ -63,7 +69,7 @@ def __getattr__(self, attr: str) -> Union["Var", Sequence["Var"]]:
f"{self.__class__.__name__!r} object has no attribute {attr!r}"
)

def __setattr__(self, attr: str, value: Union["Var", Sequence["Var"]]) -> None:
def __setattr__(self, attr: str, value: Union[Var, Sequence[Var]]) -> None:
"""Sets the attribute to a value if the attribute is present in the stored variables."""
if attr == "vars":
super().__setattr__(attr, value)
Expand All @@ -74,7 +80,7 @@ def __getitem__(self, key: str):
"""Allows dictionary-like access to retrieve variables."""
return self.vars[key]

def __setitem__(self, key: str, value) -> None:
def __setitem__(self, key: str, value: Union) -> None:
"""Allows dictionary-like access to set variables."""
self.vars[key] = value

Expand Down Expand Up @@ -160,7 +166,7 @@ def vars(self, prop_values: Optional[PropDict] = None) -> BaseVars:
if prop_values is None:
prop_values = {}

vars_dict: dict[str, Union[Var, Sequence[Var]]] = {}
vars_dict: dict[str, Union[Var, Optional[Var], Sequence[Var]]] = {}

for field in dataclasses.fields(self):
field_type = self._get_field_type(field)
Expand Down
6 changes: 4 additions & 2 deletions src/spox/_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,9 @@ def opset_req(self) -> set[tuple[str, int]]:
("", INTERNAL_MIN_OPSET)
}

def infer_output_types(self, input_prop_values) -> dict[str, Type]:
def infer_output_types(
self, input_prop_values: _value_prop.PropDict
) -> dict[str, Type]:
# First, type check that we match the ModelProto type requirements
for i, var in zip(self.graph.input, self.inputs.inputs):
if var.type is not None and not (
Expand All @@ -128,7 +130,7 @@ def infer_output_types(self, input_prop_values) -> dict[str, Type]:
}

def propagate_values(
self, input_prop_values
self, input_prop_values: _value_prop.PropDict
) -> dict[str, _value_prop.PropValueType]:
if any(
var_info.type is None or input_prop_values.get(var_info.name) is None
Expand Down
4 changes: 2 additions & 2 deletions src/spox/_internal_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def post_init(self, **kwargs):
if self.attrs.name is not None:
self.outputs.arg._rename(self.attrs.name.value)

def infer_output_types(self, input_prop_values) -> dict[str, Type]:
def infer_output_types(self, input_prop_values: PropDict) -> dict[str, Type]:
# Output type is based on the value of the type attribute
return {"arg": self.attrs.type.value}

Expand Down Expand Up @@ -161,7 +161,7 @@ class Outputs(BaseOutputs):
inputs: Inputs
outputs: Outputs

def infer_output_types(self, input_prop_values) -> dict[str, Type]:
def infer_output_types(self, input_prop_values: PropDict) -> dict[str, Type]:
return {
f"outputs_{i}": arr.type
for i, arr in enumerate(self.inputs.inputs)
Expand Down
4 changes: 2 additions & 2 deletions src/spox/_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ._attributes import AttrGraph
from ._debug import STORE_TRACEBACK
from ._exceptions import InferenceWarning
from ._fields import BaseAttributes, BaseInputs, BaseOutputs, VarFieldKind
from ._fields import BaseAttributes, BaseInputs, BaseOutputs, BaseVars, VarFieldKind
from ._type_system import Type
from ._value_prop import PropDict
from ._var import _VarInfo
Expand Down Expand Up @@ -244,7 +244,7 @@ def inference(

def get_output_vars(
self, input_prop_values: Optional[PropDict] = None, infer_types: bool = True
):
) -> BaseVars:
if input_prop_values is None:
input_prop_values = {}
# After typing everything, try to get values for outputs
Expand Down
1 change: 0 additions & 1 deletion src/spox/_standard.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,6 @@ def out_value_info(curr_key, curr_var):
]
# Initializers, passed in to allow partial data propagation
# - used so that operators like Reshape are aware of constant shapes
# TODO: fix this

initializers = []

Expand Down
8 changes: 5 additions & 3 deletions src/spox/_value_prop.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@

import enum
import logging
import typing
import warnings
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Callable, Union
from typing import Optional as tOptional

import numpy as np
import numpy.typing as npt
Expand All @@ -26,8 +26,10 @@
- PropValue -> Optional, Some (has value)
- None -> Optional, Nothing (no value)
"""
PropValueType = Union[np.ndarray, Iterable[tOptional["PropValue"]], "PropValue", None]
PropDict = dict[str, Union[Iterable[tOptional["PropValue"]], "PropValue", None]]
PropValueType = Union[
np.ndarray, Iterable[typing.Optional["PropValue"]], "PropValue", None
]
PropDict = dict[str, Union[Iterable[typing.Optional["PropValue"]], "PropValue", None]]
ORTValue = Union[np.ndarray, Iterable, None]
RefValue = Union[np.ndarray, Iterable, float, None]

Expand Down
62 changes: 28 additions & 34 deletions src/spox/_var.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright (c) QuantCo 2023-2024
# SPDX-License-Identifier: BSD-3-Clause

from __future__ import annotations

import typing
from collections.abc import Iterable, Sequence
from typing import Any, Callable, ClassVar, Optional, TypeVar, Union, overload
Expand Down Expand Up @@ -31,12 +33,12 @@ class _VarInfo:
"""

type: Optional[_type_system.Type]
_op: "Node"
_op: Node
_name: Optional[str]

def __init__(
self,
op: "Node",
op: Node,
type_: Optional[_type_system.Type],
):
"""The initializer of ``VarInfo`` is protected. Use operator constructors to construct them instead."""
Expand Down Expand Up @@ -100,12 +102,12 @@ def unwrap_optional(self) -> _type_system.Optional:
"""Equivalent to ``self.unwrap_type().unwrap_optional()``."""
return self.unwrap_type().unwrap_optional()

def __copy__(self) -> "_VarInfo":
def __copy__(self) -> _VarInfo:
# Simply return `self` to ensure that "copies" are still equal
# during the build process
return self

def __deepcopy__(self, _) -> "_VarInfo":
def __deepcopy__(self, _) -> _VarInfo:
raise ValueError("'VarInfo' objects cannot be deepcopied.")


Expand Down Expand Up @@ -158,7 +160,7 @@ def __init__(
self._var_info = var_info
self._value = value

def _get_value(self) -> "_value_prop.ORTValue":
def _get_value(self) -> _value_prop.ORTValue:
"""Get the propagated value in this Var and convert it to the ORT format. Raises if value is missing."""
if self._value is None:
raise ValueError("No propagated value associated with this Var.")
Expand Down Expand Up @@ -224,66 +226,66 @@ def _which_output(self):
def type(self):
return self._var_info.type

def __copy__(self) -> "Var":
def __copy__(self) -> Var:
# Simply return `self` to ensure that "copies" are still equal
# during the build process
return self

def __deepcopy__(self, _) -> "Var":
def __deepcopy__(self, _) -> Var:
raise ValueError("'Var' objects cannot be deepcopied.")

def __add__(self, other) -> "Var":
def __add__(self, other) -> Var:
return Var._operator_dispatcher.add(self, other)

def __sub__(self, other) -> "Var":
def __sub__(self, other) -> Var:
return Var._operator_dispatcher.sub(self, other)

def __mul__(self, other) -> "Var":
def __mul__(self, other) -> Var:
return Var._operator_dispatcher.mul(self, other)

def __truediv__(self, other) -> "Var":
def __truediv__(self, other) -> Var:
return Var._operator_dispatcher.truediv(self, other)

def __floordiv__(self, other) -> "Var":
def __floordiv__(self, other) -> Var:
return Var._operator_dispatcher.floordiv(self, other)

def __neg__(self) -> "Var":
def __neg__(self) -> Var:
return Var._operator_dispatcher.neg(self)

def __and__(self, other) -> "Var":
def __and__(self, other) -> Var:
return Var._operator_dispatcher.and_(self, other)

def __or__(self, other) -> "Var":
def __or__(self, other) -> Var:
return Var._operator_dispatcher.or_(self, other)

def __xor__(self, other) -> "Var":
def __xor__(self, other) -> Var:
return Var._operator_dispatcher.xor(self, other)

def __invert__(self) -> "Var":
def __invert__(self) -> Var:
return Var._operator_dispatcher.not_(self)

def __radd__(self, other) -> "Var":
def __radd__(self, other) -> Var:
return Var._operator_dispatcher.add(other, self)

def __rsub__(self, other) -> "Var":
def __rsub__(self, other) -> Var:
return Var._operator_dispatcher.sub(other, self)

def __rmul__(self, other) -> "Var":
def __rmul__(self, other) -> Var:
return Var._operator_dispatcher.mul(other, self)

def __rtruediv__(self, other) -> "Var":
def __rtruediv__(self, other) -> Var:
return Var._operator_dispatcher.truediv(other, self)

def __rfloordiv__(self, other) -> "Var":
def __rfloordiv__(self, other) -> Var:
return Var._operator_dispatcher.floordiv(other, self)

def __rand__(self, other) -> "Var":
def __rand__(self, other) -> Var:
return Var._operator_dispatcher.and_(other, self)

def __ror__(self, other) -> "Var":
def __ror__(self, other) -> Var:
return Var._operator_dispatcher.or_(other, self)

def __rxor__(self, other) -> "Var":
def __rxor__(self, other) -> Var:
return Var._operator_dispatcher.xor(other, self)


Expand Down Expand Up @@ -372,12 +374,4 @@ def create_prop_dict(

flattened_vars = BaseVars(kwargs).flatten_vars()

return {
key: (
var._value
if isinstance(var, Var)
else {k: v._value for k, v in var.items()}
)
for key, var in flattened_vars.items()
if var is not None
}
return {key: var._value for key, var in flattened_vars.items() if var is not None}

0 comments on commit fdf81a3

Please sign in to comment.