From 3d34b77814da5bb1001599c7648e42da32fb6046 Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 4 Sep 2024 15:33:07 -0700 Subject: [PATCH] REF: move implementations to mixin class --- pandas/core/arrays/_arrow_string_mixins.py | 45 ++++++++++++++++++++ pandas/core/arrays/arrow/array.py | 49 +--------------------- pandas/core/arrays/string_arrow.py | 44 +++++-------------- 3 files changed, 56 insertions(+), 82 deletions(-) diff --git a/pandas/core/arrays/_arrow_string_mixins.py b/pandas/core/arrays/_arrow_string_mixins.py index ba20111e0d8587..726cc2068a47a8 100644 --- a/pandas/core/arrays/_arrow_string_mixins.py +++ b/pandas/core/arrays/_arrow_string_mixins.py @@ -42,6 +42,37 @@ def _convert_int_result(self, result): # Convert an integer-dtype result to the appropriate result type raise NotImplementedError + def _str_len(self): + result = pc.utf8_length(self._pa_array) + return self._convert_int_result(result) + + def _str_lower(self) -> Self: + return type(self)(pc.utf8_lower(self._pa_array)) + + def _str_upper(self) -> Self: + return type(self)(pc.utf8_upper(self._pa_array)) + + def _str_strip(self, to_strip=None) -> Self: + if to_strip is None: + result = pc.utf8_trim_whitespace(self._pa_array) + else: + result = pc.utf8_trim(self._pa_array, characters=to_strip) + return type(self)(result) + + def _str_lstrip(self, to_strip=None) -> Self: + if to_strip is None: + result = pc.utf8_ltrim_whitespace(self._pa_array) + else: + result = pc.utf8_ltrim(self._pa_array, characters=to_strip) + return type(self)(result) + + def _str_rstrip(self, to_strip=None) -> Self: + if to_strip is None: + result = pc.utf8_rtrim_whitespace(self._pa_array) + else: + result = pc.utf8_rtrim(self._pa_array, characters=to_strip) + return type(self)(result) + def _str_pad( self, width: int, @@ -190,3 +221,17 @@ def _str_istitle(self): def _str_isupper(self): result = pc.utf8_is_upper(self._pa_array) return self._convert_bool_result(result) + + def _str_match( + self, pat: str, case: bool = 
True, flags: int = 0, na: Scalar | None = None + ): + if not pat.startswith("^"): + pat = f"^{pat}" + return self._str_contains(pat, case, flags, na, regex=True) + + def _str_fullmatch( + self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None + ): + if not pat.endswith("$") or pat.endswith("\\$"): + pat = f"{pat}$" + return self._str_match(pat, case, flags, na) diff --git a/pandas/core/arrays/arrow/array.py b/pandas/core/arrays/arrow/array.py index 807854a13f285d..8827776dbc55c3 100644 --- a/pandas/core/arrays/arrow/array.py +++ b/pandas/core/arrays/arrow/array.py @@ -1998,7 +1998,7 @@ def _rank( """ See Series.rank.__doc__. """ - return type(self)( + return self._convert_int_result( self._rank_calc( axis=axis, method=method, @@ -2337,9 +2337,6 @@ def _str_contains( result = result.fill_null(na) return type(self)(result) - def _result_converter(self, result): - return type(self)(result) - def _str_replace( self, pat: str | re.Pattern, @@ -2374,20 +2371,6 @@ def _str_repeat(self, repeats: int | Sequence[int]) -> Self: ) return type(self)(pc.binary_repeat(self._pa_array, repeats)) - def _str_match( - self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None - ) -> Self: - if not pat.startswith("^"): - pat = f"^{pat}" - return self._str_contains(pat, case, flags, na, regex=True) - - def _str_fullmatch( - self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None - ) -> Self: - if not pat.endswith("$") or pat.endswith("\\$"): - pat = f"{pat}$" - return self._str_match(pat, case, flags, na) - def _str_find(self, sub: str, start: int = 0, end: int | None = None) -> Self: if (start == 0 or start is None) and end is None: result = pc.find_substring(self._pa_array, sub) @@ -2442,36 +2425,6 @@ def _str_slice( pc.utf8_slice_codeunits(self._pa_array, start=start, stop=stop, step=step) ) - def _str_len(self) -> Self: - return type(self)(pc.utf8_length(self._pa_array)) - - def _str_lower(self) -> Self: - return 
type(self)(pc.utf8_lower(self._pa_array)) - - def _str_upper(self) -> Self: - return type(self)(pc.utf8_upper(self._pa_array)) - - def _str_strip(self, to_strip=None) -> Self: - if to_strip is None: - result = pc.utf8_trim_whitespace(self._pa_array) - else: - result = pc.utf8_trim(self._pa_array, characters=to_strip) - return type(self)(result) - - def _str_lstrip(self, to_strip=None) -> Self: - if to_strip is None: - result = pc.utf8_ltrim_whitespace(self._pa_array) - else: - result = pc.utf8_ltrim(self._pa_array, characters=to_strip) - return type(self)(result) - - def _str_rstrip(self, to_strip=None) -> Self: - if to_strip is None: - result = pc.utf8_rtrim_whitespace(self._pa_array) - else: - result = pc.utf8_rtrim(self._pa_array, characters=to_strip) - return type(self)(result) - def _str_removeprefix(self, prefix: str): if not pa_version_under13p0: starts_with = pc.starts_with(self._pa_array, pattern=prefix) diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index 2db10bdbd4d489..71628644945e7c 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -54,7 +54,6 @@ from pandas._typing import ( ArrayLike, - AxisInt, Dtype, Self, npt, @@ -294,19 +293,22 @@ def astype(self, dtype, copy: bool = True): _str_startswith = ArrowStringArrayMixin._str_startswith _str_endswith = ArrowStringArrayMixin._str_endswith _str_pad = ArrowStringArrayMixin._str_pad - _str_match = ArrowExtensionArray._str_match - _str_fullmatch = ArrowExtensionArray._str_fullmatch - _str_lower = ArrowExtensionArray._str_lower - _str_upper = ArrowExtensionArray._str_upper - _str_strip = ArrowExtensionArray._str_strip - _str_lstrip = ArrowExtensionArray._str_lstrip - _str_rstrip = ArrowExtensionArray._str_rstrip + _str_match = ArrowStringArrayMixin._str_match + _str_fullmatch = ArrowStringArrayMixin._str_fullmatch + _str_lower = ArrowStringArrayMixin._str_lower + _str_upper = ArrowStringArrayMixin._str_upper + _str_strip = 
ArrowStringArrayMixin._str_strip + _str_lstrip = ArrowStringArrayMixin._str_lstrip + _str_rstrip = ArrowStringArrayMixin._str_rstrip _str_removesuffix = ArrowStringArrayMixin._str_removesuffix _str_get = ArrowStringArrayMixin._str_get _str_capitalize = ArrowStringArrayMixin._str_capitalize _str_title = ArrowStringArrayMixin._str_title _str_swapcase = ArrowStringArrayMixin._str_swapcase _str_slice_replace = ArrowStringArrayMixin._str_slice_replace + _str_len = ArrowStringArrayMixin._str_len + + _rank = ArrowExtensionArray._rank def _str_contains( self, pat, case: bool = True, flags: int = 0, na=np.nan, regex: bool = True @@ -362,10 +364,6 @@ def _str_slice( return super()._str_slice(start, stop, step) return ArrowExtensionArray._str_slice(self, start=start, stop=stop, step=step) - def _str_len(self): - result = pc.utf8_length(self._pa_array) - return self._convert_int_result(result) - def _str_removeprefix(self, prefix: str): if not pa_version_under13p0: return ArrowExtensionArray._str_removeprefix(self, prefix) @@ -431,28 +429,6 @@ def _reduce( else: return result - def _rank( - self, - *, - axis: AxisInt = 0, - method: str = "average", - na_option: str = "keep", - ascending: bool = True, - pct: bool = False, - ): - """ - See Series.rank.__doc__. - """ - return self._convert_int_result( - self._rank_calc( - axis=axis, - method=method, - na_option=na_option, - ascending=ascending, - pct=pct, - ) - ) - def value_counts(self, dropna: bool = True) -> Series: result = super().value_counts(dropna=dropna) if self.dtype.na_value is np.nan: