[backport 2.3.x] String dtype: propagate NaNs as False in predicate methods (eg .str.startswith) (#59616) (#60014)

* String dtype: propagate NaNs as False in predicate methods (eg .str.startswith) (#59616)

(cherry picked from commit 88554d0)

* ignore object dtype inference warnings
jorisvandenbossche authored Oct 11, 2024
1 parent e3302bc commit a24a653
Showing 12 changed files with 316 additions and 146 deletions.
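For context, a minimal sketch of the behaviour this backport targets, assuming pandas 2.3+ with the future (NaN-variant) string dtype enabled; the exact repr may differ:

```python
import numpy as np
import pandas as pd

# Opt in to the future string dtype (the NaN-variant "str" dtype).
pd.set_option("future.infer_string", True)

s = pd.Series(["apple", np.nan, "banana"])

# With this change, missing values propagate as False in predicate methods,
# so the result is a plain bool Series rather than NaN/object:
print(s.str.startswith("a"))
# 0     True
# 1    False
# 2    False
# dtype: bool
```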
44 changes: 27 additions & 17 deletions pandas/core/arrays/_arrow_string_mixins.py
@@ -10,15 +10,14 @@

import numpy as np

from pandas._libs import lib
from pandas.compat import (
pa_version_under10p1,
pa_version_under11p0,
pa_version_under13p0,
pa_version_under17p0,
)

from pandas.core.dtypes.missing import isna

if not pa_version_under10p1:
import pyarrow as pa
import pyarrow.compute as pc
@@ -38,7 +37,7 @@ class ArrowStringArrayMixin:
def __init__(self, *args, **kwargs) -> None:
raise NotImplementedError

def _convert_bool_result(self, result):
def _convert_bool_result(self, result, na=lib.no_default, method_name=None):
# Convert a bool-dtype result to the appropriate result type
raise NotImplementedError

@@ -212,7 +211,9 @@ def _str_removesuffix(self, suffix: str):
result = pc.if_else(ends_with, removed, self._pa_array)
return type(self)(result)

def _str_startswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
def _str_startswith(
self, pat: str | tuple[str, ...], na: Scalar | lib.NoDefault = lib.no_default
):
if isinstance(pat, str):
result = pc.starts_with(self._pa_array, pattern=pat)
else:
@@ -225,11 +226,11 @@ def _str_startswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):

for p in pat[1:]:
result = pc.or_(result, pc.starts_with(self._pa_array, pattern=p))
if not isna(na): # pyright: ignore [reportGeneralTypeIssues]
result = result.fill_null(na)
return self._convert_bool_result(result)
return self._convert_bool_result(result, na=na, method_name="startswith")

def _str_endswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):
def _str_endswith(
self, pat: str | tuple[str, ...], na: Scalar | lib.NoDefault = lib.no_default
):
if isinstance(pat, str):
result = pc.ends_with(self._pa_array, pattern=pat)
else:
@@ -242,9 +243,7 @@ def _str_endswith(self, pat: str | tuple[str, ...], na: Scalar | None = None):

for p in pat[1:]:
result = pc.or_(result, pc.ends_with(self._pa_array, pattern=p))
if not isna(na): # pyright: ignore [reportGeneralTypeIssues]
result = result.fill_null(na)
return self._convert_bool_result(result)
return self._convert_bool_result(result, na=na, method_name="endswith")

def _str_isalnum(self):
result = pc.utf8_is_alnum(self._pa_array)
@@ -283,7 +282,12 @@ def _str_isupper(self):
return self._convert_bool_result(result)

def _str_contains(
self, pat, case: bool = True, flags: int = 0, na=None, regex: bool = True
self,
pat,
case: bool = True,
flags: int = 0,
na: Scalar | lib.NoDefault = lib.no_default,
regex: bool = True,
):
if flags:
raise NotImplementedError(f"contains not implemented with {flags=}")
@@ -293,19 +297,25 @@ def _str_contains(
else:
pa_contains = pc.match_substring
result = pa_contains(self._pa_array, pat, ignore_case=not case)
if not isna(na): # pyright: ignore [reportGeneralTypeIssues]
result = result.fill_null(na)
return self._convert_bool_result(result)
return self._convert_bool_result(result, na=na, method_name="contains")

def _str_match(
self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None
self,
pat: str,
case: bool = True,
flags: int = 0,
na: Scalar | lib.NoDefault = lib.no_default,
):
if not pat.startswith("^"):
pat = f"^{pat}"
return self._str_contains(pat, case, flags, na, regex=True)

def _str_fullmatch(
self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None
self,
pat,
case: bool = True,
flags: int = 0,
na: Scalar | lib.NoDefault = lib.no_default,
):
if not pat.endswith("$") or pat.endswith("\\$"):
pat = f"{pat}$"
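With this refactor the mixin no longer fills nulls itself; it forwards `na` (plus the method name, used for the deprecation warning) to `_convert_bool_result`, which each array subclass overrides. A small pyarrow sketch of what that fill amounts to, assuming pyarrow is installed (illustrative, not the actual pandas code path):

```python
import pyarrow as pa
import pyarrow.compute as pc

arr = pa.chunked_array([["apple", None, "banana"]])

result = pc.starts_with(arr, pattern="a")
print(result)                   # [[true, null, false]] -- the null survives the kernel

# The NaN-variant string array fills the null with False by default,
# while an explicit boolean `na` is used as-is:
print(result.fill_null(False))  # [[true, false, false]]
print(result.fill_null(True))   # [[true, true, false]]
```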
6 changes: 5 additions & 1 deletion pandas/core/arrays/arrow/array.py
@@ -2285,7 +2285,11 @@ def _apply_elementwise(self, func: Callable) -> list[list[Any]]:
for chunk in self._pa_array.iterchunks()
]

def _convert_bool_result(self, result):
def _convert_bool_result(self, result, na=lib.no_default, method_name=None):
if na is not lib.no_default and not isna(
na
): # pyright: ignore [reportGeneralTypeIssues]
result = result.fill_null(na)
return type(self)(result)

def _convert_int_result(self, result):
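For `ArrowDtype`-backed strings the default stays masked: `_convert_bool_result` only fills when a concrete `na` is passed. A hedged usage sketch (assumes pandas 2.3+ and pyarrow):

```python
import pyarrow as pa
import pandas as pd

s = pd.Series(["apple", None], dtype=pd.ArrowDtype(pa.string()))

print(s.str.startswith("a"))            # [True, <NA>]  -- default keeps the missing value
print(s.str.startswith("a", na=False))  # [True, False] -- an explicit na fills it
```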
20 changes: 16 additions & 4 deletions pandas/core/arrays/categorical.py
@@ -2675,16 +2675,28 @@ def _replace(self, *, to_replace, value, inplace: bool = False):
# ------------------------------------------------------------------------
# String methods interface
def _str_map(
self, f, na_value=np.nan, dtype=np.dtype("object"), convert: bool = True
self, f, na_value=lib.no_default, dtype=np.dtype("object"), convert: bool = True
):
# Optimization to apply the callable `f` to the categories once
# and rebuild the result by `take`ing from the result with the codes.
# Returns the same type as the object-dtype implementation though.
from pandas.core.arrays import NumpyExtensionArray

categories = self.categories
codes = self.codes
result = NumpyExtensionArray(categories.to_numpy())._str_map(f, na_value, dtype)
if categories.dtype == "string":
result = categories.array._str_map(f, na_value, dtype) # type: ignore[attr-defined]
if (
categories.dtype.na_value is np.nan # type: ignore[union-attr]
and is_bool_dtype(dtype)
and (na_value is lib.no_default or isna(na_value))
):
# NaN propagates as False for functions with boolean return type
na_value = False
else:
from pandas.core.arrays import NumpyExtensionArray

result = NumpyExtensionArray(categories.to_numpy())._str_map(
f, na_value, dtype
)
return take_nd(result, codes, fill_value=na_value)

def _str_get_dummies(self, sep: str = "|"):
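The categorical `_str_map` now dispatches to the string categories' own `_str_map` (picking up the same NaN-as-False default for boolean results) instead of always round-tripping through `NumpyExtensionArray`. A rough, hypothetical sketch of the take-based fast path it builds on:

```python
import numpy as np

# Apply the predicate once per unique category ...
categories = np.array(["apple", "banana", "cherry"], dtype=object)
codes = np.array([0, 2, -1, 1])  # -1 marks a missing entry
per_category = np.array([c.startswith("a") for c in categories])

# ... then broadcast via the codes, filling missing entries with the na_value
# (False for boolean predicates under NaN semantics).
result = np.where(codes == -1, False, per_category[np.where(codes == -1, 0, codes)])
print(result)  # [ True False False False]
```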
33 changes: 22 additions & 11 deletions pandas/core/arrays/string_.py
@@ -377,7 +377,11 @@ def _from_scalars(cls, scalars, dtype: DtypeObj) -> Self:
return cls._from_sequence(scalars, dtype=dtype)

def _str_map(
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
self,
f,
na_value=lib.no_default,
dtype: Dtype | None = None,
convert: bool = True,
):
if self.dtype.na_value is np.nan:
return self._str_map_nan_semantics(
@@ -388,7 +392,7 @@ def _str_map(

if dtype is None:
dtype = self.dtype
if na_value is None:
if na_value is lib.no_default:
na_value = self.dtype.na_value

mask = isna(self)
@@ -458,12 +462,20 @@ def _str_map_str_or_object(
return lib.map_infer_mask(arr, f, mask.view("uint8"))

def _str_map_nan_semantics(
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
self,
f,
na_value=lib.no_default,
dtype: Dtype | None = None,
convert: bool = True,
):
if dtype is None:
dtype = self.dtype
if na_value is None:
na_value = self.dtype.na_value
if na_value is lib.no_default:
if is_bool_dtype(dtype):
# NaN propagates as False
na_value = False
else:
na_value = self.dtype.na_value

mask = isna(self)
arr = np.asarray(self)
@@ -474,7 +486,8 @@
if is_integer_dtype(dtype):
na_value = 0
else:
na_value = True
# NaN propagates as False
na_value = False

result = lib.map_infer_mask(
arr,
@@ -484,15 +497,13 @@
na_value=na_value,
dtype=np.dtype(cast(type, dtype)),
)
if na_value_is_na and mask.any():
if na_value_is_na and is_integer_dtype(dtype) and mask.any():
# TODO: we could alternatively do this check before map_infer_mask
# and adjust the dtype/na_value we pass there. Which is more
# performant?
if is_integer_dtype(dtype):
result = result.astype("float64")
else:
result = result.astype("object")
result = result.astype("float64")
result[mask] = np.nan

return result

else:
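Under NaN semantics the default fill now depends on the result dtype: boolean-returning methods fill missing entries with False, while integer-returning methods are upcast to float64 so missing entries can stay NaN. A small illustration (assumes pandas 2.3+ with the future string dtype; exact repr may vary):

```python
import numpy as np
import pandas as pd

pd.set_option("future.infer_string", True)
s = pd.Series(["apple", np.nan])

print(s.str.isdigit())  # [False, False] -> dtype bool (missing filled with False)
print(s.str.find("p"))  # [1.0, NaN]     -> dtype float64 (missing stays NaN)
```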
42 changes: 28 additions & 14 deletions pandas/core/arrays/string_arrow.py
@@ -211,9 +211,29 @@ def insert(self, loc: int, item) -> ArrowStringArray:
raise TypeError("Scalar must be NA or str")
return super().insert(loc, item)

def _convert_bool_result(self, values):
def _convert_bool_result(self, values, na=lib.no_default, method_name=None):
if na is not lib.no_default and not isna(na) and not isinstance(na, bool):
# GH#59561
warnings.warn(
f"Allowing a non-bool 'na' in obj.str.{method_name} is deprecated "
"and will raise in a future version.",
FutureWarning,
stacklevel=find_stack_level(),
)
na = bool(na)

if self.dtype.na_value is np.nan:
return ArrowExtensionArray(values).to_numpy(na_value=np.nan)
if na is lib.no_default or isna(na):
# NaN propagates as False
values = values.fill_null(False)
else:
values = values.fill_null(na)
return values.to_numpy()
else:
if na is not lib.no_default and not isna(
na
): # pyright: ignore [reportGeneralTypeIssues]
values = values.fill_null(na)
return BooleanDtype().__from_arrow__(values)

def _maybe_convert_setitem_value(self, value):
@@ -309,22 +329,16 @@ def _data(self):
_str_slice = ArrowStringArrayMixin._str_slice

def _str_contains(
self, pat, case: bool = True, flags: int = 0, na=np.nan, regex: bool = True
self,
pat,
case: bool = True,
flags: int = 0,
na=lib.no_default,
regex: bool = True,
):
if flags:
return super()._str_contains(pat, case, flags, na, regex)

if not isna(na):
if not isinstance(na, bool):
# GH#59561
warnings.warn(
"Allowing a non-bool 'na' in obj.str.contains is deprecated "
"and will raise in a future version.",
FutureWarning,
stacklevel=find_stack_level(),
)
na = bool(na)

return ArrowStringArrayMixin._str_contains(self, pat, case, flags, na, regex)

def _str_replace(
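The pyarrow-backed string array also deprecates passing a non-bool `na` to predicate methods (GH#59561): the value is coerced with `bool()` and a `FutureWarning` is emitted. A hedged sketch of what callers see (assumes pandas 2.3+ with pyarrow installed):

```python
import warnings

import numpy as np
import pandas as pd

pd.set_option("future.infer_string", True)
s = pd.Series(["apple", np.nan])

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    out = s.str.contains("app", na="missing")  # non-bool na -> deprecated

print(out.tolist())         # [True, True]  (bool("missing") is True, so the null is filled with True)
print(caught[-1].category)  # <class 'FutureWarning'>
```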