Block-mapped resample with the help of flox (#1848)
### Pull Request Checklist:
- [ ] This PR addresses an already opened issue (for bug fixes /
features)
    - This PR fixes #xyz
- [x] Tests for the changes have been added (for bug fixes / features)
- [x] (If applicable) Documentation has been added / updated (for bug
fixes / features)
- [x] CHANGELOG.rst has been updated (with summary of main changes)
- [x] Link to issue (:issue:`number`) and pull request (:pull:`number`)
has been added

### What kind of change does this PR introduce?

Implements `resample_map`, a helper meant to replace all
`da.resample(...).map(...)` calls. `flox` cannot optimize these calls
automatically, so we borrow some of its logic instead. The idea is to
apply the resample-map construct to each dask block in parallel. This
only works once the array has been rechunked so that chunk boundaries
coincide with resampling period boundaries (flox provides a function
for this).
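The alignment idea can be sketched in plain Python. This is only an illustration under stated assumptions (daily data, monthly periods, nearest-edge snapping); it is not flox's algorithm, and `month_start_offsets` and `align_chunks` are hypothetical names.

```python
from datetime import date, timedelta


def month_start_offsets(start: date, n_days: int) -> list[int]:
    """Day offsets from `start` where a month begins, plus both series ends."""
    offsets = [i for i in range(n_days) if (start + timedelta(days=i)).day == 1]
    if not offsets or offsets[0] != 0:
        # The series starts mid-month: the first (partial) period begins at 0.
        offsets.insert(0, 0)
    offsets.append(n_days)
    return offsets


def align_chunks(chunk_size: int, n_days: int, period_edges: list[int]) -> tuple[int, ...]:
    """Snap every regular chunk boundary to the nearest period edge."""
    snapped = {0, n_days}
    for boundary in range(chunk_size, n_days, chunk_size):
        snapped.add(min(period_edges, key=lambda edge: abs(edge - boundary)))
    edges = sorted(snapped)
    # Chunk sizes are the distances between consecutive snapped boundaries.
    return tuple(b - a for a, b in zip(edges, edges[1:]))


# One non-leap year of daily data, initially chunked every 60 days.
edges = month_start_offsets(date(2001, 1, 1), 365)
chunks = align_chunks(60, 365, edges)
```

With these inputs every chunk ends on a month boundary and holds two whole months, so a monthly resample can run on each block independently.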

The main improvement should come from the fact that `map_blocks` hides
much of the complexity from `dask`, so the resulting graph is much
lighter. I still need to test the performance more thoroughly. My goal
is to add a short text to xclim's docs that explains when the option is
useful and when it is not. The option is activated through
`set_options`.
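A boolean option of this kind can be sketched as a validated context manager. The code below is a simplified, hypothetical stand-in modelled on the names visible in the `xclim/core/options.py` diff (`MAP_BLOCKS`, `OPTIONS`, a per-option validator); it is not the real module.

```python
# Hypothetical, simplified stand-in for an options module (assumption, not xclim's code).
MAP_BLOCKS = "resample_map_blocks"

OPTIONS = {MAP_BLOCKS: False}
_VALIDATORS = {MAP_BLOCKS: lambda opt: isinstance(opt, bool)}


class set_options:
    """Set options for the duration of a ``with`` block, then restore them."""

    def __init__(self, **kwargs):
        self.old = {}
        for name, value in kwargs.items():
            if name not in OPTIONS:
                raise ValueError(f"Unknown option: {name!r}")
            if not _VALIDATORS[name](value):
                raise ValueError(f"Invalid value for {name!r}: {value!r}")
            # Remember the previous value so it can be restored on exit.
            self.old[name] = OPTIONS[name]
        OPTIONS.update(kwargs)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        OPTIONS.update(self.old)


with set_options(resample_map_blocks=True):
    enabled_inside = OPTIONS[MAP_BLOCKS]
enabled_after = OPTIONS[MAP_BLOCKS]
```

The tests in this commit enable the real option the same way, wrapping the indicator call in `with set_options(resample_map_blocks=True):`.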

The current function only works when the output object is of the same
type as the input one, so some functions can't be wrapped with it yet.
The most important code left untouched for now is the missing-value
checks, where I think this could help a lot.

### Does this PR introduce a breaking change?
It should not. This is completely optional.

### Other information:
In progress: I still need to demonstrate the performance boost.

This depends on #1845 because I need all improvements for PC.
aulemahal authored Oct 10, 2024
2 parents 8b644c6 + 1a584d4 commit 438ef2e
Showing 14 changed files with 288 additions and 76 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -24,6 +24,7 @@ New features and enhancements
* Indicator parameters can now be assigned a new name, different from the argument name in the compute function. (:pull:`1885`).
* ``xclim.indices.run_length.windowed_max_run_sum`` accumulates positive values across runs and yields the the maximum valued run. (:pull:`1926`).
* Helper function ``xclim.indices.helpers.make_hourly_temperature`` to estimate hourly temperatures from daily min and max temperatures. (:pull:`1909`).
* New global option ``resample_map_blocks`` to wrap all ``resample().map()`` code inside a ``xr.map_blocks`` to lower the number of dask tasks. Uses utility ``xclim.indices.helpers.resample_map`` and requires ``flox`` to ensure the chunking allows such block-mapping. Defaults to False. (:pull:`1848`).

Bug fixes
^^^^^^^^^
1 change: 1 addition & 0 deletions environment.yml
@@ -11,6 +11,7 @@ dependencies:
- click >=8.1
- dask >=2.6.0
- filelock >=3.14.0
- flox >= 0.9
- jsonpickle >=3.1.0
- numba >=0.54.1
- numpy >=1.23.0
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -111,7 +111,7 @@ docs = [
"sphinxcontrib-bibtex",
"sphinxcontrib-svg2pdfconverter[Cairosvg]"
]
extras = ["fastnanquantile >=0.0.2", "POT >=0.9.4"]
extras = ["fastnanquantile >=0.0.2", "flox >=0.9", "POT >=0.9.4"]
all = ["xclim[dev]", "xclim[docs]", "xclim[extras]"]

[project.scripts]
12 changes: 10 additions & 2 deletions tests/test_atmos.py
@@ -3,6 +3,7 @@
from __future__ import annotations

import numpy as np
import pytest
import xarray as xr

from xclim import atmos, set_options
@@ -642,11 +643,18 @@ def test_chill_units(atmosds):
np.testing.assert_allclose(cu.isel(location=0), exp, rtol=1e-03)


def test_chill_portions(atmosds):
@pytest.mark.parametrize("use_dask", [True, False])
def test_chill_portions(atmosds, use_dask):
pytest.importorskip("flox")
tasmax = atmosds.tasmax
tasmin = atmosds.tasmin
tas = make_hourly_temperature(tasmin, tasmax)
cp = atmos.chill_portions(tas, date_bounds=("09-01", "03-30"), freq="YS-JUL")
if use_dask:
tas = tas.chunk(time=tas.time.size // 2, location=1)

with set_options(resample_map_blocks=True):
cp = atmos.chill_portions(tas, date_bounds=("09-01", "03-30"), freq="YS-JUL")

assert cp.attrs["units"] == "1"
assert cp.name == "cp"
# Although its 4 years of data its 5 seasons starting in July
45 changes: 45 additions & 0 deletions tests/test_helpers.py
@@ -5,8 +5,11 @@
import pytest
import xarray as xr

from xclim.core.options import set_options
from xclim.core.units import convert_units_to
from xclim.core.utils import uses_dask
from xclim.indices import helpers
from xclim.testing.helpers import assert_lazy


@pytest.mark.parametrize("method,rtol", [("spencer", 5e3), ("simple", 1e2)])
@@ -134,6 +137,48 @@ def test_cosine_of_solar_zenith_angle():
np.testing.assert_allclose(cza[:4, :], exp_cza, rtol=1e-3)


def _test_function(da, op, dim):
return getattr(da, op)(dim)


@pytest.mark.parametrize(
["in_chunks", "exp_chunks"], [(60, 6 * (2,)), (30, 12 * (1,)), (-1, (12,))]
)
def test_resample_map(tas_series, in_chunks, exp_chunks):
pytest.importorskip("flox")
tas = tas_series(365 * [1]).chunk(time=in_chunks)
with assert_lazy:
out = helpers.resample_map(
tas, "time", "MS", lambda da: da.mean("time"), map_blocks=True
)
assert out.chunks[0] == exp_chunks
out.load() # Trigger compute to see if it actually works


def test_resample_map_dataset(tas_series, pr_series):
pytest.importorskip("flox")
tas = tas_series(3 * 365 * [1], start="2000-01-01").chunk(time=365)
pr = pr_series(3 * 365 * [1], start="2000-01-01").chunk(time=365)
ds = xr.Dataset({"pr": pr, "tas": tas})
with set_options(resample_map_blocks=True):
with assert_lazy:
out = helpers.resample_map(
ds,
"time",
"YS",
lambda da: da.mean("time"),
)
assert out.chunks["time"] == (1, 1, 1)
out.load()


def test_resample_map_passthrough(tas_series):
tas = tas_series(365 * [1])
with assert_lazy:
out = helpers.resample_map(tas, "time", "MS", lambda da: da.mean("time"))
assert not uses_dask(out)


@pytest.mark.parametrize("cftime", [False, True])
def test_make_hourly_temperature(tasmax_series, tasmin_series, cftime):
tasmax = tasmax_series(np.array([20]), units="degC", cftime=cftime)
19 changes: 15 additions & 4 deletions tests/test_indices.py
@@ -1454,13 +1454,24 @@ def test_1d(self, tasmax_series, thresh, window, op, expected):
def test_resampling_order(self, tasmax_series, resample_before_rl, expected):
a = np.zeros(365)
a[5:35] = 31
tx = tasmax_series(a + K2C)
tx = tasmax_series(a + K2C).chunk()

hsf = xci.hot_spell_frequency(
tx, resample_before_rl=resample_before_rl, freq="MS"
)
).load()
assert hsf[1] == expected

@pytest.mark.parametrize("resample_map", [True, False])
def test_resampling_map(self, tasmax_series, resample_map):
pytest.importorskip("flox")
a = np.zeros(365)
a[5:35] = 31
tx = tasmax_series(a + K2C).chunk()

with set_options(resample_map_blocks=resample_map):
hsf = xci.hot_spell_frequency(tx, resample_before_rl=True, freq="MS").load()
assert hsf[1] == 1


class TestHotSpellMaxLength:
@pytest.mark.parametrize(
@@ -1746,10 +1757,10 @@ def test_run_start_at_0(self, pr_series):
def test_resampling_order(self, pr_series, resample_before_rl, expected):
a = np.zeros(365) + 10
a[5:35] = 0
pr = pr_series(a)
pr = pr_series(a).chunk()
out = xci.maximum_consecutive_dry_days(
pr, freq="ME", resample_before_rl=resample_before_rl
)
).load()
assert out[0] == expected


14 changes: 7 additions & 7 deletions xclim/core/indicator.py
@@ -165,6 +165,7 @@
infer_kind_from_parameter,
is_percentile_dataarray,
load_module,
split_auxiliary_coordinates,
)

# Indicators registry
@@ -1446,13 +1447,12 @@ def _postprocess(self, outs, das, params):
# Reduce by or and broadcast to ensure the same length in time
# When indexing is used and there are no valid points in the last period, mask will not include it
mask = reduce(np.logical_or, miss)
if (
isinstance(mask, DataArray)
and "time" in mask.dims
and mask.time.size < outs[0].time.size
):
mask = mask.reindex(time=outs[0].time, fill_value=True)
outs = [out.where(np.logical_not(mask)) for out in outs]
if isinstance(mask, DataArray): # mask might be a bool in some cases
if "time" in mask.dims and mask.time.size < outs[0].time.size:
mask = mask.reindex(time=outs[0].time, fill_value=True)
# Remove any aux coord to avoid any unwanted dask computation in the alignment within "where"
mask, _ = split_auxiliary_coordinates(mask)
outs = [out.where(~mask) for out in outs]

return outs

6 changes: 6 additions & 0 deletions xclim/core/options.py
@@ -25,6 +25,7 @@
SDBA_ENCODE_CF = "sdba_encode_cf"
KEEP_ATTRS = "keep_attrs"
AS_DATASET = "as_dataset"
MAP_BLOCKS = "resample_map_blocks"

MISSING_METHODS: dict[str, Callable] = {}

@@ -39,6 +40,7 @@
SDBA_ENCODE_CF: False,
KEEP_ATTRS: "xarray",
AS_DATASET: False,
MAP_BLOCKS: False,
}

_LOUDNESS_OPTIONS = frozenset(["log", "warn", "raise"])
@@ -71,6 +73,7 @@ def _valid_missing_options(mopts):
SDBA_ENCODE_CF: lambda opt: isinstance(opt, bool),
KEEP_ATTRS: _KEEP_ATTRS_OPTIONS.__contains__,
AS_DATASET: lambda opt: isinstance(opt, bool),
MAP_BLOCKS: lambda opt: isinstance(opt, bool),
}


@@ -185,6 +188,9 @@ class set_options:
Note that xarray's "default" is equivalent to False. Default: ``"xarray"``.
as_dataset : bool
If True, indicators output datasets. If False, they output DataArrays. Default :``False``.
resample_map_blocks : bool
If True, some indicators will wrap their resampling operations with `xr.map_blocks`, using :py:func:`xclim.indices.helpers.resample_map`.
This requires `flox` to be installed in order to ensure the chunking is appropriate. Default: ``False``.
Examples
--------
41 changes: 41 additions & 0 deletions xclim/core/utils.py
@@ -758,3 +758,44 @@ def _chunk_like(*inputs: xr.DataArray | xr.Dataset, chunks: dict[str, int] | Non
da.chunk(**{d: c for d, c in chunks.items() if d in da.dims})
)
return tuple(outputs)


def split_auxiliary_coordinates(
obj: xr.DataArray | xr.Dataset,
) -> tuple[xr.DataArray | xr.Dataset, xr.Dataset]:
"""Split auxiliary coords from the dataset.
An auxiliary coordinate is a coordinate variable that does not define a dimension and thus is not necessarily needed for dataset alignment.
Any coordinate that has a name different than its dimension(s) is flagged as auxiliary. All scalar coordinates are flagged as auxiliary.
Parameters
----------
obj : DataArray or Dataset
Xarray object
Returns
-------
clean_obj : DataArray or Dataset
Same as `obj` but without any auxiliary coordinate.
aux_coords : Dataset
The auxiliary coordinates as a dataset. Might be empty.
Note
----
This is useful to circumvent xarray's alignment checks, which will sometimes look at the auxiliary coordinate's data and can thereby trigger
unwanted dask computations.
The auxiliary coordinates can be merged back with the dataset with
:py:meth:`xarray.Dataset.assign_coords` or :py:meth:`xarray.DataArray.assign_coords`.
>>> # xdoctest: +SKIP
>>> clean, aux = split_auxiliary_coordinates(ds)
>>> merged = clean.assign_coords(aux.coords)
>>> merged.identical(ds) # True
"""
aux_crd_names = [
nm for nm, crd in obj.coords.items() if len(crd.dims) != 1 or crd.dims[0] != nm
]
aux_crd_ds = obj.coords.to_dataset()[aux_crd_names]
clean_obj = obj.drop_vars(aux_crd_names)
return clean_obj, aux_crd_ds
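The selection rule used by `split_auxiliary_coordinates` can be illustrated without xarray. In this sketch, coordinates are described by a plain dict mapping each coordinate name to its dimensions; the helper name `aux_coordinate_names` and the example coordinates are hypothetical.

```python
def aux_coordinate_names(coord_dims: dict[str, tuple[str, ...]]) -> list[str]:
    """Return the names of coordinates that do not define a dimension.

    A coordinate is kept as a "dimension coordinate" only if it has exactly
    one dimension and that dimension carries the coordinate's own name.
    Scalar coordinates (no dimensions) are therefore auxiliary too.
    """
    return [
        name
        for name, dims in coord_dims.items()
        if len(dims) != 1 or dims[0] != name
    ]


# Example layout (assumed for illustration): station data indexed by "site".
coords = {
    "time": ("time",),  # dimension coordinate
    "site": ("site",),  # dimension coordinate
    "lat": ("site",),   # auxiliary: 1D, but named differently from its dimension
    "lon": ("site",),   # auxiliary
    "height": (),       # auxiliary: scalar coordinate
}
aux_names = aux_coordinate_names(coords)
```

Here `aux_names` contains `lat`, `lon` and `height`, which are the coordinates the function would drop from the object and return separately.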
15 changes: 7 additions & 8 deletions xclim/indices/_agro.py
@@ -16,14 +16,15 @@
rate2amount,
to_agg_units,
)
from xclim.core.utils import uses_dask
from xclim.indices._conversion import potential_evapotranspiration
from xclim.indices._simple import tn_min
from xclim.indices._threshold import (
first_day_temperature_above,
first_day_temperature_below,
)
from xclim.indices.generic import aggregate_between_dates, get_zones
from xclim.indices.helpers import _gather_lat, day_lengths
from xclim.indices.helpers import _gather_lat, day_lengths, resample_map
from xclim.indices.stats import standardized_index

# Frequencies : YS: year start, QS-DEC: seasons starting in december, MS: month start
@@ -1564,7 +1565,8 @@ def _chill_portion_one_season(tas_K):

def _apply_chill_portion_one_season(tas_K):
"""Apply the chill portion function on to an xarray DataArray."""
tas_K = tas_K.chunk(time=-1)
if uses_dask(tas_K):
tas_K = tas_K.chunk(time=-1)
return xarray.apply_ufunc(
_chill_portion_one_season,
tas_K,
@@ -1627,12 +1629,9 @@ def chill_portions(
tas_K: xarray.DataArray = select_time(
convert_units_to(tas, "K"), drop=True, **indexer
)
# TODO: use resample_map once #1848 is merged
return (
tas_K.resample(time=freq)
.map(_apply_chill_portion_one_season)
.assign_attrs(units="")
)
return resample_map(
tas_K, "time", freq, _apply_chill_portion_one_season
).assign_attrs(units="")


@declare_units(tas="[temperature]")
36 changes: 22 additions & 14 deletions xclim/indices/_threshold.py
@@ -29,6 +29,7 @@
spell_length_statistics,
threshold_count,
)
from xclim.indices.helpers import resample_map

# Frequencies : YS: year start, QS-DEC: seasons starting in december, MS: month start
# See http://pandas.pydata.org/pandas-docs/stable/timeseries.html#offset-aliases
@@ -1492,12 +1493,17 @@ def last_spring_frost(
thresh = convert_units_to(thresh, tasmin)
cond = compare(tasmin, op, thresh, constrain=("<", "<="))

out = cond.resample(time=freq).map(
out = resample_map(
cond,
"time",
freq,
rl.last_run_before_date,
window=window,
date=before_date,
dim="time",
coord="dayofyear",
map_kwargs=dict(
window=window,
date=before_date,
dim="time",
coord="dayofyear",
),
)
out.attrs.update(units="", is_dayofyear=np.int32(1), calendar=get_calendar(tasmin))
return out
@@ -1663,11 +1669,12 @@ def first_snowfall(
thresh = convert_units_to(thresh, prsn, context="hydro")
cond = prsn >= thresh

out = cond.resample(time=freq).map(
out = resample_map(
cond,
"time",
freq,
rl.first_run,
window=1,
dim="time",
coord="dayofyear",
map_kwargs=dict(window=1, dim="time", coord="dayofyear"),
)
out.attrs.update(units="", is_dayofyear=np.int32(1), calendar=get_calendar(prsn))
return out
@@ -1718,11 +1725,12 @@ def last_snowfall(
thresh = convert_units_to(thresh, prsn, context="hydro")
cond = prsn >= thresh

out = cond.resample(time=freq).map(
out = resample_map(
cond,
"time",
freq,
rl.last_run,
window=1,
dim="time",
coord="dayofyear",
map_kwargs=dict(window=1, dim="time", coord="dayofyear"),
)
out.attrs.update(units="", is_dayofyear=np.int32(1), calendar=get_calendar(prsn))
return out
@@ -3151,7 +3159,7 @@ def _exceedance_date(grp):
never_reached_val = never_reached
return xarray.where((cumsum <= sum_thresh).all("time"), never_reached_val, out)

dded = c.clip(0).resample(time=freq).map(_exceedance_date)
dded = resample_map(c.clip(0), "time", freq, _exceedance_date)
dded = dded.assign_attrs(
units="", is_dayofyear=np.int32(1), calendar=get_calendar(tas)
)
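The call-site changes in `_threshold.py` all follow one pattern: keyword arguments that used to go straight to `.map(...)` are now collected in `map_kwargs`. The plain-Python analogue below illustrates that convention on a list of `(date, bool)` pairs grouped by calendar month. `resample_map_sketch` and `first_run_sketch` are hypothetical stand-ins and do not reproduce the xclim signatures.

```python
from collections import defaultdict
from datetime import date, timedelta


def resample_map_sketch(series, func, map_kwargs=None):
    """Group (date, value) pairs by calendar month and apply `func` to each group."""
    map_kwargs = map_kwargs or {}
    groups = defaultdict(list)
    for day, value in series:
        groups[(day.year, day.month)].append(value)
    # Extra keyword arguments reach `func` only through `map_kwargs`.
    return {period: func(values, **map_kwargs) for period, values in sorted(groups.items())}


def first_run_sketch(values, window=1):
    """Index of the first run of at least `window` consecutive True values, else None."""
    run = 0
    for i, flag in enumerate(values):
        run = run + 1 if flag else 0
        if run >= window:
            return i - window + 1
    return None


start = date(2001, 1, 1)
# True on 10-14 February 2001 (day offsets 40 to 44), False elsewhere.
cond = [(start + timedelta(days=i), 40 <= i <= 44) for i in range(59)]
out = resample_map_sketch(cond, first_run_sketch, map_kwargs=dict(window=3))
```

Grouping the function's own arguments in one dict keeps them separate from the arguments of the wrapper itself, such as the `map_blocks` flag used in the tests.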