From de603135bde50754fc8ef26751c0cfece624f585 Mon Sep 17 00:00:00 2001
From: Pascal Bourgault
Date: Wed, 27 Sep 2023 09:59:49 -0400
Subject: [PATCH 1/2] Raise when reduced dims are chunked in map_blocks

---
 CHANGES.rst                  |  1 +
 tests/test_sdba/test_base.py | 13 ++++++++++
 xclim/sdba/base.py           | 50 +++++++++++++++++++++---------------
 3 files changed, 44 insertions(+), 20 deletions(-)

diff --git a/CHANGES.rst b/CHANGES.rst
index 933c7377b..d5cbd74e3 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -16,6 +16,7 @@ Bug fixes
 ^^^^^^^^^
 * Fixed an error in the `pytest` configuration that prevented copying of testing data to thread-safe caches of workers under certain conditions (this should always occur). (:pull:`1473`).
   * Coincidentally, this also fixes an error that caused `pytest` to error-out when invoked without an active internet connection. Running `pytest` without network access is now supported (requires cached testing data). (:issue:`1468`).
+* Calling a ``sdba.map_blocks``-wrapped function with data chunked along the reduced dimensions will raise an error. This forbids chunking the trained dataset along the distribution dimensions, for example. (:issue:`1481`, :pull:`1482`).

 Breaking changes
 ^^^^^^^^^^^^^^^^
diff --git a/tests/test_sdba/test_base.py b/tests/test_sdba/test_base.py
index d7117c40d..3cbe2867e 100644
--- a/tests/test_sdba/test_base.py
+++ b/tests/test_sdba/test_base.py
@@ -203,3 +203,16 @@ def func(ds, *, dim):
     ).load()
     assert set(data.data.dims) == {"dayofyear"}
     assert "leftover" in data
+
+
+def test_map_blocks_error(tas_series):
+    tas = tas_series(np.arange(366), start="2000-01-01")
+    tas = tas.expand_dims(lat=[1, 2, 3, 4]).chunk(lat=1)
+
+    # Test dim parsing
+    @map_blocks(reduces=["lat"], data=[])
+    def func(ds, *, group, lon=None):
+        return ds.tas.rename("data").to_dataset()
+
+    with pytest.raises(ValueError, match="cannot be chunked"):
+        func(xr.Dataset(dict(tas=tas)), group="time")
diff --git a/xclim/sdba/base.py b/xclim/sdba/base.py
index d02e3ec47..313e7edab 100644
--- a/xclim/sdba/base.py
+++ b/xclim/sdba/base.py
@@ -557,26 +557,6 @@ def _map_blocks(ds, **kwargs):
            ) and group is None:
                raise ValueError("Missing required `group` argument.")

-            if uses_dask(ds):
-                # Use dask if any of the input is dask-backed.
-                chunks = (
-                    dict(ds.chunks)
-                    if isinstance(ds, xr.Dataset)
-                    else dict(zip(ds.dims, ds.chunks))
-                )
-                if group is not None:
-                    badchunks = {
-                        dim: chunks.get(dim)
-                        for dim in group.add_dims + [group.dim]
-                        if len(chunks.get(dim, [])) > 1
-                    }
-                    if badchunks:
-                        raise ValueError(
-                            f"The dimension(s) over which we group cannot be chunked ({badchunks})."
-                        )
-            else:
-                chunks = None
-
            # Make translation dict
            if group is not None:
                placeholders = {
@@ -602,6 +582,36 @@
                        f"Dimension {dim} is meant to be added by the computation but it is already on one of the inputs."
                    )

+            if uses_dask(ds):
+                # Use dask if any of the input is dask-backed.
+                chunks = (
+                    dict(ds.chunks)
+                    if isinstance(ds, xr.Dataset)
+                    else dict(zip(ds.dims, ds.chunks))
+                )
+                badchunks = {}
+                if group is not None:
+                    badchunks.update(
+                        {
+                            dim: chunks.get(dim)
+                            for dim in group.add_dims + [group.dim]
+                            if len(chunks.get(dim, [])) > 1
+                        }
+                    )
+                badchunks.update(
+                    {
+                        dim: chunks.get(dim)
+                        for dim in reduced_dims
+                        if len(chunks.get(dim)) > 1
+                    }
+                )
+                if badchunks:
+                    raise ValueError(
+                        f"The dimension(s) over which we group, reduce or interpolate cannot be chunked ({badchunks})."
+                    )
+            else:
+                chunks = None
+
            # Dimensions untouched by the function.
            base_dims = list(set(ds.dims) - set(new_dims) - set(reduced_dims))

From c846ca6c02c286f95fe4d57ce796880d8c649d85 Mon Sep 17 00:00:00 2001
From: Pascal Bourgault
Date: Wed, 27 Sep 2023 10:41:03 -0400
Subject: [PATCH 2/2] Fix for new xr(?)

---
 xclim/sdba/base.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/xclim/sdba/base.py b/xclim/sdba/base.py
index 313e7edab..8715e7db8 100644
--- a/xclim/sdba/base.py
+++ b/xclim/sdba/base.py
@@ -492,7 +492,7 @@ def duck_empty(dims, sizes, dtype="float64", chunks=None):
 def _decode_cf_coords(ds):
     """Decode coords in-place."""
     crds = xr.decode_cf(ds.coords.to_dataset())
-    for crdname in ds.coords.keys():
+    for crdname in list(ds.coords.keys()):
         ds[crdname] = crds[crdname]
         # decode_cf introduces an encoding key for the dtype, which can confuse the netCDF writer
         dtype = ds[crdname].encoding.get("dtype")
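
From the caller's side, the guard added in PATCH 1/2 looks roughly like the sketch below, adapted from the new test. The `tas_series` fixture is replaced with a hand-built DataArray, so the setup (and the exact chunk tuple shown in the error message) is an approximation; it requires `dask` to be installed.

import numpy as np
import pandas as pd
import xarray as xr

from xclim.sdba.base import map_blocks

# Daily data for the year 2000, with a small "lat" dimension chunked to size 1.
time = pd.date_range("2000-01-01", periods=366, freq="D")
tas = xr.DataArray(np.arange(366), dims="time", coords={"time": time}, name="tas")
tas = tas.expand_dims(lat=[1, 2, 3, 4]).chunk(lat=1)


# Declare that the wrapped function reduces (consumes) the "lat" dimension.
@map_blocks(reduces=["lat"], data=[])
def func(ds, *, group, lon=None):
    return ds.tas.rename("data").to_dataset()


# "lat" is both reduced and chunked, so with this patch the call fails early
# instead of mapping the function over incomplete blocks, with something like:
#   ValueError: The dimension(s) over which we group, reduce or interpolate
#   cannot be chunked ({'lat': (1, 1, 1, 1)}).
func(xr.Dataset(dict(tas=tas)), group="time")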
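
PATCH 2/2 snapshots the coordinate names before the loop writes back into the dataset. The commit message itself is tentative ("Fix for new xr(?)"), so the exact xarray failure mode is not pinned down here; the generic hazard the `list(...)` sidesteps, iterating over a live key view of a mapping while mutating it, can be shown with a plain dict (an analogue only, not the xarray-specific error).

# Writing into a dict while looping over its live key view raises at runtime.
ds = {"time": 1, "lat": 2}
try:
    for name in ds.keys():
        ds["decoded_" + name] = ds[name]  # grows the mapping mid-iteration
except RuntimeError as err:
    print(err)  # dictionary changed size during iteration

# Materializing the keys first, as the patch does with list(ds.coords.keys()),
# decouples the loop from any mutation of the mapping.
ds = {"time": 1, "lat": 2}
for name in list(ds.keys()):
    ds["decoded_" + name] = ds[name]
print(ds)  # {'time': 1, 'lat': 2, 'decoded_time': 1, 'decoded_lat': 2}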