Skip to content

Commit

Permalink
Type anotate everything
Browse files Browse the repository at this point in the history
  • Loading branch information
neNasko1 committed Dec 2, 2024
1 parent 8c8c05a commit 8ae420e
Show file tree
Hide file tree
Showing 9 changed files with 27 additions and 24 deletions.
4 changes: 2 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import subprocess
import sys
from subprocess import CalledProcessError
from typing import cast
from typing import Any, Optional, cast

_mod = importlib.import_module("spox")

Expand Down Expand Up @@ -60,7 +60,7 @@
# Copied and adapted from
# https://github.com/pandas-dev/pandas/blob/4a14d064187367cacab3ff4652a12a0e45d0711b/doc/source/conf.py#L613-L659
# Required configuration function to use sphinx.ext.linkcode
def linkcode_resolve(domain, info):
def linkcode_resolve(domain: str, info: dict[str, Any]) -> Optional[str]:
"""Determine the URL corresponding to a given Python object."""
if domain != "py":
return None
Expand Down
4 changes: 2 additions & 2 deletions src/spox/_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

import sys
from collections.abc import Generator
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Any

Expand All @@ -15,7 +15,7 @@
@contextmanager
def show_construction_tracebacks(
debug_index: dict[str, Any],
) -> Generator[None, None, None]:
) -> Iterator[None]:
"""
Context manager constructed with a ``Builder.build_result.debug_index``.
Expand Down
19 changes: 11 additions & 8 deletions src/spox/_future.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
"""Module containing experimental Spox features that may be standard in the future."""

import warnings
from collections.abc import Iterable
from collections.abc import Iterable, Iterator
from contextlib import contextmanager
from typing import Optional, Union
from types import ModuleType
from typing import Any, Optional, Union

import numpy as np
import numpy.typing as npt
Expand All @@ -25,7 +26,7 @@ def set_type_warning_level(level: TypeWarningLevel) -> None:


@contextmanager
def type_warning_level(level: TypeWarningLevel):
def type_warning_level(level: TypeWarningLevel) -> Iterator[None]:
prev_level = spox._node._TYPE_WARNING_LEVEL
set_type_warning_level(level)
yield
Expand All @@ -40,7 +41,7 @@ def set_value_prop_backend(backend: ValuePropBackend) -> None:


@contextmanager
def value_prop_backend(backend: ValuePropBackend):
def value_prop_backend(backend: ValuePropBackend) -> Iterator[None]:
prev_backend = spox._value_prop._VALUE_PROP_BACKEND
set_value_prop_backend(backend)
yield
Expand Down Expand Up @@ -74,7 +75,9 @@ def initializer(value: npt.ArrayLike, dtype: npt.DTypeLike = None) -> Var:


class _NumpyLikeOperatorDispatcher:
def __init__(self, op, type_promotion: bool, constant_promotion: bool):
def __init__(
self, op: ModuleType, type_promotion: bool, constant_promotion: bool
) -> None:
self.op = op
self.type_promotion = type_promotion
self.constant_promotion = constant_promotion
Expand Down Expand Up @@ -166,8 +169,8 @@ def not_(self, a: Var) -> Var:

@contextmanager
def _operator_overloading(
op, type_promotion: bool = False, constant_promotion: bool = True
):
op: ModuleType, type_promotion: bool = False, constant_promotion: bool = True
) -> Iterator[None]:
"""Enable operator overloading on Var for this block.
May be used either as a context manager, or a decorator.
Expand Down Expand Up @@ -210,7 +213,7 @@ def _operator_overloading(
Var._operator_dispatcher = prev_dispatcher


def __getattr__(name):
def __getattr__(name: str) -> Any:
if name == "operator_overloading":
warnings.warn(
"using 'operator_overloading' is deprecated, consider using https://github.com/Quantco/ndonnx instead",
Expand Down
4 changes: 2 additions & 2 deletions src/spox/_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import itertools
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Callable, Optional
from typing import Any, Callable, Optional

import onnx

Expand Down Expand Up @@ -98,7 +98,7 @@ class Outputs(BaseOutputs):
inputs: Inputs
outputs: Outputs

def pre_init(self, **kwargs) -> None:
def pre_init(self, **kwargs: Any) -> None:
self.model = kwargs["model"]

@property
Expand Down
4 changes: 2 additions & 2 deletions src/spox/_internal_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from abc import ABC
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Optional
from typing import TYPE_CHECKING, Any, Callable, Optional

import numpy as np
import onnx
Expand Down Expand Up @@ -88,7 +88,7 @@ class Outputs(BaseOutputs):
inputs: Inputs
outputs: Outputs

def post_init(self, **kwargs) -> None:
def post_init(self, **kwargs: Any) -> None:
if self.attrs.name is not None:
self.outputs.arg._rename(self.attrs.name.value)

Expand Down
8 changes: 4 additions & 4 deletions src/spox/_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from abc import ABC
from collections.abc import Generator, Iterable, Sequence
from dataclasses import dataclass
from typing import ClassVar, Optional, Union
from typing import Any, ClassVar, Optional, Union

import numpy as np
import onnx
Expand Down Expand Up @@ -98,7 +98,7 @@ def __init__(
infer_types: bool = True,
propagate_values: bool = True,
validate: bool = True,
**kwargs,
**kwargs: Any,
):
"""
Parameters
Expand Down Expand Up @@ -208,10 +208,10 @@ def get_op_repr(cls) -> str:
domain = cls.op_type.domain if cls.op_type.domain != "" else "ai.onnx"
return f"{domain}@{cls.op_type.version}::{cls.op_type.identifier}"

def pre_init(self, **kwargs) -> None:
def pre_init(self, **kwargs: Any) -> None:
"""Pre-initialization hook. Called during ``__init__`` before any field on the object is set."""

def post_init(self, **kwargs) -> None:
def post_init(self, **kwargs: Any) -> None:
"""Post-initialization hook. Called at the end of ``__init__`` after other default fields are set."""

def propagate_values(self) -> dict[str, PropValueType]:
Expand Down
4 changes: 2 additions & 2 deletions src/spox/_public.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import contextlib
import itertools
from collections.abc import Generator
from collections.abc import Iterator
from typing import Optional, Protocol

import numpy as np
Expand Down Expand Up @@ -42,7 +42,7 @@ def argument(typ: Type) -> Var:


@contextlib.contextmanager
def _temporary_renames(**kwargs: Var) -> Generator[None, None, None]:
def _temporary_renames(**kwargs: Var) -> Iterator[None]:
# The build code can't really special-case variable names that are
# not just ``Var._name``. So we set names here and reset them
# afterwards.
Expand Down
2 changes: 1 addition & 1 deletion src/spox/_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@


class NotImplementedOperatorDispatcher:
def _not_impl(self, *args):
def _not_impl(self, *args: Any):
return NotImplemented

add = sub = mul = truediv = floordiv = neg = and_ = or_ = xor = not_ = _not_impl
Expand Down
2 changes: 1 addition & 1 deletion tools/generate_opset.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ class Attribute:
# Mark whether generating extra constructor arguments caused by this should raise
allow_extra: bool = False

def __post_init__(self):
def __post_init__(self) -> None:
if self.attr_constructor != "AttrGraph" and self.subgraph_solution is not None:
raise TypeError(
"Subgraph input types should only be specified for an AttrGraph."
Expand Down

0 comments on commit 8ae420e

Please sign in to comment.