Skip to content

Commit

Permalink
Apply extra Ruff rules
Browse files Browse the repository at this point in the history
  • Loading branch information
gerlero committed Nov 2, 2024
1 parent 9f07703 commit ea076aa
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 32 deletions.
46 changes: 32 additions & 14 deletions electrolytes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from contextlib import ContextDecorator, suppress
from functools import cached_property, singledispatchmethod
from pathlib import Path
from typing import Annotated, Any, Optional
from types import TracebackType
from typing import Annotated, Optional
from warnings import warn

from filelock import FileLock
Expand Down Expand Up @@ -35,8 +36,8 @@ class Constituent(BaseModel, populate_by_name=True, frozen=True):
pkas_pos: Annotated[
list[float], Field(alias="pKaPos")
] = [] # [+1, +2, +3, ..., +pos_count]
neg_count: Annotated[int, Field(alias="negCount", validate_default=True)] = None # type: ignore
pos_count: Annotated[int, Field(alias="posCount", validate_default=True)] = None # type: ignore
neg_count: Annotated[int, Field(alias="negCount", validate_default=True)] = None # type: ignore [assignment]
pos_count: Annotated[int, Field(alias="posCount", validate_default=True)] = None # type: ignore [assignment]

def mobilities(self) -> Sequence[float]:
n = max(self.neg_count, self.pos_count, 3)
Expand Down Expand Up @@ -77,46 +78,56 @@ def _default_pka(charge: int) -> float:
return -charge

@field_validator("name")
@classmethod
def _normalize_db1_names(cls, v: str, info: ValidationInfo) -> str:
if info.context and info.context.get("fix", None) == "db1":
v = v.replace(" ", "_").replace("Cl-", "CHLORO")
return v

@field_validator("name")
def _no_whitespace(cls, v: str, info: ValidationInfo) -> str:
@classmethod
def _no_whitespace(cls, v: str, _: ValidationInfo) -> str:
parts = v.split()
if len(parts) > 1 or len(parts[0]) != len(v):
raise ValueError("name cannot contain any whitespace")
msg = "name cannot contain any whitespace"
raise ValueError(msg)
return parts[0]

@field_validator("name")
def _all_uppercase(cls, v: str, info: ValidationInfo) -> str:
@classmethod
def _all_uppercase(cls, v: str, _: ValidationInfo) -> str:
if not v.isupper():
raise ValueError("name must be all uppercase")
msg = "name must be all uppercase"
raise ValueError(msg)
return v

@field_validator("pkas_neg", "pkas_pos")
@classmethod
def _pka_lengths(cls, v: list[float], info: ValidationInfo) -> list[float]:
assert isinstance(info.field_name, str)
if len(v) != len(info.data[f"u_{info.field_name[5:]}"]):
raise ValueError(f"len({info.field_name}) != len(u_{info.field_name[5:]})")
msg = f"len({info.field_name}) != len(u_{info.field_name[5:]})"
raise ValueError(msg)
return v

@field_validator("neg_count", "pos_count", mode="before")
@classmethod
def _counts(cls, v: Optional[int], info: ValidationInfo) -> int:
assert isinstance(info.field_name, str)
if v is None:
v = len(info.data[f"u_{info.field_name[:3]}"])
elif v != len(info.data[f"u_{info.field_name[:3]}"]):
raise ValueError(f"{info.field_name} != len(u_{info.field_name[:3]})")
msg = f"{info.field_name} != len(u_{info.field_name[:3]})"
raise ValueError(msg)
return v

@model_validator(mode="after")
def _pkas_not_increasing(self) -> "Constituent":
pkas = [*self.pkas_neg, *self.pkas_pos]

if not all(x >= y for x, y in zip(pkas, pkas[1:])):
raise ValueError("pKa values must not increase with charge")
msg = "pKa values must not increase with charge"
raise ValueError(msg)

return self

Expand Down Expand Up @@ -148,7 +159,8 @@ def __init__(self, user_constituents_file: Path) -> None:
def _default_constituents(self) -> dict[str, Constituent]:
data = pkgutil.get_data(__package__, "db1.json")
if data is None:
raise RuntimeError("failed to load default constituents")
msg = "failed to load default constituents"
raise RuntimeError(msg)
constituents = _load_constituents(data, context={"fix": "db1"})
return {c.name: c for c in constituents}

Expand Down Expand Up @@ -189,7 +201,12 @@ def __enter__(self) -> "_Database":

return self

