Skip to content

Commit

Permalink
REF: move implementations to mixin class
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel committed Sep 4, 2024
1 parent ab0c761 commit 3d34b77
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 82 deletions.
45 changes: 45 additions & 0 deletions pandas/core/arrays/_arrow_string_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,37 @@ def _convert_int_result(self, result):
# Convert an integer-dtype result to the appropriate result type
raise NotImplementedError

def _str_len(self):
    """
    Compute the length of each string element.

    The lengths come from ``pc.utf8_length`` applied to ``self._pa_array``
    and are passed through ``self._convert_int_result`` before being returned.
    """
    return self._convert_int_result(pc.utf8_length(self._pa_array))

def _str_lower(self) -> Self:
    """
    Lower-case every string element.

    Applies ``pc.utf8_lower`` to ``self._pa_array`` and wraps the result in a
    new instance of this array's own type.
    """
    lowered = pc.utf8_lower(self._pa_array)
    return type(self)(lowered)

def _str_upper(self) -> Self:
    """
    Upper-case every string element.

    Applies ``pc.utf8_upper`` to ``self._pa_array`` and wraps the result in a
    new instance of this array's own type.
    """
    uppered = pc.utf8_upper(self._pa_array)
    return type(self)(uppered)

def _str_strip(self, to_strip=None) -> Self:
    """
    Trim each string element with ``pc.utf8_trim`` / ``pc.utf8_trim_whitespace``.

    Parameters
    ----------
    to_strip : optional
        Forwarded as ``characters=`` to ``pc.utf8_trim``. When ``None``,
        ``pc.utf8_trim_whitespace`` is used instead.

    Returns
    -------
    A new instance of this array's own type wrapping the trimmed values.
    """
    if to_strip is not None:
        trimmed = pc.utf8_trim(self._pa_array, characters=to_strip)
    else:
        trimmed = pc.utf8_trim_whitespace(self._pa_array)
    return type(self)(trimmed)

def _str_lstrip(self, to_strip=None) -> Self:
    """
    Trim each string element with ``pc.utf8_ltrim`` / ``pc.utf8_ltrim_whitespace``.

    Parameters
    ----------
    to_strip : optional
        Forwarded as ``characters=`` to ``pc.utf8_ltrim``. When ``None``,
        ``pc.utf8_ltrim_whitespace`` is used instead.

    Returns
    -------
    A new instance of this array's own type wrapping the trimmed values.
    """
    if to_strip is not None:
        trimmed = pc.utf8_ltrim(self._pa_array, characters=to_strip)
    else:
        trimmed = pc.utf8_ltrim_whitespace(self._pa_array)
    return type(self)(trimmed)

def _str_rstrip(self, to_strip=None) -> Self:
    """
    Trim each string element with ``pc.utf8_rtrim`` / ``pc.utf8_rtrim_whitespace``.

    Parameters
    ----------
    to_strip : optional
        Forwarded as ``characters=`` to ``pc.utf8_rtrim``. When ``None``,
        ``pc.utf8_rtrim_whitespace`` is used instead.

    Returns
    -------
    A new instance of this array's own type wrapping the trimmed values.
    """
    if to_strip is not None:
        trimmed = pc.utf8_rtrim(self._pa_array, characters=to_strip)
    else:
        trimmed = pc.utf8_rtrim_whitespace(self._pa_array)
    return type(self)(trimmed)

