Skip to content

Commit

Permalink
Improves ClassList integration with Pydantic and adds JSON conversi…
Browse files Browse the repository at this point in the history
…on for `Project`s (RascalSoftware#79)

* made ClassList generic and added Pydantic validation schema

* added coercion for lists

* added JSON conversion for Projects

* moved pydantic imports into core schema generator

* fixed and improved write_script and its test

* review fixes

* review fixes
  • Loading branch information
alexhroom authored Oct 7, 2024
1 parent 072f954 commit aff6405
Show file tree
Hide file tree
Showing 9 changed files with 649 additions and 217 deletions.
97 changes: 73 additions & 24 deletions RATapi/classlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@
import contextlib
import warnings
from collections.abc import Sequence
from typing import Any, Union
from typing import Any, Generic, TypeVar, Union

import numpy as np
import prettytable

T = TypeVar("T")

class ClassList(collections.UserList):

class ClassList(collections.UserList, Generic[T]):
"""List of instances of a particular class.
This class subclasses collections.UserList to construct a list intended to store ONLY instances of a particular
Expand All @@ -30,14 +32,14 @@ class ClassList(collections.UserList):
Parameters
----------
init_list : Sequence [object] or object, optional
init_list : Sequence [T] or T, optional
An instance, or list of instance(s), of the class to be used in this ClassList.
name_field : str, optional
The field used to define unique objects in the ClassList (default is "name").
"""

def __init__(self, init_list: Union[Sequence[object], object] = None, name_field: str = "name") -> None:
def __init__(self, init_list: Union[Sequence[T], T] = None, name_field: str = "name") -> None:
self.name_field = name_field

# Set input as list if necessary
Expand Down Expand Up @@ -81,20 +83,20 @@ def __str__(self):
output = str(self.data)
return output

def __getitem__(self, index: Union[int, slice, str, object]) -> object:
def __getitem__(self, index: Union[int, slice, str, T]) -> T:
"""Get an item by its index, name, a slice, or the object itself."""
if isinstance(index, (int, slice)):
return self.data[index]
elif isinstance(index, (str, object)):
elif isinstance(index, (str, self._class_handle)):
return self.data[self.index(index)]
else:
raise IndexError("ClassLists can only be indexed by integers, slices, name strings, or objects.")

def __setitem__(self, index: int, item: object) -> None:
def __setitem__(self, index: int, item: T) -> None:
"""Replace the object at an existing index of the ClassList."""
self._setitem(index, item)

def _setitem(self, index: int, item: object) -> None:
def _setitem(self, index: int, item: T) -> None:
"""Auxiliary routine of "__setitem__" used to enable wrapping."""
self._check_classes([item])
self._check_unique_name_fields([item])
Expand All @@ -108,11 +110,11 @@ def _delitem(self, index: int) -> None:
"""Auxiliary routine of "__delitem__" used to enable wrapping."""
del self.data[index]

def __iadd__(self, other: Sequence[object]) -> "ClassList":
def __iadd__(self, other: Sequence[T]) -> "ClassList":
"""Define in-place addition using the "+=" operator."""
return self._iadd(other)

def _iadd(self, other: Sequence[object]) -> "ClassList":
def _iadd(self, other: Sequence[T]) -> "ClassList":
"""Auxiliary routine of "__iadd__" used to enable wrapping."""
if other and not (isinstance(other, Sequence) and not isinstance(other, str)):
other = [other]
Expand All @@ -135,13 +137,13 @@ def __imul__(self, n: int) -> None:
"""Define in-place multiplication using the "*=" operator."""
raise TypeError(f"unsupported operand type(s) for *=: '{self.__class__.__name__}' and '{n.__class__.__name__}'")

def append(self, obj: object = None, **kwargs) -> None:
def append(self, obj: T = None, **kwargs) -> None:
"""Append a new object to the ClassList using either the object itself, or keyword arguments to set attribute
values.
Parameters
----------
obj : object, optional
obj : T, optional
An instance of the class specified by self._class_handle.
**kwargs : dict[str, Any], optional
The input keyword arguments for a new object in the ClassList.
Expand Down Expand Up @@ -180,15 +182,15 @@ def append(self, obj: object = None, **kwargs) -> None:
self._validate_name_field(kwargs)
self.data.append(self._class_handle(**kwargs))

def insert(self, index: int, obj: object = None, **kwargs) -> None:
def insert(self, index: int, obj: T = None, **kwargs) -> None:
"""Insert a new object into the ClassList at a given index using either the object itself, or keyword arguments
to set attribute values.
Parameters
----------
index: int
The index at which to insert a new object in the ClassList.
obj : object, optional
obj : T, optional
An instance of the class specified by self._class_handle.
**kwargs : dict[str, Any], optional
The input keyword arguments for a new object in the ClassList.
Expand Down Expand Up @@ -227,26 +229,26 @@ def insert(self, index: int, obj: object = None, **kwargs) -> None:
self._validate_name_field(kwargs)
self.data.insert(index, self._class_handle(**kwargs))

def remove(self, item: Union[object, str]) -> None:
def remove(self, item: Union[T, str]) -> None:
"""Remove an object from the ClassList using either the object itself or its name_field value."""
item = self._get_item_from_name_field(item)
self.data.remove(item)

def count(self, item: Union[object, str]) -> int:
def count(self, item: Union[T, str]) -> int:
"""Return the number of times an object appears in the ClassList using either the object itself or its
name_field value.
"""
item = self._get_item_from_name_field(item)
return self.data.count(item)

def index(self, item: Union[object, str], offset: bool = False, *args) -> int:
def index(self, item: Union[T, str], offset: bool = False, *args) -> int:
"""Return the index of a particular object in the ClassList using either the object itself or its
name_field value. If offset is specified, add one to the index. This is used to account for one-based indexing.
"""
item = self._get_item_from_name_field(item)
return self.data.index(item, *args) + int(offset)

def extend(self, other: Sequence[object]) -> None:
def extend(self, other: Sequence[T]) -> None:
"""Extend the ClassList by adding another sequence."""
if other and not (isinstance(other, Sequence) and not isinstance(other, str)):
other = [other]
Expand Down Expand Up @@ -319,7 +321,7 @@ def _validate_name_field(self, input_args: dict[str, Any]) -> None:
f"which is already specified at index {names.index(name)} of the ClassList",
)

