Skip to content

Commit

Permalink
Bumping pydantic >= 2.0.0 (#247)
Browse files Browse the repository at this point in the history
* chore: bump pydantic to >2.0

* chore: black

* chore: mypy checks

* chore : remove duplication of options/examples

* chore: add pydantic_settings

* fix: pydantic error

* chore: mypy

* fix: mock on subclass __parameters__

* chore

* fix: tests for registry

* chore: black

* fix: ignore mypy

* chore: black

* fix: move deps from dev to main requirements file

* fix: absolute imports before relative

---------

Co-authored-by: fiskrt <[email protected]>
  • Loading branch information
jannisborn and fiskrt authored Jun 13, 2024
1 parent daae05b commit 7dc926d
Show file tree
Hide file tree
Showing 12 changed files with 52 additions and 70 deletions.
3 changes: 1 addition & 2 deletions dev_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,4 @@ pytest==6.2.5
pytest-cov==2.10.1
sphinx>=5
sphinx-autodoc-typehints==1.11.1
jinja2<3.1.0
sphinx_rtd_theme==0.5.1
jinja2<3.1.0
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ numpy>=1.16.5,<1.24.0
pandas<=2.0.3
protobuf<3.20
pyarrow>=8.0.0
pydantic>=1.7.3,<2.0.0
pydantic>=2.0.0
pymatgen>=2022.11.7
PyTDC==0.3.7
pytorch_lightning<=1.7.7
Expand Down Expand Up @@ -49,3 +49,5 @@ transformers>=4.22.0,<=4.24.0
typing_extensions>=3.7.4.3
wheel>=0.26
xgboost>=1.7.6
sphinx_rtd_theme==0.5.1
pydantic-settings>=2.0.0
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -290,4 +290,7 @@ ignore_missing_imports = True
ignore_missing_imports = True

[mypy-xgboost.*]
ignore_missing_imports = True

[mypy-pydantic_settings.*]
ignore_missing_imports = True
2 changes: 1 addition & 1 deletion src/gt4sd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
#
"""Module initialization."""

__version__ = "1.4.1"
__version__ = "1.4.2"
__name__ = "gt4sd"

# NOTE: configure SSL to allow unverified contexts by default
Expand Down
2 changes: 0 additions & 2 deletions src/gt4sd/algorithms/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -912,8 +912,6 @@ def get_configuration_class_with_attributes(


class PropertyPredictor(ABC, Generic[S, U]):
"""TODO: Might be deprecated in future release."""

def __init__(self, context: U) -> None:
"""Property predictor to investigate items.
Expand Down
14 changes: 13 additions & 1 deletion src/gt4sd/algorithms/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,17 @@
from dataclasses import dataclass as vanilla_dataclass
from dataclasses import field, make_dataclass
from functools import WRAPPER_ASSIGNMENTS, update_wrapper
from typing import Any, Callable, ClassVar, Dict, List, NamedTuple, Optional, Type
from typing import (
Any,
Callable,
ClassVar,
Dict,
List,
NamedTuple,
Optional,
Type,
TypeVar,
)

import pydantic

Expand Down Expand Up @@ -185,6 +195,8 @@ def decorator(
),
], # type: ignore
)
# NOTE: Needed to circumvent a pydantic TypeError: Parameter list to Generic[...] cannot be empty
VanillaConfiguration.__parameters__ = (TypeVar("T"),) # type: ignore
# NOTE: Duplicate call necessary for pydantic >=1.10.* - see https://github.com/pydantic/pydantic/issues/4695
PydanticConfiguration: Type[AlgorithmConfiguration] = dataclass( # type: ignore
VanillaConfiguration
Expand Down
31 changes: 4 additions & 27 deletions src/gt4sd/algorithms/tests/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,19 +58,15 @@ def test_list_available_local_via_S3SyncError(mock_wrong_s3_env):

def test_inherited_validation():
Config = next(iter(ApplicationsRegistry.applications.values())).configuration_class
with pytest.raises(
ValidationError, match="algorithm_version\n +none is not an allowed value"
):
with pytest.raises(ValidationError, match="should be a valid string"):
Config(algorithm_version=None) # type: ignore

# NOTE: values convertible to string will not raise!
Config(algorithm_version=5) # type: ignore
with pytest.raises(ValidationError, match="should be a valid string"):
Config(algorithm_version=5) # type: ignore


def test_validation():
with pytest.raises(
ValidationError, match="batch_size\n +value is not a valid integer"
):
with pytest.raises(ValidationError, match="should be a valid integer"):
ApplicationsRegistry.get_configuration_instance(
algorithm_type="conditional_generation",
domain="materials",
Expand All @@ -80,25 +76,6 @@ def test_validation():
)


def test_pickable_wrapped_configurations():
# https://github.com/samuelcolvin/pydantic/issues/2111
Config = next(iter(ApplicationsRegistry.applications.values())).configuration_class
restored_obj = assert_pickable(Config(algorithm_version="test"))

# wrong type assignment, but we did not configure it to raise here:
restored_obj.algorithm_version = object
# ensure the restored dataclass is still a pydantic dataclass (mimic validation)
_, optional_errors = restored_obj.__pydantic_model__.__fields__.get(
"algorithm_version"
).validate(
restored_obj.algorithm_version,
restored_obj.__dict__,
loc="algorithm_version",
cls=restored_obj.__class__,
)
assert optional_errors is not None


def test_multiple_registration():
class OtherAlgorithm(GeneratorAlgorithm):
pass
Expand Down
8 changes: 2 additions & 6 deletions src/gt4sd/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@
import os
from functools import lru_cache
from typing import Dict, Optional, Set

from pydantic import BaseSettings
from pydantic_settings import BaseSettings, SettingsConfigDict

from .s3 import GT4SDS3Client, S3SyncError, sync_folder_with_s3, upload_file_to_s3

Expand Down Expand Up @@ -65,10 +64,7 @@ class GT4SDConfiguration(BaseSettings):
gt4sd_s3_secure_hub: bool = True
gt4sd_s3_bucket_hub_algorithms: str = "gt4sd-cos-hub-algorithms-artifacts"
gt4sd_s3_bucket_hub_properties: str = "gt4sd-cos-hub-properties-artifacts"

class Config:
# immutable and in turn hashable, that is required for lru_cache
frozen = True
model_config = SettingsConfigDict(frozen=True)

@staticmethod
@lru_cache(maxsize=None)
Expand Down
14 changes: 8 additions & 6 deletions src/gt4sd/properties/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,21 @@ class S3Parameters(PropertyPredictorParameters):
algorithm_type: str = "prediction"

domain: DomainSubmodule = Field(
..., example="molecules", description="Submodule of gt4sd.properties"
..., examples=["molecules"], description="Submodule of gt4sd.properties"
)
algorithm_name: str = Field(
..., examples=["MCA"], description="Name of the algorithm"
)
algorithm_name: str = Field(..., example="MCA", description="Name of the algorithm")
algorithm_version: str = Field(
..., example="v0", description="Version of the algorithm"
..., examples=["v0"], description="Version of the algorithm"
)
algorithm_application: str = Field(..., example="Tox21")
algorithm_application: str = Field(..., examples=["Tox21"])


class ApiTokenParameters(PropertyPredictorParameters):
api_token: str = Field(
...,
example="apk-c9db......",
examples=["apk-c9db......"],
description="The API token/key to access the service",
)

Expand All @@ -68,7 +70,7 @@ class IpAdressParameters(PropertyPredictorParameters):

host_ip: str = Field(
...,
example="xx.xx.xxx.xxx",
examples=["xx.xx.xxx.xxx"],
description="The host IP address to access the service",
)

Expand Down
30 changes: 13 additions & 17 deletions src/gt4sd/properties/molecules/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
from paccmann_generator.drug_evaluators import OrganDB as _OrganTox
from paccmann_generator.drug_evaluators import SCScore
from paccmann_generator.drug_evaluators import Tox21 as _Tox21
from pydantic import Field
from pydantic import ConfigDict, Field
from tdc import Oracle
from tdc.metadata import download_receptor_oracle_name

Expand Down Expand Up @@ -119,12 +119,12 @@ class ScscoreConfiguration(PropertyPredictorParameters):


class SimilaritySeedParameters(PropertyPredictorParameters):
smiles: str = Field(..., example="c1ccccc1")
smiles: str = Field(..., examples=["c1ccccc1"])
fp_key: str = "ECFP4"


class ActivityAgainstTargetParameters(PropertyPredictorParameters):
target: str = Field(..., example="drd2", description="name of the target.")
target: str = Field(..., examples=["drd2"], description="name of the target.")


class AskcosParameters(IpAdressParameters):
Expand All @@ -136,9 +136,8 @@ class Output(str, Enum):

output: Output = Field(
default=Output.plausability,
example=Output.synthesizability,
examples=[Output.synthesizability],
description="Main output return type from ASKCOS",
options=["plausibility", "num_step", "synthesizability", "price"],
)
save_json: bool = Field(default=False)
file_name: str = Field(default="tree_builder_result.json")
Expand All @@ -159,10 +158,7 @@ class Output(str, Enum):
min_chempop_products: int = Field(default=5)
filter_threshold: float = Field(default=0.1)
return_first: str = Field(default="true")

# Convert enum items back to strings
class Config:
use_enum_values = True
model_config = ConfigDict(use_enum_values=True)


class MoleculeOneParameters(ApiTokenParameters):
Expand All @@ -174,22 +170,23 @@ class DockingTdcParameters(PropertyPredictorParameters):
# To dock against a receptor defined via TDC
target: str = Field(
...,
example="1iep_docking",
examples=download_receptor_oracle_name,
description="Target for docking, provided via TDC",
options=download_receptor_oracle_name,
)


class DockingParameters(PropertyPredictorParameters):
# To dock against a user-provided receptor
name: str = Field(default="pyscreener")
receptor_pdb_file: str = Field(
example="/tmp/2hbs.pdb", description="Path to receptor PDB file"
examples=["/tmp/2hbs.pdb"], description="Path to receptor PDB file"
)
box_center: List[int] = Field(
example=[15.190, 53.903, 16.917], description="Docking box center"
examples=[[15.190, 53.903, 16.917]], description="Docking box center"
)
box_size: List[float] = Field(
examples=[[20, 20, 20]], description="Docking box size"
)
box_size: List[float] = Field(example=[20, 20, 20], description="Docking box size")


class S3ParametersMolecules(S3Parameters):
Expand Down Expand Up @@ -265,14 +262,13 @@ class ToxType(str, Enum):
algorithm_application: str = "OrganTox"
site: Organs = Field(
...,
example=Organs.kidney,
examples=[Organs.kidney],
description="name of the target site of interest.",
)
toxicity_type: ToxType = Field(
default=ToxType.all,
example=ToxType.chronic,
examples=[ToxType.chronic],
description="type of toxicity for which predictions are made.",
options=["chronic", "subchronic", "multigenerational", "all"],
)


Expand Down
4 changes: 2 additions & 2 deletions src/gt4sd/properties/proteins/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
class AmideConfiguration(PropertyPredictorParameters):
amide: bool = Field(
False,
example=False,
examples=[False],
description="whether the sequences are C-terminally amidated.",
)

Expand All @@ -58,7 +58,7 @@ class PhConfiguration(PropertyPredictorParameters):
class AmidePhConfiguration(PropertyPredictorParameters):
amide: bool = Field(
False,
example=False,
examples=[False],
description="whether the sequences are C-terminally amidated.",
)
ph: float = 7.0
Expand Down
7 changes: 2 additions & 5 deletions src/gt4sd/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from pathlib import PosixPath

import importlib_resources
from pydantic import BaseSettings
from pydantic_settings import BaseSettings, SettingsConfigDict


class GT4SDTestSettings(BaseSettings):
Expand All @@ -40,10 +40,7 @@ class GT4SDTestSettings(BaseSettings):
gt4sd_s3_secret_key: str = "5748375c761a4f09c30a68cd15e218e3b27ca3e2aebd7726"
gt4sd_s3_secure: bool = True
gt4sd_ci: bool = False

class Config:
# immutable and in turn hashable, that is required for lru_cache
frozen = True
model_config = SettingsConfigDict(frozen=True)

@staticmethod
@lru_cache(maxsize=None)
Expand Down

0 comments on commit 7dc926d

Please sign in to comment.