Skip to content

Commit

Permalink
Merge pull request #1407 from metno/collections
Browse files Browse the repository at this point in the history
Make Aeroval Collections Composite patterns
  • Loading branch information
lewisblake authored Nov 18, 2024
2 parents b15dadf + d68a9f8 commit 7b6616d
Show file tree
Hide file tree
Showing 11 changed files with 228 additions and 130 deletions.
191 changes: 111 additions & 80 deletions pyaerocom/aeroval/collections.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,76 @@
import abc
from fnmatch import fnmatch
import json

from pyaerocom._lowlevel_helpers import BrowseDict
from pyaerocom.aeroval.modelentry import ModelEntry
from pyaerocom.aeroval.obsentry import ObsEntry
from pyaerocom.exceptions import EntryNotAvailable, EvalEntryNameError
from pyaerocom.exceptions import EntryNotAvailable


class BaseCollection(BrowseDict, abc.ABC):
#: maximum length of entry names
MAXLEN_KEYS = 25
#: Invalid chars in entry names
FORBIDDEN_CHARS_KEYS = []
class BaseCollection(abc.ABC):
def __init__(self):
"""
Initialize an instance of BaseCollection.
The instance maintains a dictionary of entries.
"""
self._entries = {}

# TODO: Wait a few release cycles after v0.23.0 and see if this can be removed
def _check_entry_name(self, key):
if any([x in key for x in self.FORBIDDEN_CHARS_KEYS]):
raise EvalEntryNameError(
f"Invalid name: {key}. Must not contain any of the following "
f"characters: {self.FORBIDDEN_CHARS_KEYS}"
)
def __iter__(self):
"""
Iterates over each entry in the collection.
def __setitem__(self, key, value):
self._check_entry_name(key)
super().__setitem__(key, value)
Yields
------
object
The next entry in the collection.
"""
yield from self._entries.values()

def keylist(self, name_or_pattern: str = None) -> list:
"""Find model names that match input search pattern(s)
@abc.abstractmethod
def add_entry(self, key, value) -> None:
"""
Abstract method to add an entry to the collection.
Parameters
----------
key: Hashable
The key of the entry.
value: object
The value of the entry.
"""
pass

@abc.abstractmethod
def remove_entry(self, key) -> None:
"""
Abstract method to remove an entry from the collection.
Parameters
----------
key: Hashable
The key of the entry to be removed.
"""
pass

@abc.abstractmethod
def get_entry(self, key) -> object:
"""
Abstract method to get an entry from the collection.
Parameters
----------
key: Hashable
The key of the entry to retrieve.
Returns
-------
object
The entry associated with the provided key.
"""
pass

def keylist(self, name_or_pattern: str = None) -> list[str]:
"""Find model / obs names that match input search pattern(s)
Parameters
----------
Expand All @@ -48,39 +92,34 @@ def keylist(self, name_or_pattern: str = None) -> list:
name_or_pattern = "*"

matches = []
for key in self.keys():
for key in self._entries.keys():
if fnmatch(key, name_or_pattern) and key not in matches:
matches.append(key)
if len(matches) == 0:
raise KeyError(f"No matches could be found that match input {name_or_pattern}")
return matches

@abc.abstractmethod
def get_entry(self, key) -> object:
"""
Getter for eval entries
Raises
------
KeyError
if input name is not in this collection
"""
pass

@property
@abc.abstractmethod
def web_interface_names(self) -> list:
"""
List of webinterface names for
List of web interface names for each obs entry
Returns
-------
list
"""
pass
return self.keylist()

def to_json(self) -> str:
"""Serialize the entries of this collection to a JSON string."""
return json.dumps({k: v.dict() for k, v in self._entries.items()}, default=str)


class ObsCollection(BaseCollection):
"""
Dict-like object that represents a collection of obs entries
Object that represents a collection of obs entries
Keys are obs names, values are instances of :class:`ObsEntry`. Values can
"Keys" are obs names, values are instances of :class:`ObsEntry`. Values can
also be assigned as dict and will automatically be converted into
instances of :class:`ObsEntry`.
Expand All @@ -93,9 +132,16 @@ class ObsCollection(BaseCollection):
"""