def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
try:
if (
self._user_constituents_lock.lock_counter == 1
Expand Down Expand Up @@ -228,7 +245,8 @@ def __delitem__(self, name: str) -> None:
self._user_constituents_dirty = True
except KeyError:
if name in self._default_constituents:
raise ValueError(f"{name}: cannot remove default component") from None
msg = f"{name}: cannot remove default component"
raise ValueError(msg) from None
raise

def __len__(self) -> int:
Expand All @@ -238,7 +256,7 @@ def user_defined(self) -> Collection[str]:
return sorted(self._user_constituents)

@singledispatchmethod
def __contains__(self, _: Any) -> bool: # type: ignore
def __contains__(self, _: object) -> bool: # type: ignore [override]
return False

@__contains__.register
Expand Down
30 changes: 17 additions & 13 deletions electrolytes/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,51 +23,52 @@ def add(
p1: Annotated[
tuple[float, float],
typer.Option("+1", help="Mobility (*1e-9) and pKa for +1", show_default=False),
] = (None, None), # type: ignore
] = (None, None), # type: ignore [assignment]
p2: Annotated[
tuple[float, float],
typer.Option("+2", help="Mobility (*1e-9) and pKa for +2", show_default=False),
] = (None, None), # type: ignore
] = (None, None), # type: ignore [assignment]
p3: Annotated[
tuple[float, float],
typer.Option("+3", help="Mobility (*1e-9) and pKa for +3", show_default=False),
] = (None, None), # type: ignore
] = (None, None), # type: ignore [assignment]
p4: Annotated[
tuple[float, float],
typer.Option("+4", help="Mobility (*1e-9) and pKa for +4", show_default=False),
] = (None, None), # type: ignore
] = (None, None), # type: ignore [assignment]
p5: Annotated[
tuple[float, float],
typer.Option("+5", help="Mobility (*1e-9) and pKa for +5", show_default=False),
] = (None, None), # type: ignore
] = (None, None), # type: ignore [assignment]
p6: Annotated[
tuple[float, float],
typer.Option("+6", help="Mobility (*1e-9) and pKa for +6", show_default=False),
] = (None, None), # type: ignore
] = (None, None), # type: ignore [assignment]
m1: Annotated[
tuple[float, float],
typer.Option("-1", help="Mobility (*1e-9) and pKa for -1", show_default=False),
] = (None, None), # type: ignore
] = (None, None), # type: ignore [assignment]
m2: Annotated[
tuple[float, float],
typer.Option("-2", help="Mobility (*1e-9) and pKa for -2", show_default=False),
] = (None, None), # type: ignore
] = (None, None), # type: ignore [assignment]
m3: Annotated[
tuple[float, float],
typer.Option("-3", help="Mobility (*1e-9) and pKa for -3", show_default=False),
] = (None, None), # type: ignore
] = (None, None), # type: ignore [assignment]
m4: Annotated[
tuple[float, float],
typer.Option("-4", help="Mobility (*1e-9) and pKa for -4", show_default=False),
] = (None, None), # type: ignore
] = (None, None), # type: ignore [assignment]
m5: Annotated[
tuple[float, float],
typer.Option("-5", help="Mobility (*1e-9) and pKa for -5", show_default=False),
] = (None, None), # type: ignore
] = (None, None), # type: ignore [assignment]
m6: Annotated[
tuple[float, float],
typer.Option("-6", help="Mobility (*1e-9) and pKa for -6", show_default=False),
] = (None, None), # type: ignore
] = (None, None), # type: ignore [assignment]
*,
force: Annotated[
bool,
typer.Option(
Expand Down Expand Up @@ -189,6 +190,7 @@ def info(

@app.command()
def ls(
*,
user: Annotated[
Optional[bool],
typer.Option(
Expand All @@ -214,6 +216,7 @@ def rm(
names: Annotated[
list[str], typer.Argument(autocompletion=complete_name_user_defined)
],
*,
force: Annotated[
Optional[bool], typer.Option("-f", help="Ignore non-existent components")
] = False,
Expand Down Expand Up @@ -267,14 +270,15 @@ def search(
)


def version_callback(show: bool) -> None:
def version_callback(*, show: bool) -> None:
if show:
typer.echo(f"{__package__} {__version__}")
raise typer.Exit


@app.callback()
def common(
*,
version: Annotated[
bool,
typer.Option(
Expand Down
3 changes: 1 addition & 2 deletions tests/test_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import pytest

import electrolytes
import pytest
from electrolytes import Constituent, database


Expand Down
5 changes: 2 additions & 3 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import contextlib

import pytest
from typer.testing import CliRunner

import electrolytes
import pytest
from electrolytes import Constituent, database
from electrolytes.__main__ import app
from typer.testing import CliRunner

runner = CliRunner()

Expand Down

0 comments on commit ea076aa

Please sign in to comment.