Skip to content

Commit

Permalink
Backport PR pandas-dev#54566 on branch 2.1.x (ENH: support Index.any/…
Browse files Browse the repository at this point in the history
…all with float, timedelta64 dtypes) (pandas-dev#54693)

Backport PR pandas-dev#54566: ENH: support Index.any/all with float, timedelta64 dtypes

Co-authored-by: jbrockmendel <[email protected]>
  • Loading branch information
meeseeksmachine and jbrockmendel authored Aug 22, 2023
1 parent 6c5e79b commit 1dbd792
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 23 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.1.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ Other enhancements
- Many read/to_* functions, such as :meth:`DataFrame.to_pickle` and :func:`read_csv`, support forwarding compression arguments to ``lzma.LZMAFile`` (:issue:`52979`)
- Reductions :meth:`Series.argmax`, :meth:`Series.argmin`, :meth:`Series.idxmax`, :meth:`Series.idxmin`, :meth:`Index.argmax`, :meth:`Index.argmin`, :meth:`DataFrame.idxmax`, :meth:`DataFrame.idxmin` are now supported for object-dtype (:issue:`4279`, :issue:`18021`, :issue:`40685`, :issue:`43697`)
- :meth:`DataFrame.to_parquet` and :func:`read_parquet` will now write and read ``attrs`` respectively (:issue:`54346`)
- :meth:`Index.all` and :meth:`Index.any` with floating dtypes and timedelta64 dtypes no longer raise ``TypeError``, matching the :meth:`Series.all` and :meth:`Series.any` behavior (:issue:`54566`)
- :meth:`Series.cummax`, :meth:`Series.cummin` and :meth:`Series.cumprod` are now supported for pyarrow dtypes with pyarrow version 13.0 and above (:issue:`52085`)
- Added support for the DataFrame Consortium Standard (:issue:`54383`)
- Performance improvement in :meth:`.DataFrameGroupBy.quantile` and :meth:`.SeriesGroupBy.quantile` (:issue:`51722`)
Expand Down
28 changes: 15 additions & 13 deletions pandas/core/indexes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7215,11 +7215,12 @@ def any(self, *args, **kwargs):
"""
nv.validate_any(args, kwargs)
self._maybe_disable_logical_methods("any")
# error: Argument 1 to "any" has incompatible type "ArrayLike"; expected
# "Union[Union[int, float, complex, str, bytes, generic], Sequence[Union[int,
# float, complex, str, bytes, generic]], Sequence[Sequence[Any]],
# _SupportsArray]"
return np.any(self.values) # type: ignore[arg-type]
vals = self._values
if not isinstance(vals, np.ndarray):
# i.e. EA, call _reduce instead of "any" to get TypeError instead
# of AttributeError
return vals._reduce("any")
return np.any(vals)

def all(self, *args, **kwargs):
"""
Expand Down Expand Up @@ -7262,11 +7263,12 @@ def all(self, *args, **kwargs):
"""
nv.validate_all(args, kwargs)
self._maybe_disable_logical_methods("all")
# error: Argument 1 to "all" has incompatible type "ArrayLike"; expected
# "Union[Union[int, float, complex, str, bytes, generic], Sequence[Union[int,
# float, complex, str, bytes, generic]], Sequence[Sequence[Any]],
# _SupportsArray]"
return np.all(self.values) # type: ignore[arg-type]
vals = self._values
if not isinstance(vals, np.ndarray):
# i.e. EA, call _reduce instead of "all" to get TypeError instead
# of AttributeError
return vals._reduce("all")
return np.all(vals)

@final
def _maybe_disable_logical_methods(self, opname: str_t) -> None:
Expand All @@ -7275,9 +7277,9 @@ def _maybe_disable_logical_methods(self, opname: str_t) -> None:
"""
if (
isinstance(self, ABCMultiIndex)
or needs_i8_conversion(self.dtype)
or isinstance(self.dtype, (IntervalDtype, CategoricalDtype))
or is_float_dtype(self.dtype)
# TODO(3.0): PeriodArray and DatetimeArray any/all will raise,
# so checking needs_i8_conversion will be unnecessary
or (needs_i8_conversion(self.dtype) and self.dtype.kind != "m")
):
# This call will raise
make_invalid_op(opname)(self)
Expand Down
8 changes: 8 additions & 0 deletions pandas/tests/indexes/numeric/test_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,14 @@ def test_fillna_float64(self):
exp = Index([1.0, "obj", 3.0], name="x")
tm.assert_index_equal(idx.fillna("obj"), exp, exact=True)

def test_logical_compat(self, simple_index):
idx = simple_index
assert idx.all() == idx.values.all()
assert idx.any() == idx.values.any()

assert idx.all() == idx.to_series().all()
assert idx.any() == idx.to_series().any()


class TestNumericInt:
@pytest.fixture(params=[np.int64, np.int32, np.int16, np.int8, np.uint64])
Expand Down
7 changes: 6 additions & 1 deletion pandas/tests/indexes/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,12 @@ def test_format_missing(self, vals, nulls_fixture):
@pytest.mark.parametrize("op", ["any", "all"])
def test_logical_compat(self, op, simple_index):
index = simple_index
assert getattr(index, op)() == getattr(index.values, op)()
left = getattr(index, op)()
assert left == getattr(index.values, op)()
right = getattr(index.to_series(), op)()
# left might not match right exactly in e.g. string cases where the
# because we use np.any/all instead of .any/all
assert bool(left) == bool(right)

@pytest.mark.parametrize(
"index", ["string", "int64", "int32", "float64", "float32"], indirect=True
Expand Down
26 changes: 17 additions & 9 deletions pandas/tests/indexes/test_old_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,17 +209,25 @@ def test_numeric_compat(self, simple_index):
1 // idx

def test_logical_compat(self, simple_index):
if (
isinstance(simple_index, RangeIndex)
or is_numeric_dtype(simple_index.dtype)
or simple_index.dtype == object
):
if simple_index.dtype == object:
pytest.skip("Tested elsewhere.")
idx = simple_index
with pytest.raises(TypeError, match="cannot perform all"):
idx.all()
with pytest.raises(TypeError, match="cannot perform any"):
idx.any()
if idx.dtype.kind in "iufcbm":
assert idx.all() == idx._values.all()
assert idx.all() == idx.to_series().all()
assert idx.any() == idx._values.any()
assert idx.any() == idx.to_series().any()
else:
msg = "cannot perform (any|all)"
if isinstance(idx, IntervalIndex):
msg = (
r"'IntervalArray' with dtype interval\[.*\] does "
"not support reduction '(any|all)'"
)
with pytest.raises(TypeError, match=msg):
idx.all()
with pytest.raises(TypeError, match=msg):
idx.any()

def test_repr_roundtrip(self, simple_index):
if isinstance(simple_index, IntervalIndex):
Expand Down

0 comments on commit 1dbd792

Please sign in to comment.