def _check_unique_name_fields(self, input_list: Sequence[object]) -> None:
def _check_unique_name_fields(self, input_list: Sequence[T]) -> None:
"""Raise a ValueError if any value of the name_field attribute is used more than once in a list of class
objects.
Expand Down Expand Up @@ -376,7 +378,7 @@ def _check_unique_name_fields(self, input_list: Sequence[object]) -> None:
f"{newline.join(error for error in error_list)}"
)

def _check_classes(self, input_list: Sequence[object]) -> None:
def _check_classes(self, input_list: Sequence[T]) -> None:
"""Raise a ValueError if any object in a list of objects is not of the type specified by self._class_handle.
Parameters
Expand All @@ -401,17 +403,17 @@ def _check_classes(self, input_list: Sequence[object]) -> None:
f"In the input list:\n{newline.join(error for error in error_list)}\n"
)

def _get_item_from_name_field(self, value: Union[object, str]) -> Union[object, str]:
def _get_item_from_name_field(self, value: Union[T, str]) -> Union[T, str]:
"""Return the object with the given value of the name_field attribute in the ClassList.
Parameters
----------
value : object or str
value : T or str
Either an object in the ClassList, or the value of the name_field attribute of an object in the ClassList.
Returns
-------
instance : object or str
instance : T or str
Either the object with the value of the name_field attribute given by value, or the input value if an
object with that value of the name_field attribute cannot be found.
Expand All @@ -424,7 +426,7 @@ def _get_item_from_name_field(self, value: Union[object, str]) -> Union[object,
return next((model for model in self.data if getattr(model, self.name_field).lower() == lower_value), value)

@staticmethod
def _determine_class_handle(input_list: Sequence[object]):
def _determine_class_handle(input_list: Sequence[T]):
"""When inputting a sequence of object to a ClassList, the _class_handle should be set as the type of the
element which satisfies "issubclass" for all the other elements.
Expand All @@ -448,3 +450,50 @@ def _determine_class_handle(input_list: Sequence[object]):
class_handle = type(input_list[0])

