diff --git a/pyaerocom/_lowlevel_helpers.py b/pyaerocom/_lowlevel_helpers.py index ace6258ee..4b374f56f 100644 --- a/pyaerocom/_lowlevel_helpers.py +++ b/pyaerocom/_lowlevel_helpers.py @@ -77,6 +77,7 @@ def _class_name(obj): return type(obj).__name__ +# TODO: Check to see if instances of these classes can instead use pydantic class Validator(abc.ABC): def __set_name__(self, owner, name): self._name = name @@ -113,56 +114,6 @@ def validate(self, val): return val -class StrWithDefault(Validator): - def __init__(self, default: str): - self.default = default - - def validate(self, val): - if not isinstance(val, str): - if val is None: - val = self.default - else: - raise ValueError(f"need str or None, got {val}") - return val - - -class FlexList(Validator): - """list that can be instantated via input str, tuple or list or None""" - - def validate(self, val): - if isinstance(val, str): - val = [val] - elif isinstance(val, tuple): - val = list(val) - elif val is None: - val = [] - elif not isinstance(val, list): - raise ValueError(f"failed to convert {val} to list") - return val - - -class EitherOf(Validator): - _allowed = FlexList() - - def __init__(self, allowed: list): - self._allowed = allowed - - def validate(self, val): - if not any([x == val for x in self._allowed]): - raise ValueError(f"invalid value {val}, needs to be either of {self._allowed}.") - return val - - -class ListOfStrings(FlexList): - def validate(self, val): - # make sure to have a list - val = super().validate(val) - # make sure all entries are strings - if not all([isinstance(x, str) for x in val]): - raise ValueError(f"not all items are str type in input list {val}") - return val - - class Loc(abc.ABC): """Abstract descriptor representing a path location diff --git a/pyaerocom/aeroval/aux_io_helpers.py b/pyaerocom/aeroval/aux_io_helpers.py index 04552f46b..c20937291 100644 --- a/pyaerocom/aeroval/aux_io_helpers.py +++ b/pyaerocom/aeroval/aux_io_helpers.py @@ -1,8 +1,21 @@ import importlib import os import sys +from collections.abc import Callable -from pyaerocom._lowlevel_helpers import AsciiFileLoc, ListOfStrings +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + +from typing import TYPE_CHECKING + +from pydantic import ( + BaseModel, + model_validator, +) + +from pyaerocom._lowlevel_helpers import AsciiFileLoc def check_aux_info(fun, vars_required, funcs): @@ -26,11 +39,11 @@ def check_aux_info(fun, vars_required, funcs): required. """ - spec = _AuxReadSpec(fun, vars_required, funcs) + spec = _AuxReadSpec(fun=fun, vars_required=vars_required, funcs=funcs) return dict(fun=spec.fun, vars_required=spec.vars_required) -class _AuxReadSpec: +class _AuxReadSpec(BaseModel): """ Class that specifies requirements for computation of additional variables @@ -53,39 +66,22 @@ class _AuxReadSpec: """ - vars_required = ListOfStrings() - - def __init__(self, fun, vars_required: list, funcs: dict): - self.vars_required = vars_required - self.fun = self.get_func(fun, funcs) - - def get_func(self, fun, funcs): - """ - Get callable function for computation of variable - - Parameters - ---------- - fun : str or callable - Name of function or function. - funcs : dict - Dictionary with possible functions (values) and names (keys) - - Raises - ------ - ValueError - If function could not be retrieved. - - Returns - ------- - callable - callable function object. - - """ - if callable(fun): - return fun - elif isinstance(fun, str): - return funcs[fun] - raise ValueError("failed to retrieve aux func") + if TYPE_CHECKING: + fun: Callable + else: + fun: str | Callable + vars_required: list[str] + funcs: dict[str, Callable] + + @model_validator(mode="after") + def validate_fun(self) -> Self: + if callable(self.fun): + return self + elif isinstance(self.fun, str): + self.fun = self.funcs[self.fun] + return self + else: + raise ValueError("failed to retrieve aux func") class ReadAuxHandler: diff --git a/tests/aeroval/test_aux_io_helpers.py b/tests/aeroval/test_aux_io_helpers.py index cbe57e85e..3401cbb81 100644 --- a/tests/aeroval/test_aux_io_helpers.py +++ b/tests/aeroval/test_aux_io_helpers.py @@ -1,6 +1,7 @@ from pathlib import Path from textwrap import dedent +from pydantic import ValidationError from pytest import mark, param, raises from pyaerocom.aeroval.aux_io_helpers import ReadAuxHandler, check_aux_info @@ -52,20 +53,26 @@ def test_check_aux_info(fun, vars_required: list[str], funcs: dict): @mark.parametrize( - "fun,vars_required,funcs,error", + "fun,vars_required,funcs,error,", [ - param(None, [], {}, "failed to retrieve aux func", id="no func"), + param( + None, + [], + {}, + "2 validation errors for _AuxReadSpec", + id="no func", + ), param( None, [42], {}, - "not all items are str type in input list [42]", + "3 validation errors for _AuxReadSpec", id="bad type vars_required", ), ], ) def test_check_aux_info_error(fun, vars_required: list[str], funcs: dict, error: str): - with raises(ValueError) as e: + with raises(ValidationError) as e: check_aux_info(fun, vars_required, funcs) - assert str(e.value) == error + assert error in str(e.value)