Skip to content

Commit

Permalink
REF: remove test_accumulate_series_raises
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel committed Aug 1, 2023
1 parent 42e489c commit 93fea03
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 58 deletions.
28 changes: 16 additions & 12 deletions pandas/tests/extension/base/accumulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,28 +11,32 @@ class BaseAccumulateTests(BaseExtensionTests):
make sense for numeric/boolean operations.
"""

def check_accumulate(self, s, op_name, skipna):
result = getattr(s, op_name)(skipna=skipna)
def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
# Do we expect this accumulation to be supported for this dtype?
# We default to assuming "no"; subclass authors should override here.
return False

def check_accumulate(self, ser: pd.Series, op_name: str, skipna: bool):
alt = ser.astype("float64")
result = getattr(ser, op_name)(skipna=skipna)

if result.dtype == pd.Float32Dtype() and op_name == "cumprod" and skipna:
# TODO: avoid special-casing here
pytest.skip(
f"Float32 precision lead to large differences with op {op_name} "
f"and skipna={skipna}"
)

expected = getattr(s.astype("float64"), op_name)(skipna=skipna)
expected = getattr(alt, op_name)(skipna=skipna)
tm.assert_series_equal(result, expected, check_dtype=False)

@pytest.mark.parametrize("skipna", [True, False])
def test_accumulate_series_raises(self, data, all_numeric_accumulations, skipna):
op_name = all_numeric_accumulations
ser = pd.Series(data)

with pytest.raises(NotImplementedError):
getattr(ser, op_name)(skipna=skipna)

@pytest.mark.parametrize("skipna", [True, False])
def test_accumulate_series(self, data, all_numeric_accumulations, skipna):
op_name = all_numeric_accumulations
ser = pd.Series(data)
self.check_accumulate(ser, op_name, skipna)

if self._supports_accumulation(ser, op_name):
self.check_accumulate(ser, op_name, skipna)
else:
with pytest.raises(NotImplementedError):
getattr(ser, op_name)(skipna=skipna)
49 changes: 16 additions & 33 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,47 +354,30 @@ def check_accumulate(self, ser, op_name, skipna):
expected = getattr(ser.astype("Float64"), op_name)(skipna=skipna)
tm.assert_series_equal(result, expected, check_dtype=False)

@pytest.mark.parametrize("skipna", [True, False])
def test_accumulate_series_raises(self, data, all_numeric_accumulations, skipna):
pa_type = data.dtype.pyarrow_dtype
if (
(
pa.types.is_integer(pa_type)
or pa.types.is_floating(pa_type)
or pa.types.is_duration(pa_type)
)
and all_numeric_accumulations == "cumsum"
and not pa_version_under9p0
):
pytest.skip("These work, are tested by test_accumulate_series.")
def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
pa_type = ser.dtype.pyarrow_dtype

op_name = all_numeric_accumulations
ser = pd.Series(data)

with pytest.raises(NotImplementedError):
getattr(ser, op_name)(skipna=skipna)

@pytest.mark.parametrize("skipna", [True, False])
def test_accumulate_series(self, data, all_numeric_accumulations, skipna, request):
pa_type = data.dtype.pyarrow_dtype
op_name = all_numeric_accumulations
ser = pd.Series(data)

do_skip = False
if pa.types.is_string(pa_type) or pa.types.is_binary(pa_type):
if op_name in ["cumsum", "cumprod"]:
do_skip = True
return False
elif pa.types.is_temporal(pa_type) and not pa.types.is_duration(pa_type):
if op_name in ["cumsum", "cumprod"]:
do_skip = True
return False
elif pa.types.is_duration(pa_type):
if op_name == "cumprod":
do_skip = True
return False
return True

if do_skip:
pytest.skip(
f"{op_name} should *not* work, we test in "
"test_accumulate_series_raises that these correctly raise."
@pytest.mark.parametrize("skipna", [True, False])
def test_accumulate_series(self, data, all_numeric_accumulations, skipna, request):
pa_type = data.dtype.pyarrow_dtype
op_name = all_numeric_accumulations
ser = pd.Series(data)

if not self._supports_accumulation(ser, op_name):
# The base class test will check that we raise
return super().test_accumulate_series(
data, all_numeric_accumulations, skipna
)

if all_numeric_accumulations != "cumsum" or pa_version_under9p0:
Expand Down
7 changes: 3 additions & 4 deletions pandas/tests/extension/test_boolean.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,9 @@ class TestUnaryOps(base.BaseUnaryOpsTests):


class TestAccumulation(base.BaseAccumulateTests):
def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
return True

def check_accumulate(self, s, op_name, skipna):
length = 64
if not IS64 or is_platform_windows():
Expand All @@ -286,10 +289,6 @@ def check_accumulate(self, s, op_name, skipna):
expected = expected.astype("boolean")
tm.assert_series_equal(result, expected)

@pytest.mark.parametrize("skipna", [True, False])
def test_accumulate_series_raises(self, data, all_numeric_accumulations, skipna):
pass


class TestParsing(base.BaseParsingTests):
pass
Expand Down
4 changes: 1 addition & 3 deletions pandas/tests/extension/test_categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,7 @@ class TestReduce(base.BaseNoReduceTests):


class TestAccumulate(base.BaseAccumulateTests):
@pytest.mark.parametrize("skipna", [True, False])
def test_accumulate_series(self, data, all_numeric_accumulations, skipna):
pass
pass


class TestMethods(base.BaseMethodsTests):
Expand Down
5 changes: 2 additions & 3 deletions pandas/tests/extension/test_masked_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,9 +324,8 @@ class TestBooleanReduce(base.BaseBooleanReduceTests):


class TestAccumulation(base.BaseAccumulateTests):
@pytest.mark.parametrize("skipna", [True, False])
def test_accumulate_series_raises(self, data, all_numeric_accumulations, skipna):
pass
def _supports_accumulation(self, ser: pd.Series, op_name: str) -> bool:
return True

def check_accumulate(self, ser: pd.Series, op_name: str, skipna: bool):
# overwrite to ensure pd.NA is tested instead of np.nan
Expand Down
4 changes: 1 addition & 3 deletions pandas/tests/extension/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,4 @@ def test_EA_types(self, engine, data):


class TestNoNumericAccumulations(base.BaseAccumulateTests):
@pytest.mark.parametrize("skipna", [True, False])
def test_accumulate_series(self, data, all_numeric_accumulations, skipna):
pass
pass

0 comments on commit 93fea03

Please sign in to comment.