return class_handle

# Pydantic core schema which allows ClassLists to be validated
# in short: it validates that each ClassList is indeed a ClassList,
# and then validates ClassList.data as though it were a typed list
# e.g. ClassList[str] data is validated like list[str]
@classmethod
def __get_pydantic_core_schema__(cls, source: Any, handler):
# import here so that the ClassList can be instantiated and used without Pydantic installed
from pydantic import ValidatorFunctionWrapHandler
from pydantic.types import (
core_schema, # import core_schema through here rather than making pydantic_core a dependency
)
from typing_extensions import get_args, get_origin

# if annotated with a class, get the item type of that class
origin = get_origin(source)
item_tp = Any if origin is None else get_args(source)[0]

list_schema = handler.generate_schema(list[item_tp])

def coerce(v: Any, handler: ValidatorFunctionWrapHandler) -> ClassList[T]:
"""If a sequence is given, try to coerce it to a ClassList."""
if isinstance(v, Sequence):
classlist = ClassList()
if len(v) > 0 and isinstance(v[0], dict):
# we want to be OK if the type is a model and is passed as a dict;
# pydantic will coerce it or fall over later
classlist._class_handle = dict
elif item_tp is not Any:
classlist._class_handle = item_tp
classlist.extend(v)
v = classlist
v = handler(v)
return v

def validate_items(v: ClassList[T], handler: ValidatorFunctionWrapHandler) -> ClassList[T]:
v.data = handler(v.data)
return v

schema = core_schema.chain_schema(
[
core_schema.no_info_wrap_validator_function(coerce, core_schema.is_instance_schema(cls)),
core_schema.no_info_wrap_validator_function(validate_items, list_schema),
],
)

return schema
26 changes: 24 additions & 2 deletions RATapi/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""The models module. Contains the pydantic models used by RAT to store project parameters."""

import pathlib
from typing import Any, Union
from typing import Any

import numpy as np
import prettytable
Expand Down Expand Up @@ -91,6 +91,18 @@ class Contrast(RATModel):
resample: bool = False
model: list[str] = []

@model_validator(mode="before")
@classmethod
def domain_ratio_error(cls, data: Any):
"""If the extra input 'domain_ratio' is given, give a more descriptive error."""

if isinstance(data, dict) and data.get("domain_ratio", False):
raise ValueError(
"The Contrast class does not support domain ratios. Use the ContrastWithRatio class instead."
)

return data

def __str__(self):
table = prettytable.PrettyTable()
table.field_names = [key.replace("_", " ") for key in self.__dict__]
Expand Down Expand Up @@ -155,7 +167,7 @@ class CustomFile(RATModel):
filename: str = ""
function_name: str = ""
language: Languages = Languages.Python
path: Union[str, pathlib.Path] = ""
path: pathlib.Path = pathlib.Path(".")

def model_post_init(self, __context: Any) -> None:
"""If a "filename" is supplied but the "function_name" field is not set, the "function_name" should be set to
Expand Down Expand Up @@ -291,6 +303,16 @@ class Layer(RATModel, populate_by_name=True):
hydration: str = ""
hydrate_with: Hydration = Hydration.BulkOut

@model_validator(mode="before")
@classmethod
def sld_imaginary_error(cls, data: Any):
"""If the extra input 'sld_imaginary' is given, give a more descriptive error."""

if isinstance(data, dict) and data.get("SLD_imaginary", False):
raise ValueError("The Layer class does not support imaginary SLD. Use the AbsorptionLayer class instead.")

return data


class AbsorptionLayer(RATModel, populate_by_name=True):
"""Combines parameters into defined layers including absorption terms."""
Expand Down
Loading

0 comments on commit aff6405

Please sign in to comment.