SETTER_CONVERT = {dict: ObsEntry}
def add_entry(self, key: str, entry: dict | ObsEntry):
if isinstance(entry, dict):
entry = ObsEntry(**entry)
self._entries[key] = entry

def get_entry(self, key) -> object:
def remove_entry(self, key: str):
if key in self._entries:
del self._entries[key]

def get_entry(self, key: str) -> ObsEntry:
"""
Getter for obs entries
Expand All @@ -105,7 +151,7 @@ def get_entry(self, key) -> object:
if input name is not in this collection
"""
try:
entry = self[key]
entry = self._entries[key]
entry.obs_name = self.get_web_interface_name(key)
return entry
except (KeyError, AttributeError):
Expand All @@ -122,11 +168,11 @@ def get_all_vars(self) -> list[str]:
"""
vars = []
for ocfg in self.values():
for ocfg in self._entries.values():
vars.extend(ocfg.get_all_vars())
return sorted(list(set(vars)))

def get_web_interface_name(self, key):
def get_web_interface_name(self, key: str) -> str:
"""
Get web interface name for entry
Expand All @@ -147,7 +193,12 @@ def get_web_interface_name(self, key):
corresponding name
"""
return self[key].web_interface_name if self[key].web_interface_name is not None else key
entry = self._entries.get(key)
return (
entry.web_interface_name
if entry is not None and entry.web_interface_name is not None
else key
)

@property
def web_interface_names(self) -> list:
Expand All @@ -163,45 +214,38 @@ def web_interface_names(self) -> list:
@property
def all_vert_types(self):
"""List of unique vertical types specified in this collection"""
return list({x.obs_vert_type for x in self.values()})
return list({x.obs_vert_type for x in self._entries.values()})


class ModelCollection(BaseCollection):
"""
Dict-like object that represents a collection of model entries
Object that represents a collection of model entries
Keys are model names, values are instances of :class:`ModelEntry`. Values
"Keys" are model names, values are instances of :class:`ModelEntry`. Values
can also be assigned as dict and will automatically be converted into
instances of :class:`ModelEntry`.
Note
----
Entries must not necessarily be only models but may also be observations.
Entries provided in this collection refer to the x-axis in the AeroVal
heatmap display and must fulfill the protocol defined by
:class:`ModelEntry`.
"""

SETTER_CONVERT = {dict: ModelEntry}

def get_entry(self, key) -> ModelEntry:
"""Get model entry configuration
Since the configuration files for experiments are in json format, they
do not allow the storage of executable custom methods for model data
reading. Instead, these can be specified in a python module that may
be specified via :attr:`add_methods_file` and that contains a
dictionary `FUNS` that maps the method names with the callable methods.
def add_entry(self, key: str, entry: dict | ModelEntry):
if isinstance(entry, dict):
entry = ModelEntry(**entry)
entry.model_name = key
self._entries[key] = entry

As a result, this means that, by default, custom read methods for
individual models in :attr:`model_config` do not contain the
callable methods but only the names. This method will take care of
handling this and will return a dictionary where potential custom
method strings have been converted to the corresponding callable
methods.
def remove_entry(self, key: str):
if key in self._entries:
del self._entries[key]

def get_entry(self, key: str) -> ModelEntry:
"""
Get model entry configuration
Parameters
----------
model_name : str
Expand All @@ -212,20 +256,7 @@ def get_entry(self, key) -> ModelEntry:
dict
Dictionary that specifies the model setup ready for the analysis
"""
try:
entry = self[key]
entry.model_name = key
return entry
except (KeyError, AttributeError):
if key in self._entries:
return self._entries[key]
else:
raise EntryNotAvailable(f"no such entry {key}")

