Skip to content

Commit

Permalink
🏷️ First pass at addressing type checker errors
Browse files Browse the repository at this point in the history
* Decided to ignore the lookups.py because... well, the Django ORM.
* and decided to ignore zgw_consumers.legacy which is scheduled
  for removal anyway, there are better places to sink our energy
  into
* zgw_consumers/api_models/base.py should be rewritten using pydantic
  OR be based on TypedDict - the code is not type checker friendly
  • Loading branch information
sergei-maertens committed Nov 29, 2024
1 parent 13d947a commit 127dc3c
Show file tree
Hide file tree
Showing 21 changed files with 126 additions and 65 deletions.
16 changes: 16 additions & 0 deletions pyright.pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
[tool.pyright]
include = [
"zgw_consumers/"
]
exclude = [
# should really be replaced with pydantic or typed dicts...
"zgw_consumers/api_models/base.py",
# this module is quite funky... doesn't hold up to the base types
"zgw_consumers/models/lookups.py",
# this should be removed instead of fixed
"zgw_consumers/legacy/",
]
ignore = []

pythonVersion = "3.11"
pythonPlatform = "Linux"
1 change: 0 additions & 1 deletion zgw_consumers/api_models/_camel_case.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.

import re

from django.core.files import File
Expand Down
2 changes: 1 addition & 1 deletion zgw_consumers/api_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@

from dateutil.parser import parse
from dateutil.relativedelta import relativedelta
from relativedeltafield.utils import parse_relativedelta

from ._camel_case import underscoreize
from .compat import parse_relativedelta
from .types import JSONObject

__all__ = ["CONVERTERS", "factory", "Model", "ZGWModel"]
Expand Down
2 changes: 1 addition & 1 deletion zgw_consumers/api_models/besluiten.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class Besluit(ZGWModel):
uiterlijke_reactiedatum: Optional[date] = None

def get_vervalreden_display(self) -> str:
return VervalRedenen.labels[self.vervalreden]
return VervalRedenen(self.vervalreden).label


@dataclass
Expand Down
8 changes: 6 additions & 2 deletions zgw_consumers/api_models/catalogi.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,17 +123,21 @@ class Eigenschap(ZGWModel):
zaaktype: str
naam: str
definitie: str
specificatie: dict
specificatie: EigenschapSpecificatie | dict
toelichting: str = ""

def __post_init__(self):
super().__post_init__()
self.specificatie = factory(EigenschapSpecificatie, self.specificatie)
assert isinstance(self.specificatie, dict)
_specificatie = factory(EigenschapSpecificatie, self.specificatie)
assert isinstance(_specificatie, EigenschapSpecificatie)
self.specificatie = _specificatie

def to_python(self, value: str) -> Union[str, Decimal, date, datetime]:
"""
Cast the string value into the appropriate python type based on the spec.
"""
assert isinstance(self.specificatie, EigenschapSpecificatie)
formaat = self.specificatie.formaat
assert formaat in EIGENSCHAP_FORMATEN, f"Unknown format {formaat}"

Expand Down
4 changes: 0 additions & 4 deletions zgw_consumers/api_models/compat.py

This file was deleted.

2 changes: 1 addition & 1 deletion zgw_consumers/api_models/documenten.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,4 @@ class Document(ZGWModel):
locked: bool = False

def get_vertrouwelijkheidaanduiding_display(self):
return VertrouwelijkheidsAanduidingen.values[self.vertrouwelijkheidaanduiding]
return VertrouwelijkheidsAanduidingen(self.vertrouwelijkheidaanduiding).label
10 changes: 6 additions & 4 deletions zgw_consumers/api_models/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Dict, List, Union
from __future__ import annotations

JSONPrimitive = Union[str, int, None, float]
JSONValue = Union[JSONPrimitive, "JSONObject", List["JSONValue"]]
JSONObject = Dict[str, JSONValue]
from typing import TypeAlias

JSONPrimitive: TypeAlias = str | int | None | float
JSONValue: TypeAlias = "JSONPrimitive | JSONObject | list[JSONValue]"
JSONObject: TypeAlias = dict[str, JSONValue]
6 changes: 3 additions & 3 deletions zgw_consumers/api_models/zaken.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class Zaak(ZGWModel):
zaakgeometrie: dict = field(default_factory=dict)

def get_vertrouwelijkheidaanduiding_display(self):
return VertrouwelijkheidsAanduidingen.values[self.vertrouwelijkheidaanduiding]
return VertrouwelijkheidsAanduidingen(self.vertrouwelijkheidaanduiding).label


@dataclass
Expand Down Expand Up @@ -92,10 +92,10 @@ class Rol(ZGWModel):
betrokkene_identificatie: dict = field(default_factory=dict)

def get_betrokkene_type_display(self):
return RolTypes.values[self.betrokkene_type]
return RolTypes(self.betrokkene_type).label

def get_omschrijving_generiek_display(self):
return RolOmschrijving.values[self.omschrijving_generiek]
return RolOmschrijving(self.omschrijving_generiek).label


