Skip to content

Commit

Permalink
fix: remove Category inheritance from ArrowDictionary
Browse files Browse the repository at this point in the history
Signed-off-by: Daren Liang <[email protected]>
  • Loading branch information
darenliang committed Nov 14, 2024
1 parent 9667234 commit 6f23de1
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 24 deletions.
2 changes: 1 addition & 1 deletion pandera/engines/pandas_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1661,7 +1661,7 @@ def from_parametrized_dtype(cls, pyarrow_dtype: pyarrow.TimestampType):
equivalents=[pyarrow.dictionary, pyarrow.DictionaryType]
)
@immutable(init=True)
class ArrowDictionary(ArrowDataType, dtypes.Category):
class ArrowDictionary(ArrowDataType):
"""Semantic representation of a :class:`pyarrow.dictionary`."""

type: Optional[pd.ArrowDtype] = dataclasses.field(
Expand Down
2 changes: 1 addition & 1 deletion pandera/engines/pyarrow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def from_parametrized_dtype(cls, pyarrow_dtype: pyarrow.TimestampType):
equivalents=[pyarrow.dictionary, pyarrow.DictionaryType]
)
@immutable(init=True)
class ArrowDictionary(ArrowDataType, dtypes.Category):
class ArrowDictionary(ArrowDataType):
"""Semantic representation of a :class:`pyarrow.dictionary`."""

type: Optional[pd.ArrowDtype] = dataclasses.field(default=None, init=False)
Expand Down
79 changes: 57 additions & 22 deletions tests/core/test_pandas_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,10 @@ def test_pandas_date_coerce_dtype(to_df, data):
(pd.Series(["foo", "bar", "baz", None]), pyarrow.binary(3)),
(pd.Series(["foo", "barbaz", None]), pyarrow.large_binary()),
(pd.Series(["1", "1.0", "foo", "bar", None]), pyarrow.large_string()),
(
pd.Series(["a", "b", "c"]),
pyarrow.dictionary(pyarrow.int64(), pyarrow.string()),
),
)


Expand All @@ -289,17 +293,20 @@ def test_pandas_arrow_dtype(data, dtype):
pytest.skip("Support of pandas 2.0.0+ with pyarrow only")
dtype = pandas_engine.Engine.dtype(dtype)

dtype.coerce(data)
coerced_data = dtype.coerce(data)
assert coerced_data.dtype == dtype.type


pandas_arrow_dtype_error_cases = (
(
pd.Series([["a", "b", "c"]]),
pyarrow.list_(pyarrow.int64()),
pyarrow.ArrowInvalid,
),
(
pd.Series([["a", "b"]]),
pyarrow.list_(pyarrow.string(), 3),
pyarrow.ArrowInvalid,
),
(
pd.Series([{"foo": 1, "bar": "a"}]),
Expand All @@ -309,13 +316,22 @@ def test_pandas_arrow_dtype(data, dtype):
("bar", pyarrow.int64()),
]
),
pyarrow.ArrowTypeError,
),
(pd.Series(["a", "1"]), pyarrow.null, NotImplementedError),
(
pd.Series(["a", date(1970, 1, 1), "1970-01-01"]),
pyarrow.date32,
pyarrow.ArrowTypeError,
),
(
pd.Series(["a", date(1970, 1, 1), "1970-01-01"]),
pyarrow.date64,
pyarrow.ArrowTypeError,
),
(pd.Series(["a", "1"]), pyarrow.null),
(pd.Series(["a", date(1970, 1, 1), "1970-01-01"]), pyarrow.date32),
(pd.Series(["a", date(1970, 1, 1), "1970-01-01"]), pyarrow.date64),
(pd.Series(["a"]), pyarrow.duration("ns")),
(pd.Series(["a", "b"]), pyarrow.time32("ms")),
(pd.Series(["a", "b"]), pyarrow.time64("ns")),
(pd.Series(["a"]), pyarrow.duration("ns"), ValueError),
(pd.Series(["a", "b"]), pyarrow.time32("ms"), ValueError),
(pd.Series(["a", "b"]), pyarrow.time64("ns"), ValueError),
(
pd.Series(
[
Expand All @@ -324,29 +340,48 @@ def test_pandas_arrow_dtype(data, dtype):
]
),
pyarrow.map_(pyarrow.int32(), pyarrow.string()),
NotImplementedError,
),
(pd.Series([1, "foo", None]), pyarrow.binary(), pyarrow.ArrowInvalid),
(
pd.Series(["foo", "bar", "baz", None]),
pyarrow.binary(2),
NotImplementedError,
),
(
pd.Series([1, "foo", "barbaz", None]),
pyarrow.large_binary(),
pyarrow.ArrowInvalid,
),
(
pd.Series([1, 1.0, "foo", "bar", None]),
pyarrow.large_string(),
pyarrow.ArrowInvalid,
),
(
pd.Series([1.0, 2.0, 3.0]),
pyarrow.dictionary(pyarrow.int64(), pyarrow.float64()),
NotImplementedError,
),
(
pd.Series(["a", "b", "c"]),
pyarrow.dictionary(pyarrow.int64(), pyarrow.int64()),
AssertionError,
),
(pd.Series([1, "foo", None]), pyarrow.binary()),
(pd.Series(["foo", "bar", "baz", None]), pyarrow.binary(2)),
(pd.Series([1, "foo", "barbaz", None]), pyarrow.large_binary()),
(pd.Series([1, 1.0, "foo", "bar", None]), pyarrow.large_string()),
)


@pytest.mark.parametrize(("data", "dtype"), pandas_arrow_dtype_error_cases)
def test_pandas_arrow_dtype_error(data, dtype):
@pytest.mark.parametrize(
("data", "dtype", "exc"), pandas_arrow_dtype_error_cases
)
def test_pandas_arrow_dtype_error(data, dtype, exc):
"""Test pyarrow dtype raises Error on bad data."""
if not (
pandas_engine.PYARROW_INSTALLED and pandas_engine.PANDAS_2_0_0_PLUS
):
pytest.skip("Support of pandas 2.0.0+ with pyarrow only")
dtype = pandas_engine.Engine.dtype(dtype)

with pytest.raises(
(
pyarrow.ArrowInvalid,
pyarrow.ArrowTypeError,
NotImplementedError,
ValueError,
)
):
dtype.coerce(data)
with pytest.raises(exc):
coerced_data = dtype.coerce(data)
assert coerced_data.dtype == dtype.type

0 comments on commit 6f23de1

Please sign in to comment.