diff --git a/CHANGES.rst b/CHANGES.rst index 85cb2dc15..44b75d921 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -30,6 +30,7 @@ Bug fixes * Fixed ``xclim.indices.run_length.lazy_indexing`` which would sometimes trigger the loading of auxiliary coordinates. (:issue:`1483`, :pull:`1484`). * Indicators ``snd_season_length`` and ``snw_season_length`` will return 0 instead of NaN if all inputs have a (non-NaN) zero snow depth (or water-equivalent thickness). (:pull:`1492`, :issue:`1491`) * Fixed a bug in the `pytest` configuration that could prevent testing data caching from occurring in systems where the platform-dependent cache directory is not found in the user's home. (:issue:`1468`, :pull:`1473`). +* Fix ``xclim.core.dataflags.data_flags`` variable name generation (:pull:`1507`). Breaking changes ^^^^^^^^^^^^^^^^ diff --git a/tests/test_flags.py b/tests/test_flags.py index e0b60ba94..4560ecc88 100644 --- a/tests/test_flags.py +++ b/tests/test_flags.py @@ -151,3 +151,20 @@ def test_era5_ecad_qc_flag(self, open_dataset): df_flagged = df.ecad_compliant(bad_ds) np.testing.assert_array_equal(df_flagged.ecad_qc_flag, False) + + def test_names(self, pr_series): + pr = pr_series(np.zeros(365), start="1971-01-01") + flgs = df.data_flags( + pr, + flags={ + "values_op_thresh_repeating_for_n_or_more_days": { + "op": "==", + "n": 5, + "thresh": "-5.1 mm d-1", + } + }, + ) + assert ( + list(flgs.data_vars.keys())[0] + == "values_eq_minus5point1_repeating_for_5_or_more_days" + ) diff --git a/xclim/core/dataflags.py b/xclim/core/dataflags.py index e777d5d4c..86f0606a0 100644 --- a/xclim/core/dataflags.py +++ b/xclim/core/dataflags.py @@ -12,9 +12,9 @@ from typing import Sequence import numpy as np -import pint import xarray +from ..indices.generic import binary_ops from ..indices.run_length import suspicious_run from .calendar import climatological_mean_doy, within_bnds_doy from .formatting import update_xclim_history @@ -23,6 +23,7 @@ VARIABLES, InputKind, MissingVariableError, + Quantified, infer_kind_from_parameter, raise_warn_or_log, ) @@ -41,6 +42,8 @@ class DataQualityException(Exception): Message prepended to the error messages. """ + flag_array: xarray.Dataset = None + def __init__( self, flag_array: xarray.Dataset, @@ -81,10 +84,20 @@ def __str__(self): ] -def register_methods(func): - """Summarize all methods used in dataflags checks.""" - _REGISTRY[func.__name__] = func - return func +def register_methods(variable_name=None): + """Register a data flag functioné. + + Argument can be the output variable name template. The template may use any of the stringable input arguments. + If not given, the function name is used instead, which may create variable conflicts. + """ + + def _register_methods(func): + """Summarize all methods used in dataflags checks.""" + func.__dict__["variable_name"] = variable_name or func.__name__ + _REGISTRY[func.__name__] = func + return func + + return _register_methods def _sanitize_attrs(da: xarray.DataArray) -> xarray.DataArray: @@ -97,7 +110,7 @@ def _sanitize_attrs(da: xarray.DataArray) -> xarray.DataArray: return da -@register_methods +@register_methods() @update_xclim_history @declare_units(tasmax="[temperature]", tasmin="[temperature]") def tasmax_below_tasmin( @@ -130,7 +143,7 @@ def tasmax_below_tasmin( return tasmax_lt_tasmin -@register_methods +@register_methods() @update_xclim_history @declare_units(tas="[temperature]", tasmax="[temperature]") def tas_exceeds_tasmax( @@ -163,7 +176,7 @@ def tas_exceeds_tasmax( return tas_gt_tasmax -@register_methods +@register_methods() @update_xclim_history @declare_units(tas="[temperature]", tasmin="[temperature]") def tas_below_tasmin( @@ -195,11 +208,11 @@ def tas_below_tasmin( return tas_lt_tasmin -@register_methods +@register_methods() @update_xclim_history -@declare_units(da="[temperature]") +@declare_units(da="[temperature]", thresh="[temperature]") def temperature_extremely_low( - da: xarray.DataArray, *, thresh: str = "-90 degC" + da: xarray.DataArray, *, thresh: Quantified = "-90 degC" ) -> xarray.DataArray: """Check if temperatures values are below -90 degrees Celsius for any given day. @@ -229,11 +242,11 @@ def temperature_extremely_low( return extreme_low -@register_methods +@register_methods() @update_xclim_history -@declare_units(da="[temperature]") +@declare_units(da="[temperature]", thresh="[temperature]") def temperature_extremely_high( - da: xarray.DataArray, *, thresh: str = "60 degC" + da: xarray.DataArray, *, thresh: Quantified = "60 degC" ) -> xarray.DataArray: """Check if temperatures values exceed 60 degrees Celsius for any given day. @@ -263,7 +276,7 @@ def temperature_extremely_high( return extreme_high -@register_methods +@register_methods() @update_xclim_history def negative_accumulation_values( da: xarray.DataArray, @@ -293,11 +306,11 @@ def negative_accumulation_values( return negative_accumulations -@register_methods +@register_methods() @update_xclim_history -@declare_units(da="[precipitation]") +@declare_units(da="[precipitation]", thresh="[precipitation]") def very_large_precipitation_events( - da: xarray.DataArray, *, thresh="300 mm d-1" + da: xarray.DataArray, *, thresh: Quantified = "300 mm d-1" ) -> xarray.DataArray: """Check if precipitation values exceed 300 mm/day for any given day. @@ -329,10 +342,10 @@ def very_large_precipitation_events( return very_large_events -@register_methods +@register_methods("values_{op}_{thresh}_repeating_for_{n}_or_more_days") @update_xclim_history def values_op_thresh_repeating_for_n_or_more_days( - da: xarray.DataArray, *, n: int, thresh: str, op: str = "==" + da: xarray.DataArray, *, n: int, thresh: Quantified, op: str = "==" ) -> xarray.DataArray: """Check if array values repeat at a given threshold for `N` or more days. @@ -377,11 +390,14 @@ def values_op_thresh_repeating_for_n_or_more_days( return repetitions -@register_methods +@register_methods() @update_xclim_history -@declare_units(da="[speed]") +@declare_units(da="[speed]", lower="[speed]", upper="[speed]") def wind_values_outside_of_bounds( - da: xarray.DataArray, *, lower: str = "0 m s-1", upper: str = "46 m s-1" + da: xarray.DataArray, + *, + lower: Quantified = "0 m s-1", + upper: Quantified = "46 m s-1", ) -> xarray.DataArray: """Check if variable values fall below 0% or rise above 100% for any given day. @@ -419,7 +435,7 @@ def wind_values_outside_of_bounds( # TODO: 'Many excessive dry days' = the amount of dry days lies outside a 14·bivariate standard deviation -@register_methods +@register_methods("outside_{n}_standard_deviations_of_climatology") @update_xclim_history def outside_n_standard_deviations_of_climatology( da: xarray.DataArray, @@ -475,7 +491,7 @@ def outside_n_standard_deviations_of_climatology( return ~within_bounds -@register_methods +@register_methods("values_repeating_for_{n}_or_more_days") @update_xclim_history def values_repeating_for_n_or_more_days( da: xarray.DataArray, *, n: int @@ -508,7 +524,7 @@ def values_repeating_for_n_or_more_days( return repetition -@register_methods +@register_methods() @update_xclim_history def percentage_values_outside_of_bounds(da: xarray.DataArray) -> xarray.DataArray: """Check if variable values fall below 0% or rise above 100% for any given day. @@ -589,28 +605,35 @@ def data_flags( # noqa: C901 ... ) """ - def _convert_value_to_str(var_name, val) -> str: - """Convert variable units to an xarray data variable-like string.""" - if isinstance(val, str): - try: - # Use pint to - val = str2pint(val).magnitude - if isinstance(val, float): + def get_variable_name(function, kwargs): + fmtargs = {} + kwargs = kwargs or {} + for arg, param in signature(function).parameters.items(): + val = kwargs.get(arg, param.default) + kind = infer_kind_from_parameter(param) + if arg == "op": + fmtargs[arg] = val if val not in binary_ops else binary_ops[val] + elif kind in [ + InputKind.FREQ_STR, + InputKind.NUMBER, + InputKind.STRING, + InputKind.DAY_OF_YEAR, + InputKind.DATE, + InputKind.BOOL, + ]: + fmtargs[arg] = val + elif kind == InputKind.QUANTIFIED: + if isinstance(val, xarray.DataArray): + fmtargs[arg] = "array" + else: + val = str2pint(val).magnitude if Decimal(val) % 1 == 0: val = str(int(val)) else: - val = "point".join(str(val).split(".")) - except pint.UndefinedUnitError: - pass - - if isinstance(val, (int, str)): - # Replace spaces between units with underlines - var_name = var_name.replace(f"_{param}_", f"_{str(val).replace(' ', '_')}_") - # Change hyphens in units into the word "_minus_" - if "-" in var_name: - var_name = var_name.replace("-", "_minus_") - - return var_name + val = str(val).replace(".", "point") + val = val.replace("-", "minus") + fmtargs[arg] = str(val) + return function.variable_name.format(**fmtargs) def _missing_vars(function, dataset: xarray.Dataset, var_provided: str): """Handle missing variables in passed datasets.""" @@ -659,12 +682,9 @@ def _missing_vars(function, dataset: xarray.Dataset, var_provided: str): for flag_func in flag_funcs: for name, kwargs in flag_func.items(): func = _REGISTRY[name] - variable_name = str(name) + variable_name = get_variable_name(func, kwargs) named_da_variable = None - if kwargs: - for param, value in kwargs.items(): - variable_name = _convert_value_to_str(variable_name, value) try: extras = _missing_vars(func, ds, str(da.name)) # Entries in extras implies that there are two variables being compared diff --git a/xclim/indices/generic.py b/xclim/indices/generic.py index 03a4918cf..674dc2185 100644 --- a/xclim/indices/generic.py +++ b/xclim/indices/generic.py @@ -34,6 +34,7 @@ __all__ = [ "aggregate_between_dates", + "binary_ops", "compare", "count_level_crossings", "count_occurrences",