diff --git a/electrolytes/__init__.py b/electrolytes/__init__.py index 76dcdda..5f4c94c 100644 --- a/electrolytes/__init__.py +++ b/electrolytes/__init__.py @@ -1,10 +1,11 @@ +from __future__ import annotations + import pkgutil from collections.abc import Collection, Iterator, Mapping, Sequence from contextlib import ContextDecorator, suppress from functools import cached_property, singledispatchmethod from pathlib import Path -from types import TracebackType -from typing import Annotated, Optional +from typing import TYPE_CHECKING, Annotated from warnings import warn from filelock import FileLock @@ -18,11 +19,20 @@ ) from typer import get_app_dir +if TYPE_CHECKING: + import sys + from types import TracebackType + + if sys.version_info >= (3, 11): + from typing import Self + else: + from typing_extensions import Self + __version__ = "0.4.6" class Constituent(BaseModel, populate_by_name=True, frozen=True): - id: Optional[int] = None + id: int | None = None name: str u_neg: Annotated[ list[float], Field(alias="uNeg") @@ -112,7 +122,7 @@ def _pka_lengths(cls, v: list[float], info: ValidationInfo) -> list[float]: @field_validator("neg_count", "pos_count", mode="before") @classmethod - def _counts(cls, v: Optional[int], info: ValidationInfo) -> int: + def _counts(cls, v: int | None, info: ValidationInfo) -> int: assert isinstance(info.field_name, str) if v is None: v = len(info.data[f"u_{info.field_name[:3]}"]) @@ -122,7 +132,7 @@ def _counts(cls, v: Optional[int], info: ValidationInfo) -> int: return v @model_validator(mode="after") - def _pkas_not_increasing(self) -> "Constituent": + def _pkas_not_increasing(self) -> Constituent: pkas = [*self.pkas_neg, *self.pkas_pos] if not all(x >= y for x, y in zip(pkas, pkas[1:])): @@ -136,7 +146,7 @@ def _pkas_not_increasing(self) -> "Constituent": def _load_constituents( - data: bytes, context: Optional[dict[str, str]] = None + data: bytes, context: dict[str, str] | None = None ) -> list[Constituent]: return _StoredConstituents.validate_json(data, context=context)["constituents"] @@ -194,7 +204,7 @@ def _save_user_constituents(self) -> None: self._user_constituents_file.write_bytes(data) self._user_constituents_dirty = False - def __enter__(self) -> "_Database": + def __enter__(self) -> Self: if not self._user_constituents_lock.is_locked: self._invalidate_user_constituents() self._user_constituents_lock.acquire() @@ -203,9 +213,9 @@ def __enter__(self) -> "_Database": def __exit__( self, - exc_type: Optional[type[BaseException]], - exc_value: Optional[BaseException], - traceback: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, ) -> None: try: if ( diff --git a/electrolytes/__main__.py b/electrolytes/__main__.py index 0fd7a34..6bb3192 100644 --- a/electrolytes/__main__.py +++ b/electrolytes/__main__.py @@ -1,4 +1,6 @@ -from typing import Annotated, Optional +from __future__ import annotations + +from typing import Annotated import typer @@ -135,7 +137,7 @@ def add( @app.command() def info( names: Annotated[ - Optional[list[str]], + list[str] | None, typer.Argument(help="Component names", autocompletion=complete_name), ] = None, ) -> None: @@ -192,7 +194,7 @@ def info( def ls( *, user: Annotated[ - Optional[bool], + bool | None, typer.Option( "--user/--default", help="List only user-defined/default components" ), @@ -218,7 +220,7 @@ def rm( ], *, force: Annotated[ - Optional[bool], typer.Option("-f", help="Ignore non-existent components") + bool | None, typer.Option("-f", help="Ignore non-existent components") ] = False, ) -> None: """Remove user-defined components from the database.""" @@ -243,7 +245,7 @@ def rm( def search( text: str, user: Annotated[ - Optional[bool], + bool | None, typer.Option( "--user/--default", help="Search only user-defined/default components" ), diff --git a/pyproject.toml b/pyproject.toml index 76dbbc2..8dfccac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,7 @@ dynamic = ["version"] lint = ["ruff"] typing = [ "mypy==1.*", + "typing-extensions>=4,<5; python_version < '3.11'", "pytest>=7,<9", ] test = [