Skip to content

Commit

Permalink
feat: add pyarrow list and struct to pandas engine (#1699)
Browse files Browse the repository at this point in the history
* feat: add pyarrow list and struct to pandas engine

Signed-off-by: Ajith Aravind <[email protected]>

* test: add tests for pyarrow list and struct

Signed-off-by: Ajith Aravind <[email protected]>

* fix: linting errors

Signed-off-by: Ajith Aravind <[email protected]>

---------

Signed-off-by: Ajith Aravind <[email protected]>
  • Loading branch information
aaravind100 authored Jun 27, 2024
1 parent c3011b5 commit 44a9763
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 1 deletion.
67 changes: 66 additions & 1 deletion pandera/engines/pandas_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
List,
NamedTuple,
Optional,
Tuple,
Type,
Union,
cast,
Expand Down Expand Up @@ -1765,7 +1766,7 @@ class ArrowDecimal128(DataType, dtypes.Decimal):
precision: int = 28
scale: int = 0

def __post_init__(self) -> None:
def __post_init__(self):
type_ = pd.ArrowDtype(
pyarrow.decimal128(self.precision, self.scale)
)
Expand Down Expand Up @@ -1832,3 +1833,67 @@ def from_parametrized_dtype(
value_type=pyarrow_dtype.value_type, # type: ignore
ordered=pyarrow_dtype.ordered, # type: ignore
)

@Engine.register_dtype(
    equivalents=[
        pyarrow.list_,
        pyarrow.ListType,
        pyarrow.FixedSizeListType,
    ]
)
@immutable(init=True)
class ArrowList(DataType):
    """Semantic representation of a :class:`pyarrow.list_`.

    The pandas-facing ``type`` is derived from ``value_type`` and
    ``list_size`` in ``__post_init__``; it is not an ``__init__`` argument.
    """

    # Built in ``__post_init__`` as ``pd.ArrowDtype(pyarrow.list_(...))``.
    type: Optional[pd.ArrowDtype] = dataclasses.field(
        default=None, init=False
    )
    # Element type (or field) of the list; defaults to a pyarrow string.
    value_type: Optional[Union[pyarrow.DataType, pyarrow.Field]] = (
        pyarrow.string()
    )
    # Forwarded to ``pyarrow.list_``; ``-1`` is that function's own default.
    list_size: Optional[int] = -1

    def __post_init__(self):
        """Build the ``pd.ArrowDtype`` wrapper for the configured list type."""
        # ``list_size`` is annotated as Optional, but ``pyarrow.list_`` expects
        # an integer, so ``None`` is mapped to the ``-1`` default.
        size = -1 if self.list_size is None else self.list_size
        type_ = pd.ArrowDtype(pyarrow.list_(self.value_type, size))
        # Plain attribute assignment is bypassed on purpose here.
        # NOTE(review): presumably because ``immutable`` freezes the
        # dataclass -- confirm against its definition.
        object.__setattr__(self, "type", type_)

    @classmethod
    def from_parametrized_dtype(
        cls,
        pyarrow_dtype: Union[pyarrow.ListType, pyarrow.FixedSizeListType],
    ):
        """Create an :class:`ArrowList` from a concrete pyarrow list type.

        :param pyarrow_dtype: pyarrow list type to mirror. Its ``value_type``
            is always copied; ``list_size`` is copied only when the type
            exposes that attribute, otherwise the class default is used.
        :returns: an instance of ``cls`` describing ``pyarrow_dtype``.
        """
        kwargs = {"value_type": pyarrow_dtype.value_type}  # type: ignore
        # Probe for ``list_size`` explicitly instead of wrapping the whole
        # constructor call in ``except AttributeError``: that would also
        # swallow unrelated AttributeErrors raised during construction.
        if hasattr(pyarrow_dtype, "list_size"):
            kwargs["list_size"] = pyarrow_dtype.list_size  # type: ignore
        return cls(**kwargs)

@Engine.register_dtype(equivalents=[pyarrow.struct, pyarrow.StructType])
@immutable(init=True)
class ArrowStruct(DataType):
    """Semantic representation of a :class:`pyarrow.struct`.

    The pandas-facing ``type`` is derived from ``fields`` in
    ``__post_init__``; it is not an ``__init__`` argument.
    """

    # Built in ``__post_init__`` as ``pd.ArrowDtype(pyarrow.struct(...))``.
    type: Optional[pd.ArrowDtype] = dataclasses.field(
        default=None, init=False
    )
    # Struct members, forwarded to ``pyarrow.struct``; empty by default.
    fields: Optional[
        Union[
            Iterable[Union[pyarrow.Field, Tuple[str, pyarrow.DataType]]],
            Dict[str, pyarrow.DataType],
        ]
    ] = tuple()

    def __post_init__(self):
        """Build the ``pd.ArrowDtype`` wrapper for the configured struct."""
        # ``fields`` is annotated as Optional; treat ``None`` like the empty
        # default instead of handing it to ``pyarrow.struct``.
        struct_fields = () if self.fields is None else self.fields
        type_ = pd.ArrowDtype(pyarrow.struct(struct_fields))
        # Plain attribute assignment is bypassed on purpose here.
        # NOTE(review): presumably because ``immutable`` freezes the
        # dataclass -- confirm against its definition.
        object.__setattr__(self, "type", type_)

    @classmethod
    def from_parametrized_dtype(cls, pyarrow_dtype: pyarrow.StructType):
        """Create an :class:`ArrowStruct` from a concrete pyarrow struct type.

        :param pyarrow_dtype: pyarrow struct type whose fields are copied, in
            index order.
        :returns: an instance of ``cls`` holding those fields as a tuple.
        """
        # A tuple matches the declared default of ``fields`` and, unlike a
        # list, does not make the instance unhashable.
        return cls(
            fields=tuple(
                pyarrow_dtype.field(i)
                for i in range(pyarrow_dtype.num_fields)
            )  # type: ignore
        )
60 changes: 60 additions & 0 deletions tests/core/test_pandas_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import hypothesis.strategies as st
import numpy as np
import pandas as pd
import pyarrow
import pytest
import pytz
from hypothesis import given
Expand Down Expand Up @@ -237,3 +238,62 @@ def test_pandas_date_coerce_dtype(to_df, data):
assert (
coerced_data.map(lambda x: isinstance(x, date)) | coerced_data.isna()
).all()


# (data, dtype) pairs where the data matches the pyarrow dtype.
pandas_arrow_dtype_cases = (
    # List of strings with no fixed size.
    (
        pd.Series(data=[["a", "b", "c"]]),
        pyarrow.list_(pyarrow.string()),
    ),
    # List of strings with a fixed size of two.
    (
        pd.Series(data=[["a", "b"]]),
        pyarrow.list_(pyarrow.string(), list_size=2),
    ),
    # Struct with an int64 "foo" member and a string "bar" member.
    (
        pd.Series(data=[{"foo": 1, "bar": "a"}]),
        pyarrow.struct(
            [
                pyarrow.field("foo", pyarrow.int64()),
                pyarrow.field("bar", pyarrow.string()),
            ]
        ),
    ),
)


@pytest.mark.parametrize(("data", "dtype"), pandas_arrow_dtype_cases)
def test_pandas_arrow_dtype(data, dtype):
    """Coerce matching data to a pyarrow-backed dtype; it must not raise."""
    # Resolve the raw pyarrow type to its engine dtype, then coerce.
    engine_dtype = pandas_engine.Engine.dtype(dtype)
    engine_dtype.coerce(data)


# (data, dtype) pairs where the data does not match the pyarrow dtype.
pandas_arrow_dtype_errors_cases = (
    # String elements against an int64 value type.
    (
        pd.Series(data=[["a", "b", "c"]]),
        pyarrow.list_(pyarrow.int64()),
    ),
    # Two elements against a fixed list size of three.
    (
        pd.Series(data=[["a", "b"]]),
        pyarrow.list_(pyarrow.string(), list_size=3),
    ),
    # Struct member types swapped relative to the data.
    (
        pd.Series(data=[{"foo": 1, "bar": "a"}]),
        pyarrow.struct(
            [
                pyarrow.field("foo", pyarrow.string()),
                pyarrow.field("bar", pyarrow.int64()),
            ]
        ),
    ),
)


@pytest.mark.parametrize(("data", "dtype"), pandas_arrow_dtype_errors_cases)
def test_pandas_arrow_dtype_errors(data, dtype):
    """Coercing mismatched data must raise ArrowInvalid or ArrowTypeError."""
    engine_dtype = pandas_engine.Engine.dtype(dtype)
    expected_errors = (pyarrow.ArrowInvalid, pyarrow.ArrowTypeError)

    with pytest.raises(expected_errors):
        engine_dtype.coerce(data)
2 changes: 2 additions & 0 deletions tests/strategies/test_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@
pandas_engine.ArrowUInt16,
pandas_engine.ArrowUInt32,
pandas_engine.ArrowUInt64,
pandas_engine.ArrowList,
pandas_engine.ArrowStruct,
]
)

Expand Down

0 comments on commit 44a9763

Please sign in to comment.