@dataclass
Expand Down
2 changes: 1 addition & 1 deletion zgw_consumers/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,4 @@ def install_schema_fetcher_cache():
except ImportError:
return

schema_fetcher.cache = OASCache()
schema_fetcher.cache = OASCache() # type: ignore - untyped library...
9 changes: 5 additions & 4 deletions zgw_consumers/concurrent.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,13 @@ class parallel:
def __init__(self, **kwargs):
self.executor = futures.ThreadPoolExecutor(**kwargs)

def submit(*args, **kwargs):
if len(args) >= 2:
self, _fn, *args = args
def submit(self, *args, **kwargs):
if len(args) >= 1:
_fn, *args = args
elif "fn" in kwargs:
_fn = kwargs.pop("fn")
self, *args = args
else:
raise TypeError("Invalid signature")

fn = wrap_fn(_fn)

Expand Down
35 changes: 27 additions & 8 deletions zgw_consumers/drf/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,18 @@ def get_fields(self):
serializer_class=self.__class__.__name__
)
assert hasattr(
self.Meta, "model"
self.Meta, "model" # pyright: ignore[reportAttributeAccessIssue]
), 'Class {serializer_class} missing "Meta.model" attribute'.format(
serializer_class=self.__class__.__name__
)

declared_fields = copy.deepcopy(self._declared_fields)
model = self.Meta.model
depth = getattr(self.Meta, "depth", 0)
model = self.Meta.model # pyright: ignore[reportAttributeAccessIssue]
depth = getattr(
self.Meta, # pyright: ignore[reportAttributeAccessIssue]
"depth",
0,
)

if depth is not None:
assert depth >= 0, "'depth' may not be negative."
Expand Down Expand Up @@ -88,7 +92,7 @@ def get_fields(self):
return fields

def get_field_names(self, declared_fields):
fields = self.Meta.fields
fields = self.Meta.fields # pyright: ignore[reportAttributeAccessIssue]
# Ensure that all declared fields have also been included in the
# `Meta.fields` option.

Expand All @@ -114,9 +118,19 @@ def get_extra_kwargs(self):
Return a dictionary mapping field names to a dictionary of
additional keyword arguments.
"""
extra_kwargs = copy.deepcopy(getattr(self.Meta, "extra_kwargs", {}))
extra_kwargs = copy.deepcopy(
getattr(
self.Meta, # pyright: ignore[reportAttributeAccessIssue]
"extra_kwargs",
{},
)
)

read_only_fields = getattr(self.Meta, "read_only_fields", None)
read_only_fields = getattr(
self.Meta, # pyright: ignore[reportAttributeAccessIssue]
"read_only_fields",
None,
)
if read_only_fields is not None:
if not isinstance(read_only_fields, (list, tuple)):
raise TypeError(
Expand All @@ -131,7 +145,10 @@ def get_extra_kwargs(self):
else:
# Guard against the possible misspelling `readonly_fields` (used
# by the Django admin and others).
assert not hasattr(self.Meta, "readonly_fields"), (
assert not hasattr(
self.Meta, # pyright: ignore[reportAttributeAccessIssue]
"readonly_fields",
), (
"Serializer `%s.%s` has field `readonly_fields`; "
"the correct spelling for the option is `read_only_fields`."
% (self.__class__.__module__, self.__class__.__name__)
Expand Down Expand Up @@ -186,7 +203,9 @@ def build_standard_field(self, field_name, model_field_type):
if "choices" in field_kwargs:
# Fields with choices get coerced into `ChoiceField`
# instead of using their regular typed field.
field_class = self.serializer_choice_field
# fmt: off
field_class = self.serializer_choice_field # pyright: ignore[reportAttributeAccessIssue]
# fmt: on
# Some model fields may introduce kwargs that would not be valid
# for the choice field. We need to strip these out.
# Eg. models.DecimalField(max_digits=3, decimal_places=1, choices=DECIMAL_CHOICES)
Expand Down
13 changes: 9 additions & 4 deletions zgw_consumers/drf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,17 @@ def extract_model_field_type(model_class, field_name):

# support for Optional / List
if hasattr(typehint, "__origin__"):
if typehint.__origin__ is list and typehint.__args__:
subtypehint = typehint.__args__[0]
if (
typehint.__origin__ is list # pyright: ignore[reportAttributeAccessIssue]
and typehint.__args__ # pyright: ignore[reportAttributeAccessIssue]
):
# fmt: off
subtypehint = typehint.__args__[0] # pyright: ignore[reportAttributeAccessIssue]
# fmt: on
raise NotImplementedError("TODO: support collections")

if typehint.__origin__ is Union:
typehint = typehint.__args__
if typehint.__origin__ is Union: # pyright: ignore[reportAttributeAccessIssue]
typehint = typehint.__args__ # pyright: ignore[reportAttributeAccessIssue]
# Optional is ONE type combined with None
typehint = next(t for t in typehint if t is not None)
return typehint
2 changes: 1 addition & 1 deletion zgw_consumers/models/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class RestAPIService(Service):
validators=[FileExtensionValidator(["yml", "yaml"])],
)

class Meta:
class Meta: # pyright: ignore[reportIncompatibleVariableOverride]
abstract = True

def clean(self):
Expand Down
2 changes: 1 addition & 1 deletion zgw_consumers/models/certificates.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@


class Certificate(NewCertificate):
class Meta:
class Meta: # pyright: ignore[reportIncompatibleVariableOverride]
proxy = True
25 changes: 14 additions & 11 deletions zgw_consumers/models/fields.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from typing import Optional
from urllib.parse import urljoin

from django.core import checks
Expand All @@ -12,14 +11,14 @@ def __init__(self, field):
self.field = field

def get_base_url(self, base_val) -> str:
return getattr(base_val, "api_root", None)
return getattr(base_val, "api_root", "")

def get_base_val(self, detail_url: str):
from zgw_consumers.models import Service

return Service.get_service(detail_url)

def __get__(self, instance: Model, cls=None) -> Optional[str]:
def __get__(self, instance: Model | None, cls=None) -> str | None:
if instance is None:
return None

Expand All @@ -30,7 +29,7 @@ def __get__(self, instance: Model, cls=None) -> Optional[str]:
# todo cache value
return urljoin(base_url, relative_val)

def __set__(self, instance: Model, value: Optional[str]):
def __set__(self, instance: Model, value: str | None):
if value is None and not self.field.null:
raise ValueError(
"A 'None'-value is not allowed. Make the field "
Expand Down Expand Up @@ -64,9 +63,9 @@ class ServiceUrlField(Field):
"""

