Skip to content

Commit

Permalink
Backport PR pandas-dev#54574: ENH: add cummax/cummin/cumprod support …
Browse files Browse the repository at this point in the history
…for arrow dtypes
  • Loading branch information
lukemanley authored and meeseeksmachine committed Aug 17, 2023
1 parent 005d876 commit d7f62e5
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 13 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.1.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ Other enhancements
- Many read/to_* functions, such as :meth:`DataFrame.to_pickle` and :func:`read_csv`, support forwarding compression arguments to ``lzma.LZMAFile`` (:issue:`52979`)
- Reductions :meth:`Series.argmax`, :meth:`Series.argmin`, :meth:`Series.idxmax`, :meth:`Series.idxmin`, :meth:`Index.argmax`, :meth:`Index.argmin`, :meth:`DataFrame.idxmax`, :meth:`DataFrame.idxmin` are now supported for object-dtype (:issue:`4279`, :issue:`18021`, :issue:`40685`, :issue:`43697`)
- :meth:`DataFrame.to_parquet` and :func:`read_parquet` will now write and read ``attrs`` respectively (:issue:`54346`)
- :meth:`Series.cummax`, :meth:`Series.cummin` and :meth:`Series.cumprod` are now supported for pyarrow dtypes with pyarrow version 13.0 and above (:issue:`52085`)
- Added support for the DataFrame Consortium Standard (:issue:`54383`)
- Performance improvement in :meth:`.DataFrameGroupBy.quantile` and :meth:`.SeriesGroupBy.quantile` (:issue:`51722`)

Expand Down
17 changes: 14 additions & 3 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1389,6 +1389,9 @@ def _accumulate(
NotImplementedError : subclass does not define accumulations
"""
pyarrow_name = {
"cummax": "cumulative_max",
"cummin": "cumulative_min",
"cumprod": "cumulative_prod_checked",
"cumsum": "cumulative_sum_checked",
}.get(name, name)
pyarrow_meth = getattr(pc, pyarrow_name, None)
Expand All @@ -1398,12 +1401,20 @@ def _accumulate(
data_to_accum = self._pa_array

pa_dtype = data_to_accum.type
if pa.types.is_duration(pa_dtype):
data_to_accum = data_to_accum.cast(pa.int64())

convert_to_int = (
pa.types.is_temporal(pa_dtype) and name in ["cummax", "cummin"]
) or (pa.types.is_duration(pa_dtype) and name == "cumsum")

if convert_to_int:
if pa_dtype.bit_width == 32:
data_to_accum = data_to_accum.cast(pa.int32())
else:
data_to_accum = data_to_accum.cast(pa.int64())

result = pyarrow_meth(data_to_accum, skip_nulls=skipna, **kwargs)

if pa.types.is_duration(pa_dtype):
if convert_to_int:
result = result.cast(pa_dtype)

return type(self)(result)
Expand Down
33 changes: 23 additions & 10 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,10 +347,15 @@ class TestBaseAccumulateTests(base.BaseAccumulateTests):
def check_accumulate(self, ser, op_name, skipna):
result = getattr(ser, op_name)(skipna=skipna)

if ser.dtype.kind == "m":
pa_type = ser.dtype.pyarrow_dtype
if pa.types.is_temporal(pa_type):
# Just check that we match the integer behavior.
ser = ser.astype("int64[pyarrow]")
result = result.astype("int64[pyarrow]")
if pa_type.bit_width == 32:
int_type = "int32[pyarrow]"
else:
int_type = "int64[pyarrow]"
ser = ser.astype(int_type)
result = result.astype(int_type)

result = result.astype("Float64")
expected = getattr(ser.astype("Float64"), op_name)(skipna=skipna)
Expand All @@ -361,14 +366,20 @@ def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
# attribute "pyarrow_dtype"
pa_type = ser.dtype.pyarrow_dtype # type: ignore[union-attr]

if pa.types.is_string(pa_type) or pa.types.is_binary(pa_type):
if op_name in ["cumsum", "cumprod"]:
if (
pa.types.is_string(pa_type)
or pa.types.is_binary(pa_type)
or pa.types.is_decimal(pa_type)
):
if op_name in ["cumsum", "cumprod", "cummax", "cummin"]:
return False
elif pa.types.is_temporal(pa_type) and not pa.types.is_duration(pa_type):
if op_name in ["cumsum", "cumprod"]:
elif pa.types.is_boolean(pa_type):
if op_name in ["cumprod", "cummax", "cummin"]:
return False
elif pa.types.is_duration(pa_type):
if op_name == "cumprod":
elif pa.types.is_temporal(pa_type):
if op_name == "cumsum" and not pa.types.is_duration(pa_type):
return False
elif op_name == "cumprod":
return False
return True

Expand All @@ -384,7 +395,9 @@ def test_accumulate_series(self, data, all_numeric_accumulations, skipna, reques
data, all_numeric_accumulations, skipna
)

if all_numeric_accumulations != "cumsum" or pa_version_under9p0:
if pa_version_under9p0 or (
pa_version_under13p0 and all_numeric_accumulations != "cumsum"
):
# xfailing takes a long time to run because pytest
# renders the exception messages even when not showing them
opt = request.config.option
Expand Down

0 comments on commit d7f62e5

Please sign in to comment.