Skip to content

Commit

Permalink
cleanup schema
Browse files Browse the repository at this point in the history
  • Loading branch information
DylanLukes committed Aug 24, 2024
1 parent 7f570f7 commit 26c9a8f
Show file tree
Hide file tree
Showing 14 changed files with 176 additions and 92 deletions.
14 changes: 14 additions & 0 deletions src/renkon/api.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# SPDX-FileCopyrightText: 2024-present Dylan Lukes <[email protected]>
#
# SPDX-License-Identifier: BSD-3-Clause

__all__ = [
"int_",
"float_",
"str_",
"bool_",
"any_",
"none",
"numeric",
"equatable",
"comparable"
]

from renkon.core.model.type import int_, float_, str_, bool_, any_, none, numeric, equatable, comparable
15 changes: 2 additions & 13 deletions src/renkon/core/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,18 @@

__all__ = [
"RenkonType",
"ColumnName",
"ColumnNames",
"ColumnType",
"ColumnTypes",
"ColumnTypeSet",
"Schema",
"BitSeries",
"Schema",
"TraitId",
"TraitKind",
"TraitPattern",
"TraitSpec",
"TraitSketch",
"TraitResult",
"TraitResultScore",
"TraitScore",
]

from renkon.core.model.bitseries import BitSeries
from renkon.core.model.schema import Schema
from renkon.core.model.trait.kind import TraitKind
from renkon.core.model.trait.pattern import TraitPattern
from renkon.core.model.trait.result import TraitResult, TraitResultScore
from renkon.core.model.trait.sketch import TraitSketch
from renkon.core.model.trait.spec import TraitId, TraitSpec
from renkon.core.model.trait import TraitKind, TraitPattern, TraitResult, TraitScore, TraitSketch, TraitId, TraitSpec
from renkon.core.model.type import RenkonType
from renkon.core.model.type_aliases import ColumnName, ColumnNames, ColumnType, ColumnTypes, ColumnTypeSet
80 changes: 30 additions & 50 deletions src/renkon/core/model/schema.py
Original file line number Diff line number Diff line change
@@ -1,63 +1,43 @@
from collections.abc import Hashable, Iterator, Mapping, Sequence
from typing import Self, overload
import sys
from collections.abc import Sequence
from typing import Self

from polars.type_aliases import SchemaDict
from pydantic import ConfigDict, RootModel
from polars.type_aliases import SchemaDict as PolarsSchemaDict
from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema as cs

from renkon.core.model.type import RenkonType, tyconv_pl_to_rk
from renkon.core.model.type_aliases import ColumnName, ColumnNames
from renkon.core.model.type import RenkonType, tyconv_pl_to_rk, tyconv_rk_to_pl

type ColumnName = str
type ColumnNames = tuple[ColumnName, ...]

class Schema(RootModel[dict[ColumnName, RenkonType]], Mapping[ColumnName, RenkonType], Hashable):
"""
Represents a schema for some or all of the columns of a data frame.
type ColumnType = RenkonType
type ColumnTypes = tuple[ColumnType, ...]
type ColumnTypeSet = frozenset[ColumnType]

Explicitly preserves order of its entries, provides a .index method for lookup of
the index of a column name, and convenience accessors for column names and types.
# The schema depends on dict preserving insertion order. That is only a
# language-level guarantee from Python 3.7 onward (in CPython 3.6 it was an
# implementation detail), so refuse to load on anything older.
if sys.version_info < (3, 7):
    raise RuntimeError("Dictionaries are not guaranteed to preserve order before Python 3.7.")

Note that Python dict preserves insertion order since Python 3.7.
"""

model_config = ConfigDict(frozen=True)
root: dict[ColumnName, RenkonType]

def __hash__(self) -> int:
return hash(tuple(self.root.items()))

@overload
def __getitem__(self, key: ColumnName) -> RenkonType: ...

@overload
def __getitem__(self, key: ColumnNames) -> Self: ...

