Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lowlevel helpers clean up #1410

Merged
merged 7 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 1 addition & 50 deletions pyaerocom/_lowlevel_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def _class_name(obj):
return type(obj).__name__


# TODO: Check to see if instances of these classes can instead use pydantic
class Validator(abc.ABC):
def __set_name__(self, owner, name):
self._name = name
Expand Down Expand Up @@ -113,56 +114,6 @@ def validate(self, val):
return val


class StrWithDefault(Validator):
def __init__(self, default: str):
self.default = default

def validate(self, val):
if not isinstance(val, str):
if val is None:
val = self.default
else:
raise ValueError(f"need str or None, got {val}")
return val


class FlexList(Validator):
"""list that can be instantated via input str, tuple or list or None"""

def validate(self, val):
if isinstance(val, str):
val = [val]
elif isinstance(val, tuple):
val = list(val)
elif val is None:
val = []
elif not isinstance(val, list):
raise ValueError(f"failed to convert {val} to list")
return val


class EitherOf(Validator):
_allowed = FlexList()

def __init__(self, allowed: list):
self._allowed = allowed

def validate(self, val):
if not any([x == val for x in self._allowed]):
raise ValueError(f"invalid value {val}, needs to be either of {self._allowed}.")
return val


class ListOfStrings(FlexList):
def validate(self, val):
# make sure to have a list
val = super().validate(val)
# make sure all entries are strings
if not all([isinstance(x, str) for x in val]):
raise ValueError(f"not all items are str type in input list {val}")
return val


class Loc(abc.ABC):
"""Abstract descriptor representing a path location

Expand Down
58 changes: 22 additions & 36 deletions pyaerocom/aeroval/aux_io_helpers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
import importlib
import os
import sys
from collections.abc import Callable

from pyaerocom._lowlevel_helpers import AsciiFileLoc, ListOfStrings
from pydantic import (
BaseModel,
model_validator,
)

from pyaerocom._lowlevel_helpers import AsciiFileLoc


def check_aux_info(fun, vars_required, funcs):
Expand All @@ -26,11 +32,11 @@
required.

"""
spec = _AuxReadSpec(fun, vars_required, funcs)
spec = _AuxReadSpec(fun=fun, vars_required=vars_required, funcs=funcs)
return dict(fun=spec.fun, vars_required=spec.vars_required)


class _AuxReadSpec:
class _AuxReadSpec(BaseModel):
"""
Class that specifies requirements for computation of additional variables

Expand All @@ -53,39 +59,19 @@

"""

vars_required = ListOfStrings()

def __init__(self, fun, vars_required: list, funcs: dict):
self.vars_required = vars_required
self.fun = self.get_func(fun, funcs)

def get_func(self, fun, funcs):
"""
Get callable function for computation of variable

Parameters
----------
fun : str or callable
Name of function or function.
funcs : dict
Dictionary with possible functions (values) and names (keys)

Raises
------
ValueError
If function could not be retrieved.

Returns
-------
callable
callable function object.

"""
if callable(fun):
return fun
elif isinstance(fun, str):
return funcs[fun]
raise ValueError("failed to retrieve aux func")
fun: str | Callable
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to use something like this to make fun typehint as str only during typechecking, since that's what we expect it to be after validation? I think this would work better with tools such as mypy and code autocompletion.

vars_required: list[str]
funcs: dict[str, Callable]

@model_validator(mode="after")
def validate_fun(self) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Return type should be _AuxReadSpec

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can return Self

if callable(self.fun):
return self
elif isinstance(self.fun, str):
self.fun = self.funcs[self.fun]
return self
else:
raise ValueError("failed to retrieve aux func")

Check warning on line 74 in pyaerocom/aeroval/aux_io_helpers.py

View check run for this annotation

Codecov / codecov/patch

pyaerocom/aeroval/aux_io_helpers.py#L74

Added line #L74 was not covered by tests


class ReadAuxHandler:
Expand Down
17 changes: 12 additions & 5 deletions tests/aeroval/test_aux_io_helpers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path
from textwrap import dedent

from pydantic import ValidationError
from pytest import mark, param, raises

from pyaerocom.aeroval.aux_io_helpers import ReadAuxHandler, check_aux_info
Expand Down Expand Up @@ -52,20 +53,26 @@ def test_check_aux_info(fun, vars_required: list[str], funcs: dict):


@mark.parametrize(
"fun,vars_required,funcs,error",
"fun,vars_required,funcs,error,",
[
param(None, [], {}, "failed to retrieve aux func", id="no func"),
param(
None,
[],
{},
"2 validation errors for _AuxReadSpec",
id="no func",
),
param(
None,
[42],
{},
"not all items are str type in input list [42]",
"3 validation errors for _AuxReadSpec",
id="bad type vars_required",
),
],
)
def test_check_aux_info_error(fun, vars_required: list[str], funcs: dict, error: str):
with raises(ValueError) as e:
with raises(ValidationError) as e:
check_aux_info(fun, vars_required, funcs)

assert str(e.value) == error
assert error in str(e.value)
Loading