@property
def web_interface_names(self) -> list:
"""
List of web interface names for each obs entry
Returns
-------
list
"""
return self.keylist()
4 changes: 2 additions & 2 deletions pyaerocom/aeroval/experiment_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,8 +751,8 @@ def _is_part_of_experiment(self, obs_name, obs_var, mod_name, mod_var) -> bool:
# occurrence of web_interface_name).
allobs = self.cfg.obs_cfg
obs_matches = []
for key, ocfg in allobs.items():
if obs_name == allobs.get_web_interface_name(key):
for ocfg in allobs:
if obs_name == allobs.get_web_interface_name(ocfg.obs_name):
obs_matches.append(ocfg)
if len(obs_matches) == 0:
self._invalid["obs"].append(obs_name)
Expand Down
6 changes: 3 additions & 3 deletions pyaerocom/aeroval/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def make_dummy_model(obs_list: list, cfg) -> str:
tmp_var_obj = Variable()
# Loops over variables in obs
for obs in obs_list:
for var in cfg.obs_cfg[obs].obs_vars:
for var in cfg.obs_cfg.get_entry(obs).obs_vars:
# Create dummy cube

dummy_cube = make_dummy_cube(var, start_yr=start, stop_yr=stop, freq=freq)
Expand All @@ -185,13 +185,13 @@ def make_dummy_model(obs_list: list, cfg) -> str:
for dummy_grid_yr in yr_gen:
# Add to netcdf
yr = dummy_grid_yr.years_avail()[0]
vert_code = cfg.obs_cfg[obs].obs_vert_type
vert_code = cfg.obs_cfg.get_entry(obs).obs_vert_type

save_name = dummy_grid_yr.aerocom_savename(model_id, var, vert_code, yr, freq)
dummy_grid_yr.to_netcdf(outdir, savename=save_name)

# Add dummy model to cfg
cfg.model_cfg["dummy"] = ModelEntry(model_id="dummy_model")
cfg.model_cfg.add_entry("dummy", ModelEntry(model_id="dummy_model"))

return model_id

Expand Down
32 changes: 16 additions & 16 deletions pyaerocom/aeroval/setup_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,29 +501,29 @@ def colocation_opts(self) -> ColocationSetup:

# These attributes require special attention b/c they're not based on Pydantic's BaseModel class.

obs_cfg: ObsCollection | dict = ObsCollection()

@field_validator("obs_cfg")
def validate_obs_cfg(cls, v):
if isinstance(v, ObsCollection):
return v
return ObsCollection(v)
@computed_field
@cached_property
def obs_cfg(self) -> ObsCollection:
oc = ObsCollection()
for k, v in self.model_extra.get("obs_cfg", {}).items():
oc.add_entry(k, v)
return oc

@field_serializer("obs_cfg")
def serialize_obs_cfg(self, obs_cfg: ObsCollection):
return obs_cfg.json_repr()
return obs_cfg.to_json()

model_cfg: ModelCollection | dict = ModelCollection()

@field_validator("model_cfg")
def validate_model_cfg(cls, v):
if isinstance(v, ModelCollection):
return v
return ModelCollection(v)
@computed_field
@cached_property
def model_cfg(self) -> ModelCollection:
mc = ModelCollection()
for k, v in self.model_extra.get("model_cfg", {}).items():
mc.add_entry(k, v)
return mc

@field_serializer("model_cfg")
def serialize_model_cfg(self, model_cfg: ModelCollection):
return model_cfg.json_repr()
return model_cfg.to_json()

###########################
## Methods
Expand Down
5 changes: 3 additions & 2 deletions pyaerocom/aeroval/superobs_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,9 @@ def _run_var(self, model_name, obs_name, var_name, try_colocate_if_missing):
coldata_files = []
coldata_resolutions = []
vert_codes = []
obs_needed = self.cfg.obs_cfg[obs_name].obs_id
vert_code = self.cfg.obs_cfg.get_entry(obs_name).obs_vert_type
obs_entry = self.cfg.obs_cfg.get_entry(obs_name)
obs_needed = obs_entry.obs_id
vert_code = obs_entry.obs_vert_type
for oname in obs_needed:
fp, ts_type, vert_code = self._get_coldata_fileinfo(
model_name, oname, var_name, try_colocate_if_missing
Expand Down
3 changes: 2 additions & 1 deletion pyaerocom/aeroval/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,8 @@ def compute_model_average_and_diversity(

unit_out = get_variable(var_name).units

for mname in models:
for m in models:
mname = m.model_name
logger.info(f"Adding {mname} ({var_name})")

mid = cfg.cfg.model_cfg.get_entry(mname).model_id
Expand Down
Loading

0 comments on commit 7b6616d

Please sign in to comment.