Skip to content

Commit

Permalink
Fix first, last again (#381)
Browse files Browse the repository at this point in the history
* Fix first, last again

Add more first, last tests

* Fix

* fix type ignores

* Add one more property test

* Support cohorts and grouped_combine

* fix docs

* fix profile
  • Loading branch information
dcherian authored Aug 7, 2024
1 parent b05586c commit f0ce343
Show file tree
Hide file tree
Showing 4 changed files with 143 additions and 20 deletions.
52 changes: 39 additions & 13 deletions flox/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,9 @@ def _is_minmax_reduction(func: T_Agg) -> bool:


def _is_first_last_reduction(func: T_Agg) -> bool:
    """Return True when ``func`` denotes a first/last style reduction.

    ``func`` may be a reduction name or an ``Aggregation`` instance; for the
    latter, the check is made against its ``name`` attribute. The recognised
    names are "nanfirst", "nanlast", "first" and "last".
    """
    # Unwrap Aggregation objects so they are classified by name, exactly like
    # plain strings are.
    if isinstance(func, Aggregation):
        func = func.name
    return func in ["nanfirst", "nanlast", "first", "last"]


def _get_expected_groups(by: T_By, sort: bool) -> T_ExpectIndex:
Expand Down Expand Up @@ -1642,7 +1644,12 @@ def dask_groupby_agg(
# This allows us to discover groups at compute time, support argreductions, lower intermediate
# memory usage (but method="cohorts" would also work to reduce memory in some cases)
labels_are_unknown = is_duck_dask_array(by_input) and expected_groups is None
do_simple_combine = not _is_arg_reduction(agg) and not labels_are_unknown
do_grouped_combine = (
_is_arg_reduction(agg)
or labels_are_unknown
or (_is_first_last_reduction(agg) and array.dtype.kind != "f")
)
do_simple_combine = not do_grouped_combine

if method == "blockwise":
# use the "non dask" code path, but applied blockwise
Expand Down Expand Up @@ -1698,7 +1705,7 @@ def dask_groupby_agg(

tree_reduce = partial(
dask.array.reductions._tree_reduce,
name=f"{name}-reduce",
name=f"{name}-simple-reduce",
dtype=array.dtype,
axis=axis,
keepdims=True,
Expand Down Expand Up @@ -1733,14 +1740,20 @@ def dask_groupby_agg(
groups_ = []
for blks, cohort in chunks_cohorts.items():
cohort_index = pd.Index(cohort)
reindexer = partial(reindex_intermediates, agg=agg, unique_groups=cohort_index)
reindexer = (
partial(reindex_intermediates, agg=agg, unique_groups=cohort_index)
if do_simple_combine
else identity
)
reindexed = subset_to_blocks(intermediate, blks, block_shape, reindexer)
# now that we have reindexed, we can set reindex=True explicitly
reduced_.append(
tree_reduce(
reindexed,
combine=partial(combine, agg=agg, reindex=True),
aggregate=partial(aggregate, expected_groups=cohort_index, reindex=True),
combine=partial(combine, agg=agg, reindex=do_simple_combine),
aggregate=partial(
aggregate, expected_groups=cohort_index, reindex=do_simple_combine
),
)
)
# This is done because pandas promotes to 64-bit types when an Index is created
Expand Down Expand Up @@ -1986,8 +1999,13 @@ def _validate_reindex(
expected_groups,
any_by_dask: bool,
is_dask_array: bool,
array_dtype: Any,
) -> bool | None:
# logger.debug("Entering _validate_reindex: reindex is {}".format(reindex)) # noqa
def first_or_last():
return func in ["first", "last"] or (
_is_first_last_reduction(func) and array_dtype.kind != "f"
)

all_numpy = not is_dask_array and not any_by_dask
if reindex is True and not all_numpy:
Expand All @@ -1997,7 +2015,7 @@ def _validate_reindex(
raise ValueError(
"reindex=True is not a valid choice for method='blockwise' or method='cohorts'."
)
if func in ["first", "last"]:
if first_or_last():
raise ValueError("reindex must be None or False when func is 'first' or 'last.")

if reindex is None:
Expand All @@ -2008,9 +2026,10 @@ def _validate_reindex(
if all_numpy:
return True

if func in ["first", "last"]:
if first_or_last():
# have to do the grouped_combine since there's no good fill_value
reindex = False
# Also needed for nanfirst, nanlast with no-NaN dtypes
return False

if method == "blockwise":
# for grouping by dask arrays, we set reindex=True
Expand Down Expand Up @@ -2412,12 +2431,19 @@ def groupby_reduce(
if method == "cohorts" and any_by_dask:
raise ValueError(f"method={method!r} can only be used when grouping by numpy arrays.")

if not is_duck_array(array):
array = np.asarray(array)

reindex = _validate_reindex(
reindex, func, method, expected_groups, any_by_dask, is_duck_dask_array(array)
reindex,
func,
method,
expected_groups,
any_by_dask,
is_duck_dask_array(array),
array.dtype,
)

if not is_duck_array(array):
array = np.asarray(array)
is_bool_array = np.issubdtype(array.dtype, bool)
array = array.astype(np.intp) if is_bool_array else array

Expand Down Expand Up @@ -2601,7 +2627,7 @@ def groupby_reduce(

# TODO: clean this up
reindex = _validate_reindex(
reindex, func, method, expected_, any_by_dask, is_duck_dask_array(array)
reindex, func, method, expected_, any_by_dask, is_duck_dask_array(array), array.dtype
)

if TYPE_CHECKING:
Expand Down
3 changes: 2 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@
suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.too_slow],
)
# Register the verbose, 300-example configuration under the name "default"
# and activate it immediately, so the name passed to load_profile matches the
# name that was registered.
# NOTE(review): this presumably replaces Hypothesis' built-in "default"
# profile for the whole test session — confirm that is intended.
settings.register_profile(
    "default",
    max_examples=300,
    suppress_health_check=[HealthCheck.filter_too_much, HealthCheck.too_slow],
    verbosity=Verbosity.verbose,
)
settings.load_profile("default")


@pytest.fixture(
Expand Down
93 changes: 87 additions & 6 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,33 @@ def test_dask_reduce_axis_subset():
)


@pytest.mark.parametrize("group_idx", [[0, 1, 0], [0, 0, 1], [1, 0, 0], [1, 1, 0]])
@pytest.mark.parametrize(
    "func",
    [
        # "first", "last",
        "nanfirst",
        "nanlast",
    ],
)
@pytest.mark.parametrize(
    "chunks",
    [
        None,
        *(
            pytest.param(size, marks=pytest.mark.skipif(not has_dask, reason="no dask"))
            for size in (1, 2, 3)
        ),
    ],
)
def test_first_last_useless(func, chunks, group_idx):
    """nanfirst/nanlast over an all-zero int8 array must yield all zeros.

    The input is a (2, 3) int8 array of zeros, used either as a numpy array
    (``chunks is None``) or wrapped as a dask array with the given chunking.
    Every ``group_idx`` contains exactly the labels 0 and 1, so the expected
    result is a (2, 2) int8 array of zeros.
    """
    zeros = np.zeros((2, 3), dtype=np.int8)
    values = zeros if chunks is None else dask.array.from_array(zeros, chunks=chunks)
    labels = np.array(group_idx)
    result, _ = groupby_reduce(values, labels, func=func, engine="numpy")
    assert_equal(result, np.zeros((2, 2), dtype=np.int8))


@pytest.mark.parametrize("func", ["first", "last", "nanfirst", "nanlast"])
@pytest.mark.parametrize("axis", [(0, 1)])
def test_first_last_disallowed(axis, func):
Expand Down Expand Up @@ -1563,18 +1590,36 @@ def test_validate_reindex_map_reduce(
dask_expected, reindex, func, expected_groups, any_by_dask
) -> None:
actual = _validate_reindex(
reindex, func, "map-reduce", expected_groups, any_by_dask, is_dask_array=True
reindex,
func,
"map-reduce",
expected_groups,
any_by_dask,
is_dask_array=True,
array_dtype=np.dtype("int32"),
)
assert actual is dask_expected

# always reindex with all numpy inputs
actual = _validate_reindex(
reindex, func, "map-reduce", expected_groups, any_by_dask=False, is_dask_array=False
reindex,
func,
"map-reduce",
expected_groups,
any_by_dask=False,
is_dask_array=False,
array_dtype=np.dtype("int32"),
)
assert actual

actual = _validate_reindex(
True, func, "map-reduce", expected_groups, any_by_dask=False, is_dask_array=False
True,
func,
"map-reduce",
expected_groups,
any_by_dask=False,
is_dask_array=False,
array_dtype=np.dtype("int32"),
)
assert actual

Expand All @@ -1584,19 +1629,37 @@ def test_validate_reindex() -> None:
for method in methods:
with pytest.raises(NotImplementedError):
_validate_reindex(
True, "argmax", method, expected_groups=None, any_by_dask=False, is_dask_array=True
True,
"argmax",
method,
expected_groups=None,
any_by_dask=False,
is_dask_array=True,
array_dtype=np.dtype("int32"),
)

methods: list[T_Method] = ["blockwise", "cohorts"]
for method in methods:
with pytest.raises(ValueError):
_validate_reindex(
True, "sum", method, expected_groups=None, any_by_dask=False, is_dask_array=True
True,
"sum",
method,
expected_groups=None,
any_by_dask=False,
is_dask_array=True,
array_dtype=np.dtype("int32"),
)

for func in ["sum", "argmax"]:
actual = _validate_reindex(
None, func, method, expected_groups=None, any_by_dask=False, is_dask_array=True
None,
func,
method,
expected_groups=None,
any_by_dask=False,
is_dask_array=True,
array_dtype=np.dtype("int32"),
)
assert actual is False

Expand All @@ -1608,6 +1671,7 @@ def test_validate_reindex() -> None:
expected_groups=np.array([1, 2, 3]),
any_by_dask=False,
is_dask_array=True,
array_dtype=np.dtype("int32"),
)

assert _validate_reindex(
Expand All @@ -1617,6 +1681,7 @@ def test_validate_reindex() -> None:
expected_groups=np.array([1, 2, 3]),
any_by_dask=True,
is_dask_array=True,
array_dtype=np.dtype("int32"),
)
assert _validate_reindex(
None,
Expand All @@ -1625,8 +1690,24 @@ def test_validate_reindex() -> None:
expected_groups=np.array([1, 2, 3]),
any_by_dask=True,
is_dask_array=True,
array_dtype=np.dtype("int32"),
)

kwargs = dict(
method="blockwise",
expected_groups=np.array([1, 2, 3]),
any_by_dask=True,
is_dask_array=True,
)

for func in ["nanfirst", "nanlast"]:
assert not _validate_reindex(None, func, array_dtype=np.dtype("int32"), **kwargs) # type: ignore[arg-type]
assert _validate_reindex(None, func, array_dtype=np.dtype("float32"), **kwargs) # type: ignore[arg-type]

for func in ["first", "last"]:
assert not _validate_reindex(None, func, array_dtype=np.dtype("int32"), **kwargs) # type: ignore[arg-type]
assert not _validate_reindex(None, func, array_dtype=np.dtype("float32"), **kwargs) # type: ignore[arg-type]


@requires_dask
def test_1d_blockwise_sort_optimization():
Expand Down
15 changes: 15 additions & 0 deletions tests/test_properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
pytest.importorskip("cftime")

import dask
import hypothesis.extra.numpy as npst
import hypothesis.strategies as st
import numpy as np
from hypothesis import assume, given, note
Expand All @@ -19,6 +20,7 @@

from . import assert_equal
from .strategies import by_arrays, chunked_arrays, func_st, numeric_arrays
from .strategies import chunks as chunks_strategy

dask.config.set(scheduler="sync")

Expand Down Expand Up @@ -208,3 +210,16 @@ def test_first_last(data, array: dask.array.Array, func: str) -> None:
first, *_ = groupby_reduce(array, by, func=func, engine="flox")
second, *_ = groupby_reduce(array, by, func=mate, engine="flox")
assert_equal(first, second)


@given(data=st.data(), func=st.sampled_from(["nanfirst", "nanlast"]))
def test_first_last_useless(data, func):
    """Property test: nanfirst/nanlast of an all-zero int8 array is all zeros.

    Draws an arbitrary array shape, a ``by`` array matching the last axis, and
    a chunking (``None`` keeps the input as a numpy array). The reduction runs
    along the last axis, so the result must be zeros with that axis replaced
    by the number of groups found.
    """
    # Keep the draw order (shape, by, chunks) stable so examples replay
    # identically.
    shape = data.draw(npst.array_shapes())
    by = data.draw(by_arrays(shape=shape[-1:]))
    chunks = data.draw(chunks_strategy(shape=shape))

    zeros = np.zeros(shape, dtype=np.int8)
    array = zeros if chunks is None else dask.array.from_array(zeros, chunks=chunks)

    actual, groups = groupby_reduce(array, by, axis=-1, func=func, engine="numpy")
    expected_shape = (*shape[:-1], len(groups))
    assert_equal(actual, np.zeros(expected_shape, dtype=array.dtype))

0 comments on commit f0ce343

Please sign in to comment.