Skip to content

Commit

Permalink
Backport PR pandas-dev#54515: ENH: ArrowExtensionArray(duration) workarounds for pyarrow versions >= 11.0
Browse files Browse the repository at this point in the history
  • Loading branch information
lukemanley authored and meeseeksmachine committed Aug 13, 2023
1 parent 06e04b0 commit 592d94d
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1011,12 +1011,11 @@ def factorize(
) -> tuple[np.ndarray, ExtensionArray]:
null_encoding = "mask" if use_na_sentinel else "encode"

- pa_type = self._pa_array.type
- if pa.types.is_duration(pa_type):
+ data = self._pa_array
+ pa_type = data.type
+ if pa_version_under11p0 and pa.types.is_duration(pa_type):
  # https://github.com/apache/arrow/issues/15226#issuecomment-1376578323
- data = self._pa_array.cast(pa.int64())
- else:
- data = self._pa_array
+ data = data.cast(pa.int64())

if pa.types.is_dictionary(data.type):
encoded = data
Expand All @@ -1034,7 +1033,7 @@ def factorize(
)
uniques = type(self)(encoded.chunk(0).dictionary)

- if pa.types.is_duration(pa_type):
+ if pa_version_under11p0 and pa.types.is_duration(pa_type):
uniques = cast(ArrowExtensionArray, uniques.astype(self.dtype))
return indices, uniques

Expand Down Expand Up @@ -1273,15 +1272,15 @@ def unique(self) -> Self:
"""
pa_type = self._pa_array.type

- if pa.types.is_duration(pa_type):
+ if pa_version_under11p0 and pa.types.is_duration(pa_type):
# https://github.com/apache/arrow/issues/15226#issuecomment-1376578323
data = self._pa_array.cast(pa.int64())
else:
data = self._pa_array

pa_result = pc.unique(data)

- if pa.types.is_duration(pa_type):
+ if pa_version_under11p0 and pa.types.is_duration(pa_type):
pa_result = pa_result.cast(pa_type)

return type(self)(pa_result)
Expand All @@ -1304,7 +1303,7 @@ def value_counts(self, dropna: bool = True) -> Series:
Series.value_counts
"""
pa_type = self._pa_array.type
- if pa.types.is_duration(pa_type):
+ if pa_version_under11p0 and pa.types.is_duration(pa_type):
# https://github.com/apache/arrow/issues/15226#issuecomment-1376578323
data = self._pa_array.cast(pa.int64())
else:
Expand All @@ -1324,7 +1323,7 @@ def value_counts(self, dropna: bool = True) -> Series:
values = values.filter(mask)
counts = counts.filter(mask)

- if pa.types.is_duration(pa_type):
+ if pa_version_under11p0 and pa.types.is_duration(pa_type):
values = values.cast(pa_type)

counts = ArrowExtensionArray(counts)
Expand Down

0 comments on commit 592d94d

Please sign in to comment.