diff --git a/pyaerocom/_lowlevel_helpers.py b/pyaerocom/_lowlevel_helpers.py index ace6258ee..4b374f56f 100644 --- a/pyaerocom/_lowlevel_helpers.py +++ b/pyaerocom/_lowlevel_helpers.py @@ -77,6 +77,7 @@ def _class_name(obj): return type(obj).__name__ +# TODO: Check to see if instances of these classes can instead use pydantic class Validator(abc.ABC): def __set_name__(self, owner, name): self._name = name @@ -113,56 +114,6 @@ def validate(self, val): return val -class StrWithDefault(Validator): - def __init__(self, default: str): - self.default = default - - def validate(self, val): - if not isinstance(val, str): - if val is None: - val = self.default - else: - raise ValueError(f"need str or None, got {val}") - return val - - -class FlexList(Validator): - """list that can be instantated via input str, tuple or list or None""" - - def validate(self, val): - if isinstance(val, str): - val = [val] - elif isinstance(val, tuple): - val = list(val) - elif val is None: - val = [] - elif not isinstance(val, list): - raise ValueError(f"failed to convert {val} to list") - return val - - -class EitherOf(Validator): - _allowed = FlexList() - - def __init__(self, allowed: list): - self._allowed = allowed - - def validate(self, val): - if not any([x == val for x in self._allowed]): - raise ValueError(f"invalid value {val}, needs to be either of {self._allowed}.") - return val - - -class ListOfStrings(FlexList): - def validate(self, val): - # make sure to have a list - val = super().validate(val) - # make sure all entries are strings - if not all([isinstance(x, str) for x in val]): - raise ValueError(f"not all items are str type in input list {val}") - return val - - class Loc(abc.ABC): """Abstract descriptor representing a path location diff --git a/pyaerocom/aeroval/aux_io_helpers.py b/pyaerocom/aeroval/aux_io_helpers.py index 04552f46b..c20937291 100644 --- a/pyaerocom/aeroval/aux_io_helpers.py +++ b/pyaerocom/aeroval/aux_io_helpers.py @@ -1,8 +1,21 @@ import importlib import 
os import sys +from collections.abc import Callable -from pyaerocom._lowlevel_helpers import AsciiFileLoc, ListOfStrings +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + +from typing import TYPE_CHECKING + +from pydantic import ( + BaseModel, + model_validator, +) + +from pyaerocom._lowlevel_helpers import AsciiFileLoc def check_aux_info(fun, vars_required, funcs): @@ -26,11 +39,11 @@ def check_aux_info(fun, vars_required, funcs): required. """ - spec = _AuxReadSpec(fun, vars_required, funcs) + spec = _AuxReadSpec(fun=fun, vars_required=vars_required, funcs=funcs) return dict(fun=spec.fun, vars_required=spec.vars_required) -class _AuxReadSpec: +class _AuxReadSpec(BaseModel): """ Class that specifies requirements for computation of additional variables @@ -53,39 +66,22 @@ class _AuxReadSpec: """ - vars_required = ListOfStrings() - - def __init__(self, fun, vars_required: list, funcs: dict): - self.vars_required = vars_required - self.fun = self.get_func(fun, funcs) - - def get_func(self, fun, funcs): - """ - Get callable function for computation of variable - - Parameters - ---------- - fun : str or callable - Name of function or function. - funcs : dict - Dictionary with possible functions (values) and names (keys) - - Raises - ------ - ValueError - If function could not be retrieved. - - Returns - ------- - callable - callable function object. 
- - """ - if callable(fun): - return fun - elif isinstance(fun, str): - return funcs[fun] - raise ValueError("failed to retrieve aux func") + if TYPE_CHECKING: + fun: Callable + else: + fun: str | Callable + vars_required: list[str] + funcs: dict[str, Callable] + + @model_validator(mode="after") + def validate_fun(self) -> Self: + if callable(self.fun): + return self + elif isinstance(self.fun, str): + self.fun = self.funcs[self.fun] + return self + else: + raise ValueError("failed to retrieve aux func") class ReadAuxHandler: diff --git a/pyaerocom/aeroval/coldatatojson_engine.py b/pyaerocom/aeroval/coldatatojson_engine.py index 368948eed..3a4058ce8 100644 --- a/pyaerocom/aeroval/coldatatojson_engine.py +++ b/pyaerocom/aeroval/coldatatojson_engine.py @@ -141,6 +141,18 @@ def process_coldata(self, coldata: ColocatedData): if "var_name_input" in coldata.metadata: obs_var = coldata.metadata["var_name_input"][0] model_var = coldata.metadata["var_name_input"][1] + elif ( + "obs_vars" in coldata.metadata + ): # Try and get from obs_vars. Should not be needed in reading in a ColocatedData object created with pyaerocom. + obs_var = model_var = coldata.metadata["obs_vars"] + coldata.metadata["var_name_input"] = [obs_var, model_var] + logger.warning( + "ColdataToJsonEngine: Failed to access var_name_input from coldata.metadata. " + "This could be because you're using a ColocatedData object created outside of pyaerocom. " + "Setting obs_var and model_data to value in obs_vars instead. " + "Setting var_input_data to these values as well. 
" + ) + else: obs_var = model_var = "UNDEFINED" @@ -369,7 +381,13 @@ def _process_stats_timeseries_for_all_regions( region = regnames[reg] self.exp_output.add_heatmap_timeseries_entry( - stats_ts, region, obs_name, var_name_web, vert_code, model_name, model_var + stats_ts, + region, + obs_name, + var_name_web, + vert_code, + model_name, + model_var, ) logger.info("Processing heatmap data for all regions") diff --git a/pyaerocom/aeroval/collections.py b/pyaerocom/aeroval/collections.py index dc67e3fc0..79c46c6a8 100644 --- a/pyaerocom/aeroval/collections.py +++ b/pyaerocom/aeroval/collections.py @@ -1,32 +1,75 @@ import abc from fnmatch import fnmatch -from pyaerocom._lowlevel_helpers import BrowseDict from pyaerocom.aeroval.modelentry import ModelEntry from pyaerocom.aeroval.obsentry import ObsEntry -from pyaerocom.exceptions import EntryNotAvailable, EvalEntryNameError +from pyaerocom.exceptions import EntryNotAvailable -class BaseCollection(BrowseDict, abc.ABC): - #: maximum length of entry names - MAXLEN_KEYS = 25 - #: Invalid chars in entry names - FORBIDDEN_CHARS_KEYS = [] +class BaseCollection(abc.ABC): + def __init__(self): + """ + Initialize an instance of BaseCollection. + The instance maintains a dictionary of entries. + """ + self._entries = {} + + def __iter__(self): + """ + Iterates over each entry in the collection. + + Yields + ------ + object + The next entry in the collection. + """ + yield from self._entries.values() - # TODO: Wait a few release cycles after v0.23.0 and see if this can be removed - def _check_entry_name(self, key): - if any([x in key for x in self.FORBIDDEN_CHARS_KEYS]): - raise EvalEntryNameError( - f"Invalid name: {key}. Must not contain any of the following " - f"characters: {self.FORBIDDEN_CHARS_KEYS}" - ) + @abc.abstractmethod + def add_entry(self, key, value) -> None: + """ + Abstract method to add an entry to the collection. 
- def __setitem__(self, key, value): - self._check_entry_name(key) - super().__setitem__(key, value) + Parameters + ---------- + key: Hashable + The key of the entry. + value: object + The value of the entry. + """ + pass - def keylist(self, name_or_pattern: str = None) -> list: - """Find model names that match input search pattern(s) + @abc.abstractmethod + def remove_entry(self, key) -> None: + """ + Abstract method to remove an entry from the collection. + + Parameters + ---------- + key: Hashable + The key of the entry to be removed. + """ + pass + + @abc.abstractmethod + def get_entry(self, key) -> object: + """ + Abstract method to get an entry from the collection. + + Parameters + ---------- + key: Hashable + The key of the entry to retrieve. + + Returns + ------- + object + The entry associated with the provided key. + """ + pass + + def keylist(self, name_or_pattern: str = None) -> list[str]: + """Find model / obs names that match input search pattern(s) Parameters ---------- @@ -48,39 +91,47 @@ def keylist(self, name_or_pattern: str = None) -> list: name_or_pattern = "*" matches = [] - for key in self.keys(): + for key in self._entries.keys(): if fnmatch(key, name_or_pattern) and key not in matches: matches.append(key) if len(matches) == 0: raise KeyError(f"No matches could be found that match input {name_or_pattern}") return matches - @abc.abstractmethod - def get_entry(self, key) -> object: + @property + def web_interface_names(self) -> list: """ - Getter for eval entries + List of web interface names for each obs entry - Raises - ------ - KeyError - if input name is not in this collection + Returns + ------- + list """ - pass + return self.keylist() - @property - @abc.abstractmethod - def web_interface_names(self) -> list: + def as_dict(self) -> dict: """ - List of webinterface names for + Convert object to serializable dict + + Returns + ------- + dict + content of class + """ - pass + output = {} + for key, val in self._entries.items(): + if 
hasattr(val, "json_repr"): + val = val.json_repr() + output[key] = val + return output class ObsCollection(BaseCollection): """ - Dict-like object that represents a collection of obs entries + Object that represents a collection of obs entries - Keys are obs names, values are instances of :class:`ObsEntry`. Values can + "Keys" are obs names, values are instances of :class:`ObsEntry`. Values can also be assigned as dict and will automatically be converted into instances of :class:`ObsEntry`. @@ -93,9 +144,16 @@ class ObsCollection(BaseCollection): """ - SETTER_CONVERT = {dict: ObsEntry} + def add_entry(self, key: str, entry: dict | ObsEntry): + if isinstance(entry, dict): + entry = ObsEntry(**entry) + self._entries[key] = entry - def get_entry(self, key) -> object: + def remove_entry(self, key: str): + if key in self._entries: + del self._entries[key] + + def get_entry(self, key: str) -> ObsEntry: """ Getter for obs entries @@ -105,7 +163,7 @@ def get_entry(self, key) -> object: if input name is not in this collection """ try: - entry = self[key] + entry = self._entries[key] entry.obs_name = self.get_web_interface_name(key) return entry except (KeyError, AttributeError): @@ -122,11 +180,11 @@ def get_all_vars(self) -> list[str]: """ vars = [] - for ocfg in self.values(): + for ocfg in self._entries.values(): vars.extend(ocfg.get_all_vars()) return sorted(list(set(vars))) - def get_web_interface_name(self, key): + def get_web_interface_name(self, key: str) -> str: """ Get webinterface name for entry @@ -147,7 +205,12 @@ def get_web_interface_name(self, key): corresponding name """ - return self[key].web_interface_name if self[key].web_interface_name is not None else key + entry = self._entries.get(key) + return ( + entry.web_interface_name + if entry is not None and entry.web_interface_name is not None + else key + ) @property def web_interface_names(self) -> list: @@ -163,45 +226,38 @@ def web_interface_names(self) -> list: @property def all_vert_types(self): 
"""List of unique vertical types specified in this collection""" - return list({x.obs_vert_type for x in self.values()}) + return list({x.obs_vert_type for x in self._entries.values()}) class ModelCollection(BaseCollection): """ - Dict-like object that represents a collection of model entries + Object that represents a collection of model entries - Keys are model names, values are instances of :class:`ModelEntry`. Values + "Keys" are model names, values are instances of :class:`ModelEntry`. Values can also be assigned as dict and will automatically be converted into instances of :class:`ModelEntry`. - Note ---- Entries must not necessarily be only models but may also be observations. Entries provided in this collection refer to the x-axis in the AeroVal heatmap display and must fulfill the protocol defined by :class:`ModelEntry`. - """ - SETTER_CONVERT = {dict: ModelEntry} - - def get_entry(self, key) -> ModelEntry: - """Get model entry configuration + def add_entry(self, key: str, entry: dict | ModelEntry): + if isinstance(entry, dict): + entry = ModelEntry(**entry) + entry.model_name = key + self._entries[key] = entry - Since the configuration files for experiments are in json format, they - do not allow the storage of executable custom methods for model data - reading. Instead, these can be specified in a python module that may - be specified via :attr:`add_methods_file` and that contains a - dictionary `FUNS` that maps the method names with the callable methods. - - As a result, this means that, by default, custom read methods for - individual models in :attr:`model_config` do not contain the - callable methods but only the names. This method will take care of - handling this and will return a dictionary where potential custom - method strings have been converted to the corresponding callable - methods. 
+ def remove_entry(self, key: str): + if key in self._entries: + del self._entries[key] + def get_entry(self, key: str) -> ModelEntry: + """ + Get model entry configuration Parameters ---------- model_name : str @@ -212,20 +268,7 @@ def get_entry(self, key) -> ModelEntry: dict Dictionary that specifies the model setup ready for the analysis """ - try: - entry = self[key] - entry.model_name = key - return entry - except (KeyError, AttributeError): + if key in self._entries: + return self._entries[key] + else: raise EntryNotAvailable(f"no such entry {key}") - - @property - def web_interface_names(self) -> list: - """ - List of web interface names for each obs entry - - Returns - ------- - list - """ - return self.keylist() diff --git a/pyaerocom/aeroval/experiment_output.py b/pyaerocom/aeroval/experiment_output.py index 9e6ce980e..a15d524c2 100644 --- a/pyaerocom/aeroval/experiment_output.py +++ b/pyaerocom/aeroval/experiment_output.py @@ -751,8 +751,8 @@ def _is_part_of_experiment(self, obs_name, obs_var, mod_name, mod_var) -> bool: # occurence of web_interface_name). 
allobs = self.cfg.obs_cfg obs_matches = [] - for key, ocfg in allobs.items(): - if obs_name == allobs.get_web_interface_name(key): + for ocfg in allobs: + if obs_name == allobs.get_web_interface_name(ocfg.obs_name): obs_matches.append(ocfg) if len(obs_matches) == 0: self._invalid["obs"].append(obs_name) diff --git a/pyaerocom/aeroval/helpers.py b/pyaerocom/aeroval/helpers.py index 0a692161f..b4a8fa43f 100644 --- a/pyaerocom/aeroval/helpers.py +++ b/pyaerocom/aeroval/helpers.py @@ -162,7 +162,7 @@ def make_dummy_model(obs_list: list, cfg) -> str: tmp_var_obj = Variable() # Loops over variables in obs for obs in obs_list: - for var in cfg.obs_cfg[obs].obs_vars: + for var in cfg.obs_cfg.get_entry(obs).obs_vars: # Create dummy cube dummy_cube = make_dummy_cube(var, start_yr=start, stop_yr=stop, freq=freq) @@ -185,13 +185,13 @@ def make_dummy_model(obs_list: list, cfg) -> str: for dummy_grid_yr in yr_gen: # Add to netcdf yr = dummy_grid_yr.years_avail()[0] - vert_code = cfg.obs_cfg[obs].obs_vert_type + vert_code = cfg.obs_cfg.get_entry(obs).obs_vert_type save_name = dummy_grid_yr.aerocom_savename(model_id, var, vert_code, yr, freq) dummy_grid_yr.to_netcdf(outdir, savename=save_name) # Add dummy model to cfg - cfg.model_cfg["dummy"] = ModelEntry(model_id="dummy_model") + cfg.model_cfg.add_entry("dummy", ModelEntry(model_id="dummy_model")) return model_id diff --git a/pyaerocom/aeroval/setup_classes.py b/pyaerocom/aeroval/setup_classes.py index 1677c1b3e..8fb575f06 100644 --- a/pyaerocom/aeroval/setup_classes.py +++ b/pyaerocom/aeroval/setup_classes.py @@ -501,29 +501,29 @@ def colocation_opts(self) -> ColocationSetup: # These attributes require special attention b/c they're not based on Pydantic's BaseModel class. 
- obs_cfg: ObsCollection | dict = ObsCollection() - - @field_validator("obs_cfg") - def validate_obs_cfg(cls, v): - if isinstance(v, ObsCollection): - return v - return ObsCollection(v) + @computed_field + @cached_property + def obs_cfg(self) -> ObsCollection: + oc = ObsCollection() + for k, v in self.model_extra.get("obs_cfg", {}).items(): + oc.add_entry(k, v) + return oc @field_serializer("obs_cfg") def serialize_obs_cfg(self, obs_cfg: ObsCollection): - return obs_cfg.json_repr() + return obs_cfg.as_dict() - model_cfg: ModelCollection | dict = ModelCollection() - - @field_validator("model_cfg") - def validate_model_cfg(cls, v): - if isinstance(v, ModelCollection): - return v - return ModelCollection(v) + @computed_field + @cached_property + def model_cfg(self) -> ModelCollection: + mc = ModelCollection() + for k, v in self.model_extra.get("model_cfg", {}).items(): + mc.add_entry(k, v) + return mc @field_serializer("model_cfg") def serialize_model_cfg(self, model_cfg: ModelCollection): - return model_cfg.json_repr() + return model_cfg.as_dict() ########################### ## Methods diff --git a/pyaerocom/aeroval/superobs_engine.py b/pyaerocom/aeroval/superobs_engine.py index fdae8d8b7..46364caaa 100644 --- a/pyaerocom/aeroval/superobs_engine.py +++ b/pyaerocom/aeroval/superobs_engine.py @@ -75,8 +75,9 @@ def _run_var(self, model_name, obs_name, var_name, try_colocate_if_missing): coldata_files = [] coldata_resolutions = [] vert_codes = [] - obs_needed = self.cfg.obs_cfg[obs_name].obs_id - vert_code = self.cfg.obs_cfg.get_entry(obs_name).obs_vert_type + obs_entry = self.cfg.obs_cfg.get_entry(obs_name) + obs_needed = obs_entry.obs_id + vert_code = obs_entry.obs_vert_type for oname in obs_needed: fp, ts_type, vert_code = self._get_coldata_fileinfo( model_name, oname, var_name, try_colocate_if_missing diff --git a/pyaerocom/aeroval/utils.py b/pyaerocom/aeroval/utils.py index 8f680fe05..a4163ca3d 100644 --- a/pyaerocom/aeroval/utils.py +++ b/pyaerocom/aeroval/utils.py 
@@ -145,7 +145,8 @@ def compute_model_average_and_diversity( unit_out = get_variable(var_name).units - for mname in models: + for m in models: + mname = m.model_name logger.info(f"Adding {mname} ({var_name})") mid = cfg.cfg.model_cfg.get_entry(mname).model_id diff --git a/pyaerocom/colocation/colocator.py b/pyaerocom/colocation/colocator.py index 21fa3bf6e..744b52173 100644 --- a/pyaerocom/colocation/colocator.py +++ b/pyaerocom/colocation/colocator.py @@ -830,7 +830,9 @@ def _read_gridded(self, var_name, is_model): def _try_get_vert_which_alt(self, is_model, var_name): if is_model: if self.colocation_setup.obs_vert_type in self.colocation_setup.OBS_VERT_TYPES_ALT: - return self.OBS_VERT_TYPES_ALT[self.colocation_setup.obs_vert_type] + return self.colocation_setup.OBS_VERT_TYPES_ALT[ + self.colocation_setup.obs_vert_type + ] raise DataCoverageError(f"No alternative vert type found for {var_name}") def _check_remove_outliers_gridded(self, data, var_name, is_model): diff --git a/pyaerocom/extras/satellite_l2/aeolus_l2a.py b/pyaerocom/extras/satellite_l2/aeolus_l2a.py index 806a04931..8edb1eb45 100755 --- a/pyaerocom/extras/satellite_l2/aeolus_l2a.py +++ b/pyaerocom/extras/satellite_l2/aeolus_l2a.py @@ -715,7 +715,7 @@ def read_file( seconds_to_add = np.datetime64("2000-01-01T00:00:00") - np.datetime64( "1970-01-01T00:00:00" ) - seconds_to_add = seconds_to_add.astype(np.float_) + seconds_to_add = seconds_to_add.astype(np.float64) # the same can be achieved using pandas, but we stick to numpy here # base_time = pd.DatetimeIndex(['2000-01-01']) @@ -832,9 +832,9 @@ def read_file( # return as one multidimensional numpy array that can be put into self.data directly # (column wise because the column numbers do not match) index_pointer = 0 - data = np.empty([self._ROWNO, self._COLNO], dtype=np.float_) + data = np.empty([self._ROWNO, self._COLNO], dtype=np.float64) - for idx, _time in enumerate(file_data["time"].astype(np.float_)): + for idx, _time in 
enumerate(file_data["time"].astype(np.float64)): # skip times of profiles without a single valid extinction # the following is deprecated in current nympy # if _time in times_to_skip: @@ -912,7 +912,7 @@ def read_file( if index_pointer >= self._ROWNO: # add another array chunk to self.data - chunk = np.empty([self._CHUNKSIZE, self._COLNO], dtype=np.float_) + chunk = np.empty([self._CHUNKSIZE, self._COLNO], dtype=np.float64) data = np.append(data, chunk, axis=0) # return only the needed elements... @@ -1137,7 +1137,7 @@ def colocate( start = time.perf_counter() data = ungridded_data_obj._data - ret_data = np.empty([self._ROWNO, self._COLNO], dtype=np.float_) + ret_data = np.empty([self._ROWNO, self._COLNO], dtype=np.float64) index_counter = 0 cut_flag = True matching_indexes = [] @@ -1164,7 +1164,7 @@ def colocate( ret_data, np.zeros( [end_index - len(ret_data), self._COLNO], - dtype=np.float_, + dtype=np.float64, axis=0, ), ) @@ -1874,7 +1874,7 @@ def read_data_fields( seconds_to_add = np.datetime64("2000-01-01T00:00:00") - np.datetime64( "1970-01-01T00:00:00" ) - seconds_to_add = seconds_to_add.astype(np.float_) + seconds_to_add = seconds_to_add.astype(np.float64) # the same can be achieved using pandas, but we stick to numpy here # base_time = pd.DatetimeIndex(['2000-01-01']) @@ -1952,7 +1952,7 @@ def codarecord2pythonstruct(self, codaRec): out_struct[codaRec._registeredFields[idx]] = {} for str_name in dummy: out_struct[codaRec._registeredFields[idx]][str_name] = np.empty( - rec_length, dtype=np.float_ + rec_length, dtype=np.float64 ) for str_name in dummy: diff --git a/pyaerocom/extras/satellite_l2/base_reader.py b/pyaerocom/extras/satellite_l2/base_reader.py index ec95f79c7..eef613860 100644 --- a/pyaerocom/extras/satellite_l2/base_reader.py +++ b/pyaerocom/extras/satellite_l2/base_reader.py @@ -389,7 +389,7 @@ def read( index_store[_key] = file_data[_key].shape[0] input_shape = list(file_data[_key].shape) input_shape[0] = self._ROWNO - data_obj._data[_key] = 
np.empty(input_shape, dtype=np.float_) + data_obj._data[_key] = np.empty(input_shape, dtype=np.float64) if len(input_shape) == 1: data_obj._data[_key][0 : file_data[_key].shape[0]] = file_data[_key] elif len(input_shape) == 2: @@ -423,7 +423,7 @@ def read( if index_store[_key] + elements_to_add > data_obj._data[_key].shape[0]: current_shape = list(data_obj._data[_key].shape) current_shape[0] = current_shape[0] + self._CHUNKSIZE - tmp_data = np.empty(current_shape, dtype=np.float_) + tmp_data = np.empty(current_shape, dtype=np.float64) if len(current_shape) == 1: tmp_data[0 : data_obj._data[_key].shape[0]] = data_obj._data[_key] elif len(current_shape) == 2: diff --git a/pyaerocom/extras/satellite_l2/sentinel5p.py b/pyaerocom/extras/satellite_l2/sentinel5p.py index 53df591f2..b535bf826 100755 --- a/pyaerocom/extras/satellite_l2/sentinel5p.py +++ b/pyaerocom/extras/satellite_l2/sentinel5p.py @@ -124,7 +124,7 @@ def __init__( self.NAN_DICT.update({self._ALTITUDENAME: -1.0}) # scaling factors e.g. 
for unit conversion - self.SCALING_FACTORS[self._NO2NAME] = np.float_(6.022140857e19 / 1.0e15) + self.SCALING_FACTORS[self._NO2NAME] = np.float64(6.022140857e19 / 1.0e15) # the following defines necessary quality flags for a value to make it into the used data set # the flag needs to have a HIGHER or EQUAL value than the one listed here @@ -137,11 +137,11 @@ def __init__( self.CODA_READ_PARAMETERS[self._NO2NAME] = {} self.CODA_READ_PARAMETERS[self._NO2NAME]["metadata"] = {} self.CODA_READ_PARAMETERS[self._NO2NAME]["vars"] = {} - self.CODA_READ_PARAMETERS[self._NO2NAME]["time_offset"] = np.float_(24.0 * 60.0 * 60.0) + self.CODA_READ_PARAMETERS[self._NO2NAME]["time_offset"] = np.float64(24.0 * 60.0 * 60.0) self.CODA_READ_PARAMETERS[self._O3NAME] = {} self.CODA_READ_PARAMETERS[self._O3NAME]["metadata"] = {} self.CODA_READ_PARAMETERS[self._O3NAME]["vars"] = {} - self.CODA_READ_PARAMETERS[self._O3NAME]["time_offset"] = np.float_(24.0 * 60.0 * 60.0) + self.CODA_READ_PARAMETERS[self._O3NAME]["time_offset"] = np.float64(24.0 * 60.0 * 60.0) # self.CODA_READ_PARAMETERS[DATASET_NAME]['metadata'][_TIME_NAME] = 'PRODUCT/time_utc' self.CODA_READ_PARAMETERS[self._NO2NAME]["metadata"][self._TIME_NAME] = "PRODUCT/time" @@ -293,7 +293,7 @@ def __init__( self.CODA_READ_PARAMETERS[self._AVERAGINGKERNELNAME] = {} self.CODA_READ_PARAMETERS[self._AVERAGINGKERNELNAME]["metadata"] = {} self.CODA_READ_PARAMETERS[self._AVERAGINGKERNELNAME]["vars"] = {} - self.CODA_READ_PARAMETERS[self._AVERAGINGKERNELNAME]["time_offset"] = np.float_( + self.CODA_READ_PARAMETERS[self._AVERAGINGKERNELNAME]["time_offset"] = np.float64( 24.0 * 60.0 * 60.0 ) self.CODA_READ_PARAMETERS[self._AVERAGINGKERNELNAME]["metadata"][self._TIME_NAME] = ( @@ -330,7 +330,7 @@ def __init__( self.CODA_READ_PARAMETERS[self._LEVELSNAME] = {} self.CODA_READ_PARAMETERS[self._LEVELSNAME]["metadata"] = {} self.CODA_READ_PARAMETERS[self._LEVELSNAME]["vars"] = {} - self.CODA_READ_PARAMETERS[self._LEVELSNAME]["time_offset"] = np.float_( 
+ self.CODA_READ_PARAMETERS[self._LEVELSNAME]["time_offset"] = np.float64( 24.0 * 60.0 * 60.0 ) self.CODA_READ_PARAMETERS[self._LEVELSNAME]["vars"][self._LEVELSNAME] = "PRODUCT/layer" @@ -338,7 +338,7 @@ def __init__( self.CODA_READ_PARAMETERS[self._GROUNDPRESSURENAME] = {} self.CODA_READ_PARAMETERS[self._GROUNDPRESSURENAME]["metadata"] = {} self.CODA_READ_PARAMETERS[self._GROUNDPRESSURENAME]["vars"] = {} - self.CODA_READ_PARAMETERS[self._GROUNDPRESSURENAME]["time_offset"] = np.float_( + self.CODA_READ_PARAMETERS[self._GROUNDPRESSURENAME]["time_offset"] = np.float64( 24.0 * 60.0 * 60.0 ) self.CODA_READ_PARAMETERS[self._GROUNDPRESSURENAME]["vars"][ @@ -349,7 +349,7 @@ def __init__( self.CODA_READ_PARAMETERS[self._TM5_TROPOPAUSE_LAYER_INDEX_NAME]["metadata"] = {} self.CODA_READ_PARAMETERS[self._TM5_TROPOPAUSE_LAYER_INDEX_NAME]["vars"] = {} self.CODA_READ_PARAMETERS[self._TM5_TROPOPAUSE_LAYER_INDEX_NAME]["time_offset"] = ( - np.float_(24.0 * 60.0 * 60.0) + np.float64(24.0 * 60.0 * 60.0) ) self.CODA_READ_PARAMETERS[self._TM5_TROPOPAUSE_LAYER_INDEX_NAME]["vars"][ self._TM5_TROPOPAUSE_LAYER_INDEX_NAME @@ -358,7 +358,7 @@ def __init__( self.CODA_READ_PARAMETERS[self._TM5_CONSTANT_A_NAME] = {} self.CODA_READ_PARAMETERS[self._TM5_CONSTANT_A_NAME]["metadata"] = {} self.CODA_READ_PARAMETERS[self._TM5_CONSTANT_A_NAME]["vars"] = {} - self.CODA_READ_PARAMETERS[self._TM5_CONSTANT_A_NAME]["time_offset"] = np.float_( + self.CODA_READ_PARAMETERS[self._TM5_CONSTANT_A_NAME]["time_offset"] = np.float64( 24.0 * 60.0 * 60.0 ) self.CODA_READ_PARAMETERS[self._TM5_CONSTANT_A_NAME]["vars"][ @@ -367,7 +367,7 @@ def __init__( self.CODA_READ_PARAMETERS[self._TM5_CONSTANT_B_NAME] = {} self.CODA_READ_PARAMETERS[self._TM5_CONSTANT_B_NAME]["metadata"] = {} self.CODA_READ_PARAMETERS[self._TM5_CONSTANT_B_NAME]["vars"] = {} - self.CODA_READ_PARAMETERS[self._TM5_CONSTANT_B_NAME]["time_offset"] = np.float_( + self.CODA_READ_PARAMETERS[self._TM5_CONSTANT_B_NAME]["time_offset"] = np.float64( 24.0 * 60.0 * 
60.0 ) self.CODA_READ_PARAMETERS[self._TM5_CONSTANT_B_NAME]["vars"][ @@ -514,7 +514,7 @@ def read_file( seconds_to_add = np.datetime64("2010-01-01T00:00:00") - np.datetime64( "1970-01-01T00:00:00" ) - seconds_to_add = seconds_to_add.astype(np.float_) + seconds_to_add = seconds_to_add.astype(np.float64) # the same can be achieved using pandas, but we stick to numpy here # base_time = pd.DatetimeIndex(['2000-01-01']) @@ -591,7 +591,7 @@ def read_file( # return as one multidimensional numpy array that can be put into self.data directly # (column wise because the column numbers do not match) index_pointer = 0 - data = np.empty([self._ROWNO, colno], dtype=np.float_) + data = np.empty([self._ROWNO, colno], dtype=np.float64) # loop over the times for idx, _time in enumerate(file_data[self._TIME_OFFSET_NAME]): # loop over the number of ground pixels @@ -610,9 +610,9 @@ def read_file( # time can be a scalar... try: - data[index_pointer, self._TIMEINDEX] = _time.astype(np.float_) + data[index_pointer, self._TIMEINDEX] = _time.astype(np.float64) except Exception: - data[index_pointer, self._TIMEINDEX] = _time[_index].astype(np.float_) + data[index_pointer, self._TIMEINDEX] = _time[_index].astype(np.float64) # loop over the variables for var in vars_to_read_in: @@ -633,7 +633,9 @@ def read_file( if index_pointer >= self._ROWNO: # add another array chunk to self.data data = np.append( - data, np.empty([self._CHUNKSIZE, self._COLNO], dtype=np.float_), axis=0 + data, + np.empty([self._CHUNKSIZE, self._COLNO], dtype=np.float64), + axis=0, ) # unneeded after update (_ROWNO is now dynamic and returns shape index 0 of numpy array) # self._ROWNO += self._CHUNKSIZE diff --git a/pyaerocom/io/gaw/reader.py b/pyaerocom/io/gaw/reader.py index 2e61e992b..0f519ff77 100644 --- a/pyaerocom/io/gaw/reader.py +++ b/pyaerocom/io/gaw/reader.py @@ -136,7 +136,7 @@ def read_file(self, filename, vars_to_retrieve=None, vars_as_series=False): if np.shape(data[i])[0] != 10: del data[i] - data = 
np.array(data) + data = np.array(data, dtype=object) # names of the columns in the file that I want to use file_vars = file_vars[5:9] @@ -233,8 +233,7 @@ def read_file(self, filename, vars_to_retrieve=None, vars_as_series=False): if any("99:99" in s for s in data[:, 1]): datestring = data[:, 0] else: - datestring = np.core.defchararray.add(data[:, 0], "T") - datestring = np.core.defchararray.add(datestring, data[:, 1]) + datestring = data[:, 0] + "T" + data[:, 1] data_out["dtime"] = datestring.astype("datetime64[s]") # Replace invalid measurements with nan values diff --git a/pyaerocom/io/pyaro/pyaro_config.py b/pyaerocom/io/pyaro/pyaro_config.py index 4891cff8e..a05d4f747 100644 --- a/pyaerocom/io/pyaro/pyaro_config.py +++ b/pyaerocom/io/pyaro/pyaro_config.py @@ -3,7 +3,7 @@ import logging from importlib import resources from pathlib import Path -from typing import ClassVar +from typing import ClassVar, Any import yaml from pydantic import BaseModel, ConfigDict @@ -15,6 +15,8 @@ # TODO Check a validator if extra/kwarg is serializable. 
Either in json_repr or as a @field_validator on extra +FilterArgs = dict[str, Any] + class PyaroConfig(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") @@ -30,7 +32,7 @@ class PyaroConfig(BaseModel): name: str data_id: str filename_or_obj_or_url: str | list[str] | Path | list[Path] - filters: dict[str, dict[str, list[str]] | dict[str, list[tuple]]] + filters: dict[str, FilterArgs] name_map: dict[str, str] | None = None # no Unit conversion option ########################## diff --git a/pyaerocom/io/read_aeronet_invv3.py b/pyaerocom/io/read_aeronet_invv3.py index eec1868a3..61909ad75 100644 --- a/pyaerocom/io/read_aeronet_invv3.py +++ b/pyaerocom/io/read_aeronet_invv3.py @@ -196,7 +196,7 @@ def read_file(self, filename, vars_to_retrieve=None, vars_as_series=False): # copy the data fields that are available (rest will be filled # below) for var, idx in vars_available.items(): - val = np.float_(dummy_arr[idx]) + val = np.float64(dummy_arr[idx]) if val == self.NAN_VAL: val = np.nan data_out[var].append(val) diff --git a/pyaerocom/io/read_aeronet_sdav3.py b/pyaerocom/io/read_aeronet_sdav3.py index 85ecf90b1..dcbd0b230 100644 --- a/pyaerocom/io/read_aeronet_sdav3.py +++ b/pyaerocom/io/read_aeronet_sdav3.py @@ -199,7 +199,7 @@ def read_file(self, filename, vars_to_retrieve=None, vars_as_series=False): # copy the data fields for var, idx in vars_available.items(): - val = np.float_(dummy_arr[idx]) + val = np.float64(dummy_arr[idx]) if val == self.NAN_VAL: val = np.nan data_out[var].append(val) diff --git a/pyaerocom/io/read_aeronet_sunv3.py b/pyaerocom/io/read_aeronet_sunv3.py index fe99b1a71..b56fb0c5d 100644 --- a/pyaerocom/io/read_aeronet_sunv3.py +++ b/pyaerocom/io/read_aeronet_sunv3.py @@ -266,7 +266,7 @@ def read_file(self, filename, vars_to_retrieve=None, vars_as_series=False): data_out["dtime"].append(np.datetime64(datestring)) for var, idx in vars_available.items(): - val = np.float_(dummy_arr[idx]) + val = 
np.float64(dummy_arr[idx]) if val == self.NAN_VAL: val = np.nan data_out[var].append(val) diff --git a/pyaerocom/io/read_eea_aqerep_base.py b/pyaerocom/io/read_eea_aqerep_base.py index 576eed027..32aef8841 100644 --- a/pyaerocom/io/read_eea_aqerep_base.py +++ b/pyaerocom/io/read_eea_aqerep_base.py @@ -112,16 +112,16 @@ class ReadEEAAQEREPBase(ReadUngriddedBase): # conversion factor between concX and vmrX CONV_FACTOR = {} - CONV_FACTOR["concSso2"] = np.float_(0.50052292274792) - CONV_FACTOR["concNno2"] = np.float_(0.3044517868011477) - CONV_FACTOR["concNno"] = np.float_(0.466788868521913) - CONV_FACTOR["vmro3"] = np.float_( + CONV_FACTOR["concSso2"] = np.float64(0.50052292274792) + CONV_FACTOR["concNno2"] = np.float64(0.3044517868011477) + CONV_FACTOR["concNno"] = np.float64(0.466788868521913) + CONV_FACTOR["vmro3"] = np.float64( 0.493 ) # retrieved using STD atmosphere from geonum and pya.mathutils.concx_to_vmrx - CONV_FACTOR["vmro3max"] = np.float_( + CONV_FACTOR["vmro3max"] = np.float64( 0.493 ) # retrieved using STD atmosphere from geonum and pya.mathutils.concx_to_vmrx - CONV_FACTOR["vmrno2"] = np.float_( + CONV_FACTOR["vmrno2"] = np.float64( 0.514 ) # retrieved using STD atmosphere from geonum and pya.mathutils.concx_to_vmrx @@ -315,7 +315,7 @@ def read_file(self, filename, var_name, vars_as_series=False): if idx in time_indexes: data_dict[header[idx]] = np.zeros(self.MAX_LINES_TO_READ, dtype="datetime64[s]") else: - data_dict[header[idx]] = np.empty(self.MAX_LINES_TO_READ, dtype=np.float_) + data_dict[header[idx]] = np.empty(self.MAX_LINES_TO_READ, dtype=np.float64) # read the data... 
# DE,http://gdi.uba.de/arcgis/rest/services/inspire/DE.UBA.AQD,NET.DE_BB,STA.DE_DEBB054,DEBB054,SPO.DE_DEBB054_PM2_dataGroup1,SPP.DE_DEBB054_PM2_automatic_light-scat_Duration-30minute,SAM.DE_DEBB054_2,PM2.5,http://dd.eionet.europa.eu/vocabulary/aq/pollutant/6001,hour,3.2000000000,µg/m3,2020-01-04 00:00:00 +01:00,2020-01-04 01:00:00 +01:00,1,2 @@ -353,7 +353,7 @@ def read_file(self, filename, var_name, vars_as_series=False): # data is not a time # sometimes there's no value in the file. Set that to nan try: - data_dict[header[idx]][lineidx] = np.float_(rows[idx]) + data_dict[header[idx]][lineidx] = np.float64(rows[idx]) except (ValueError, IndexError): data_dict[header[idx]][lineidx] = np.nan diff --git a/pyaerocom/sample_data_access/minimal_dataset.py b/pyaerocom/sample_data_access/minimal_dataset.py index f76a39ca6..c887a81cf 100644 --- a/pyaerocom/sample_data_access/minimal_dataset.py +++ b/pyaerocom/sample_data_access/minimal_dataset.py @@ -10,7 +10,7 @@ __all__ = ["download_minimal_dataset"] #: tarfile to download -DEFAULT_TESTDATA_FILE = "testdata-minimal.tar.gz.20240722" +DEFAULT_TESTDATA_FILE = "testdata-minimal.tar.gz.20241120" minimal_dataset = pooch.create( path=const.OUTPUTDIR, # ~/MyPyaerocom/ @@ -24,6 +24,7 @@ "testdata-minimal.tar.gz.20231019": "md5:f8912ee83d6749fb2a9b1eda1d664ca2", "testdata-minimal.tar.gz.20231116": "md5:5da747f6596817295ba7affe3402b722", "testdata-minimal.tar.gz.20240722": "md5:7d933901c6d273d012f132c60df086cc", + "testdata-minimal.tar.gz.20241120": "md5:4d2bc1782b1f468321817139d327e014", }, ) diff --git a/pyaerocom_env.yml b/pyaerocom_env.yml index 560f153be..1ec9b1104 100644 --- a/pyaerocom_env.yml +++ b/pyaerocom_env.yml @@ -3,7 +3,7 @@ channels: - conda-forge dependencies: - - iris >=3.8.1 + - iris >=3.11.0 - xarray >=2022.12.0 - cartopy >=0.21.1 - matplotlib-base >=3.7.1 @@ -31,7 +31,7 @@ dependencies: - pip: - geojsoncontour - geocoder_reverse_natural_earth >= 0.0.2 - - pyaro >= 0.0.12 + - pyaro >= 0.0.14 - aerovaldb >= 
0.1.1 ## testing - pytest >=7.4 diff --git a/pyproject.toml b/pyproject.toml index 23ca3ca87..1c504c142 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,13 +22,13 @@ classifiers = [ requires-python = ">=3.10" dependencies = [ "aerovaldb>=0.1.1", - "scitools-iris>=3.8.1", + "scitools-iris>=3.11.0", "xarray>=2022.12.0", "cartopy>=0.21.1", "matplotlib>=3.7.1", "scipy>=1.10.1", "pandas>=1.5.3", - "numpy>=1.24.4, <2.0.0", + "numpy>=1.24.4", "seaborn>=0.12.2", "dask", "geonum==1.5.0", @@ -44,10 +44,10 @@ dependencies = [ 'importlib-resources>=5.10; python_version < "3.11"', 'typing-extensions>=4.6.1; python_version < "3.11"', # https://github.com/SciTools/cf-units/issues/218 - 'cf-units>=3.1', + 'cf-units>=3.3.0', "pydantic>=2.7.1", "pyproj>=3.0.0", - "pyaro>=0.0.12", + "pyaro>=0.0.14", "pooch>=1.7.0", "psutil>=5.0.0", ] @@ -215,7 +215,7 @@ extras = setenv = UDUNITS2_XML_PATH=/usr/share/xml/udunits/udunits2-common.xml deps = - scitools-iris ==3.8.1; python_version < "3.11" + scitools-iris ==3.11.0; python_version < "3.11" cartopy ==0.21.1; python_version < "3.11" matplotlib ==3.7.1; python_version < "3.11" scipy ==1.10.1; python_version < "3.11" @@ -226,9 +226,9 @@ deps = tomli ==2.0.1; python_version < "3.11" importlib-resources ==5.10; python_version < "3.11" typing-extensions ==4.6.1; python_version < "3.11" - cf-units ==3.1; python_version < "3.11" + cf-units ==3.3.0; python_version < "3.11" pydantic ==2.7.1; python_version < "3.11" - pyaro == 0.0.10; python_version < "3.11" + pyaro == 0.0.14; python_version < "3.11" pooch ==1.7.0; python_version < "3.11" xarray ==2022.12.0; python_version < "3.11" pandas ==1.5.3; python_version < "3.11" diff --git a/scripts/aeolus2netcdf.py b/scripts/aeolus2netcdf.py index dfb626b6b..65ebb6689 100644 --- a/scripts/aeolus2netcdf.py +++ b/scripts/aeolus2netcdf.py @@ -60,10 +60,10 @@ def main(): help="set path of CODA_DEFINITION env variable", default="/lustre/storeA/project/aerocom/aerocom1/ADM_CALIPSO_TEST/", ) - 
parser.add_argument("--latmin", help="min latitude to return", default=np.float_(30.0)) - parser.add_argument("--latmax", help="max latitude to return", default=np.float_(76.0)) - parser.add_argument("--lonmin", help="min longitude to return", default=np.float_(-30.0)) - parser.add_argument("--lonmax", help="max longitude to return", default=np.float_(45.0)) + parser.add_argument("--latmin", help="min latitude to return", default=np.float64(30.0)) + parser.add_argument("--latmax", help="max latitude to return", default=np.float64(76.0)) + parser.add_argument("--lonmin", help="min longitude to return", default=np.float64(-30.0)) + parser.add_argument("--lonmax", help="max longitude to return", default=np.float64(45.0)) parser.add_argument( "--dir", help="work on all files below this directory", @@ -186,16 +186,16 @@ def main(): options["tempdir"] = args.tempdir if args.latmin: - options["latmin"] = np.float_(args.latmin) + options["latmin"] = np.float64(args.latmin) if args.latmax: - options["latmax"] = np.float_(args.latmax) + options["latmax"] = np.float64(args.latmax) if args.lonmin: - options["lonmin"] = np.float_(args.lonmin) + options["lonmin"] = np.float64(args.lonmin) if args.lonmax: - options["lonmax"] = np.float_(args.lonmax) + options["lonmax"] = np.float64(args.lonmax) if args.emep: options["emepflag"] = args.emep @@ -437,7 +437,7 @@ def main(): nc_longitudes = nc_data["lon"].data nc_lev_no = len(nc_data["lev"]) nc_colocated_data = np.zeros( - [aeolus_profile_no * nc_lev_no, obj._COLNO], dtype=np.float_ + [aeolus_profile_no * nc_lev_no, obj._COLNO], dtype=np.float64 ) # locate current rounded Aeolus time in netcdf file diff --git a/scripts/read_sentinel5p.py b/scripts/read_sentinel5p.py index c87b3c548..060bce7f9 100644 --- a/scripts/read_sentinel5p.py +++ b/scripts/read_sentinel5p.py @@ -52,10 +52,10 @@ def main(): ) # parser.add_argument("--codadef", help="set path of CODA_DEFINITION env variable", # 
default='/lustre/storeA/project/aerocom/aerocom1/ADM_CALIPSO_TEST/') - parser.add_argument("--latmin", help="min latitude to return", default=np.float_(30.0)) - parser.add_argument("--latmax", help="max latitude to return", default=np.float_(76.0)) - parser.add_argument("--lonmin", help="min longitude to return", default=np.float_(-30.0)) - parser.add_argument("--lonmax", help="max longitude to return", default=np.float_(45.0)) + parser.add_argument("--latmin", help="min latitude to return", default=np.float64(30.0)) + parser.add_argument("--latmax", help="max latitude to return", default=np.float64(76.0)) + parser.add_argument("--lonmin", help="min longitude to return", default=np.float64(-30.0)) + parser.add_argument("--lonmax", help="max longitude to return", default=np.float64(45.0)) parser.add_argument( "--dir", help="work on all files below this directory", @@ -152,16 +152,16 @@ def main(): options["tempdir"] = args.tempdir if args.latmin: - options["latmin"] = np.float_(args.latmin) + options["latmin"] = np.float64(args.latmin) if args.latmax: - options["latmax"] = np.float_(args.latmax) + options["latmax"] = np.float64(args.latmax) if args.lonmin: - options["lonmin"] = np.float_(args.lonmin) + options["lonmin"] = np.float64(args.lonmin) if args.lonmax: - options["lonmax"] = np.float_(args.lonmax) + options["lonmax"] = np.float64(args.lonmax) if args.emep: options["emepflag"] = args.emep diff --git a/tests/aeroval/test_aeroval_HIGHLEV.py b/tests/aeroval/test_aeroval_HIGHLEV.py index 558983fa7..c2b07ed99 100644 --- a/tests/aeroval/test_aeroval_HIGHLEV.py +++ b/tests/aeroval/test_aeroval_HIGHLEV.py @@ -113,11 +113,11 @@ def test_reanalyse_existing(eval_config: dict, reanalyse_existing: bool): @pytest.mark.parametrize("cfg", ["cfgexp4"]) def test_superobs_different_resolutions(eval_config: dict): cfg = EvalSetup(**eval_config) - cfg.model_cfg["TM5-AP3-CTRL"].model_ts_type_read = None - cfg.model_cfg["TM5-AP3-CTRL"].flex_ts_type = True + 
cfg.model_cfg.get_entry("TM5-AP3-CTRL").model_ts_type_read = None + cfg.model_cfg.get_entry("TM5-AP3-CTRL").flex_ts_type = True - cfg.obs_cfg["AERONET-Sun"].ts_type = "daily" - cfg.obs_cfg["AERONET-SDA"].ts_type = "monthly" + cfg.obs_cfg.get_entry("AERONET-Sun").ts_type = "daily" + cfg.obs_cfg.get_entry("AERONET-SDA").ts_type = "monthly" proc = ExperimentProcessor(cfg) proc.exp_output.delete_experiment_data(also_coldata=True) diff --git a/tests/aeroval/test_aux_io_helpers.py b/tests/aeroval/test_aux_io_helpers.py index cbe57e85e..3401cbb81 100644 --- a/tests/aeroval/test_aux_io_helpers.py +++ b/tests/aeroval/test_aux_io_helpers.py @@ -1,6 +1,7 @@ from pathlib import Path from textwrap import dedent +from pydantic import ValidationError from pytest import mark, param, raises from pyaerocom.aeroval.aux_io_helpers import ReadAuxHandler, check_aux_info @@ -52,20 +53,26 @@ def test_check_aux_info(fun, vars_required: list[str], funcs: dict): @mark.parametrize( - "fun,vars_required,funcs,error", + "fun,vars_required,funcs,error,", [ - param(None, [], {}, "failed to retrieve aux func", id="no func"), + param( + None, + [], + {}, + "2 validation errors for _AuxReadSpec", + id="no func", + ), param( None, [42], {}, - "not all items are str type in input list [42]", + "3 validation errors for _AuxReadSpec", id="bad type vars_required", ), ], ) def test_check_aux_info_error(fun, vars_required: list[str], funcs: dict, error: str): - with raises(ValueError) as e: + with raises(ValidationError) as e: check_aux_info(fun, vars_required, funcs) - assert str(e.value) == error + assert error in str(e.value) diff --git a/tests/aeroval/test_collections.py b/tests/aeroval/test_collections.py index 5f2a98f7c..1327ca6b8 100644 --- a/tests/aeroval/test_collections.py +++ b/tests/aeroval/test_collections.py @@ -1,28 +1,91 @@ from pyaerocom.aeroval.collections import ObsCollection, ModelCollection +from pyaerocom.aeroval.obsentry import ObsEntry +from pyaerocom.aeroval.modelentry import 
ModelEntry +import pytest -def test_obscollection(): - oc = ObsCollection(model1=dict(obs_id="bla", obs_vars="od550aer", obs_vert_type="Column")) +def test_obscollection_init_and_add_entry(): + oc = ObsCollection() + oc.add_entry("model1", dict(obs_id="bla", obs_vars="od550aer", obs_vert_type="Column")) assert oc - oc["AN-EEA-MP"] = dict( - is_superobs=True, - obs_id=("AirNow", "EEA-NRT-rural", "MarcoPolo"), - obs_vars=["concpm10", "concpm25", "vmro3", "vmrno2"], - obs_vert_type="Surface", + oc.add_entry( + "AN-EEA-MP", + dict( + is_superobs=True, + obs_id=("AirNow", "EEA-NRT-rural", "MarcoPolo"), + obs_vars=["concpm10", "concpm25", "vmro3", "vmrno2"], + obs_vert_type="Surface", + ), ) - assert "AN-EEA-MP" in oc + assert "AN-EEA-MP" in oc.keylist() -def test_modelcollection(): - mc = ModelCollection(model1=dict(model_id="bla", obs_vars="od550aer", obs_vert_type="Column")) +def test_obscollection_add_and_get_entry(): + collection = ObsCollection() + entry = ObsEntry(obs_id="obs1", obs_vars=("var1",)) + collection.add_entry("key1", entry) + retrieved_entry = collection.get_entry("key1") + assert retrieved_entry == entry + + +def test_obscollection_add_and_remove_entry(): + collection = ObsCollection() + entry = ObsEntry(obs_id="obs1", obs_vars=("var1",)) + collection.add_entry("key1", entry) + collection.remove_entry("key1") + with pytest.raises(KeyError): + collection.get_entry("key1") + + +def test_obscollection_get_web_interface_name(): + collection = ObsCollection() + entry = ObsEntry(obs_id="obs1", obs_vars=("var1",), web_interface_name="web_name") + collection.add_entry("key1", entry) + assert collection.get_web_interface_name("key1") == "web_name" + + +def test_obscollection_all_vert_types(): + collection = ObsCollection() + entry1 = ObsEntry( + obs_id="obs1", obs_vars=("var1",), obs_vert_type="Surface" + ) # Assuming ObsEntry has an obs_vert_type attribute + entry2 = ObsEntry(obs_id="obs2", obs_vars=("var2",), obs_vert_type="Profile") + 
collection.add_entry("key1", entry1) + collection.add_entry("key2", entry2) + assert set(collection.all_vert_types) == {"Surface", "Profile"} + + +def test_modelcollection_init_and_add_entry(): + mc = ModelCollection() + mc.add_entry("model1", dict(model_id="bla", obs_vars="od550aer", obs_vert_type="Column")) assert mc - mc["ECMWF_OSUITE"] = dict( - model_id="ECMWF_OSUITE", - obs_vars=["concpm10"], - obs_vert_type="Surface", + mc.add_entry( + "ECMWF_OSUITE", + dict( + model_id="ECMWF_OSUITE", + obs_vars=["concpm10"], + obs_vert_type="Surface", + ), ) - assert "ECMWF_OSUITE" in mc + assert "ECMWF_OSUITE" in mc.keylist() + + +def test_modelcollection_add_and_get_entry(): + collection = ModelCollection() + entry = ModelEntry(model_id="mod1") + collection.add_entry("key1", entry) + retrieved_entry = collection.get_entry("key1") + assert retrieved_entry == entry + + +def test_modelcollection_add_and_remove_entry(): + collection = ModelCollection() + entry = ModelEntry(model_id="obs1") + collection.add_entry("key1", entry) + collection.remove_entry("key1") + with pytest.raises(KeyError): + collection.get_entry("key1") diff --git a/tests/aeroval/test_experiment_output.py b/tests/aeroval/test_experiment_output.py index a74498eb1..522516948 100644 --- a/tests/aeroval/test_experiment_output.py +++ b/tests/aeroval/test_experiment_output.py @@ -293,10 +293,10 @@ def test_Experiment_Output_clean_json_files_CFG1(eval_config: dict): @pytest.mark.parametrize("cfg", ["cfgexp1"]) def test_Experiment_Output_clean_json_files_CFG1_INVALIDMOD(eval_config: dict): cfg = EvalSetup(**eval_config) - cfg.model_cfg["mod1"] = cfg.model_cfg["TM5-AP3-CTRL"] + cfg.model_cfg.add_entry("mod1", cfg.model_cfg.get_entry("TM5-AP3-CTRL")) proc = ExperimentProcessor(cfg) proc.run() - del cfg.model_cfg["mod1"] + cfg.model_cfg.remove_entry("mod1") modified = proc.exp_output.clean_json_files() assert len(modified) == 13 @@ -305,10 +305,9 @@ def 
test_Experiment_Output_clean_json_files_CFG1_INVALIDMOD(eval_config: dict): @pytest.mark.parametrize("cfg", ["cfgexp1"]) def test_Experiment_Output_clean_json_files_CFG1_INVALIDOBS(eval_config: dict): cfg = EvalSetup(**eval_config) - cfg.obs_cfg["obs1"] = cfg.obs_cfg["AERONET-Sun"] + cfg.obs_cfg.add_entry("obs1", cfg.obs_cfg.get_entry("AERONET-Sun")) proc = ExperimentProcessor(cfg) proc.run() - del cfg.obs_cfg["obs1"] modified = proc.exp_output.clean_json_files() assert len(modified) == 13 @@ -354,7 +353,7 @@ def test_Experiment_Output_drop_stats_and_decimals( stats_decimals, ) cfg = EvalSetup(**eval_config) - cfg.model_cfg["mod1"] = cfg.model_cfg["TM5-AP3-CTRL"] + cfg.model_cfg.add_entry("mod1", cfg.model_cfg.get_entry("TM5-AP3-CTRL")) proc = ExperimentProcessor(cfg) proc.run() path = Path(proc.exp_output.exp_dir) diff --git a/tests/aeroval/test_experiment_processor.py b/tests/aeroval/test_experiment_processor.py index 5374e831a..af0be63c0 100644 --- a/tests/aeroval/test_experiment_processor.py +++ b/tests/aeroval/test_experiment_processor.py @@ -31,6 +31,7 @@ def test_ExperimentProcessor_run(processor: ExperimentProcessor): processor.run() +# Temporary until ObsCollection implemented similarly then can run same test @geojson_unavail @pytest.mark.parametrize( "cfg,kwargs,error", @@ -47,7 +48,9 @@ def test_ExperimentProcessor_run(processor: ExperimentProcessor): ), ], ) -def test_ExperimentProcessor_run_error(processor: ExperimentProcessor, kwargs: dict, error: str): +def test_ExperimentProcessor_run_error_obs_name( + processor: ExperimentProcessor, kwargs: dict, error: str +): with pytest.raises(KeyError) as e: processor.run(**kwargs) assert str(e.value) == error diff --git a/tests/aeroval/test_helpers.py b/tests/aeroval/test_helpers.py index 085efe943..c2af1998a 100644 --- a/tests/aeroval/test_helpers.py +++ b/tests/aeroval/test_helpers.py @@ -137,6 +137,6 @@ def test__get_min_max_year_periods_error(): @pytest.mark.parametrize("cfg", ["cfgexp1"]) def
test_make_dummy_model(eval_config: dict): cfg = EvalSetup(**eval_config) - assert cfg.obs_cfg["AERONET-Sun"] + assert cfg.obs_cfg.get_entry("AERONET-Sun") model_id = make_dummy_model(["AERONET-Sun"], cfg) assert model_id == "dummy_model" diff --git a/tests/aeroval/test_json_utils.py b/tests/aeroval/test_json_utils.py index 5969e449f..c2d06a434 100644 --- a/tests/aeroval/test_json_utils.py +++ b/tests/aeroval/test_json_utils.py @@ -29,13 +29,13 @@ def json_path(tmp_path: Path) -> Path: id="single float", ), pytest.param( - [np.float_(2.3456789), np.float32(3.456789012)], + [np.float64(2.3456789), np.float32(3.456789012)], 3, [2.346, pytest.approx(3.457, 1e-3)], id="np.float list", ), pytest.param( - (np.float128(4.567890123), np.float_(5.6789012345)), + (np.float128(4.567890123), np.float64(5.6789012345)), 5, [pytest.approx(4.56789, 1e-5), 5.67890], id="np.float tuple", diff --git a/tests/fixtures/collocated_data.py b/tests/fixtures/collocated_data.py index 445271d76..a975adaf0 100644 --- a/tests/fixtures/collocated_data.py +++ b/tests/fixtures/collocated_data.py @@ -246,7 +246,7 @@ def _create_fake_partial_trends_coldata_3d(colocate_time: bool = False): if d in dtimes[i]: data[0, j, i] = 1 else: - data[0, j, i] = np.NaN + data[0, j, i] = np.nan meta = { "data_source": ["fakeobs", "fakemod"],