def __getitem__(self, key: ColumnName | ColumnNames) -> RenkonType | Self:
match key:
case str():
return self.root[key]
case tuple():
return self.subschema(key)

def __iter__(self) -> Iterator[ColumnName]: # type: ignore
yield from iter(self.root)

def __len__(self) -> int:
return len(self.root)

def __lt__(self, other: Self) -> bool:
"""Compares two schemas by their column names in lexicographic order."""
return self.columns < other.columns

class Schema(dict[str, RenkonType]):
@property
def columns(self) -> ColumnNames:
return tuple(self.root.keys())
def columns(self):
    """Return the column names as a list, in the mapping's iteration order."""
    return list(self)

@property
def dtypes(self) -> tuple[RenkonType, ...]:
return tuple(self.root.values())
def types(self):
    """Return the column types as a list, in the same order as the column names."""
    return [*self.values()]

def subschema(self, columns: Sequence[str]) -> Self:
return self.__class__({col: self[col] for col in columns})

@classmethod
def from_polars(cls, schema_dict: SchemaDict) -> Self:
return cls(root={col_name: tyconv_pl_to_rk(pl_ty) for col_name, pl_ty in schema_dict.items()})
def from_polars(cls, schema: PolarsSchemaDict):
    """Build an instance of ``cls`` from a Polars schema mapping.

    Each value of ``schema`` is converted with ``tyconv_pl_to_rk``; entries
    keep the iteration order of ``schema``.
    """
    converted = {}
    for col, pl_ty in schema.items():
        converted[col] = tyconv_pl_to_rk(pl_ty)
    return cls(converted)

def subschema(self, columns: Sequence[str]) -> Self:
return self.__class__(root={col: self.root[col] for col in columns})
def to_polars(self) -> PolarsSchemaDict:
    """Return a plain dict mapping each column name to its type converted with ``tyconv_rk_to_pl``."""
    polars_schema = {}
    for col, rk_ty in self.items():
        polars_schema[col] = tyconv_rk_to_pl(rk_ty)
    return polars_schema

@classmethod
def __get_pydantic_core_schema__(cls, source_type: type, handler: GetCoreSchemaHandler, /):
    """Pydantic hook describing how a field of this type is validated.

    The incoming value is first validated as a plain ``dict`` via ``handler``,
    and the validated result is then passed to ``cls.__call__``. The two steps
    are chained in that order.
    """
    validate_as_dict = handler(dict)
    construct = cs.no_info_plain_validator_function(cls.__call__)
    return cs.chain_schema([validate_as_dict, construct])
4 changes: 2 additions & 2 deletions src/renkon/core/model/trait/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
#
# SPDX-License-Identifier: BSD-3-Clause

__all__ = ["TraitId", "TraitKind", "TraitPattern", "TraitSpec", "TraitSketch", "TraitResult"]
__all__ = ["TraitId", "TraitKind", "TraitPattern", "TraitSpec", "TraitSketch", "TraitResult", "TraitScore"]

from renkon.core.model.trait.kind import TraitKind
from renkon.core.model.trait.pattern import TraitPattern
from renkon.core.model.trait.result import TraitResult
from renkon.core.model.trait.result import TraitResult, TraitScore
from renkon.core.model.trait.sketch import TraitSketch
from renkon.core.model.trait.spec import TraitId, TraitSpec
4 changes: 2 additions & 2 deletions src/renkon/core/model/trait/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from renkon.core.model.bitseries import BitSeries
from renkon.core.model.trait.sketch import TraitSketch

type TraitResultScore = Annotated[float, Gt(0.0), Lt(1.0)]
type TraitScore = Annotated[float, Gt(0.0), Lt(1.0)]


class TraitResult(BaseModel):
Expand All @@ -19,7 +19,7 @@ class TraitResult(BaseModel):

sketch: TraitSketch

score: TraitResultScore
score: TraitScore
match_mask: BitSeries

