Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow subclassed series when using to_schema #1093

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions pandera/core/pandas/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,10 +360,7 @@ def _build_columns_index( # pylint:disable=too-many-locals

dtype = None if dtype is Any else dtype

if (
annotation.origin in SERIES_TYPES
or annotation.raw_annotation in SERIES_TYPES
):
if _annotation_is_valid_column(annotation):
col_constructor = field.to_column if field else Column

if check_name is False:
Expand Down Expand Up @@ -585,6 +582,17 @@ def _build_schema_index(
return index


def _annotation_is_valid_column(annotation: AnnotationInfo) -> bool:
if annotation.origin in SERIES_TYPES:
return True
if annotation.raw_annotation in SERIES_TYPES:
return True
if isinstance(annotation.origin, type):
if issubclass(annotation.origin, tuple(SERIES_TYPES)):
return True
return False


def _regex_filter(seq: Iterable, regexps: Iterable[str]) -> Set[str]:
"""Filter items matching at least one of the regexes."""
matched: Set[str] = set()
Expand Down
23 changes: 23 additions & 0 deletions tests/core/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pandera.core.extensions as pax
from pandera.errors import SchemaError, SchemaInitError
from pandera.typing import DataFrame, Index, Series, String
from pandera.typing.common import GenericDtype


def test_to_schema_and_validate() -> None:
Expand Down Expand Up @@ -103,6 +104,28 @@ class InvalidDtype(pa.DataFrameModel):
InvalidDtype.to_schema()


def test_sublcassing_series():
"""Test that when DataFrameModel.to_schema() does not raise an error for
sublcassed Series."""

class ValidSeries(Series, Generic[GenericDtype]):
pass

class ValidSchema(pa.DataFrameModel):
a: ValidSeries[int]

ValidSchema.to_schema()

class InvalidSeries(int, Generic[GenericDtype]):
pass

class InvalidSchema(pa.DataFrameModel):
a: InvalidSeries[int]

with pytest.raises(pa.errors.SchemaInitError, match="Invalid annotation"):
InvalidSchema.to_schema()


def test_optional_column() -> None:
"""Test that optional columns are not required."""

Expand Down