Skip to content

Commit

Permalink
REF (string): de-duplicate str_endswith, startswith
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel committed Aug 26, 2024
1 parent 15e9e7a commit f8e3f5e
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 72 deletions.
69 changes: 67 additions & 2 deletions pandas/core/arrays/_arrow_string_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,35 @@

from pandas.compat import pa_version_under10p1

from pandas.core.dtypes.missing import isna

if not pa_version_under10p1:
import pyarrow as pa
import pyarrow.compute as pc

if TYPE_CHECKING:
from pandas._typing import Self
from collections.abc import Sized

from pandas._typing import (
Scalar,
Self,
)


class ArrowStringArrayMixin:
_pa_array = None
# _object_compat specifies whether we should 1) attempt to match behaviors
# of the object-backed StringDtype and 2) fall back to object-based
# computation for cases that pyarrow does not support natively.
_object_compat = False
_pa_array: Sized

def __init__(self, *args, **kwargs) -> None:
raise NotImplementedError

def _result_converter(self, values, na=None):
# Convert a bool-dtype pyarrow result to appropriate output type.
raise NotImplementedError

def _str_pad(
self,
width: int,
Expand Down Expand Up @@ -89,3 +104,53 @@ def _str_removesuffix(self, suffix: str):
removed = pc.utf8_slice_codeunits(self._pa_array, 0, stop=-len(suffix))
result = pc.if_else(ends_with, removed, self._pa_array)
return type(self)(result)

def _str_startswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
    """
    Test whether each element starts with ``pat``.

    Parameters
    ----------
    pat : str or tuple of str
        A single prefix, or a tuple of prefixes of which any one may match.
    na : scalar, optional
        Value substituted for null entries of the boolean result. When
        ``isna(na)`` holds, nulls are left in place.

    Returns
    -------
    The output of ``self._result_converter`` applied to the boolean
    pyarrow result.
    """
    values = self._pa_array

    if isinstance(pat, str):
        matches = pc.starts_with(values, pattern=pat)
    elif len(pat) > 0:
        # Several candidate prefixes: OR the per-prefix results together.
        matches = pc.starts_with(values, pattern=pat[0])
        for prefix in pat[1:]:
            matches = pc.or_(matches, pc.starts_with(values, pattern=prefix))
    elif self._object_compat:
        # Empty tuple: mimic the existing behaviour of the string extension
        # array and the python string method (all False, masked where the
        # input is missing).
        matches = pa.array(
            np.zeros(len(values), dtype=np.bool_),
            mask=isna(values),
        )
    else:
        # Empty tuple: pd.StringDtype() returns null for missing values and
        # False for valid values.
        matches = pc.if_else(pc.is_null(values), None, False)

    if isna(na):
        return self._result_converter(matches)
    return self._result_converter(matches.fill_null(na))

def _str_endswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
    """
    Test whether each element ends with ``pat``.

    Parameters
    ----------
    pat : str or tuple of str
        A single suffix, or a tuple of suffixes of which any one may match.
    na : scalar, optional
        Value substituted for null entries of the boolean result. When
        ``isna(na)`` holds, nulls are left in place.

    Returns
    -------
    The output of ``self._result_converter`` applied to the boolean
    pyarrow result.
    """
    values = self._pa_array

    if isinstance(pat, str):
        matches = pc.ends_with(values, pattern=pat)
    elif len(pat) > 0:
        # Several candidate suffixes: OR the per-suffix results together.
        matches = pc.ends_with(values, pattern=pat[0])
        for suffix in pat[1:]:
            matches = pc.or_(matches, pc.ends_with(values, pattern=suffix))
    elif self._object_compat:
        # Empty tuple: mimic the existing behaviour of the string extension
        # array and the python string method (all False, masked where the
        # input is missing).
        matches = pa.array(
            np.zeros(len(values), dtype=np.bool_),
            mask=isna(values),
        )
    else:
        # Empty tuple: pd.StringDtype() returns null for missing values and
        # False for valid values.
        matches = pc.if_else(pc.is_null(values), None, False)

    if isna(na):
        return self._result_converter(matches)
    return self._result_converter(matches.fill_null(na))
33 changes: 1 addition & 32 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2305,38 +2305,7 @@ def _str_contains(
result = result.fill_null(na)
return type(self)(result)

def _str_startswith(self, pat: str | tuple[str, ...], na=None) -> Self:
    """
    Test whether each element starts with ``pat``.

    Parameters
    ----------
    pat : str or tuple of str
        A single prefix, or a tuple of prefixes of which any one may match.
    na : scalar, optional
        Value substituted for null entries of the boolean result. When
        ``isna(na)`` holds, nulls are left in place.

    Returns
    -------
    A new instance of ``type(self)`` built from the boolean pyarrow result.
    """
    data = self._pa_array

    if isinstance(pat, str):
        flags = pc.starts_with(data, pattern=pat)
    elif len(pat) > 0:
        # Any of the given prefixes may match, so OR the individual results.
        flags = pc.starts_with(data, pattern=pat[0])
        for candidate in pat[1:]:
            flags = pc.or_(flags, pc.starts_with(data, pattern=candidate))
    else:
        # For empty tuple, pd.StringDtype() returns null for missing values
        # and false for valid values.
        flags = pc.if_else(pc.is_null(data), None, False)

    if isna(na):
        return type(self)(flags)
    return type(self)(flags.fill_null(na))

def _str_endswith(self, pat: str | tuple[str, ...], na=None) -> Self:
if isinstance(pat, str):
result = pc.ends_with(self._pa_array, pattern=pat)
else:
if len(pat) == 0:
# For empty tuple, pd.StringDtype() returns null for missing values
# and false for valid values.
result = pc.if_else(pc.is_null(self._pa_array), None, False)
else:
result = pc.ends_with(self._pa_array, pattern=pat[0])

for p in pat[1:]:
result = pc.or_(result, pc.ends_with(self._pa_array, pattern=p))
if not isna(na):
result = result.fill_null(na)
def _result_converter(self, result):
return type(self)(result)

def _str_replace(
Expand Down
39 changes: 1 addition & 38 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ def astype(self, dtype, copy: bool = True):

# ------------------------------------------------------------------------
# String methods interface
_object_compat = True

_str_map = BaseStringArray._str_map

Expand All @@ -298,44 +299,6 @@ def _str_contains(
result[isna(result)] = bool(na)
return result

def _str_startswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
    """
    Test whether each element starts with ``pat``.

    Parameters
    ----------
    pat : str or tuple of str
        A single prefix, or a tuple of prefixes of which any one may match.
    na : scalar, optional
        Value substituted for null entries of the boolean result. When
        ``isna(na)`` holds, nulls are left in place.

    Returns
    -------
    The output of ``self._result_converter`` applied to the boolean
    pyarrow result.
    """
    data = self._pa_array

    if isinstance(pat, str):
        flags = pc.starts_with(data, pattern=pat)
    elif len(pat) > 0:
        # Any of the given prefixes may match, so OR the individual results.
        flags = pc.starts_with(data, pattern=pat[0])
        for candidate in pat[1:]:
            flags = pc.or_(flags, pc.starts_with(data, pattern=candidate))
    else:
        # Empty tuple: mimic existing behaviour of string extension array
        # and python string method.
        all_false = np.zeros(len(data), dtype=bool)
        flags = pa.array(all_false, mask=isna(data))

    if isna(na):
        return self._result_converter(flags)
    return self._result_converter(flags.fill_null(na))

def _str_endswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
    """
    Test whether each element ends with ``pat``.

    Parameters
    ----------
    pat : str or tuple of str
        A single suffix, or a tuple of suffixes of which any one may match.
    na : scalar, optional
        Value substituted for null entries of the boolean result. When
        ``isna(na)`` holds, nulls are left in place.

    Returns
    -------
    The output of ``self._result_converter`` applied to the boolean
    pyarrow result.
    """
    data = self._pa_array

    if isinstance(pat, str):
        flags = pc.ends_with(data, pattern=pat)
    elif len(pat) > 0:
        # Any of the given suffixes may match, so OR the individual results.
        flags = pc.ends_with(data, pattern=pat[0])
        for candidate in pat[1:]:
            flags = pc.or_(flags, pc.ends_with(data, pattern=candidate))
    else:
        # Empty tuple: mimic existing behaviour of string extension array
        # and python string method.
        all_false = np.zeros(len(data), dtype=bool)
        flags = pa.array(all_false, mask=isna(data))

    if isna(na):
        return self._result_converter(flags)
    return self._result_converter(flags.fill_null(na))

def _str_replace(
self,
pat: str | re.Pattern,
Expand Down

0 comments on commit f8e3f5e

Please sign in to comment.