def _str_pad(
self,
width: int,
Expand Down Expand Up @@ -190,3 +221,17 @@ def _str_istitle(self):
def _str_isupper(self):
    """
    Evaluate ``pc.utf8_is_upper`` on every string element.

    The raw result is passed through ``self._convert_bool_result`` before
    being returned.
    """
    return self._convert_bool_result(pc.utf8_is_upper(self._pa_array))

def _str_match(
self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None
):
if not pat.startswith("^"):
pat = f"^{pat}"
return self._str_contains(pat, case, flags, na, regex=True)

def _str_fullmatch(
self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None
):
if not pat.endswith("$") or pat.endswith("\\$"):
pat = f"{pat}$"
return self._str_match(pat, case, flags, na)
49 changes: 1 addition & 48 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1998,7 +1998,7 @@ def _rank(
"""
See Series.rank.__doc__.
"""
return type(self)(
return self._convert_int_result(
self._rank_calc(
axis=axis,
method=method,
Expand Down Expand Up @@ -2337,9 +2337,6 @@ def _str_contains(
result = result.fill_null(na)
return type(self)(result)

def _result_converter(self, result):
return type(self)(result)

def _str_replace(
self,
pat: str | re.Pattern,
Expand Down Expand Up @@ -2374,20 +2371,6 @@ def _str_repeat(self, repeats: int | Sequence[int]) -> Self:
)
return type(self)(pc.binary_repeat(self._pa_array, repeats))

def _str_match(
self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None
) -> Self:
if not pat.startswith("^"):
pat = f"^{pat}"
return self._str_contains(pat, case, flags, na, regex=True)

def _str_fullmatch(
self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None
) -> Self:
if not pat.endswith("$") or pat.endswith("\\$"):
pat = f"{pat}$"
return self._str_match(pat, case, flags, na)

def _str_find(self, sub: str, start: int = 0, end: int | None = None) -> Self:
if (start == 0 or start is None) and end is None:
result = pc.find_substring(self._pa_array, sub)
Expand Down Expand Up @@ -2442,36 +2425,6 @@ def _str_slice(
pc.utf8_slice_codeunits(self._pa_array, start=start, stop=stop, step=step)
)

def _str_len(self) -> Self:
    """
    Compute the length of each string element.

    Applies ``pc.utf8_length`` to ``self._pa_array`` and wraps the result in a
    new instance of this array's own type.
    """
    lengths = pc.utf8_length(self._pa_array)
    return type(self)(lengths)

def _str_lower(self) -> Self:
    """
    Lower-case every string element.

    Applies ``pc.utf8_lower`` to ``self._pa_array`` and wraps the result in a
    new instance of this array's own type.
    """
    lowered = pc.utf8_lower(self._pa_array)
    return type(self)(lowered)

def _str_upper(self) -> Self:
    """
    Upper-case every string element.

    Applies ``pc.utf8_upper`` to ``self._pa_array`` and wraps the result in a
    new instance of this array's own type.
    """
    uppered = pc.utf8_upper(self._pa_array)
    return type(self)(uppered)

def _str_strip(self, to_strip=None) -> Self:
    """
    Trim each string element with ``pc.utf8_trim`` / ``pc.utf8_trim_whitespace``.

    ``to_strip`` is forwarded as ``characters=`` to ``pc.utf8_trim``; when it
    is ``None`` the whitespace variant is used. The result is wrapped in a
    new instance of this array's own type.
    """
    if to_strip is not None:
        trimmed = pc.utf8_trim(self._pa_array, characters=to_strip)
    else:
        trimmed = pc.utf8_trim_whitespace(self._pa_array)
    return type(self)(trimmed)

def _str_lstrip(self, to_strip=None) -> Self:
    """
    Trim each string element with ``pc.utf8_ltrim`` / ``pc.utf8_ltrim_whitespace``.

    ``to_strip`` is forwarded as ``characters=`` to ``pc.utf8_ltrim``; when it
    is ``None`` the whitespace variant is used. The result is wrapped in a
    new instance of this array's own type.
    """
    if to_strip is not None:
        trimmed = pc.utf8_ltrim(self._pa_array, characters=to_strip)
    else:
        trimmed = pc.utf8_ltrim_whitespace(self._pa_array)
    return type(self)(trimmed)

def _str_rstrip(self, to_strip=None) -> Self:
    """
    Trim each string element with ``pc.utf8_rtrim`` / ``pc.utf8_rtrim_whitespace``.

    ``to_strip`` is forwarded as ``characters=`` to ``pc.utf8_rtrim``; when it
    is ``None`` the whitespace variant is used. The result is wrapped in a
    new instance of this array's own type.
    """
    if to_strip is not None:
        trimmed = pc.utf8_rtrim(self._pa_array, characters=to_strip)
    else:
        trimmed = pc.utf8_rtrim_whitespace(self._pa_array)
    return type(self)(trimmed)

def _str_removeprefix(self, prefix: str):
if not pa_version_under13p0:
starts_with = pc.starts_with(self._pa_array, pattern=prefix)
Expand Down
44 changes: 10 additions & 34 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@

from pandas._typing import (
ArrayLike,
AxisInt,
Dtype,
Self,
npt,
Expand Down Expand Up @@ -294,19 +293,22 @@ def astype(self, dtype, copy: bool = True):
_str_startswith = ArrowStringArrayMixin._str_startswith
_str_endswith = ArrowStringArrayMixin._str_endswith
_str_pad = ArrowStringArrayMixin._str_pad
_str_match = ArrowExtensionArray._str_match
_str_fullmatch = ArrowExtensionArray._str_fullmatch
_str_lower = ArrowExtensionArray._str_lower
_str_upper = ArrowExtensionArray._str_upper
_str_strip = ArrowExtensionArray._str_strip
_str_lstrip = ArrowExtensionArray._str_lstrip
_str_rstrip = ArrowExtensionArray._str_rstrip
_str_match = ArrowStringArrayMixin._str_match
_str_fullmatch = ArrowStringArrayMixin._str_fullmatch
_str_lower = ArrowStringArrayMixin._str_lower
_str_upper = ArrowStringArrayMixin._str_upper
_str_strip = ArrowStringArrayMixin._str_strip
_str_lstrip = ArrowStringArrayMixin._str_lstrip
_str_rstrip = ArrowStringArrayMixin._str_rstrip
_str_removesuffix = ArrowStringArrayMixin._str_removesuffix
_str_get = ArrowStringArrayMixin._str_get
_str_capitalize = ArrowStringArrayMixin._str_capitalize
_str_title = ArrowStringArrayMixin._str_title
_str_swapcase = ArrowStringArrayMixin._str_swapcase
_str_slice_replace = ArrowStringArrayMixin._str_slice_replace
_str_len = ArrowStringArrayMixin._str_len

_rank = ArrowExtensionArray._rank

def _str_contains(
self, pat, case: bool = True, flags: int = 0, na=np.nan, regex: bool = True
Expand Down Expand Up @@ -362,10 +364,6 @@ def _str_slice(
return super()._str_slice(start, stop, step)
return ArrowExtensionArray._str_slice(self, start=start, stop=stop, step=step)

def _str_len(self):
    """
    Compute the length of each string element.

    The lengths come from ``pc.utf8_length`` applied to ``self._pa_array``
    and are passed through ``self._convert_int_result`` before being returned.
    """
    return self._convert_int_result(pc.utf8_length(self._pa_array))

def _str_removeprefix(self, prefix: str):
if not pa_version_under13p0:
return ArrowExtensionArray._str_removeprefix(self, prefix)
Expand Down Expand Up @@ -431,28 +429,6 @@ def _reduce(
else:
return result

def _rank(
self,
*,
axis: AxisInt = 0,
method: str = "average",
na_option: str = "keep",
ascending: bool = True,
pct: bool = False,
):
"""
See Series.rank.__doc__.
"""
return self._convert_int_result(
self._rank_calc(
axis=axis,
method=method,
na_option=na_option,
ascending=ascending,
pct=pct,
)
)

def value_counts(self, dropna: bool = True) -> Series:
result = super().value_counts(dropna=dropna)
if self.dtype.na_value is np.nan:
Expand Down

0 comments on commit 3d34b77

Please sign in to comment.