From b54bbf638b3d0e8c584e36a0ad22c8dcebd95b43 Mon Sep 17 00:00:00 2001 From: Patrick Hoefler <61934744+phofl@users.noreply.github.com> Date: Thu, 17 Aug 2023 10:42:36 +0200 Subject: [PATCH] Backport PR #54535: REF: Replace "pyarrow" string storage checks with variable --- pandas/conftest.py | 5 +++++ pandas/tests/arrays/string_/test_string.py | 20 +++++++++---------- .../tests/arrays/string_/test_string_arrow.py | 4 ++-- pandas/tests/extension/test_string.py | 12 +++++------ 4 files changed, 23 insertions(+), 18 deletions(-) diff --git a/pandas/conftest.py b/pandas/conftest.py index 757ca817d1b85..5210e727aeb3c 100644 --- a/pandas/conftest.py +++ b/pandas/conftest.py @@ -1996,3 +1996,8 @@ def warsaw(request) -> str: tzinfo for Europe/Warsaw using pytz, dateutil, or zoneinfo. """ return request.param + + +@pytest.fixture() +def arrow_string_storage(): + return ("pyarrow",) diff --git a/pandas/tests/arrays/string_/test_string.py b/pandas/tests/arrays/string_/test_string.py index cfd3314eb5944..de93e89ecacd5 100644 --- a/pandas/tests/arrays/string_/test_string.py +++ b/pandas/tests/arrays/string_/test_string.py @@ -115,8 +115,8 @@ def test_add(dtype): tm.assert_series_equal(result, expected) -def test_add_2d(dtype, request): - if dtype.storage == "pyarrow": +def test_add_2d(dtype, request, arrow_string_storage): + if dtype.storage in arrow_string_storage: reason = "Failed: DID NOT RAISE " mark = pytest.mark.xfail(raises=None, reason=reason) request.node.add_marker(mark) @@ -144,8 +144,8 @@ def test_add_sequence(dtype): tm.assert_extension_array_equal(result, expected) -def test_mul(dtype, request): - if dtype.storage == "pyarrow": +def test_mul(dtype, request, arrow_string_storage): + if dtype.storage in arrow_string_storage: reason = "unsupported operand type(s) for *: 'ArrowStringArray' and 'int'" mark = pytest.mark.xfail(raises=NotImplementedError, reason=reason) request.node.add_marker(mark) @@ -369,8 +369,8 @@ def test_min_max(method, skipna, dtype, request): @pytest.mark.parametrize("method", ["min", "max"]) @pytest.mark.parametrize("box", [pd.Series, pd.array]) -def test_min_max_numpy(method, box, dtype, request): - if dtype.storage == "pyarrow" and box is pd.array: +def test_min_max_numpy(method, box, dtype, request, arrow_string_storage): + if dtype.storage in arrow_string_storage and box is pd.array: if box is pd.array: reason = "'<=' not supported between instances of 'str' and 'NoneType'" else: @@ -384,7 +384,7 @@ def test_min_max_numpy(method, box, dtype, request): assert result == expected -def test_fillna_args(dtype, request): +def test_fillna_args(dtype, request, arrow_string_storage): # GH 37987 arr = pd.array(["a", pd.NA], dtype=dtype) @@ -397,7 +397,7 @@ def test_fillna_args(dtype, request): expected = pd.array(["a", "b"], dtype=dtype) tm.assert_extension_array_equal(res, expected) - if dtype.storage == "pyarrow": + if dtype.storage in arrow_string_storage: msg = "Invalid value '1' for dtype string" else: msg = "Cannot set non-string value '1' into a StringArray." @@ -503,10 +503,10 @@ def test_use_inf_as_na(values, expected, dtype): tm.assert_frame_equal(result, expected) -def test_memory_usage(dtype): +def test_memory_usage(dtype, arrow_string_storage): # GH 33963 - if dtype.storage == "pyarrow": + if dtype.storage in arrow_string_storage: pytest.skip(f"not applicable for {dtype.storage}") series = pd.Series(["a", "b", "c"], dtype=dtype) diff --git a/pandas/tests/arrays/string_/test_string_arrow.py b/pandas/tests/arrays/string_/test_string_arrow.py index 6912d5038ae0d..1ab628f186b47 100644 --- a/pandas/tests/arrays/string_/test_string_arrow.py +++ b/pandas/tests/arrays/string_/test_string_arrow.py @@ -49,10 +49,10 @@ def test_config_bad_storage_raises(): @skip_if_no_pyarrow @pytest.mark.parametrize("chunked", [True, False]) @pytest.mark.parametrize("array", ["numpy", "pyarrow"]) -def test_constructor_not_string_type_raises(array, chunked): +def test_constructor_not_string_type_raises(array, chunked, arrow_string_storage): import pyarrow as pa - array = pa if array == "pyarrow" else np + array = pa if array in arrow_string_storage else np arr = array.array([1, 2, 3]) if chunked: diff --git a/pandas/tests/extension/test_string.py b/pandas/tests/extension/test_string.py index 6597ff84e3ca4..4e142eb6e14b8 100644 --- a/pandas/tests/extension/test_string.py +++ b/pandas/tests/extension/test_string.py @@ -103,8 +103,8 @@ def test_is_not_string_type(self, dtype): class TestInterface(base.BaseInterfaceTests): - def test_view(self, data, request): - if data.dtype.storage == "pyarrow": + def test_view(self, data, request, arrow_string_storage): + if data.dtype.storage in arrow_string_storage: pytest.skip(reason="2D support not implemented for ArrowStringArray") super().test_view(data) @@ -116,8 +116,8 @@ def test_from_dtype(self, data): class TestReshaping(base.BaseReshapingTests): - def test_transpose(self, data, request): - if data.dtype.storage == "pyarrow": + def test_transpose(self, data, request, arrow_string_storage): + if data.dtype.storage in arrow_string_storage: pytest.skip(reason="2D support not implemented for ArrowStringArray") super().test_transpose(data) @@ -127,8 +127,8 @@ class TestGetitem(base.BaseGetitemTests): class TestSetitem(base.BaseSetitemTests): - def test_setitem_preserves_views(self, data, request): - if data.dtype.storage == "pyarrow": + def test_setitem_preserves_views(self, data, request, arrow_string_storage): + if data.dtype.storage in arrow_string_storage: pytest.skip(reason="2D support not implemented for ArrowStringArray") super().test_setitem_preserves_views(data)