Skip to content

Commit

Permalink
BUG (string): ArrowStringArray.find corner cases
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel committed Aug 20, 2024
1 parent c4467a9 commit 57373c5
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 20 deletions.
7 changes: 5 additions & 2 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2387,14 +2387,17 @@ def _str_fullmatch(
pat = f"{pat}$"
return self._str_match(pat, case, flags, na)

def _convert_int_dtype(self, result):
return type(self)(result)

def _str_find(self, sub: str, start: int = 0, end: int | None = None) -> Self:
if (start == 0 or start is None) and end is None:
result = pc.find_substring(self._pa_array, sub)
else:
if sub == "":
# GH 56792
result = self._apply_elementwise(lambda val: val.find(sub, start, end))
return type(self)(pa.chunked_array(result))
return self._convert_int_dtype(pa.chunked_array(result))
if start is None:
start_offset = 0
start = 0
Expand All @@ -2408,7 +2411,7 @@ def _str_find(self, sub: str, start: int = 0, end: int | None = None) -> Self:
found = pc.not_equal(result, pa.scalar(-1, type=result.type))
offset_result = pc.add(result, start_offset)
result = pc.if_else(found, offset_result, -1)
return type(self)(result)
return self._convert_int_dtype(result)

def _str_join(self, sep: str) -> Self:
if pa.types.is_string(self._pa_array.type) or pa.types.is_large_string(
Expand Down
15 changes: 1 addition & 14 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ def astype(self, dtype, copy: bool = True):
# String methods interface

_str_map = BaseStringArray._str_map
_str_find = ArrowExtensionArray._str_find

def _str_contains(
self, pat, case: bool = True, flags: int = 0, na=np.nan, regex: bool = True
Expand Down Expand Up @@ -474,20 +475,6 @@ def _str_count(self, pat: str, flags: int = 0):
result = pc.count_substring_regex(self._pa_array, pat)
return self._convert_int_dtype(result)

def _str_find(self, sub: str, start: int = 0, end: int | None = None):
if start != 0 and end is not None:
slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end)
result = pc.find_substring(slices, sub)
not_found = pc.equal(result, -1)
offset_result = pc.add(result, end - start)
result = pc.if_else(not_found, result, offset_result)
elif start == 0 and end is None:
slices = self._pa_array
result = pc.find_substring(slices, sub)
else:
return super()._str_find(sub, start, end)
return self._convert_int_dtype(result)

def _str_get_dummies(self, sep: str = "|"):
dummies_pa, labels = ArrowExtensionArray(self._pa_array)._str_get_dummies(sep)
if len(labels) == 0:
Expand Down
9 changes: 5 additions & 4 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@
import numpy as np
import pytest

from pandas._config import using_string_dtype

from pandas._libs import lib
from pandas._libs.tslibs import timezones
from pandas.compat import (
Expand Down Expand Up @@ -1995,7 +1993,6 @@ def test_str_find_large_start():
tm.assert_series_equal(result, expected)


@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
@pytest.mark.skipif(
pa_version_under13p0, reason="https://github.com/apache/arrow/issues/36311"
)
Expand All @@ -2007,11 +2004,15 @@ def test_str_find_e2e(start, end, sub):
["abcaadef", "abc", "abcdeddefgj8292", "ab", "a", ""],
dtype=ArrowDtype(pa.string()),
)
object_series = s.astype(pd.StringDtype())
object_series = s.astype(pd.StringDtype(storage="python"))
result = s.str.find(sub, start, end)
expected = object_series.str.find(sub, start, end).astype(result.dtype)
tm.assert_series_equal(result, expected)

arrow_str_series = s.astype(pd.StringDtype(storage="pyarrow"))
result2 = arrow_str_series.str.find(sub, start, end).astype(result.dtype)
tm.assert_series_equal(result2, expected)


def test_str_find_negative_start_negative_end_no_match():
# GH 56791
Expand Down

0 comments on commit 57373c5

Please sign in to comment.