Skip to content

Commit

Permalink
Fix mypy errors
Browse files Browse the repository at this point in the history
  • Loading branch information
gerlero committed Oct 9, 2023
1 parent 9ef6b0b commit 8258fc0
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
14 changes: 8 additions & 6 deletions electrolytes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from contextlib import ContextDecorator
from warnings import warn

from pydantic import BaseModel, Field, field_validator, FieldValidationInfo, model_validator, TypeAdapter
from pydantic import BaseModel, Field, field_validator, ValidationInfo, model_validator, TypeAdapter
from filelock import FileLock
from typer import get_app_dir

Expand Down Expand Up @@ -71,32 +71,34 @@ def _default_pka(charge: int) -> float:


@field_validator("name")
def _normalize_db1_names(cls, v: str, info: FieldValidationInfo) -> str:
def _normalize_db1_names(cls, v: str, info: ValidationInfo) -> str:
if info.context and info.context.get("fix", None) == "db1":
v = v.replace(" ", "_").replace("Cl-", "CHLORO")
return v

@field_validator("name")
def _no_whitespace(cls, v: str, info: FieldValidationInfo) -> str:
def _no_whitespace(cls, v: str, info: ValidationInfo) -> str:
parts = v.split()
if len(parts) > 1 or len(parts[0]) != len(v):
raise ValueError("name cannot contain any whitespace")
return parts[0]

@field_validator("name")
def _all_uppercase(cls, v: str, info: FieldValidationInfo) -> str:
def _all_uppercase(cls, v: str, info: ValidationInfo) -> str:
if not v.isupper():
raise ValueError("name must be all uppercase")
return v

@field_validator("pkas_neg", "pkas_pos")
def _pka_lengths(cls, v: List[float], info: FieldValidationInfo) -> List[float]:
def _pka_lengths(cls, v: List[float], info: ValidationInfo) -> List[float]:
assert isinstance(info.field_name, str)
if len(v) != len(info.data[f"u_{info.field_name[5:]}"]):
raise ValueError(f"len({info.field_name}) != len(u_{info.field_name[5:]})")
return v

@field_validator("neg_count", "pos_count", mode="before")
def _counts(cls, v: Optional[int], info: FieldValidationInfo) -> int:
def _counts(cls, v: Optional[int], info: ValidationInfo) -> int:
assert isinstance(info.field_name, str)
if v is None:
v = len(info.data[f"u_{info.field_name[:3]}"])
elif v != len(info.data[f"u_{info.field_name[:3]}"]):
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ dynamic = ["version"]
[project.optional-dependencies]
lint = [
"mypy==1.*",
"pytest==7.*",
]
test = [
"pytest==7.*",
Expand Down

0 comments on commit 8258fc0

Please sign in to comment.