Skip to content

Commit

Permalink
[backport 2.3.x] String dtype: implement sum reduction (pandas-dev#59853) (pandas-dev#60157)
Browse files Browse the repository at this point in the history

String dtype: implement sum reduction (pandas-dev#59853)

(cherry picked from commit 2fdb16b)
  • Loading branch information
jorisvandenbossche authored Oct 31, 2024
1 parent e620e9d commit 4f189a4
Show file tree
Hide file tree
Showing 15 changed files with 121 additions and 150 deletions.
2 changes: 1 addition & 1 deletion doc/source/whatsnew/v2.3.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ enhancement1
Other enhancements
^^^^^^^^^^^^^^^^^^

-
- The :meth:`~Series.sum` reduction is now implemented for ``StringDtype`` columns (:issue:`59853`)
-

.. ---------------------------------------------------------------------------
Expand Down
4 changes: 4 additions & 0 deletions pandas/core/array_algos/masked_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ def _reductions(
):
return libmissing.NA

if values.dtype == np.dtype(object):
# object dtype does not support `where` unless an `initial` value is passed
values = values[~mask]
return func(values, axis=axis, **kwargs)
return func(values, where=~mask, axis=axis, **kwargs)


Expand Down
32 changes: 32 additions & 0 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
unpack_tuple_and_ellipses,
validate_indices,
)
from pandas.core.nanops import check_below_min_count
from pandas.core.strings.base import BaseStringArrayMethods

from pandas.io._util import _arrow_dtype_mapping
Expand Down Expand Up @@ -1694,6 +1695,37 @@ def pyarrow_meth(data, skip_nulls, **kwargs):
denominator = pc.sqrt_checked(pc.count(self._pa_array))
return pc.divide_checked(numerator, denominator)

elif name == "sum" and (
pa.types.is_string(pa_type) or pa.types.is_large_string(pa_type)
):

def pyarrow_meth(data, skip_nulls, min_count=0):  # type: ignore[misc]
    """
    Join all strings in ``data`` into one scalar (the string "sum").

    A null scalar of the input's type is returned instead when
    ``skip_nulls`` is False and ``data`` holds at least one null, or when
    ``check_below_min_count`` reports that ``min_count`` is not met.
    """
    # Remember the incoming type: null results keep it even though the
    # joined result may come from a string-cast copy of the data.
    input_type = data.type
    has_nulls = data.null_count > 0
    null_mask = pc.is_null(data) if has_nulls else None

    if not skip_nulls:
        # Without skipping, a single null makes the whole result null.
        if has_nulls or check_below_min_count((len(data),), None, min_count):
            return pa.scalar(None, type=input_type)
    else:
        if min_count > 0:
            np_mask = null_mask.to_numpy() if has_nulls else None
            if check_below_min_count((len(data),), np_mask, min_count):
                return pa.scalar(None, type=input_type)
        if has_nulls:
            # binary_join returns null if there is any null ->
            # have to filter out any nulls
            data = data.filter(pc.invert(null_mask))

    if pa.types.is_large_string(data.type):
        # binary_join only supports string, not large_string
        data = data.cast(pa.string())

    # Wrap all values in a single list element so binary_join concatenates
    # the entire array in one call.
    offsets = [0, len(data)]
    single_list = pa.ListArray.from_arrays(offsets, data.combine_chunks())[0]
    return pc.binary_join(single_list, "")

else:
pyarrow_name = {
"median": "quantile",
Expand Down
18 changes: 16 additions & 2 deletions pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,8 +812,8 @@ def _reduce(
else:
return nanops.nanall(self._ndarray, skipna=skipna)

if name in ["min", "max"]:
result = getattr(self, name)(skipna=skipna, axis=axis)
if name in ["min", "max", "sum"]:
result = getattr(self, name)(skipna=skipna, axis=axis, **kwargs)
if keepdims:
return self._from_sequence([result], dtype=self.dtype)
return result
Expand All @@ -839,6 +839,20 @@ def max(self, axis=None, skipna: bool = True, **kwargs) -> Scalar:
)
return self._wrap_reduction_result(axis, result)

def sum(
    self,
    *,
    axis: AxisInt | None = None,
    skipna: bool = True,
    min_count: int = 0,
    **kwargs,
) -> Scalar:
    """
    Reduce the string array with ``masked_reductions.sum``.

    For string values this concatenates the elements (e.g. "a", "b", "c"
    gives "abc").

    Parameters
    ----------
    axis : int or None, default None
        Only handed to ``_wrap_reduction_result``; it is not passed to the
        reduction itself.
    skipna : bool, default True
        Whether missing values are excluded from the reduction.
    min_count : int, default 0
        Minimum number of values required for a non-missing result. It is
        forwarded to ``masked_reductions.sum``; previously it was accepted
        here but silently ignored.
    **kwargs
        Checked by ``nv.validate_sum``; present for NumPy compatibility.

    Returns
    -------
    Scalar
        The reduction result after ``_wrap_reduction_result``.
    """
    nv.validate_sum((), kwargs)
    result = masked_reductions.sum(
        values=self._ndarray,
        mask=self.isna(),
        skipna=skipna,
        # Forward min_count so both storage backends honour it.
        min_count=min_count,
    )
    return self._wrap_reduction_result(axis, result)

def value_counts(self, dropna: bool = True) -> Series:
from pandas.core.algorithms import value_counts_internal as value_counts

Expand Down
6 changes: 5 additions & 1 deletion pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,11 @@ def _reduce(
return result.astype(np.bool_)
return result

result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs)
if name in ("min", "max", "sum", "argmin", "argmax"):
result = self._reduce_calc(name, skipna=skipna, keepdims=keepdims, **kwargs)
else:
raise TypeError(f"Cannot perform reduction '{name}' with string dtype")

if name in ("argmin", "argmax") and isinstance(result, pa.Array):
return self._convert_int_result(result)
elif isinstance(result, pa.Array):
Expand Down
10 changes: 0 additions & 10 deletions pandas/tests/apply/test_frame_apply.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,6 @@
import numpy as np
import pytest

from pandas._config import using_string_dtype

from pandas.compat import HAS_PYARROW

from pandas.core.dtypes.dtypes import CategoricalDtype

import pandas as pd
Expand Down Expand Up @@ -1173,7 +1169,6 @@ def test_agg_with_name_as_column_name():
tm.assert_series_equal(result, expected)


@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_agg_multiple_mixed():
# GH 20909
mdf = DataFrame(
Expand Down Expand Up @@ -1202,9 +1197,6 @@ def test_agg_multiple_mixed():
tm.assert_frame_equal(result, expected)


@pytest.mark.xfail(
using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)"
)
def test_agg_multiple_mixed_raises():
# GH 20909
mdf = DataFrame(
Expand Down Expand Up @@ -1294,7 +1286,6 @@ def test_agg_reduce(axis, float_frame):
tm.assert_frame_equal(result, expected)


@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)")
def test_nuiscance_columns():
# GH 15015
df = DataFrame(
Expand Down Expand Up @@ -1471,7 +1462,6 @@ def test_apply_datetime_tz_issue(engine, request):
tm.assert_series_equal(result, expected)


@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
@pytest.mark.parametrize("df", [DataFrame({"A": ["a", None], "B": ["c", "d"]})])
@pytest.mark.parametrize("method", ["min", "max", "sum"])
def test_mixed_column_raises(df, method, using_infer_string):
Expand Down
39 changes: 20 additions & 19 deletions pandas/tests/apply/test_invalid_arg.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,6 @@
import numpy as np
import pytest

from pandas._config import using_string_dtype

from pandas.compat import HAS_PYARROW
from pandas.errors import SpecificationError

from pandas import (
Expand Down Expand Up @@ -212,10 +209,6 @@ def transform(row):
data.apply(transform, axis=1)


# we should raise a proper TypeError instead of propagating the pyarrow error
@pytest.mark.xfail(
using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)"
)
@pytest.mark.parametrize(
"df, func, expected",
tm.get_cython_table_params(
Expand All @@ -225,21 +218,25 @@ def transform(row):
def test_agg_cython_table_raises_frame(df, func, expected, axis, using_infer_string):
# GH 21224
if using_infer_string:
import pyarrow as pa
if df.dtypes.iloc[0].storage == "pyarrow":
import pyarrow as pa

expected = (expected, pa.lib.ArrowNotImplementedError)
# TODO(infer_string)
# should raise a proper TypeError instead of propagating the pyarrow error

msg = "can't multiply sequence by non-int of type 'str'|has no kernel"
expected = (expected, pa.lib.ArrowNotImplementedError)
else:
expected = (expected, NotImplementedError)

msg = (
"can't multiply sequence by non-int of type 'str'|has no kernel|cannot perform"
)
warn = None if isinstance(func, str) else FutureWarning
with pytest.raises(expected, match=msg):
with tm.assert_produces_warning(warn, match="using DataFrame.cumprod"):
df.agg(func, axis=axis)


# we should raise a proper TypeError instead of propagating the pyarrow error
@pytest.mark.xfail(
using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)"
)
@pytest.mark.parametrize(
"series, func, expected",
chain(
Expand All @@ -263,11 +260,15 @@ def test_agg_cython_table_raises_series(series, func, expected, using_infer_stri
msg = r"Cannot convert \['a' 'b' 'c'\] to numeric"

if using_infer_string:
import pyarrow as pa

expected = (expected, pa.lib.ArrowNotImplementedError)

msg = msg + "|does not support|has no kernel"
if series.dtype.storage == "pyarrow":
import pyarrow as pa

# TODO(infer_string)
# should raise a proper TypeError instead of propagating the pyarrow error
expected = (expected, pa.lib.ArrowNotImplementedError)
else:
expected = (expected, NotImplementedError)
msg = msg + "|does not support|has no kernel|Cannot perform|cannot perform"
warn = None if isinstance(func, str) else FutureWarning

with pytest.raises(expected, match=msg):
Expand Down
2 changes: 0 additions & 2 deletions pandas/tests/arrays/string_/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,15 +444,13 @@ def test_astype_float(dtype, any_float_dtype):


@pytest.mark.parametrize("skipna", [True, False])
@pytest.mark.xfail(reason="Not implemented StringArray.sum")
def test_reduce(skipna, dtype):
arr = pd.Series(["a", "b", "c"], dtype=dtype)
result = arr.sum(skipna=skipna)
assert result == "abc"


@pytest.mark.parametrize("skipna", [True, False])
@pytest.mark.xfail(reason="Not implemented StringArray.sum")
def test_reduce_missing(skipna, dtype):
arr = pd.Series([None, "a", None, "b", "c", None], dtype=dtype)
result = arr.sum(skipna=skipna)
Expand Down
26 changes: 5 additions & 21 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,10 +459,11 @@ def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool:
pass
else:
return False
elif pa.types.is_binary(pa_dtype) and op_name == "sum":
return False
elif (
pa.types.is_string(pa_dtype) or pa.types.is_binary(pa_dtype)
) and op_name in [
"sum",
"mean",
"median",
"prod",
Expand Down Expand Up @@ -553,6 +554,7 @@ def test_reduce_series_boolean(
return super().test_reduce_series_boolean(data, all_boolean_reductions, skipna)

def _get_expected_reduction_dtype(self, arr, op_name: str, skipna: bool):
pa_type = arr._pa_array.type
if op_name in ["max", "min"]:
cmp_dtype = arr.dtype
elif arr.dtype.name == "decimal128(7, 3)[pyarrow]":
Expand All @@ -562,6 +564,8 @@ def _get_expected_reduction_dtype(self, arr, op_name: str, skipna: bool):
cmp_dtype = "float64[pyarrow]"
elif op_name in ["median", "var", "std", "mean", "skew"]:
cmp_dtype = "float64[pyarrow]"
elif op_name == "sum" and pa.types.is_string(pa_type):
cmp_dtype = arr.dtype
else:
cmp_dtype = {
"i": "int64[pyarrow]",
Expand All @@ -585,26 +589,6 @@ def test_median_not_approximate(self, typ):
result = pd.Series([1, 2], dtype=f"{typ}[pyarrow]").median()
assert result == 1.5

def test_in_numeric_groupby(self, data_for_grouping):
dtype = data_for_grouping.dtype
if is_string_dtype(dtype):
df = pd.DataFrame(
{
"A": [1, 1, 2, 2, 3, 3, 1, 4],
"B": data_for_grouping,
"C": [1, 1, 1, 1, 1, 1, 1, 1],
}
)

expected = pd.Index(["C"])
msg = re.escape(f"agg function failed [how->sum,dtype->{dtype}")
with pytest.raises(TypeError, match=msg):
df.groupby("A").sum()
result = df.groupby("A").sum(numeric_only=True).columns
tm.assert_index_equal(result, expected)
else:
super().test_in_numeric_groupby(data_for_grouping)

def test_construct_from_string_own_name(self, dtype, request):
pa_dtype = dtype.pyarrow_dtype
if pa.types.is_decimal(pa_dtype):
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/extension/test_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def _get_expected_exception(

def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool:
return (
op_name in ["min", "max"]
op_name in ["min", "max", "sum"]
or ser.dtype.na_value is np.nan # type: ignore[union-attr]
and op_name in ("any", "all")
)
Expand Down
Loading

0 comments on commit 4f189a4

Please sign in to comment.