From 9cd4eb354c6c66fbf2505be3db029e8cf96fb7ab Mon Sep 17 00:00:00 2001 From: Weery Date: Thu, 16 Feb 2023 09:17:46 +0000 Subject: [PATCH] Allow subclassed series when using to_schema Signed-off-by: Weery --- pandera/core/pandas/model.py | 16 ++++++++++++---- tests/core/test_model.py | 23 +++++++++++++++++++++++ 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/pandera/core/pandas/model.py b/pandera/core/pandas/model.py index af1ab0df7..2f20bd0f3 100644 --- a/pandera/core/pandas/model.py +++ b/pandera/core/pandas/model.py @@ -360,10 +360,7 @@ def _build_columns_index( # pylint:disable=too-many-locals dtype = None if dtype is Any else dtype - if ( - annotation.origin in SERIES_TYPES - or annotation.raw_annotation in SERIES_TYPES - ): + if _annotation_is_valid_column(annotation): col_constructor = field.to_column if field else Column if check_name is False: @@ -585,6 +582,17 @@ def _build_schema_index( return index +def _annotation_is_valid_column(annotation: AnnotationInfo) -> bool: + if annotation.origin in SERIES_TYPES: + return True + if annotation.raw_annotation in SERIES_TYPES: + return True + if isinstance(annotation.origin, type): + if issubclass(annotation.origin, tuple(SERIES_TYPES)): + return True + return False + + def _regex_filter(seq: Iterable, regexps: Iterable[str]) -> Set[str]: """Filter items matching at least one of the regexes.""" matched: Set[str] = set() diff --git a/tests/core/test_model.py b/tests/core/test_model.py index 10ecb892a..8dab91ac0 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -11,6 +11,7 @@ import pandera.core.extensions as pax from pandera.errors import SchemaError, SchemaInitError from pandera.typing import DataFrame, Index, Series, String +from pandera.typing.common import GenericDtype def test_to_schema_and_validate() -> None: @@ -103,6 +104,28 @@ class InvalidDtype(pa.DataFrameModel): InvalidDtype.to_schema() +def test_sublcassing_series(): + """Test that when DataFrameModel.to_schema() does not raise an error for + sublcassed Series.""" + + class ValidSeries(Series, Generic[GenericDtype]): + pass + + class ValidSchema(pa.DataFrameModel): + a: ValidSeries[int] + + ValidSchema.to_schema() + + class InvalidSeries(int, Generic[GenericDtype]): + pass + + class InvalidSchema(pa.DataFrameModel): + a: InvalidSeries[int] + + with pytest.raises(pa.errors.SchemaInitError, match="Invalid annotation"): + InvalidSchema.to_schema() + + def test_optional_column() -> None: """Test that optional columns are not required."""