Merge pull request #60 from StreetEasy/bug-legacy-categoricals-int
bug-legacy-categoricals-int
Casyfill authored May 10, 2023
2 parents 054e7b2 + 02fbf61 commit 18a6529
Showing 9 changed files with 98 additions and 61 deletions.
5 changes: 5 additions & 0 deletions changelog.md
@@ -1,5 +1,10 @@
# Changelog

v0.0.9:

- Pydantic bumped to `1.10`
- Bug Fix: Categorical constraints (`exact_set`, `oneof`, `include`) can now keep `int` and `float` values. This extends to legacy schemas as well.

v0.0.8:
Legacy Schema Aliases (support for legacy schemas):
- `min_value` now also supports `min` alias
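To make the changelog entry concrete, here is a minimal sketch of the fixed behaviour, assuming a protocol-2 schema dict shaped like the `property_benchmarks_v2.json` added later in this diff; the column values are illustrative, not part of the PR.

    # Sketch of the v0.0.9 fix (illustrative values): integer members of a
    # categorical constraint are kept as ints instead of being coerced to str.
    from dfschema.core.core import DfSchema

    schema = {
        "metadata": {"protocol_version": 2.0, "version": "2023-01-30"},
        "columns": [
            {
                "name": "BOROUGH_ID",
                "dtype": "int",
                "categorical": {"mode": "oneof", "value_set": [100, 200, 300, 400, 500]},
            }
        ],
    }

    S = DfSchema.from_dict(schema)
    col = next(c for c in S.columns if c.name == "BOROUGH_ID")
    assert col.categorical.value_set == frozenset((100, 200, 300, 400, 500))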
4 changes: 2 additions & 2 deletions dfschema/core/column.py
@@ -1,6 +1,6 @@
import sys
from datetime import date, datetime
from typing import List, Optional, Set, Union, Tuple # , Pattern
from typing import List, Optional, FrozenSet, Union, Tuple # , Pattern
from warnings import warn

import pandas as pd
@@ -140,7 +140,7 @@ def validate_column(self, series: pd.Series, root, col_name: Optional[str] = Non


class Categorical(BaseModel): # type: ignore
value_set: Optional[Set[str]] = None
value_set: Optional[Union[FrozenSet[int], FrozenSet[float], FrozenSet[str],]] = None
mode: Optional[Literal["oneof", "exact_set", "include"]] = None
unique: bool = Field(
False, description="if true, the column must contain only unique values"
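The annotation change above is the heart of the fix. A standalone pydantic v1 sketch (not dfschema code) of why it matters: `Set[str]` silently coerces integer members to strings, while the `FrozenSet` union tries the `int` member first and leaves them untouched.

    # Standalone pydantic v1 illustration (not part of dfschema) of the annotation change.
    from typing import FrozenSet, Optional, Set, Union

    from pydantic import BaseModel


    class OldCategorical(BaseModel):
        value_set: Optional[Set[str]] = None  # pre-0.0.9 annotation


    class NewCategorical(BaseModel):
        value_set: Optional[Union[FrozenSet[int], FrozenSet[float], FrozenSet[str]]] = None


    assert OldCategorical(value_set=[100, 200]).value_set == {"100", "200"}  # ints coerced to str
    assert NewCategorical(value_set=[100, 200]).value_set == frozenset((100, 200))  # ints preserved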
24 changes: 13 additions & 11 deletions dfschema/core/core.py
@@ -14,6 +14,7 @@
from .legacy import infer_protocol_version, LegacySchemaRegistry
from .generate import generate_schema_dict_from_df

# from .utils import SchemaEncoder
# from .base_config import BaseConfig


@@ -40,13 +41,17 @@ class MetaData(BaseModel):
)


class DfSchema(BaseModel, extra=Extra.forbid, arbitrary_types_allowed=True): # type: ignore
class DfSchema(BaseModel): # type: ignore
"""Main class of the package
Represents a Schema to check (validate) dataframe against. Schema
is flavor-agnostic (does not specify what kind of dataframe it is)
"""

class Config:
extra = Extra.forbid
arbitrary_types_allowed = True

metadata: Optional[MetaData] = Field(
MetaData(),
description="optional metadata, including version and protocol version",
@@ -225,11 +230,14 @@ def to_file(self, path: Union[str, Path]) -> None:
path = Path(path)

try:
schema_dict = self.dict(exclude_none=True)

if path.suffix == ".json":
schema_json = self.json(exclude_none=True, indent=4)
with path.open("w") as f:
json.dump(schema_dict, f, indent=4)
f.write(schema_json)
elif path.suffix in (".yml", ".yaml"):
schema_dict = self.dict(exclude_none=True)

try:
import yaml

@@ -246,10 +254,7 @@ def to_file(self, path: Union[str, Path]) -> None:
raise DataFrameSchemaError(f"Error writing schema to file {path}") from e

@classmethod
def from_dict(
cls,
dict_: dict,
) -> "DfSchema":
def from_dict(cls, dict_: dict,) -> "DfSchema":
"""create DfSchema from dict.
same as `DfSchema(**dict_)`, but will also migrate old protocol schemas if necessary.
@@ -324,10 +329,7 @@ class SubsetSchema(BaseModel, extra=Extra.forbid, arbitrary_types_allowed=True):
predicate to select subset.
- If string, will be interpreted as query for `df.query()`.
- If dict, keys should be column names, values should be values to exactly match"""
predicate: Union[
dict,
str,
] = Field(..., description=_predicate_description)
predicate: Union[dict, str,] = Field(..., description=_predicate_description)

shape: Optional[ShapeSchema] = Field(None, description="shape expectations")
columns: Optional[List[ColSchema]] = Field([], description="columns expectations")
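Two notes on the core.py changes. First, `to_file` now serializes through `self.json(...)` rather than `json.dump(self.dict(...))`, presumably so that pydantic's own encoders handle the frozenset `value_set` (the stdlib `json` module cannot serialize sets). Second, `from_dict` is the entry point exercised by the new test: given a protocol-1 schema it migrates before constructing the model, so integer `oneof` lists survive as int value sets. A sketch along the lines of `tests/test_v1.py`, with the schema body trimmed from `property_benchmarks.json`:

    # Sketch of from_dict with a legacy (protocol v1) schema; the column entry is
    # trimmed from tests/test_schemas/v1/good/property_benchmarks.json.
    from dfschema.core.core import DfSchema

    legacy = {
        "version": "2023-01-30",
        "strict": True,
        "columns": {
            "BOROUGH_ID": {"dtype": "int", "na_limit": 1, "oneof": [100, 200, 300, 400, 500]},
        },
    }

    S = DfSchema.from_dict(legacy)  # detects protocol 1.0 and migrates to the current model
    catcol = next(c for c in S.columns if c.name == "BOROUGH_ID")
    assert catcol.categorical.mode == "oneof"
    assert catcol.categorical.value_set == frozenset((100, 200, 300, 400, 500))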
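And for the `SubsetSchema.predicate` docstring above, a small pandas-only sketch of the two predicate forms it describes: a `df.query()` string and an exact-match mapping select the same subset. The loop is an illustration of the documented semantics, not the package's implementation.

    # Illustration of the two predicate forms described in SubsetSchema's docstring.
    import pandas as pd

    df = pd.DataFrame({"BOROUGH_ID": [100, 100, 200], "UNITTYPE": ["condo", "coop", "condo"]})

    subset_from_query = df.query("BOROUGH_ID == 100")  # string predicate -> df.query()

    predicate = {"BOROUGH_ID": 100}  # dict predicate -> exact match per column
    mask = pd.Series(True, index=df.index)
    for column, value in predicate.items():
        mask &= df[column] == value
    subset_from_mapping = df[mask]

    assert subset_from_query.equals(subset_from_mapping)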
15 changes: 8 additions & 7 deletions dfschema/core/legacy/v1.py
@@ -2,7 +2,7 @@

# import json

from typing import Optional, Union, Dict, List, Tuple
from typing import Optional, Union, Dict, List, Tuple, Set
from ..logger import logger
from ..dtype import DtypeLiteral

@@ -35,8 +35,8 @@ class Config:

na_limit: Union[None, bool, float] = Field(None, gt=0, le=1.0)

include: Optional[List[str]] = None
oneof: Optional[List[str]] = Field(None, alias="one_of")
include: Optional[Union[Set[int], Set[float], Set[str]]] = None
oneof: Optional[Union[Set[int], Set[float], Set[str]]] = Field(None, alias="one_of")
unique: Optional[bool] = None


@@ -54,9 +54,7 @@ class Config:
allow_population_by_field_name = True

version: Optional[str] = Field(
None,
description="version of the schema",
example="2022-06-12",
None, description="version of the schema", example="2022-06-12",
)

protocol_version: float = Field(1.0, description="version of the protocol")
@@ -122,7 +120,10 @@ def migrate(self) -> Tuple[dict, float]:
if col.get(k) is not None:
categorical = col.get("categorical", dict())
try:
categorical["value_set"] = set(col.pop(k, {}))
categorical["value_set"] = set(col.pop(k, set()))
logger.debug(
f'Converting Categorical value set for mode={k}: {categorical["value_set"]}'
)
except TypeError as e:
raise TypeError(k, col, e)
categorical["mode"] = k
79 changes: 40 additions & 39 deletions poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions property_benchmarks_v2.json
@@ -0,0 +1 @@
{"metadata": {"protocol_version": 2.0, "version": "2023-01-30"}, "columns": [{"name": "ID", "dtype": "int", "na_limit": 0.99, "value_limits": {"min": 1.0}}, {"name": "LAST_SALE_ID", "dtype": "float", "na_limit": 0.9}, {"name": "BUILDING_ID", "dtype": "float", "na_limit": 0.9}, {"name": "UNITTYPE", "dtype": "str", "na_limit": 0.9}, {"name": "NEIGHBORHOOD_ID", "dtype": "int", "na_limit": 0.99}, {"name": "SUBMARKET_ID", "dtype": "number", "na_limit": 0.99}, {"name": "BOROUGH_ID", "dtype": "int", "na_limit": 0.99, "categorical": {"value_set":
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "dfschema"
version = "0.0.8"
version = "0.0.9"
description = "lightweight pandas.DataFrame schema"
authors = ["Philipp <[email protected]>"]
readme = "README.md"
@@ -13,7 +13,7 @@ python = ">=3.7.1,<4.0"
pandas = "^1.2.4"
sqlalchemy = {version = "1.*", optional = true}
pandera = {version = "^0.6", optional = true}
pydantic = "^1.9.1"
pydantic = ">1.10"
typer = {version = "^0.6.1", optional = true}
PyYAML = {version = "^6.0", optional = true}

14 changes: 14 additions & 0 deletions tests/test_schemas/v1/good/property_benchmarks.json
@@ -0,0 +1,14 @@
{
"version": "2023-01-30",
"strict": true,
"columns": {
"ID": {"dtype":"int", "na_limit":1, "min_value":1},
"LAST_SALE_ID": {"dtype":"float", "na_limit":0.9},
"BUILDING_ID": {"dtype":"float", "na_limit":0.9},
"UNITTYPE": {"dtype":"str", "na_limit":0.9},
"NEIGHBORHOOD_ID": {"dtype":"int", "na_limit":1},
"SUBMARKET_ID": {"dtype":"number", "na_limit":true},
"BOROUGH_ID": {"dtype":"int", "na_limit":1, "oneof":[100,200,300,400,500]},
"MEDIAN_NHOOD_PPSF": {"dtype":"number", "na_limit":0.9}
}
}
13 changes: 13 additions & 0 deletions tests/test_v1.py
@@ -17,3 +17,16 @@ def test_schema_objects(good_schema_v1: dict):
new = S.dict()
model_col = [c for c in new["columns"] if c["name"] == "model"][0]
assert model_col.get("categorical", {}).get("mode") == "exact_set"


def test_categorical_dtypes():
from dfschema.core.core import DfSchema
import json
from pathlib import Path

path = Path(__name__).parent / "tests/test_schemas/v1/good/property_benchmarks.json"
schema = json.loads(path.read_text())

S = DfSchema.from_dict(schema)
catcol = [el for el in S.columns if el.name == "BOROUGH_ID"][0]
assert catcol.categorical.value_set == frozenset((100, 200, 300, 400, 500))
