Skip to content

Commit

Permalink
Fix direct reductions of Xarray objects (#339)
Browse files Browse the repository at this point in the history
* Fix direct reductions of Xarray objects

Closes pydata/xarray#8819

* Fix doctest
  • Loading branch information
dcherian authored Mar 13, 2024
1 parent 41372e0 commit b0cabf3
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 38 deletions.
16 changes: 9 additions & 7 deletions flox/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,10 +201,10 @@ def xarray_reduce(
>>> da = xr.ones_like(labels)
>>> # Sum all values in da that matches the elements in the group index:
>>> xarray_reduce(da, labels, func="sum")
<xarray.DataArray 'label' (label: 4)>
<xarray.DataArray 'label' (label: 4)> Size: 32B
array([3, 2, 2, 2])
Coordinates:
* label (label) int64 0 1 2 3
* label (label) int64 32B 0 1 2 3
"""

if skipna is not None and isinstance(func, Aggregation):
Expand Down Expand Up @@ -303,14 +303,16 @@ def xarray_reduce(
# reducing along a dimension along which groups do not vary
# This is really just a normal reduction.
# This is not right when binning so we exclude.
if isinstance(func, str):
dsfunc = func[3:] if skipna else func
else:
if isinstance(func, str) and func.startswith("nan"):
raise ValueError(f"Specify func={func[3:]}, skipna=True instead of func={func}")
elif isinstance(func, Aggregation):
raise NotImplementedError(
"func must be a string when reducing along a dimension not present in `by`"
)
# TODO: skipna needs test
result = getattr(ds_broad, dsfunc)(dim=dim_tuple, skipna=skipna)
# skipna is not supported for all reductions
# https://github.com/pydata/xarray/issues/8819
kwargs = {"skipna": skipna} if skipna is not None else {}
result = getattr(ds_broad, func)(dim=dim_tuple, **kwargs)
if isinstance(obj, xr.DataArray):
return obj._from_temp_dataset(result)
else:
Expand Down
32 changes: 32 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,35 @@ def assert_equal_tuple(a, b):
np.testing.assert_array_equal(a_, b_)
else:
assert a_ == b_


# Reductions listed here are wrapped with the ``requires_scipy`` mark below.
SCIPY_STATS_FUNCS = ("mode", "nanmode")

# The four order-statistic reductions followed by the scipy-backed ones.
BLOCKWISE_FUNCS = (
    "median",
    "nanmedian",
    "quantile",
    "nanquantile",
    *SCIPY_STATS_FUNCS,
)

# Every reduction name used to parametrize tests.  Plain names are strings;
# the scipy-backed names are ``pytest.param`` objects carrying the
# ``requires_scipy`` mark.  Order is kept exactly as-is because it determines
# the order of the generated test cases.
ALL_FUNCS = (
    "sum", "nansum",
    "argmax", "nanfirst", "nanargmax",
    "prod", "nanprod",
    "mean", "nanmean",
    "var", "nanvar",
    "std", "nanstd",
    "max", "nanmax",
    "min", "nanmin",
    "argmin", "nanargmin",
    "any", "all",
    "nanlast",
    "median", "nanmedian",
    "quantile", "nanquantile",
    *(pytest.param(name, marks=requires_scipy) for name in SCIPY_STATS_FUNCS),
)
34 changes: 3 additions & 31 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,14 @@
)

from . import (
ALL_FUNCS,
BLOCKWISE_FUNCS,
SCIPY_STATS_FUNCS,
assert_equal,
assert_equal_tuple,
has_dask,
raise_if_dask_computes,
requires_dask,
requires_scipy,
)

logger = logging.getLogger("flox")
Expand All @@ -60,36 +62,6 @@ def dask_array_ones(*args):


# Quantile value shared by the tests in this module.
DEFAULT_QUANTILE = 0.9

# Reductions listed here are wrapped with the ``requires_scipy`` mark below.
SCIPY_STATS_FUNCS = ("mode", "nanmode")

# The four order-statistic reductions followed by the scipy-backed ones.
BLOCKWISE_FUNCS = (
    "median",
    "nanmedian",
    "quantile",
    "nanquantile",
    *SCIPY_STATS_FUNCS,
)

# Every reduction name used to parametrize tests.  Plain names are strings;
# the scipy-backed names are ``pytest.param`` objects carrying the
# ``requires_scipy`` mark.  Order is kept exactly as-is because it determines
# the order of the generated test cases.
ALL_FUNCS = (
    "sum", "nansum",
    "argmax", "nanfirst", "nanargmax",
    "prod", "nanprod",
    "mean", "nanmean",
    "var", "nanvar",
    "std", "nanstd",
    "max", "nanmax",
    "min", "nanmin",
    "argmin", "nanargmin",
    "any", "all",
    "nanlast",
    "median", "nanmedian",
    "quantile", "nanquantile",
    *(pytest.param(name, marks=requires_scipy) for name in SCIPY_STATS_FUNCS),
)

if TYPE_CHECKING:
from flox.core import T_Agg, T_Engine, T_ExpectedGroupsOpt, T_Method
Expand Down
30 changes: 30 additions & 0 deletions tests/test_xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from flox.xarray import rechunk_for_blockwise, xarray_reduce

from . import (
ALL_FUNCS,
assert_equal,
has_dask,
raise_if_dask_computes,
Expand Down Expand Up @@ -710,3 +711,32 @@ def test_multiple_quantiles(q, chunk, by_ndim, skipna):
with xr.set_options(use_flox=False):
expected = da.groupby(by).quantile(q, skipna=skipna)
xr.testing.assert_allclose(expected, actual)


@pytest.mark.parametrize("func", ALL_FUNCS)
def test_direct_reduction(func):
    """Compare a groupby reduction with ``use_flox=True`` against ``use_flox=False``.

    Regression test for https://github.com/pydata/xarray/issues/8819.

    A small 2x3 array with dims ``("x", "y")`` is grouped by ``"x"`` and
    reduced with ``func``; except for ``first``/``last``, the reduction is
    taken over ``dim="y"``.  Both code paths must give identical results.

    Parameters
    ----------
    func : str
        Reduction name from ``ALL_FUNCS``.  A leading ``"nan"`` is stripped
        and expressed as ``skipna=True`` instead.
    """
    if "arg" in func or "mode" in func:
        pytest.skip("arg* and mode reductions are not exercised by this test")

    rand = np.random.choice([True, False], size=(2, 3))
    # ``any`` / ``all`` are run on the boolean data; everything else on floats.
    if func not in ["any", "all"]:
        rand = rand.astype(float)

    # Only a *leading* "nan" marks the skipna variant.  Check the prefix
    # explicitly so the ``func[3:]`` strip below can never cut a name that
    # merely contains "nan" somewhere else.
    if func.startswith("nan"):
        func = func[3:]
        kwargs = {"skipna": True}
    else:
        kwargs = {}

    # ``first`` / ``last`` are called without an explicit ``dim``.
    if "first" not in func and "last" not in func:
        kwargs["dim"] = "y"

    if "quantile" in func:
        kwargs["q"] = 0.9

    data = xr.DataArray(rand, dims=("x", "y"), coords={"x": [10, 20], "y": [0, 1, 2]})
    with xr.set_options(use_flox=True):
        actual = getattr(data.groupby("x", squeeze=False), func)(**kwargs)
    with xr.set_options(use_flox=False):
        expected = getattr(data.groupby("x", squeeze=False), func)(**kwargs)
    xr.testing.assert_identical(expected, actual)

0 comments on commit b0cabf3

Please sign in to comment.