Skip to content

Commit

Permalink
Various typing and signature fixes (#2018)
Browse files Browse the repository at this point in the history
### What kind of change does this PR introduce?

* Addresses a few typing errors.
* Fixes a few overwritten variables in doctests.
* Fixes the call signature of `hot_spell_max_magnitude` to remove `op`.
* Fixes the call signature of `mbcn_adjust` to precisely report that
DataArrays are expected and a Dataset is returned.

### Does this PR introduce a breaking change?

Yes. `hot_spell_max_magnitude` previously accepted an `op` argument that
was unused in the algorithm. This has been removed.

### Other information:

There is much more work to be done, but many changes require significant
refactors. There's enough in the current release to justify stopping
here for now.
  • Loading branch information
Zeitsperre authored Dec 11, 2024
2 parents 8b31dc2 + 31f86b0 commit 4b4ae51
Show file tree
Hide file tree
Showing 17 changed files with 57 additions and 50 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Breaking changes
^^^^^^^^^^^^^^^^
* The minimum required version of `dask` has been increased to `2024.8.1`. (:issue:`1992`, :pull:`1991`).
* The docstrings of many `xclim` modules, classes, methods, and functions have been slightly adjusted to ensure stricter compliance with established `numpy` docstring conventions. (:pull:`1988`).
* The call signature of ``xclim.indices.hot_spell_max_magnitude`` originally asked for an `op` argument that was not used. This argument has been removed. (:pull:`2018`).

Bug fixes
^^^^^^^^^
Expand All @@ -34,6 +35,7 @@ Internal changes
* `xclim` now uses a `src` layout for the codebase. Structure-dependent functions, documentation, and build commands have been adapted to reflect these changes. (:pull:`1971`).
* Added a more robust `yamllint` configuration to ensure that all YAML files are linted consistently. (:pull:`1971`).
* Addressed a very rare singular matrix error that can happen in ``test_loess_smoothing_nan``. (:pull:`2015`).
* Addressed a handful of typing and call signature issues in the `xclim` codebase. (:pull:`2018`).

CI changes
^^^^^^^^^^
Expand Down
7 changes: 4 additions & 3 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import datetime
import json
import os
import pathlib
import sys
import warnings

Expand Down Expand Up @@ -67,13 +68,13 @@
# Dump indicators to json. The json is added to the html output (html_extra_path)
# It is read by _static/indsearch.js to populate the table in indicators.rst
os.makedirs("_dynamic", exist_ok=True)
with open("_dynamic/indicators.json", "w") as f:
with pathlib.Path("_dynamic/indicators.json").open("w") as f:
json.dump(indicators, f)


# Dump variables information
with open("variables.json", "w") as fout:
with open("../src/xclim/data/variables.yml") as fin:
with pathlib.Path("variables.json").open("w") as fout:
with pathlib.Path("../src/xclim/data/variables.yml").open() as fin:
data = yaml.safe_load(fin)
json.dump(data, fout)

Expand Down
2 changes: 1 addition & 1 deletion src/xclim/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ def cli(ctx, **kwargs): # numpydoc ignore=PR01

@cli.result_callback()
@click.pass_context
def write_file(ctx, *args, **kwargs): # numpydoc ignore=PR01
def write_file(ctx, *_, **kwargs): # numpydoc ignore=PR01
"""Write the output dataset to file."""
if ctx.obj["output"] is not None:
if ctx.obj["verbose"]:
Expand Down
6 changes: 5 additions & 1 deletion src/xclim/core/bootstrapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,17 @@ def bootstrap_func(compute_index_func: Callable, **kwargs) -> xarray.DataArray:
:cite:cts:`zhang_avoiding_2005`
"""
# Identify the input and the percentile arrays from the bound arguments
per_key = None
per_key, da_key = None, None
for name, val in kwargs.items():
if isinstance(val, DataArray):
if "percentile_doy" in val.attrs.get("history", ""):
per_key = name
else:
da_key = name
if da_key is None or per_key is None:
raise KeyError(
"The input data and the percentile DataArray must be provided as named arguments."
)
# Extract the DataArray inputs from the arguments
da: DataArray = kwargs.pop(da_key)
per_da: DataArray | None = kwargs.pop(per_key, None)
Expand Down
4 changes: 2 additions & 2 deletions src/xclim/core/dataflags.py
Original file line number Diff line number Diff line change
Expand Up @@ -639,11 +639,11 @@ def data_flags( # noqa: C901
>>> from xclim.core.dataflags import data_flags
>>> ds = xr.open_dataset(path_to_pr_file)
>>> flagged = data_flags(ds.pr, ds)
>>> flagged_multi = data_flags(ds.pr, ds)
>>> # The next example evaluates only one data flag, passing specific parameters. It also aggregates the flags
>>> # yearly over the "time" dimension only, such that a True means there is a bad data point for that year
>>> # at that location.
>>> flagged = data_flags(
>>> flagged_single = data_flags(
... ds.pr,
... ds,
... flags={"very_large_precipitation_events": {"thresh": "250 mm d-1"}},
Expand Down
6 changes: 4 additions & 2 deletions src/xclim/core/indicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,7 +873,7 @@ def __call__(self, *args, **kwds):
out_attrs.pop("units", None)
else:
out_attrs = {}
out_attrs = [out_attrs.copy() for i in range(self.n_outs)]
out_attrs = [out_attrs.copy() for _ in range(self.n_outs)]

das, params = self._preprocess_and_checks(das, params)

Expand Down Expand Up @@ -943,7 +943,9 @@ def __call__(self, *args, **kwds):
return outs[0]
return tuple(outs)

def _parse_variables_from_call(self, args, kwds) -> tuple[OrderedDict, dict]:
def _parse_variables_from_call(
self, args, kwds
) -> tuple[OrderedDict, OrderedDict, OrderedDict | dict]:
"""Extract variable and optional variables from call arguments."""
# Bind call arguments to `compute` arguments and set defaults.
ba = self.__signature__.bind(*args, **kwds)
Expand Down
6 changes: 3 additions & 3 deletions src/xclim/core/units.py
Original file line number Diff line number Diff line change
Expand Up @@ -1308,7 +1308,7 @@ def _check_output_has_units(

# FIXME: This needs to be properly annotated for mypy compliance.
# See: https://mypy.readthedocs.io/en/stable/generics.html#declaring-decorators
def declare_relative_units(**units_by_name) -> Callable:
def declare_relative_units(**units_by_name: str) -> Callable:
r"""
Function decorator checking the units of arguments.
Expand All @@ -1317,7 +1317,7 @@ def declare_relative_units(**units_by_name) -> Callable:
Parameters
----------
**units_by_name : dict
**units_by_name : str
Mapping from the input parameter names to dimensions relative to other parameters.
The dimensions can be a single parameter name as `<other_var>` or more complex expressions,
such as `<other_var> * [time]`.
Expand Down Expand Up @@ -1430,7 +1430,7 @@ def declare_units(**units_by_name) -> Callable:
Parameters
----------
**units_by_name : dict
**units_by_name : str
Mapping from the input parameter names to their units or dimensionality ("[...]").
If this decorates a function previously decorated with :py:func:`declare_relative_units`,
the relative unit declarations are made absolute with the information passed here.
Expand Down
1 change: 1 addition & 0 deletions src/xclim/ensembles/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ def ensemble_mean_std_max_min(
ds_out[f"{v}_max"] = ens[v].max(dim="realization")
ds_out[f"{v}_min"] = ens[v].min(dim="realization")

enough = None
if min_members != 1:
enough = ens[v].notnull().sum("realization") >= min_members

Expand Down
5 changes: 3 additions & 2 deletions src/xclim/ensembles/_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from __future__ import annotations

from typing import Any
from warnings import warn

import numpy as np
Expand Down Expand Up @@ -128,7 +129,7 @@ def kkz_reduce_ensemble(
*,
dist_method: str = "euclidean",
standardize: bool = True,
**cdist_kwargs,
**cdist_kwargs: Any,
) -> list:
r"""
Return a sample of ensemble members using KKZ selection.
Expand All @@ -152,7 +153,7 @@ def kkz_reduce_ensemble(
standardize : bool
Whether to standardize the input before running the selection or not.
Standardization consists of translating the data so that it has a zero mean and scaling it so that it has a unit standard deviation.
**cdist_kwargs : dict
**cdist_kwargs : Any
All extra arguments are passed as-is to `scipy.spatial.distance.cdist`, see its docs for more information.
Returns
Expand Down
2 changes: 1 addition & 1 deletion src/xclim/indices/_agro.py
Original file line number Diff line number Diff line change
Expand Up @@ -1215,7 +1215,7 @@ def standardized_precipitation_index(
... method="ML",
... zero_inflated=True,
... ) # First getting params
>>> spi_3 = standardized_precipitation_index(pr, params=params)
>>> spi_3_fitted = standardized_precipitation_index(pr, params=params)
"""
fitkwargs = fitkwargs or {}
dist_methods = {"gamma": ["ML", "APP"], "fisk": ["ML", "APP"]}
Expand Down
11 changes: 4 additions & 7 deletions src/xclim/indices/_threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -1973,7 +1973,6 @@ def hot_spell_max_magnitude(
thresh: Quantified = "25.0 degC",
window: int = 3,
freq: str = "YS",
op: str = ">",
resample_before_rl: bool = True,
) -> xarray.DataArray:
"""
Expand All @@ -1993,8 +1992,6 @@ def hot_spell_max_magnitude(
Minimum number of days with temperature above threshold to qualify as a heatwave.
freq : str
Resampling frequency.
op : {">", ">=", "gt", "ge"}
Comparison operation. Default: ">".
resample_before_rl : bool
Determines if the resampling should take place before or after the run
length encoding (or a similar algorithm) is applied to runs.
Expand Down Expand Up @@ -3273,8 +3270,8 @@ def dry_spell_frequency(
--------
>>> from xclim.indices import dry_spell_frequency
>>> pr = xr.open_dataset(path_to_pr_file).pr
>>> dsf = dry_spell_frequency(pr=pr, op="sum")
>>> dsf = dry_spell_frequency(pr=pr, op="max")
>>> dsf_sum = dry_spell_frequency(pr=pr, op="sum")
>>> dsf_max = dry_spell_frequency(pr=pr, op="max")
"""
pram = rate2amount(convert_units_to(pr, "mm/d", context="hydro"), out_units="mm")
return spell_length_statistics(
Expand Down Expand Up @@ -3481,8 +3478,8 @@ def wet_spell_frequency(
--------
>>> from xclim.indices import wet_spell_frequency
>>> pr = xr.open_dataset(path_to_pr_file).pr
>>> dsf = wet_spell_frequency(pr=pr, op="sum")
>>> dsf = wet_spell_frequency(pr=pr, op="min")
>>> dsf_sum = wet_spell_frequency(pr=pr, op="sum")
>>> dsf_min = wet_spell_frequency(pr=pr, op="min")
"""
pram = rate2amount(convert_units_to(pr, "mm/d", context="hydro"), out_units="mm")
return spell_length_statistics(
Expand Down
4 changes: 1 addition & 3 deletions src/xclim/indices/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,9 +803,7 @@ def make_hourly_temperature(tasmin: xr.DataArray, tasmax: xr.DataArray) -> xr.Da
hourly = data.resample(time="h").ffill().isel(time=slice(0, -1))

# To avoid "invalid value encountered in log" warning we set hours before sunset to 1
nighttime_hours = nighttime_hours = (
hourly.time.dt.hour + 1 - hourly.daylength
).clip(1)
nighttime_hours = (hourly.time.dt.hour + 1 - hourly.daylength).clip(1)

return xr.where(
hourly.time.dt.hour < hourly.daylength,
Expand Down
2 changes: 1 addition & 1 deletion src/xclim/indices/run_length.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def resample_and_rl(
Resampling frequency.
dim : str
The dimension along which to find runs.
**kwargs : dict
**kwargs : Any
Keyword arguments needed in `compute`.
Returns
Expand Down
10 changes: 5 additions & 5 deletions src/xclim/sdba/_adjustment.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,9 +339,9 @@ def _npdft_adjust(sim, af_q, rots, quantiles, method, extrap):


def mbcn_adjust(
ref: xr.Dataset,
hist: xr.Dataset,
sim: xr.Dataset,
ref: xr.DataArray,
hist: xr.DataArray,
sim: xr.DataArray,
ds: xr.Dataset,
pts_dims: Sequence[str],
interp: str,
Expand All @@ -350,7 +350,7 @@ def mbcn_adjust(
base_kws_vars: dict,
adj_kws: dict,
period_dim: str | None,
) -> xr.DataArray:
) -> xr.Dataset:
"""Perform the adjustment portion of the MBCn multivariate bias correction technique.
The function :py:func:`mbcn_train` pre-computes the adjustment factors for each rotation
Expand Down Expand Up @@ -696,7 +696,7 @@ def npdf_transform(ds: xr.Dataset, **kwargs) -> xr.Dataset:
hist : simulated timeseries on the reference period
sim : Simulated timeseries on the projected period.
rot_matrices : Random rotation matrices.
**kwargs : dict
**kwargs : Any
pts_dim : multivariate dimension name
base : Adjustment class
base_kws : Kwargs for initialising the adjustment object
Expand Down
18 changes: 9 additions & 9 deletions src/xclim/sdba/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,14 +852,14 @@ def grouped_time_indexes(times, group):
Time indexes of the blocks (built with a rolling window of `group.window` if any).
"""

def _get_group_complement(da, group):
def _get_group_complement(_da, _group):
# complement of "dayofyear": "year", etc.
gr = group if isinstance(group, str) else group.name
if gr == "time.dayofyear":
return da.time.dt.year
if gr == "time.month":
return da.time.dt.strftime("%Y-%d")
raise NotImplementedError(f"Grouping {gr} not implemented.")
_gr = _group if isinstance(_group, str) else _group.name
if _gr == "time.dayofyear":
return _da.time.dt.year
if _gr == "time.month":
return _da.time.dt.strftime("%Y-%d")
raise NotImplementedError(f"Grouping {_gr} not implemented.")

# does not work with group == "time.month"
group = group if isinstance(group, Grouper) else Grouper(group)
Expand All @@ -871,14 +871,14 @@ def _get_group_complement(da, group):
)
if gr == "time.dayofyear":
# time indices for each block with window = 1
g_idxs = timeind.groupby(gr).apply(
g_idxs = timeind.groupby(gr).map(
lambda da: da.assign_coords(time=_get_group_complement(da, gr)).rename(
{"time": "year"}
)
)
# time indices for each block with general window
da = timeind.rolling(time=win, center=True).construct(window_dim=win_dim0)
gw_idxs = da.groupby(gr).apply(
gw_idxs = da.groupby(gr).map(
lambda da: da.assign_coords(time=_get_group_complement(da, gr)).stack(
{win_dim: ["time", win_dim0]}
)
Expand Down
8 changes: 5 additions & 3 deletions src/xclim/sdba/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from __future__ import annotations

import itertools
from collections.abc import Callable
from collections.abc import Callable, Sequence
from warnings import warn

import bottleneck as bn
Expand Down Expand Up @@ -80,7 +80,9 @@ def map_cdf(
)


def ecdf(x: xr.DataArray, value: float, dim: str = "time") -> xr.DataArray:
def ecdf(
x: xr.DataArray, value: float, dim: str | Sequence[str] = "time"
) -> xr.DataArray:
"""Return the empirical CDF of a sample at a given value.
Parameters
Expand Down Expand Up @@ -948,7 +950,7 @@ def _skipna_correlation(data):
# The output
out = np.empty((nv, nv), dtype=coef.dtype)
# A 2D mask of removed variables
M = (mask_omit)[:, np.newaxis] | (mask_omit)[np.newaxis, :]
M = mask_omit[:, np.newaxis] | mask_omit[np.newaxis, :]
out[~M] = coef.flatten()
out[M] = np.nan
return out
Expand Down
13 changes: 6 additions & 7 deletions src/xclim/testing/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,13 +173,12 @@ def add_doctest_filepaths() -> dict[str, Any]:
dict[str, Any]
A dictionary of xdoctest namespace objects.
"""
namespace: dict = {}
namespace["np"] = np
namespace["xclim"] = xclim
namespace["tas"] = test_timeseries(
np.random.rand(365) * 20 + 253.15, variable="tas"
)
namespace["pr"] = test_timeseries(np.random.rand(365) * 5, variable="pr")
namespace = {
"np": np,
"xclim": xclim,
"tas": test_timeseries(np.random.rand(365) * 20 + 253.15, variable="tas"),
"pr": test_timeseries(np.random.rand(365) * 5, variable="pr"),
}
return namespace


Expand Down

0 comments on commit 4b4ae51

Please sign in to comment.