Skip to content

Commit

Permalink
REF: dont patch assert_frame_equal
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel committed Aug 1, 2023
1 parent 4357621 commit 95593db
Showing 1 changed file with 64 additions and 61 deletions.
125 changes: 64 additions & 61 deletions pandas/tests/extension/json/test_json.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,74 +79,14 @@ def data_for_grouping():


class BaseJSON:
# NumPy doesn't handle an array of equal-length UserDicts.
# The default assert_series_equal eventually does a
# Series.values, which raises. We work around it by
# converting the UserDicts to dicts.
@classmethod
def assert_series_equal(cls, left, right, *args, **kwargs):
if left.dtype.name == "json":
assert left.dtype == right.dtype
left = pd.Series(
JSONArray(left.values.astype(object)), index=left.index, name=left.name
)
right = pd.Series(
JSONArray(right.values.astype(object)),
index=right.index,
name=right.name,
)
tm.assert_series_equal(left, right, *args, **kwargs)

@classmethod
def assert_frame_equal(cls, left, right, *args, **kwargs):
obj_type = kwargs.get("obj", "DataFrame")
tm.assert_index_equal(
left.columns,
right.columns,
exact=kwargs.get("check_column_type", "equiv"),
check_names=kwargs.get("check_names", True),
check_exact=kwargs.get("check_exact", False),
check_categorical=kwargs.get("check_categorical", True),
obj=f"{obj_type}.columns",
)

jsons = (left.dtypes == "json").index

for col in jsons:
cls.assert_series_equal(left[col], right[col], *args, **kwargs)

left = left.drop(columns=jsons)
right = right.drop(columns=jsons)
tm.assert_frame_equal(left, right, *args, **kwargs)
pass


class TestDtype(BaseJSON, base.BaseDtypeTests):
pass


class TestInterface(BaseJSON, base.BaseInterfaceTests):
def test_custom_asserts(self):
# This would always trigger the KeyError from trying to put
# an array of equal-length UserDicts inside an ndarray.
data = JSONArray(
[
collections.UserDict({"a": 1}),
collections.UserDict({"b": 2}),
collections.UserDict({"c": 3}),
]
)
a = pd.Series(data)
self.assert_series_equal(a, a)
self.assert_frame_equal(a.to_frame(), a.to_frame())

b = pd.Series(data.take([0, 0, 1]))
msg = r"Series are different"
with pytest.raises(AssertionError, match=msg):
self.assert_series_equal(a, b)

with pytest.raises(AssertionError, match=msg):
self.assert_frame_equal(a.to_frame(), b.to_frame())

@pytest.mark.xfail(
reason="comparison method not implemented for JSONArray (GH-37867)"
)
Expand Down Expand Up @@ -385,3 +325,66 @@ class TestComparisonOps(BaseJSON, base.BaseComparisonOpsTests):

class TestPrinting(BaseJSON, base.BasePrintingTests):
pass


def custom_assert_series_equal(left, right, *args, **kwargs):
# NumPy doesn't handle an array of equal-length UserDicts.
# The default assert_series_equal eventually does a
# Series.values, which raises. We work around it by
# converting the UserDicts to dicts.
if left.dtype.name == "json":
assert left.dtype == right.dtype
left = pd.Series(
JSONArray(left.values.astype(object)), index=left.index, name=left.name
)
right = pd.Series(
JSONArray(right.values.astype(object)),
index=right.index,
name=right.name,
)
tm.assert_series_equal(left, right, *args, **kwargs)


def custom_assert_frame_equal(left, right, *args, **kwargs):
obj_type = kwargs.get("obj", "DataFrame")
tm.assert_index_equal(
left.columns,
right.columns,
exact=kwargs.get("check_column_type", "equiv"),
check_names=kwargs.get("check_names", True),
check_exact=kwargs.get("check_exact", False),
check_categorical=kwargs.get("check_categorical", True),
obj=f"{obj_type}.columns",
)

jsons = (left.dtypes == "json").index

for col in jsons:
custom_assert_series_equal(left[col], right[col], *args, **kwargs)

left = left.drop(columns=jsons)
right = right.drop(columns=jsons)
tm.assert_frame_equal(left, right, *args, **kwargs)


def test_custom_asserts():
# This would always trigger the KeyError from trying to put
# an array of equal-length UserDicts inside an ndarray.
data = JSONArray(
[
collections.UserDict({"a": 1}),
collections.UserDict({"b": 2}),
collections.UserDict({"c": 3}),
]
)
a = pd.Series(data)
custom_assert_series_equal(a, a)
custom_assert_frame_equal(a.to_frame(), a.to_frame())

b = pd.Series(data.take([0, 0, 1]))
msg = r"Series are different"
with pytest.raises(AssertionError, match=msg):
custom_assert_series_equal(a, b)

with pytest.raises(AssertionError, match=msg):
custom_assert_frame_equal(a.to_frame(), b.to_frame())

0 comments on commit 95593db

Please sign in to comment.