Skip to content

Commit

Permalink
Backport PR pandas-dev#54586: REF: Refactor conversion of na value
Browse files Browse the repository at this point in the history
  • Loading branch information
phofl authored and meeseeksmachine committed Aug 21, 2023
1 parent f7723bb commit 3c36e73
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 31 deletions.
10 changes: 10 additions & 0 deletions pandas/tests/strings/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,12 @@
# Needed for new arrow string dtype

import pandas as pd

object_pyarrow_numpy = ("object",)


def _convert_na_value(ser, expected):
if ser.dtype != object:
# GH#18463
expected = expected.fillna(pd.NA)
return expected
9 changes: 5 additions & 4 deletions pandas/tests/strings/test_find_replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
Series,
_testing as tm,
)
from pandas.tests.strings import object_pyarrow_numpy
from pandas.tests.strings import (
_convert_na_value,
object_pyarrow_numpy,
)

# --------------------------------------------------------------------------------------
# str.contains
Expand Down Expand Up @@ -758,9 +761,7 @@ def test_findall(any_string_dtype):
ser = Series(["fooBAD__barBAD", np.nan, "foo", "BAD"], dtype=any_string_dtype)
result = ser.str.findall("BAD[_]*")
expected = Series([["BAD__", "BAD"], np.nan, [], ["BAD"]])
if ser.dtype != object:
# GH#18463
expected = expected.fillna(pd.NA)
expected = _convert_na_value(ser, expected)
tm.assert_series_equal(result, expected)


Expand Down
37 changes: 10 additions & 27 deletions pandas/tests/strings/test_split_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
Series,
_testing as tm,
)
from pandas.tests.strings import _convert_na_value


@pytest.mark.parametrize("method", ["split", "rsplit"])
Expand All @@ -20,9 +21,7 @@ def test_split(any_string_dtype, method):

result = getattr(values.str, method)("_")
exp = Series([["a", "b", "c"], ["c", "d", "e"], np.nan, ["f", "g", "h"]])
if values.dtype != object:
# GH#18463
exp = exp.fillna(pd.NA)
exp = _convert_na_value(values, exp)
tm.assert_series_equal(result, exp)


Expand All @@ -32,9 +31,7 @@ def test_split_more_than_one_char(any_string_dtype, method):
values = Series(["a__b__c", "c__d__e", np.nan, "f__g__h"], dtype=any_string_dtype)
result = getattr(values.str, method)("__")
exp = Series([["a", "b", "c"], ["c", "d", "e"], np.nan, ["f", "g", "h"]])
if values.dtype != object:
# GH#18463
exp = exp.fillna(pd.NA)
exp = _convert_na_value(values, exp)
tm.assert_series_equal(result, exp)

result = getattr(values.str, method)("__", expand=False)
Expand All @@ -46,9 +43,7 @@ def test_split_more_regex_split(any_string_dtype):
values = Series(["a,b_c", "c_d,e", np.nan, "f,g,h"], dtype=any_string_dtype)
result = values.str.split("[,_]")
exp = Series([["a", "b", "c"], ["c", "d", "e"], np.nan, ["f", "g", "h"]])
if values.dtype != object:
# GH#18463
exp = exp.fillna(pd.NA)
exp = _convert_na_value(values, exp)
tm.assert_series_equal(result, exp)


Expand Down Expand Up @@ -128,9 +123,7 @@ def test_rsplit(any_string_dtype):
values = Series(["a,b_c", "c_d,e", np.nan, "f,g,h"], dtype=any_string_dtype)
result = values.str.rsplit("[,_]")
exp = Series([["a,b_c"], ["c_d,e"], np.nan, ["f,g,h"]])
if values.dtype != object:
# GH#18463
exp = exp.fillna(pd.NA)
exp = _convert_na_value(values, exp)
tm.assert_series_equal(result, exp)


Expand All @@ -139,9 +132,7 @@ def test_rsplit_max_number(any_string_dtype):
values = Series(["a_b_c", "c_d_e", np.nan, "f_g_h"], dtype=any_string_dtype)
result = values.str.rsplit("_", n=1)
exp = Series([["a_b", "c"], ["c_d", "e"], np.nan, ["f_g", "h"]])
if values.dtype != object:
# GH#18463
exp = exp.fillna(pd.NA)
exp = _convert_na_value(values, exp)
tm.assert_series_equal(result, exp)


Expand Down Expand Up @@ -455,9 +446,7 @@ def test_partition_series_more_than_one_char(method, exp, any_string_dtype):
s = Series(["a__b__c", "c__d__e", np.nan, "f__g__h", None], dtype=any_string_dtype)
result = getattr(s.str, method)("__", expand=False)
expected = Series(exp)
if s.dtype != object:
# GH#18463
expected = expected.fillna(pd.NA)
expected = _convert_na_value(s, expected)
tm.assert_series_equal(result, expected)


Expand All @@ -480,9 +469,7 @@ def test_partition_series_none(any_string_dtype, method, exp):
s = Series(["a b c", "c d e", np.nan, "f g h", None], dtype=any_string_dtype)
result = getattr(s.str, method)(expand=False)
expected = Series(exp)
if s.dtype != object:
# GH#18463
expected = expected.fillna(pd.NA)
expected = _convert_na_value(s, expected)
tm.assert_series_equal(result, expected)


Expand All @@ -505,9 +492,7 @@ def test_partition_series_not_split(any_string_dtype, method, exp):
s = Series(["abc", "cde", np.nan, "fgh", None], dtype=any_string_dtype)
result = getattr(s.str, method)("_", expand=False)
expected = Series(exp)
if s.dtype != object:
# GH#18463
expected = expected.fillna(pd.NA)
expected = _convert_na_value(s, expected)
tm.assert_series_equal(result, expected)


Expand All @@ -531,9 +516,7 @@ def test_partition_series_unicode(any_string_dtype, method, exp):

result = getattr(s.str, method)("_", expand=False)
expected = Series(exp)
if s.dtype != object:
# GH#18463
expected = expected.fillna(pd.NA)
expected = _convert_na_value(s, expected)
tm.assert_series_equal(result, expected)


Expand Down

0 comments on commit 3c36e73

Please sign in to comment.