diff --git a/pyaerocom/aeroval/collections.py b/pyaerocom/aeroval/collections.py index dc67e3fc0..ff16cabfc 100644 --- a/pyaerocom/aeroval/collections.py +++ b/pyaerocom/aeroval/collections.py @@ -1,32 +1,76 @@ import abc from fnmatch import fnmatch +import json -from pyaerocom._lowlevel_helpers import BrowseDict from pyaerocom.aeroval.modelentry import ModelEntry from pyaerocom.aeroval.obsentry import ObsEntry -from pyaerocom.exceptions import EntryNotAvailable, EvalEntryNameError +from pyaerocom.exceptions import EntryNotAvailable -class BaseCollection(BrowseDict, abc.ABC): - #: maximum length of entry names - MAXLEN_KEYS = 25 - #: Invalid chars in entry names - FORBIDDEN_CHARS_KEYS = [] +class BaseCollection(abc.ABC): + def __init__(self): + """ + Initialize an instance of BaseCollection. + The instance maintains a dictionary of entries. + """ + self._entries = {} - # TODO: Wait a few release cycles after v0.23.0 and see if this can be removed - def _check_entry_name(self, key): - if any([x in key for x in self.FORBIDDEN_CHARS_KEYS]): - raise EvalEntryNameError( - f"Invalid name: {key}. Must not contain any of the following " - f"characters: {self.FORBIDDEN_CHARS_KEYS}" - ) + def __iter__(self): + """ + Iterates over each entry in the collection. - def __setitem__(self, key, value): - self._check_entry_name(key) - super().__setitem__(key, value) + Yields + ------ + object + The next entry in the collection. + """ + yield from self._entries.values() - def keylist(self, name_or_pattern: str = None) -> list: - """Find model names that match input search pattern(s) + @abc.abstractmethod + def add_entry(self, key, value) -> None: + """ + Abstract method to add an entry to the collection. + + Parameters + ---------- + key: Hashable + The key of the entry. + value: object + The value of the entry. + """ + pass + + @abc.abstractmethod + def remove_entry(self, key) -> None: + """ + Abstract method to remove an entry from the collection. 
+ + Parameters + ---------- + key: Hashable + The key of the entry to be removed. + """ + pass + + @abc.abstractmethod + def get_entry(self, key) -> object: + """ + Abstract method to get an entry from the collection. + + Parameters + ---------- + key: Hashable + The key of the entry to retrieve. + + Returns + ------- + object + The entry associated with the provided key. + """ + pass + + def keylist(self, name_or_pattern: str = None) -> list[str]: + """Find model / obs names that match input search pattern(s) Parameters ---------- @@ -48,39 +92,34 @@ def keylist(self, name_or_pattern: str = None) -> list: name_or_pattern = "*" matches = [] - for key in self.keys(): + for key in self._entries.keys(): if fnmatch(key, name_or_pattern) and key not in matches: matches.append(key) if len(matches) == 0: raise KeyError(f"No matches could be found that match input {name_or_pattern}") return matches - @abc.abstractmethod - def get_entry(self, key) -> object: - """ - Getter for eval entries - - Raises - ------ - KeyError - if input name is not in this collection - """ - pass - @property - @abc.abstractmethod def web_interface_names(self) -> list: """ - List of webinterface names for + List of web interface names for each entry + + Returns + ------- + list """ - pass + return self.keylist() + + def to_json(self) -> str: + """Serialize the collection to a JSON string.""" + return json.dumps({k: v.dict() for k, v in self._entries.items()}, default=str) class ObsCollection(BaseCollection): """ - Dict-like object that represents a collection of obs entries + Object that represents a collection of obs entries - Keys are obs names, values are instances of :class:`ObsEntry`. Values can + "Keys" are obs names, values are instances of :class:`ObsEntry`. Values can also be assigned as dict and will automatically be converted into instances of :class:`ObsEntry`. 
@@ -93,9 +132,16 @@ class ObsCollection(BaseCollection): """ - SETTER_CONVERT = {dict: ObsEntry} + def add_entry(self, key: str, entry: dict | ObsEntry): + if isinstance(entry, dict): + entry = ObsEntry(**entry) + self._entries[key] = entry - def get_entry(self, key) -> object: + def remove_entry(self, key: str): + if key in self._entries: + del self._entries[key] + + def get_entry(self, key: str) -> ObsEntry: """ Getter for obs entries @@ -105,7 +151,7 @@ def get_entry(self, key) -> object: if input name is not in this collection """ try: - entry = self[key] + entry = self._entries[key] entry.obs_name = self.get_web_interface_name(key) return entry except (KeyError, AttributeError): @@ -122,11 +168,11 @@ def get_all_vars(self) -> list[str]: """ vars = [] - for ocfg in self.values(): + for ocfg in self._entries.values(): vars.extend(ocfg.get_all_vars()) return sorted(list(set(vars))) - def get_web_interface_name(self, key): + def get_web_interface_name(self, key: str) -> str: """ Get webinterface name for entry @@ -147,7 +193,12 @@ def get_web_interface_name(self, key): corresponding name """ - return self[key].web_interface_name if self[key].web_interface_name is not None else key + entry = self._entries.get(key) + return ( + entry.web_interface_name + if entry is not None and entry.web_interface_name is not None + else key + ) @property def web_interface_names(self) -> list: @@ -163,45 +214,38 @@ def web_interface_names(self) -> list: @property def all_vert_types(self): """List of unique vertical types specified in this collection""" - return list({x.obs_vert_type for x in self.values()}) + return list({x.obs_vert_type for x in self._entries.values()}) class ModelCollection(BaseCollection): """ - Dict-like object that represents a collection of model entries + Object that represents a collection of model entries - Keys are model names, values are instances of :class:`ModelEntry`. Values + "Keys" are model names, values are instances of :class:`ModelEntry`. 
Values can also be assigned as dict and will automatically be converted into instances of :class:`ModelEntry`. - Note ---- Entries must not necessarily be only models but may also be observations. Entries provided in this collection refer to the x-axis in the AeroVal heatmap display and must fulfill the protocol defined by :class:`ModelEntry`. - """ - SETTER_CONVERT = {dict: ModelEntry} - - def get_entry(self, key) -> ModelEntry: - """Get model entry configuration - - Since the configuration files for experiments are in json format, they - do not allow the storage of executable custom methods for model data - reading. Instead, these can be specified in a python module that may - be specified via :attr:`add_methods_file` and that contains a - dictionary `FUNS` that maps the method names with the callable methods. + def add_entry(self, key: str, entry: dict | ModelEntry): + if isinstance(entry, dict): + entry = ModelEntry(**entry) + entry.model_name = key + self._entries[key] = entry - As a result, this means that, by default, custom read methods for - individual models in :attr:`model_config` do not contain the - callable methods but only the names. This method will take care of - handling this and will return a dictionary where potential custom - method strings have been converted to the corresponding callable - methods. 
+ def remove_entry(self, key: str): + if key in self._entries: + del self._entries[key] + def get_entry(self, key: str) -> ModelEntry: + """ + Get model entry configuration Parameters ---------- model_name : str @@ -212,20 +256,7 @@ def get_entry(self, key) -> ModelEntry: dict Dictionary that specifies the model setup ready for the analysis """ - try: - entry = self[key] - entry.model_name = key - return entry - except (KeyError, AttributeError): + if key in self._entries: + return self._entries[key] + else: raise EntryNotAvailable(f"no such entry {key}") - - @property - def web_interface_names(self) -> list: - """ - List of web interface names for each obs entry - - Returns - ------- - list - """ - return self.keylist() diff --git a/pyaerocom/aeroval/experiment_output.py b/pyaerocom/aeroval/experiment_output.py index 9e6ce980e..a15d524c2 100644 --- a/pyaerocom/aeroval/experiment_output.py +++ b/pyaerocom/aeroval/experiment_output.py @@ -751,8 +751,8 @@ def _is_part_of_experiment(self, obs_name, obs_var, mod_name, mod_var) -> bool: # occurence of web_interface_name). 
allobs = self.cfg.obs_cfg obs_matches = [] - for key, ocfg in allobs.items(): - if obs_name == allobs.get_web_interface_name(key): + for ocfg in allobs: + if obs_name == allobs.get_web_interface_name(ocfg.obs_name): obs_matches.append(ocfg) if len(obs_matches) == 0: self._invalid["obs"].append(obs_name) diff --git a/pyaerocom/aeroval/helpers.py b/pyaerocom/aeroval/helpers.py index 0a692161f..b4a8fa43f 100644 --- a/pyaerocom/aeroval/helpers.py +++ b/pyaerocom/aeroval/helpers.py @@ -162,7 +162,7 @@ def make_dummy_model(obs_list: list, cfg) -> str: tmp_var_obj = Variable() # Loops over variables in obs for obs in obs_list: - for var in cfg.obs_cfg[obs].obs_vars: + for var in cfg.obs_cfg.get_entry(obs).obs_vars: # Create dummy cube dummy_cube = make_dummy_cube(var, start_yr=start, stop_yr=stop, freq=freq) @@ -185,13 +185,13 @@ def make_dummy_model(obs_list: list, cfg) -> str: for dummy_grid_yr in yr_gen: # Add to netcdf yr = dummy_grid_yr.years_avail()[0] - vert_code = cfg.obs_cfg[obs].obs_vert_type + vert_code = cfg.obs_cfg.get_entry(obs).obs_vert_type save_name = dummy_grid_yr.aerocom_savename(model_id, var, vert_code, yr, freq) dummy_grid_yr.to_netcdf(outdir, savename=save_name) # Add dummy model to cfg - cfg.model_cfg["dummy"] = ModelEntry(model_id="dummy_model") + cfg.model_cfg.add_entry("dummy", ModelEntry(model_id="dummy_model")) return model_id diff --git a/pyaerocom/aeroval/setup_classes.py b/pyaerocom/aeroval/setup_classes.py index 1677c1b3e..a75b36a8e 100644 --- a/pyaerocom/aeroval/setup_classes.py +++ b/pyaerocom/aeroval/setup_classes.py @@ -501,29 +501,29 @@ def colocation_opts(self) -> ColocationSetup: # These attributes require special attention b/c they're not based on Pydantic's BaseModel class. 
- obs_cfg: ObsCollection | dict = ObsCollection() - - @field_validator("obs_cfg") - def validate_obs_cfg(cls, v): - if isinstance(v, ObsCollection): - return v - return ObsCollection(v) + @computed_field + @cached_property + def obs_cfg(self) -> ObsCollection: + oc = ObsCollection() + for k, v in self.model_extra.get("obs_cfg", {}).items(): + oc.add_entry(k, v) + return oc @field_serializer("obs_cfg") def serialize_obs_cfg(self, obs_cfg: ObsCollection): - return obs_cfg.json_repr() + return obs_cfg.to_json() - model_cfg: ModelCollection | dict = ModelCollection() - - @field_validator("model_cfg") - def validate_model_cfg(cls, v): - if isinstance(v, ModelCollection): - return v - return ModelCollection(v) + @computed_field + @cached_property + def model_cfg(self) -> ModelCollection: + mc = ModelCollection() + for k, v in self.model_extra.get("model_cfg", {}).items(): + mc.add_entry(k, v) + return mc @field_serializer("model_cfg") def serialize_model_cfg(self, model_cfg: ModelCollection): - return model_cfg.json_repr() + return model_cfg.to_json() ########################### ## Methods diff --git a/pyaerocom/aeroval/superobs_engine.py b/pyaerocom/aeroval/superobs_engine.py index fdae8d8b7..46364caaa 100644 --- a/pyaerocom/aeroval/superobs_engine.py +++ b/pyaerocom/aeroval/superobs_engine.py @@ -75,8 +75,9 @@ def _run_var(self, model_name, obs_name, var_name, try_colocate_if_missing): coldata_files = [] coldata_resolutions = [] vert_codes = [] - obs_needed = self.cfg.obs_cfg[obs_name].obs_id - vert_code = self.cfg.obs_cfg.get_entry(obs_name).obs_vert_type + obs_entry = self.cfg.obs_cfg.get_entry(obs_name) + obs_needed = obs_entry.obs_id + vert_code = obs_entry.obs_vert_type for oname in obs_needed: fp, ts_type, vert_code = self._get_coldata_fileinfo( model_name, oname, var_name, try_colocate_if_missing diff --git a/pyaerocom/aeroval/utils.py b/pyaerocom/aeroval/utils.py index 8f680fe05..a4163ca3d 100644 --- a/pyaerocom/aeroval/utils.py +++ b/pyaerocom/aeroval/utils.py 
@@ -145,7 +145,8 @@ def compute_model_average_and_diversity( unit_out = get_variable(var_name).units - for mname in models: + for m in models: + mname = m.model_name logger.info(f"Adding {mname} ({var_name})") mid = cfg.cfg.model_cfg.get_entry(mname).model_id diff --git a/tests/aeroval/test_aeroval_HIGHLEV.py b/tests/aeroval/test_aeroval_HIGHLEV.py index 558983fa7..c2b07ed99 100644 --- a/tests/aeroval/test_aeroval_HIGHLEV.py +++ b/tests/aeroval/test_aeroval_HIGHLEV.py @@ -113,11 +113,11 @@ def test_reanalyse_existing(eval_config: dict, reanalyse_existing: bool): @pytest.mark.parametrize("cfg", ["cfgexp4"]) def test_superobs_different_resolutions(eval_config: dict): cfg = EvalSetup(**eval_config) - cfg.model_cfg["TM5-AP3-CTRL"].model_ts_type_read = None - cfg.model_cfg["TM5-AP3-CTRL"].flex_ts_type = True + cfg.model_cfg.get_entry("TM5-AP3-CTRL").model_ts_type_read = None + cfg.model_cfg.get_entry("TM5-AP3-CTRL").flex_ts_type = True - cfg.obs_cfg["AERONET-Sun"].ts_type = "daily" - cfg.obs_cfg["AERONET-SDA"].ts_type = "monthly" + cfg.obs_cfg.get_entry("AERONET-Sun").ts_type = "daily" + cfg.obs_cfg.get_entry("AERONET-SDA").ts_type = "monthly" proc = ExperimentProcessor(cfg) proc.exp_output.delete_experiment_data(also_coldata=True) diff --git a/tests/aeroval/test_collections.py b/tests/aeroval/test_collections.py index 5f2a98f7c..1327ca6b8 100644 --- a/tests/aeroval/test_collections.py +++ b/tests/aeroval/test_collections.py @@ -1,28 +1,91 @@ from pyaerocom.aeroval.collections import ObsCollection, ModelCollection +from pyaerocom.aeroval.obsentry import ObsEntry +from pyaerocom.aeroval.modelentry import ModelEntry +import pytest -def test_obscollection(): - oc = ObsCollection(model1=dict(obs_id="bla", obs_vars="od550aer", obs_vert_type="Column")) +def test_obscollection_init_and_add_entry(): + oc = ObsCollection() + oc.add_entry("model1", dict(obs_id="bla", obs_vars="od550aer", obs_vert_type="Column")) assert oc - oc["AN-EEA-MP"] = dict( - is_superobs=True, - 
obs_id=("AirNow", "EEA-NRT-rural", "MarcoPolo"), - obs_vars=["concpm10", "concpm25", "vmro3", "vmrno2"], - obs_vert_type="Surface", + oc.add_entry( + "AN-EEA-MP", + dict( + is_superobs=True, + obs_id=("AirNow", "EEA-NRT-rural", "MarcoPolo"), + obs_vars=["concpm10", "concpm25", "vmro3", "vmrno2"], + obs_vert_type="Surface", + ), ) - assert "AN-EEA-MP" in oc + assert "AN-EEA-MP" in oc.keylist() -def test_modelcollection(): - mc = ModelCollection(model1=dict(model_id="bla", obs_vars="od550aer", obs_vert_type="Column")) +def test_obscollection_add_and_get_entry(): + collection = ObsCollection() + entry = ObsEntry(obs_id="obs1", obs_vars=("var1",)) + collection.add_entry("key1", entry) + retrieved_entry = collection.get_entry("key1") + assert retrieved_entry == entry + + +def test_obscollection_add_and_remove_entry(): + collection = ObsCollection() + entry = ObsEntry(obs_id="obs1", obs_vars=("var1",)) + collection.add_entry("key1", entry) + collection.remove_entry("key1") + with pytest.raises(KeyError): + collection.get_entry("key1") + + +def test_obscollection_get_web_interface_name(): + collection = ObsCollection() + entry = ObsEntry(obs_id="obs1", obs_vars=("var1",), web_interface_name="web_name") + collection.add_entry("key1", entry) + assert collection.get_web_interface_name("key1") == "web_name" + + +def test_obscollection_all_vert_types(): + collection = ObsCollection() + entry1 = ObsEntry( + obs_id="obs1", obs_vars=("var1",), obs_vert_type="Surface" + ) # Assuming ObsEntry has an obs_vert_type attribute + entry2 = ObsEntry(obs_id="obs2", obs_vars=("var2",), obs_vert_type="Profile") + collection.add_entry("key1", entry1) + collection.add_entry("key2", entry2) + assert set(collection.all_vert_types) == {"Surface", "Profile"} + + +def test_modelcollection_init_and_add_entry(): + mc = ModelCollection() + mc.add_entry("model1", dict(model_id="bla", obs_vars="od550aer", obs_vert_type="Column")) assert mc - mc["ECMWF_OSUITE"] = dict( - model_id="ECMWF_OSUITE", - 
obs_vars=["concpm10"], - obs_vert_type="Surface", + mc.add_entry( + "ECMWF_OSUITE", + dict( + model_id="ECMWF_OSUITE", + obs_vars=["concpm10"], + obs_vert_type="Surface", + ), ) - assert "ECMWF_OSUITE" in mc + assert "ECMWF_OSUITE" in mc.keylist() + + +def test_modelcollection_add_and_get_entry(): + collection = ModelCollection() + entry = ModelEntry(model_id="mod1") + collection.add_entry("key1", entry) + retrieved_entry = collection.get_entry("key1") + assert retrieved_entry == entry + + +def test_modelcollection_add_and_remove_entry(): + collection = ModelCollection() + entry = ModelEntry(model_id="obs1") + collection.add_entry("key1", entry) + collection.remove_entry("key1") + with pytest.raises(KeyError): + collection.get_entry("key1") diff --git a/tests/aeroval/test_experiment_output.py b/tests/aeroval/test_experiment_output.py index a74498eb1..522516948 100644 --- a/tests/aeroval/test_experiment_output.py +++ b/tests/aeroval/test_experiment_output.py @@ -293,10 +293,10 @@ def test_Experiment_Output_clean_json_files_CFG1(eval_config: dict): @pytest.mark.parametrize("cfg", ["cfgexp1"]) def test_Experiment_Output_clean_json_files_CFG1_INVALIDMOD(eval_config: dict): cfg = EvalSetup(**eval_config) - cfg.model_cfg["mod1"] = cfg.model_cfg["TM5-AP3-CTRL"] + cfg.model_cfg.add_entry("mod1", cfg.model_cfg.get_entry("TM5-AP3-CTRL")) proc = ExperimentProcessor(cfg) proc.run() - del cfg.model_cfg["mod1"] + cfg.model_cfg.remove_entry("mod1") modified = proc.exp_output.clean_json_files() assert len(modified) == 13 @@ -305,10 +305,9 @@ def test_Experiment_Output_clean_json_files_CFG1_INVALIDMOD(eval_config: dict): @pytest.mark.parametrize("cfg", ["cfgexp1"]) def test_Experiment_Output_clean_json_files_CFG1_INVALIDOBS(eval_config: dict): cfg = EvalSetup(**eval_config) - cfg.obs_cfg["obs1"] = cfg.obs_cfg["AERONET-Sun"] + cfg.obs_cfg.add_entry("obs1", cfg.obs_cfg.get_entry("AERONET-Sun")) proc = ExperimentProcessor(cfg) proc.run() - del cfg.obs_cfg["obs1"] modified = 
proc.exp_output.clean_json_files() assert len(modified) == 13 @@ -354,7 +353,7 @@ def test_Experiment_Output_drop_stats_and_decimals( stats_decimals, ) cfg = EvalSetup(**eval_config) - cfg.model_cfg["mod1"] = cfg.model_cfg["TM5-AP3-CTRL"] + cfg.model_cfg.add_entry("mod1", cfg.model_cfg.get_entry("TM5-AP3-CTRL")) proc = ExperimentProcessor(cfg) proc.run() path = Path(proc.exp_output.exp_dir) diff --git a/tests/aeroval/test_experiment_processor.py b/tests/aeroval/test_experiment_processor.py index 5374e831a..af0be63c0 100644 --- a/tests/aeroval/test_experiment_processor.py +++ b/tests/aeroval/test_experiment_processor.py @@ -31,6 +31,7 @@ def test_ExperimentProcessor_run(processor: ExperimentProcessor): processor.run() +# Temporary until ObsCollection is implemented similarly; then the same test can be run @geojson_unavail @pytest.mark.parametrize( "cfg,kwargs,error", @@ -47,7 +48,9 @@ def test_ExperimentProcessor_run(processor: ExperimentProcessor): ), ], ) -def test_ExperimentProcessor_run_error(processor: ExperimentProcessor, kwargs: dict, error: str): +def test_ExperimentProcessor_run_error_obs_name( + processor: ExperimentProcessor, kwargs: dict, error: str +): with pytest.raises(KeyError) as e: processor.run(**kwargs) assert str(e.value) == error diff --git a/tests/aeroval/test_helpers.py b/tests/aeroval/test_helpers.py index 085efe943..c2af1998a 100644 --- a/tests/aeroval/test_helpers.py +++ b/tests/aeroval/test_helpers.py @@ -137,6 +137,6 @@ def test__get_min_max_year_periods_error(): @pytest.mark.parametrize("cfg", ["cfgexp1"]) def test_make_dummy_model(eval_config: dict): cfg = EvalSetup(**eval_config) - assert cfg.obs_cfg["AERONET-Sun"] + assert cfg.obs_cfg.get_entry("AERONET-Sun") model_id = make_dummy_model(["AERONET-Sun"], cfg) assert model_id == "dummy_model"