From a9f51a154f1e7e06513f5a937065f6d2fd699088 Mon Sep 17 00:00:00 2001
From: Deepak Cherian
Date: Tue, 13 Aug 2024 22:56:11 -0600
Subject: [PATCH] Fix bug with NaNs in `by` and method='blockwise' (#384)

xref https://github.com/pydata/xarray/pull/9320
---
 flox/core.py       | 12 ++++++++++--
 tests/test_core.py | 14 ++++++++++++++
 2 files changed, 24 insertions(+), 2 deletions(-)

diff --git a/flox/core.py b/flox/core.py
index d7fa5f6a..bb84ab80 100644
--- a/flox/core.py
+++ b/flox/core.py
@@ -2663,10 +2663,18 @@ def groupby_reduce(
                 groups = (groups[0][sorted_idx],)
 
         if factorize_early:
+            assert len(groups) == 1
+            (groups_,) = groups
             # nan group labels are factorized to -1, and preserved
             # now we get rid of them by reindexing
-            # This also handles bins with no data
-            result = reindex_(result, from_=groups[0], to=expected_, fill_value=fill_value).reshape(
+            # First, for "blockwise", we can have -1 repeated in different blocks
+            # This breaks the reindexing so remove those first.
+            if method == "blockwise" and (mask := groups_ == -1).sum(axis=-1) > 1:
+                result = result[..., ~mask]
+                groups_ = groups_[..., ~mask]
+
+            # This reindex also handles bins with no data
+            result = reindex_(result, from_=groups_, to=expected_, fill_value=fill_value).reshape(
                 result.shape[:-1] + grp_shape
             )
             groups = final_groups
diff --git a/tests/test_core.py b/tests/test_core.py
index 5d4e7ec3..2d225206 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -1929,3 +1929,17 @@ def test_ffill_bfill(chunks, size, add_nan_by, func):
     expected = flox.groupby_scan(array.compute(), by, func=func)
     actual = flox.groupby_scan(array, by, func=func)
     assert_equal(expected, actual)
+
+
+@requires_dask
+def test_blockwise_nans():
+    array = dask.array.ones((1, 10), chunks=2)
+    by = np.array([-1, 0, -1, 1, -1, 2, -1, 3, 4, 4])
+    actual, actual_groups = flox.groupby_reduce(
+        array, by, func="sum", expected_groups=pd.RangeIndex(0, 5)
+    )
+    expected, expected_groups = flox.groupby_reduce(
+        array.compute(), by, func="sum", expected_groups=pd.RangeIndex(0, 5)
+    )
+    assert_equal(expected_groups, actual_groups)
+    assert_equal(expected, actual)