params: dict[str, tuple[str, Any]]
49 changes: 39 additions & 10 deletions src/renkon/core/model/trait/sketch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,55 @@

from pydantic import BaseModel, model_validator

from renkon.core.model.schema import Schema
from renkon.core.model.trait.spec import TraitSpec
from renkon.core.model.type_aliases import ColumnName, ColumnType


class TraitSketch(BaseModel):
"""
Represents a sketch of a trait with holes filled.
:param trait: the trait being sketched.
:param metavar_bindings: the assignments of (typed) column names to metavariable in the trait form.
:param spec: the trait being sketched.
:param schema: schema (names -> types) of the data
:param bindings: bindings (metavariables -> actual column names)
"""

trait: TraitSpec
metavar_bindings: dict[str, tuple[ColumnName, ColumnType]]
spec: TraitSpec
schema: Schema # pyright: ignore [reportIncompatibleMethodOverride]
bindings: dict[str, str]

@model_validator(mode="after")
def check_columns(self) -> Self:
bound_colnames = set(self.metavar_bindings.keys())
metavars = set(self.trait.pattern.metavars)
if bound_colnames != metavars:
msg = f"Bindings {bound_colnames} do not match trait metavariables {metavars}"
def check_bindings_keys(self) -> Self:
pattern_mvars = set(self.spec.pattern.metavars)
bound_mvars = set(self.bindings.keys())

missing_mvars = pattern_mvars - bound_mvars
extra_mvars = bound_mvars - pattern_mvars

if len(missing_mvars) > 0:
msg = f"Metavariables {missing_mvars} are missing in bindings {self.bindings}"
raise ValueError(msg)

if len(extra_mvars) > 0:
msg = f"Metavariables {extra_mvars} do not occur in pattern {self.spec.pattern}"
raise ValueError(msg)

return self

@model_validator(mode="after")
def check_bindings_values(self) -> Self:
for mvar, col in self.bindings.items():
if col not in self.schema.columns:
msg = f"Cannot bind '{mvar}' to '{col} not found in {list(self.schema.columns)}"
raise ValueError(msg)
return self

# @model_validator(mode="after")
# def check_bindings_typings(self) -> Self:
# for mvar, _ in self.bindings.items():
# match self.trait.typings[mvar]:
# case RenkonType():
# pass
# case str():
# pass
# return self
10 changes: 0 additions & 10 deletions src/renkon/core/model/type_aliases.py

This file was deleted.

