From 26c9a8fef032cc184f961eb6b79f9cffa8fbe89b Mon Sep 17 00:00:00 2001 From: Dylan Lukes Date: Fri, 23 Aug 2024 18:17:39 -0700 Subject: [PATCH] cleanup schema --- src/renkon/api.py | 14 ++++ src/renkon/core/model/__init__.py | 15 +--- src/renkon/core/model/schema.py | 80 +++++++------------ src/renkon/core/model/trait/__init__.py | 4 +- src/renkon/core/model/trait/result.py | 4 +- src/renkon/core/model/trait/sketch.py | 49 +++++++++--- src/renkon/core/model/type_aliases.py | 10 --- src/renkon/core/old_trait/compare.py | 2 +- src/renkon/core/old_trait/util/instantiate.py | 2 +- src/renkon/core/trait/__init__.py | 14 +++- src/renkon/core/trait/base.py | 5 +- src/renkon/core/trait/linear.py | 2 +- tests/renkon/core/model/test_schema.py | 12 +++ tests/renkon/core/model/test_sketch.py | 55 +++++++++++++ 14 files changed, 176 insertions(+), 92 deletions(-) delete mode 100644 src/renkon/core/model/type_aliases.py create mode 100644 tests/renkon/core/model/test_schema.py create mode 100644 tests/renkon/core/model/test_sketch.py diff --git a/src/renkon/api.py b/src/renkon/api.py index 169f7c2..1295756 100644 --- a/src/renkon/api.py +++ b/src/renkon/api.py @@ -1,3 +1,17 @@ # SPDX-FileCopyrightText: 2024-present Dylan Lukes # # SPDX-License-Identifier: BSD-3-Clause + +__all__ = [ + "int_", + "float_", + "str_", + "bool_", + "any_", + "none", + "numeric", + "equatable", + "comparable" +] + +from renkon.core.model.type import int_, float_, str_, bool_, any_, none, numeric, equatable, comparable diff --git a/src/renkon/core/model/__init__.py b/src/renkon/core/model/__init__.py index f11db4f..9ce3d25 100644 --- a/src/renkon/core/model/__init__.py +++ b/src/renkon/core/model/__init__.py @@ -4,29 +4,18 @@ __all__ = [ "RenkonType", - "ColumnName", - "ColumnNames", - "ColumnType", - "ColumnTypes", - "ColumnTypeSet", "Schema", "BitSeries", - "Schema", "TraitId", "TraitKind", "TraitPattern", "TraitSpec", "TraitSketch", "TraitResult", - "TraitResultScore", + "TraitScore", ] from renkon.core.model.bitseries import BitSeries from renkon.core.model.schema import Schema -from renkon.core.model.trait.kind import TraitKind -from renkon.core.model.trait.pattern import TraitPattern -from renkon.core.model.trait.result import TraitResult, TraitResultScore -from renkon.core.model.trait.sketch import TraitSketch -from renkon.core.model.trait.spec import TraitId, TraitSpec +from renkon.core.model.trait import TraitKind, TraitPattern, TraitResult, TraitScore, TraitSketch, TraitId, TraitSpec from renkon.core.model.type import RenkonType -from renkon.core.model.type_aliases import ColumnName, ColumnNames, ColumnType, ColumnTypes, ColumnTypeSet diff --git a/src/renkon/core/model/schema.py b/src/renkon/core/model/schema.py index 99bd885..f3584ac 100644 --- a/src/renkon/core/model/schema.py +++ b/src/renkon/core/model/schema.py @@ -1,63 +1,43 @@ -from collections.abc import Hashable, Iterator, Mapping, Sequence -from typing import Self, overload +import sys +from collections.abc import Sequence +from typing import Self -from polars.type_aliases import SchemaDict -from pydantic import ConfigDict, RootModel +from polars.type_aliases import SchemaDict as PolarsSchemaDict +from pydantic import GetCoreSchemaHandler +from pydantic_core import core_schema as cs -from renkon.core.model.type import RenkonType, tyconv_pl_to_rk -from renkon.core.model.type_aliases import ColumnName, ColumnNames +from renkon.core.model.type import RenkonType, tyconv_pl_to_rk, tyconv_rk_to_pl +type ColumnName = str +type ColumnNames = tuple[ColumnName, ...] -class Schema(RootModel[dict[ColumnName, RenkonType]], Mapping[ColumnName, RenkonType], Hashable): - """ - Represents a schema for some or all of the columns a data frame. +type ColumnType = RenkonType +type ColumnTypes = tuple[ColumnType, ...] +type ColumnTypeSet = frozenset[ColumnType] - Explicitly preserves order of its entries, provides a .index method for lookup of - the index of a column name, and convenience accessors for column names and types. +if sys.version_info <= (3, 6): + raise RuntimeError("Dictionaries are not guaranteed to preserve order before Python 3.6.") - Note that Python dict preserves insertion order since Pyt@hon 3.7. - """ - - model_config = ConfigDict(frozen=True) - root: dict[ColumnName, RenkonType] - - def __hash__(self) -> int: - return hash(tuple(self.root.items())) - - @overload - def __getitem__(self, key: ColumnName) -> RenkonType: ... - - @overload - def __getitem__(self, key: ColumnNames) -> Self: ... - - def __getitem__(self, key: ColumnName | ColumnNames) -> RenkonType | Self: - match key: - case str(): - return self.root[key] - case tuple(): - return self.subschema(key) - - def __iter__(self) -> Iterator[ColumnName]: # type: ignore - yield from iter(self.root) - - def __len__(self) -> int: - return len(self.root) - - def __lt__(self, other: Self) -> bool: - """Compares two schemas by their column names in lexicographic order.""" - return self.columns < other.columns +class Schema(dict[str, RenkonType]): @property - def columns(self) -> ColumnNames: - return tuple(self.root.keys()) + def columns(self): + return list(self.keys()) @property - def dtypes(self) -> tuple[RenkonType, ...]: - return tuple(self.root.values()) + def types(self): + return list(self.values()) + + def subschema(self, columns: Sequence[str]) -> Self: + return self.__class__({col: self[col] for col in columns}) @classmethod - def from_polars(cls, schema_dict: SchemaDict) -> Self: - return cls(root={col_name: tyconv_pl_to_rk(pl_ty) for col_name, pl_ty in schema_dict.items()}) + def from_polars(cls, schema: PolarsSchemaDict): + return cls({col: tyconv_pl_to_rk(pl_ty) for col, pl_ty in schema.items()}) - def subschema(self, columns: Sequence[str]) -> Self: - return self.__class__(root={col: self.root[col] for col in columns}) + def to_polars(self) -> PolarsSchemaDict: + return {col: tyconv_rk_to_pl(rk_ty) for col, rk_ty in self.items()} + + @classmethod + def __get_pydantic_core_schema__(cls, source_type: type, handler: GetCoreSchemaHandler, /): + return cs.chain_schema([handler(dict), cs.no_info_plain_validator_function(cls.__call__)]) diff --git a/src/renkon/core/model/trait/__init__.py b/src/renkon/core/model/trait/__init__.py index d198655..52fdeed 100644 --- a/src/renkon/core/model/trait/__init__.py +++ b/src/renkon/core/model/trait/__init__.py @@ -2,10 +2,10 @@ # # SPDX-License-Identifier: BSD-3-Clause -__all__ = ["TraitId", "TraitKind", "TraitPattern", "TraitSpec", "TraitSketch", "TraitResult"] +__all__ = ["TraitId", "TraitKind", "TraitPattern", "TraitSpec", "TraitSketch", "TraitResult", "TraitScore"] from renkon.core.model.trait.kind import TraitKind from renkon.core.model.trait.pattern import TraitPattern -from renkon.core.model.trait.result import TraitResult +from renkon.core.model.trait.result import TraitResult, TraitScore from renkon.core.model.trait.sketch import TraitSketch from renkon.core.model.trait.spec import TraitId, TraitSpec diff --git a/src/renkon/core/model/trait/result.py b/src/renkon/core/model/trait/result.py index 0e80186..ab2defa 100644 --- a/src/renkon/core/model/trait/result.py +++ b/src/renkon/core/model/trait/result.py @@ -9,7 +9,7 @@ from renkon.core.model.bitseries import BitSeries from renkon.core.model.trait.sketch import TraitSketch -type TraitResultScore = Annotated[float, Gt(0.0), Lt(1.0)] +type TraitScore = Annotated[float, Gt(0.0), Lt(1.0)] class TraitResult(BaseModel): @@ -19,7 +19,7 @@ class TraitResult(BaseModel): sketch: TraitSketch - score: TraitResultScore + score: TraitScore match_mask: BitSeries params: dict[str, tuple[str, Any]] diff --git a/src/renkon/core/model/trait/sketch.py b/src/renkon/core/model/trait/sketch.py index 8d6d6c6..454999d 100644 --- a/src/renkon/core/model/trait/sketch.py +++ b/src/renkon/core/model/trait/sketch.py @@ -5,26 +5,55 @@ from pydantic import BaseModel, model_validator +from renkon.core.model.schema import Schema from renkon.core.model.trait.spec import TraitSpec -from renkon.core.model.type_aliases import ColumnName, ColumnType class TraitSketch(BaseModel): """ Represents a sketch of a trait with holes filled. - :param trait: the trait being sketched. - :param metavar_bindings: the assignments of (typed) column names to metavariable in the trait form. + :param spec: the trait being sketched. + :param schema: schema (names -> types) of the data + :param bindings: bindings (metavariables -> actual column names) """ - trait: TraitSpec - metavar_bindings: dict[str, tuple[ColumnName, ColumnType]] + spec: TraitSpec + schema: Schema # pyright: ignore [reportIncompatibleMethodOverride] + bindings: dict[str, str] @model_validator(mode="after") - def check_columns(self) -> Self: - bound_colnames = set(self.metavar_bindings.keys()) - metavars = set(self.trait.pattern.metavars) - if bound_colnames != metavars: - msg = f"Bindings {bound_colnames} do not match trait metavariables {metavars}" + def check_bindings_keys(self) -> Self: + pattern_mvars = set(self.spec.pattern.metavars) + bound_mvars = set(self.bindings.keys()) + + missing_mvars = pattern_mvars - bound_mvars + extra_mvars = bound_mvars - pattern_mvars + + if len(missing_mvars) > 0: + msg = f"Metavariables {missing_mvars} are missing in bindings {self.bindings}" raise ValueError(msg) + + if len(extra_mvars) > 0: + msg = f"Metavariables {extra_mvars} do not occur in pattern {self.spec.pattern}" + raise ValueError(msg) + return self + + @model_validator(mode="after") + def check_bindings_values(self) -> Self: + for mvar, col in self.bindings.items(): + if col not in self.schema.columns: + msg = f"Cannot bind '{mvar}' to '{col} not found in {list(self.schema.columns)}" + raise ValueError(msg) + return self + + # @model_validator(mode="after") + # def check_bindings_typings(self) -> Self: + # for mvar, _ in self.bindings.items(): + # match self.trait.typings[mvar]: + # case RenkonType(): + # pass + # case str(): + # pass + # return self diff --git a/src/renkon/core/model/type_aliases.py b/src/renkon/core/model/type_aliases.py deleted file mode 100644 index aa3e553..0000000 --- a/src/renkon/core/model/type_aliases.py +++ /dev/null @@ -1,10 +0,0 @@ -from __future__ import annotations - -from renkon.core.model.type import RenkonType - -type ColumnName = str -type ColumnNames = tuple[ColumnName, ...] - -type ColumnType = RenkonType -type ColumnTypes = tuple[ColumnType, ...] -type ColumnTypeSet = frozenset[ColumnType] diff --git a/src/renkon/core/old_trait/compare.py b/src/renkon/core/old_trait/compare.py index 756164a..62ced61 100644 --- a/src/renkon/core/old_trait/compare.py +++ b/src/renkon/core/old_trait/compare.py @@ -62,7 +62,7 @@ def numeric_eq(lhs: pl.Series, rhs: pl.Series, schema: Schema) -> pl.Series: Equality comparator for numeric types which uses the equality operator, but in the case of floats uses np.isclose. """ - if set(schema.dtypes) & FLOAT_DTYPES: + if set(schema.types) & FLOAT_DTYPES: logger.warning(f"Sketch {schema} contains floats, using fuzzy check with rtol=1.e-5, atol=1.e-8.") return pl.Series(np.isclose(lhs, rhs)) diff --git a/src/renkon/core/old_trait/util/instantiate.py b/src/renkon/core/old_trait/util/instantiate.py index 989cecf..7056be1 100644 --- a/src/renkon/core/old_trait/util/instantiate.py +++ b/src/renkon/core/old_trait/util/instantiate.py @@ -9,7 +9,7 @@ def check_type_compatibility(meta: TraitMeta, schema: Schema) -> bool: if len(schema) != meta.arity: return False - for dtype, supported_dtypes in zip(schema.dtypes, meta.supported_dtypes, strict=True): + for dtype, supported_dtypes in zip(schema.types, meta.supported_dtypes, strict=True): if dtype not in supported_dtypes: return False diff --git a/src/renkon/core/trait/__init__.py b/src/renkon/core/trait/__init__.py index fbd8674..83f5746 100644 --- a/src/renkon/core/trait/__init__.py +++ b/src/renkon/core/trait/__init__.py @@ -1,6 +1,18 @@ # SPDX-FileCopyrightText: 2024-present Dylan Lukes # # SPDX-License-Identifier: BSD-3-Clause -__all__ = ["Equal", "Less", "LessOrEqual", "Greater", "GreaterOrEqual"] +__all__ = [ + "Equal", + "Less", + "LessOrEqual", + "Greater", + "GreaterOrEqual", + "NonNull", + "NonZero", + "NonNegative", + "Linear2" +] from renkon.core.trait.compare import Equal, Greater, GreaterOrEqual, Less, LessOrEqual +from renkon.core.trait.refinement import NonNull, NonZero, NonNegative +from renkon.core.trait.linear import Linear2 \ No newline at end of file diff --git a/src/renkon/core/trait/base.py b/src/renkon/core/trait/base.py index 3bc0656..7c393c8 100644 --- a/src/renkon/core/trait/base.py +++ b/src/renkon/core/trait/base.py @@ -4,7 +4,7 @@ from typing import ClassVar, Protocol, final import renkon.core.model.type as rk_type -from renkon.core.model import TraitId, TraitKind, TraitPattern, TraitSketch, TraitSpec +from renkon.core.model import TraitId, TraitKind, TraitPattern, TraitSketch, TraitSpec, Schema from renkon.core.model.type import RenkonType @@ -47,6 +47,9 @@ def typevars(self) -> dict[str, RenkonType]: def typings(self) -> dict[str, RenkonType | str]: return self.spec.typings + def can_sketch(self, schema: Schema, bindings: dict[str, str]) -> bool: + return False # todo: implement + def sketch(self, **kwargs: RenkonType) -> TraitSketch: return TraitSketch.model_validate( { diff --git a/src/renkon/core/trait/linear.py b/src/renkon/core/trait/linear.py index 6d51d99..14a1196 100644 --- a/src/renkon/core/trait/linear.py +++ b/src/renkon/core/trait/linear.py @@ -6,7 +6,7 @@ from renkon.core.trait.base import Trait -class Linear(Trait): +class Linear2(Trait): spec = TraitSpec( id=f"{__qualname__}", name=f"{__name__}", diff --git a/tests/renkon/core/model/test_schema.py b/tests/renkon/core/model/test_schema.py new file mode 100644 index 0000000..ead0eb9 --- /dev/null +++ b/tests/renkon/core/model/test_schema.py @@ -0,0 +1,12 @@ +# SPDX-FileCopyrightText: 2024-present Dylan Lukes +# +# SPDX-License-Identifier: BSD-3-Clause +from renkon.core.model.schema import Schema + + +def test_schema(): + schema = Schema({}) + print(schema) + +def test_schema_as_model_field(): + pass \ No newline at end of file diff --git a/tests/renkon/core/model/test_sketch.py b/tests/renkon/core/model/test_sketch.py new file mode 100644 index 0000000..3395dce --- /dev/null +++ b/tests/renkon/core/model/test_sketch.py @@ -0,0 +1,55 @@ +# SPDX-FileCopyrightText: 2024-present Dylan Lukes +# +# SPDX-License-Identifier: BSD-3-Clause +import pytest + +import renkon.api as rk +from renkon.core.model import TraitSketch, Schema +from renkon.core.trait import Linear2, Equal + + +def test_sketch_bindings_missing(): + schema = Schema({ + "x": rk.int_(), + "y": rk.int_() + }) + with pytest.raises(ValueError, match="missing in bindings"): + TraitSketch( + spec=Equal.spec, + schema=schema, + bindings={ + "A": "x" + } + ) + + +def test_sketch_bindings_extra(): + schema = Schema({ + "x": rk.int_(), + "y": rk.int_() + }) + with pytest.raises(ValueError, match="do not occur in pattern"): + TraitSketch( + spec=Equal.spec, + schema=schema, + bindings={ + "A": "x", + "B": "y", + "C": "z" + } + ) + + +def test_sketch_linear2(): + schema = Schema({ + "time": rk.float_(), + "open tabs": rk.float_() + }) + TraitSketch( + spec=Linear2.spec, + schema=schema, + bindings={ + "X": "time", + "Y": "open tabs" + } + )