Skip to content

Commit

Permalink
Fix data flags var name generation (#1507)
Browse files Browse the repository at this point in the history
<!--Please ensure the PR fulfills the following requirements! -->
<!-- If this is your first PR, make sure to add your details to the
AUTHORS.rst! -->
### Pull Request Checklist:
- [ ] This PR addresses an already opened issue (for bug fixes /
features)
- This PR fixes an issue raised in private communication by @RondeauG
- [x] Tests for the changes have been added (for bug fixes / features)
- [x] (If applicable) Documentation has been added / updated (for bug
fixes / features)
- [x] CHANGES.rst has been updated (with summary of main changes)
- [x] Link to issue (:issue:`number`) and pull request (:pull:`number`)
has been added

### What kind of change does this PR introduce?

Changes how the data flags variable names are generated.

In the previous version, every kwarg was sent to `str2pint`. We were
relying on a `try: except` to catch those kwargs that weren't
quantities. This had the caveat of requiring an explicit list of
possible errors. @RondeauG had a case where `op='>='` was triggering a
`ValueError`, which wasn't listed.

I changed this "implicit" parsing to an "explicit" one:
- Data flags declare the variable name as a templated string.
- Parameters are iterated through and handled according to their
"InputKind"
- "Quantified" inputs are handled as before, but all others are passed
as-is.

This required:
- Changing the registering decorator to accept a templated string
argument
- Correctly annotating the inputs, which led to adding missing
thresholds to the declared units

### Does this PR introduce a breaking change?
Yes, I made the arbitrary choice of changing how we stringify the minus
sign:

```
Input thresh : -5.1 mm
Template : "values_greater_{thresh}"
Before : "values_greater__minus_5point1"
This PR: "values_greater_minus5point1"
```
However, I don't think this will affect many projects.

### Other information:
@RondeauG, could you try this new branch with your code? (That's why I
made you a reviewer)
  • Loading branch information
aulemahal authored Oct 23, 2023
2 parents 9732c00 + 8f9fedc commit 7edffda
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 49 deletions.
1 change: 1 addition & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ Bug fixes
* Fixed ``xclim.indices.run_length.lazy_indexing`` which would sometimes trigger the loading of auxiliary coordinates. (:issue:`1483`, :pull:`1484`).
* Indicators ``snd_season_length`` and ``snw_season_length`` will return 0 instead of NaN if all inputs have a (non-NaN) zero snow depth (or water-equivalent thickness). (:pull:`1492`, :issue:`1491`)
* Fixed a bug in the `pytest` configuration that could prevent testing data caching from occurring in systems where the platform-dependent cache directory is not found in the user's home. (:issue:`1468`, :pull:`1473`).
* Fix ``xclim.core.dataflags.data_flags`` variable name generation (:pull:`1507`).

Breaking changes
^^^^^^^^^^^^^^^^
Expand Down
17 changes: 17 additions & 0 deletions tests/test_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,3 +151,20 @@ def test_era5_ecad_qc_flag(self, open_dataset):

df_flagged = df.ecad_compliant(bad_ds)
np.testing.assert_array_equal(df_flagged.ecad_qc_flag, False)

def test_names(self, pr_series):
    """Check the flag variable name generated from the keyword arguments.

    The operator, the count and the (negative, fractional) threshold passed to
    ``values_op_thresh_repeating_for_n_or_more_days`` must all show up in the
    name of the resulting data variable.
    """
    precip = pr_series(np.zeros(365), start="1971-01-01")
    flag_kwargs = {
        "op": "==",
        "n": 5,
        "thresh": "-5.1 mm d-1",
    }
    flagged = df.data_flags(
        precip,
        flags={"values_op_thresh_repeating_for_n_or_more_days": flag_kwargs},
    )

    # Only the first data variable of the returned dataset is inspected.
    first_name = list(flagged.data_vars.keys())[0]
    expected_name = "values_eq_minus5point1_repeating_for_5_or_more_days"
    assert first_name == expected_name
118 changes: 69 additions & 49 deletions xclim/core/dataflags.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from typing import Sequence

import numpy as np
import pint
import xarray

from ..indices.generic import binary_ops
from ..indices.run_length import suspicious_run
from .calendar import climatological_mean_doy, within_bnds_doy
from .formatting import update_xclim_history
Expand All @@ -23,6 +23,7 @@
VARIABLES,
InputKind,
MissingVariableError,
Quantified,
infer_kind_from_parameter,
raise_warn_or_log,
)
Expand All @@ -41,6 +42,8 @@ class DataQualityException(Exception):
Message prepended to the error messages.
"""

flag_array: xarray.Dataset = None

def __init__(
self,
flag_array: xarray.Dataset,
Expand Down Expand Up @@ -81,10 +84,20 @@ def __str__(self):
]


def register_methods(func):
"""Summarize all methods used in dataflags checks."""
_REGISTRY[func.__name__] = func
return func
def register_methods(variable_name=None):
    """Register a data flag function.

    Supported usages are ``@register_methods("template_{arg}")``,
    ``@register_methods()`` and, for backward compatibility, the bare
    ``@register_methods`` form.

    Parameters
    ----------
    variable_name : str or callable, optional
        Output variable name template. The template may use any of the stringable input arguments.
        If not given, the function name is used instead, which may create variable conflicts.
        If a callable is received (bare-decorator usage), it is registered directly
        with its own name as the variable name.

    Returns
    -------
    callable
        The registering decorator, or the registered function itself when used as a bare decorator.
    """
    if callable(variable_name):
        # Bare usage (``@register_methods``): the "argument" is the function to register.
        # Without this branch the decorated name would be bound to the inner decorator
        # and the function would never reach the registry.
        return register_methods()(variable_name)

    def _register_methods(func):
        """Attach the variable name template to ``func`` and store it in the registry."""
        func.variable_name = variable_name or func.__name__
        _REGISTRY[func.__name__] = func
        return func

    return _register_methods


def _sanitize_attrs(da: xarray.DataArray) -> xarray.DataArray:
Expand All @@ -97,7 +110,7 @@ def _sanitize_attrs(da: xarray.DataArray) -> xarray.DataArray:
return da


@register_methods
@register_methods()
@update_xclim_history
@declare_units(tasmax="[temperature]", tasmin="[temperature]")
def tasmax_below_tasmin(
Expand Down Expand Up @@ -130,7 +143,7 @@ def tasmax_below_tasmin(
return tasmax_lt_tasmin


@register_methods
@register_methods()
@update_xclim_history
@declare_units(tas="[temperature]", tasmax="[temperature]")
def tas_exceeds_tasmax(
Expand Down Expand Up @@ -163,7 +176,7 @@ def tas_exceeds_tasmax(
return tas_gt_tasmax


@register_methods
@register_methods()
@update_xclim_history
@declare_units(tas="[temperature]", tasmin="[temperature]")
def tas_below_tasmin(
Expand Down Expand Up @@ -195,11 +208,11 @@ def tas_below_tasmin(
return tas_lt_tasmin


@register_methods
@register_methods()
@update_xclim_history
@declare_units(da="[temperature]")
@declare_units(da="[temperature]", thresh="[temperature]")
def temperature_extremely_low(
da: xarray.DataArray, *, thresh: str = "-90 degC"
da: xarray.DataArray, *, thresh: Quantified = "-90 degC"
) -> xarray.DataArray:
"""Check if temperatures values are below -90 degrees Celsius for any given day.
Expand Down Expand Up @@ -229,11 +242,11 @@ def temperature_extremely_low(
return extreme_low


@register_methods
@register_methods()
@update_xclim_history
@declare_units(da="[temperature]")
@declare_units(da="[temperature]", thresh="[temperature]")
def temperature_extremely_high(
da: xarray.DataArray, *, thresh: str = "60 degC"
da: xarray.DataArray, *, thresh: Quantified = "60 degC"
) -> xarray.DataArray:
"""Check if temperatures values exceed 60 degrees Celsius for any given day.
Expand Down Expand Up @@ -263,7 +276,7 @@ def temperature_extremely_high(
return extreme_high


@register_methods
@register_methods()
@update_xclim_history
def negative_accumulation_values(
da: xarray.DataArray,
Expand Down Expand Up @@ -293,11 +306,11 @@ def negative_accumulation_values(
return negative_accumulations


@register_methods
@register_methods()
@update_xclim_history
@declare_units(da="[precipitation]")
@declare_units(da="[precipitation]", thresh="[precipitation]")
def very_large_precipitation_events(
da: xarray.DataArray, *, thresh="300 mm d-1"
da: xarray.DataArray, *, thresh: Quantified = "300 mm d-1"
) -> xarray.DataArray:
"""Check if precipitation values exceed 300 mm/day for any given day.
Expand Down Expand Up @@ -329,10 +342,10 @@ def very_large_precipitation_events(
return very_large_events


@register_methods
@register_methods("values_{op}_{thresh}_repeating_for_{n}_or_more_days")
@update_xclim_history
def values_op_thresh_repeating_for_n_or_more_days(
da: xarray.DataArray, *, n: int, thresh: str, op: str = "=="
da: xarray.DataArray, *, n: int, thresh: Quantified, op: str = "=="
) -> xarray.DataArray:
"""Check if array values repeat at a given threshold for `N` or more days.
Expand Down Expand Up @@ -377,11 +390,14 @@ def values_op_thresh_repeating_for_n_or_more_days(
return repetitions


@register_methods
@register_methods()
@update_xclim_history
@declare_units(da="[speed]")
@declare_units(da="[speed]", lower="[speed]", upper="[speed]")
def wind_values_outside_of_bounds(
da: xarray.DataArray, *, lower: str = "0 m s-1", upper: str = "46 m s-1"
da: xarray.DataArray,
*,
lower: Quantified = "0 m s-1",
upper: Quantified = "46 m s-1",
) -> xarray.DataArray:
"""Check if wind speed values fall below the lower bound or rise above the upper bound for any given day.
Expand Down Expand Up @@ -419,7 +435,7 @@ def wind_values_outside_of_bounds(
# TODO: 'Many excessive dry days' = the amount of dry days lies outside a 14·bivariate standard deviation


@register_methods
@register_methods("outside_{n}_standard_deviations_of_climatology")
@update_xclim_history
def outside_n_standard_deviations_of_climatology(
da: xarray.DataArray,
Expand Down Expand Up @@ -475,7 +491,7 @@ def outside_n_standard_deviations_of_climatology(
return ~within_bounds


@register_methods
@register_methods("values_repeating_for_{n}_or_more_days")
@update_xclim_history
def values_repeating_for_n_or_more_days(
da: xarray.DataArray, *, n: int
Expand Down Expand Up @@ -508,7 +524,7 @@ def values_repeating_for_n_or_more_days(
return repetition


@register_methods
@register_methods()
@update_xclim_history
def percentage_values_outside_of_bounds(da: xarray.DataArray) -> xarray.DataArray:
"""Check if variable values fall below 0% or rise above 100% for any given day.
Expand Down Expand Up @@ -589,28 +605,35 @@ def data_flags( # noqa: C901
... )
"""

def _convert_value_to_str(var_name, val) -> str:
"""Convert variable units to an xarray data variable-like string."""
if isinstance(val, str):
try:
# Use pint to
val = str2pint(val).magnitude
if isinstance(val, float):
def get_variable_name(function, kwargs):
fmtargs = {}
kwargs = kwargs or {}
for arg, param in signature(function).parameters.items():
val = kwargs.get(arg, param.default)
kind = infer_kind_from_parameter(param)
if arg == "op":
fmtargs[arg] = val if val not in binary_ops else binary_ops[val]
elif kind in [
InputKind.FREQ_STR,
InputKind.NUMBER,
InputKind.STRING,
InputKind.DAY_OF_YEAR,
InputKind.DATE,
InputKind.BOOL,
]:
fmtargs[arg] = val
elif kind == InputKind.QUANTIFIED:
if isinstance(val, xarray.DataArray):
fmtargs[arg] = "array"
else:
val = str2pint(val).magnitude
if Decimal(val) % 1 == 0:
val = str(int(val))
else:
val = "point".join(str(val).split("."))
except pint.UndefinedUnitError:
pass

if isinstance(val, (int, str)):
# Replace spaces between units with underlines
var_name = var_name.replace(f"_{param}_", f"_{str(val).replace(' ', '_')}_")
# Change hyphens in units into the word "_minus_"
if "-" in var_name:
var_name = var_name.replace("-", "_minus_")

return var_name
val = str(val).replace(".", "point")
val = val.replace("-", "minus")
fmtargs[arg] = str(val)
return function.variable_name.format(**fmtargs)

def _missing_vars(function, dataset: xarray.Dataset, var_provided: str):
"""Handle missing variables in passed datasets."""
Expand Down Expand Up @@ -659,12 +682,9 @@ def _missing_vars(function, dataset: xarray.Dataset, var_provided: str):
for flag_func in flag_funcs:
for name, kwargs in flag_func.items():
func = _REGISTRY[name]
variable_name = str(name)
variable_name = get_variable_name(func, kwargs)
named_da_variable = None

if kwargs:
for param, value in kwargs.items():
variable_name = _convert_value_to_str(variable_name, value)
try:
extras = _missing_vars(func, ds, str(da.name))
# Entries in extras implies that there are two variables being compared
Expand Down
1 change: 1 addition & 0 deletions xclim/indices/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

__all__ = [
"aggregate_between_dates",
"binary_ops",
"compare",
"count_level_crossings",
"count_occurrences",
Expand Down

0 comments on commit 7edffda

Please sign in to comment.