2 changes: 1 addition & 1 deletion src/renkon/core/old_trait/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def numeric_eq(lhs: pl.Series, rhs: pl.Series, schema: Schema) -> pl.Series:
Equality comparator for numeric types which uses the equality operator,
but in the case of floats uses np.isclose.
"""
if set(schema.dtypes) & FLOAT_DTYPES:
if set(schema.types) & FLOAT_DTYPES:
logger.warning(f"Sketch {schema} contains floats, using fuzzy check with rtol=1.e-5, atol=1.e-8.")
return pl.Series(np.isclose(lhs, rhs))

Expand Down
2 changes: 1 addition & 1 deletion src/renkon/core/old_trait/util/instantiate.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def check_type_compatibility(meta: TraitMeta, schema: Schema) -> bool:
if len(schema) != meta.arity:
return False

for dtype, supported_dtypes in zip(schema.dtypes, meta.supported_dtypes, strict=True):
for dtype, supported_dtypes in zip(schema.types, meta.supported_dtypes, strict=True):
if dtype not in supported_dtypes:
return False

Expand Down
14 changes: 13 additions & 1 deletion src/renkon/core/trait/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
# SPDX-FileCopyrightText: 2024-present Dylan Lukes <[email protected]>
#
# SPDX-License-Identifier: BSD-3-Clause
__all__ = ["Equal", "Less", "LessOrEqual", "Greater", "GreaterOrEqual"]
__all__ = [
"Equal",
"Less",
"LessOrEqual",
"Greater",
"GreaterOrEqual",
"NonNull",
"NonZero",
"NonNegative",
"Linear2"
]

from renkon.core.trait.compare import Equal, Greater, GreaterOrEqual, Less, LessOrEqual
from renkon.core.trait.refinement import NonNull, NonZero, NonNegative
from renkon.core.trait.linear import Linear2
5 changes: 4 additions & 1 deletion src/renkon/core/trait/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import ClassVar, Protocol, final

import renkon.core.model.type as rk_type
from renkon.core.model import TraitId, TraitKind, TraitPattern, TraitSketch, TraitSpec
from renkon.core.model import TraitId, TraitKind, TraitPattern, TraitSketch, TraitSpec, Schema
from renkon.core.model.type import RenkonType


Expand Down Expand Up @@ -47,6 +47,9 @@ def typevars(self) -> dict[str, RenkonType]:
def typings(self) -> dict[str, RenkonType | str]:
return self.spec.typings

def can_sketch(self, schema: Schema, bindings: dict[str, str]) -> bool:
    """Report whether this trait can be sketched for ``schema`` with ``bindings``.

    Not implemented yet: currently returns ``False`` for every input.
    """
    return False  # todo: implement

def sketch(self, **kwargs: RenkonType) -> TraitSketch:
return TraitSketch.model_validate(
{
Expand Down
2 changes: 1 addition & 1 deletion src/renkon/core/trait/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from renkon.core.trait.base import Trait


class Linear(Trait):
class Linear2(Trait):
spec = TraitSpec(
id=f"{__qualname__}",
name=f"{__name__}",
Expand Down
12 changes: 12 additions & 0 deletions tests/renkon/core/model/test_schema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# SPDX-FileCopyrightText: 2024-present Dylan Lukes <[email protected]>
#
# SPDX-License-Identifier: BSD-3-Clause
from renkon.core.model.schema import Schema


def test_schema():
    """An empty Schema compares equal to an empty dict and has no columns or types."""
    schema = Schema({})
    # The original only printed the schema, so it could never fail.
    assert schema == {}
    assert schema.columns == []
    assert schema.types == []

def test_schema_as_model_field():
    # Placeholder: this test performs no checks yet.
    # NOTE(review): presumably meant to cover using Schema as a pydantic model
    # field — confirm and implement.
    pass
55 changes: 55 additions & 0 deletions tests/renkon/core/model/test_sketch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# SPDX-FileCopyrightText: 2024-present Dylan Lukes <[email protected]>
#
# SPDX-License-Identifier: BSD-3-Clause
import pytest

import renkon.api as rk
from renkon.core.model import TraitSketch, Schema
from renkon.core.trait import Linear2, Equal


def test_sketch_bindings_missing():
    """Binding only "A" for Equal is expected to fail with a "missing in bindings" error."""
    column_types = {"x": rk.int_(), "y": rk.int_()}
    schema = Schema(column_types)
    partial_bindings = {"A": "x"}
    with pytest.raises(ValueError, match="missing in bindings"):
        TraitSketch(spec=Equal.spec, schema=schema, bindings=partial_bindings)


def test_sketch_bindings_extra():
    """Binding an additional "C" for Equal is expected to fail with a "do not occur in pattern" error."""
    column_types = {"x": rk.int_(), "y": rk.int_()}
    schema = Schema(column_types)
    surplus_bindings = {"A": "x", "B": "y", "C": "z"}
    with pytest.raises(ValueError, match="do not occur in pattern"):
        TraitSketch(spec=Equal.spec, schema=schema, bindings=surplus_bindings)


def test_sketch_linear2():
    """Constructing a Linear2 sketch with "X" and "Y" bound to schema columns should not raise."""
    column_types = {"time": rk.float_(), "open tabs": rk.float_()}
    schema = Schema(column_types)
    bindings = {"X": "time", "Y": "open tabs"}
    TraitSketch(spec=Linear2.spec, schema=schema, bindings=bindings)

0 comments on commit 26c9a8f

Please sign in to comment.