# field flags
name = None
name: str
concrete = False
column = None
column: str | None = None # pyright: ignore[reportIncompatibleVariableOverride]
db_column = None

descriptor_class = ServiceUrlDescriptor
Expand Down Expand Up @@ -132,10 +131,10 @@ def _add_check_constraint(
return

@property
def attname(self) -> str:
def attname(self) -> str: # pyright: ignore[reportIncompatibleVariableOverride]
return self.name

def get_attname_column(self):
def get_attname_column(self): # pyright: ignore[reportIncompatibleMethodOverride]
return self.attname, None

def deconstruct(self):
Expand All @@ -150,13 +149,17 @@ def deconstruct(self):

@property
def _base_field(self) -> ForeignKey:
return self.model._meta.get_field(self.base_field)
field = self.model._meta.get_field(self.base_field)
assert isinstance(field, ForeignKey)
return field

@property
def _relative_field(self) -> CharField:
return self.model._meta.get_field(self.relative_field)
field = self.model._meta.get_field(self.relative_field)
assert isinstance(field, CharField)
return field

def check(self, **kwargs):
def check(self, **kwargs) -> list[checks.CheckMessage]:
return [
*self._check_field_name(),
*self._check_base_field(),
Expand Down
8 changes: 5 additions & 3 deletions zgw_consumers/models/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import socket
import uuid
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Callable
from urllib.parse import urlparse, urlsplit, urlunsplit

from django.core.exceptions import ValidationError
Expand Down Expand Up @@ -116,7 +116,9 @@ class Service(RestAPIService):

objects = ServiceManager()

class Meta:
get_api_type_display: Callable[[], str]

class Meta: # pyright: ignore[reportIncompatibleVariableOverride]
verbose_name = _("service")
verbose_name_plural = _("services")

Expand Down Expand Up @@ -290,7 +292,7 @@ class NLXConfig(SingletonModel):
blank=True,
)

class Meta:
class Meta: # pyright: ignore[reportIncompatibleVariableOverride]
verbose_name = _("NLX configuration")

@property
Expand Down
16 changes: 12 additions & 4 deletions zgw_consumers/nlx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
from collections.abc import Iterable
from itertools import groupby
from typing import TypedDict

import requests
from ape_pie import APIClient
Expand All @@ -26,9 +27,12 @@ def _rewrite_url(value: str, rewrites: Iterable[tuple[str, str]]) -> str | None:

class Rewriter:
def __init__(self):
self.rewrites: list[tuple[str, str]] = Service.objects.exclude(
nlx=""
).values_list("api_root", "nlx")
qs = Service.objects.exclude(nlx="").values_list("api_root", "nlx")
self._rewrites = qs

@property
def rewrites(self) -> list[tuple[str, str]]:
return list(self._rewrites)

@property
def reverse_rewrites(self) -> list[tuple[str, str]]:
Expand Down Expand Up @@ -159,7 +163,11 @@ class NLXClient(NLXMixin, APIClient):


Organization = dict[str, str]
ServiceType = dict[str, str]


class ServiceType(TypedDict):
name: str
organization: Organization


def get_nlx_services() -> list[tuple[Organization, list[ServiceType]]]:
Expand Down
Loading

0 comments on commit 127dc3c

Please sign in to comment.