Skip to content

Commit

Permalink
typevar validation for Sketch
Browse files Browse the repository at this point in the history
  • Loading branch information
DylanLukes committed Aug 27, 2024
1 parent 3568635 commit 9d9911c
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 34 deletions.
65 changes: 55 additions & 10 deletions src/renkon/core/model/trait/sketch.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from pydantic import BaseModel, model_validator

import renkon.core.model.type as rk
from renkon.core.model.schema import Schema
from renkon.core.model.trait.spec import TraitSpec
from renkon.core.model.type import RenkonType
Expand All @@ -26,6 +27,9 @@ class TraitSketch(BaseModel):
# Inverted lookup from column name to metavariable
_bindings_inv: dict[str, str] = {}

# Instantiations of typevars to concrete types
_typevar_insts: dict[str, RenkonType] = {}

@model_validator(mode="after")
def _populate_bindings_inv(self) -> Self:
self._bindings_inv = {v: k for (k, v) in self.bindings.items()}
Expand Down Expand Up @@ -57,18 +61,59 @@ def _check_bindings_values(self) -> Self:
raise ValueError(msg)
return self

@model_validator(mode="after")
def _populate_typevar_insts(self) -> Self:
"""(Try to) instantiate each type variable to a concrete type."""

col_to_type = self.schema
mvar_to_col = self.bindings
mvar_to_typing = self.spec.typings
mvar_to_type = {mvar: col_to_type[col] for (mvar, col) in mvar_to_col.items()}

typevars = self.spec.typevars
typevar_insts: dict[str, RenkonType] = self._typevar_insts

for typevar_name, typevar_bound in typevars.items():
# Filter mvar_to_type to only entries that reference this type variable.
typevar_mvar_to_type = {
mvar: mvar_to_type[mvar]
for (mvar, mvar_typing) in mvar_to_typing.items()
if isinstance(mvar_typing, str) and mvar_typing == typevar_name
}

# Check that all the bounds are satisfied.
for mvar, mvar_type in typevar_mvar_to_type.items():
if not mvar_type.is_subtype(typevar_bound):
msg = (f"Column '{mvar_to_col[mvar]} has incompatible type '{mvar_type}', "
f"does not satisfy bound '{typevar_bound}' of typevar '{typevar_name}'.")
raise TypeError(msg)

# Attempt to find a least upper bound to instantiate the typevar to.
lub_ty = rk.union(rk.any_())
for mvar_type in typevar_mvar_to_type.values():
lub_ty &= rk.union(mvar_type)
lub_ty = lub_ty.normalize()

if lub_ty == rk.none():
msg = f"Could not instantiate typevar '{typevar_name}' given concrete typings {typevar_mvar_to_type}"
raise TypeError(msg)

typevar_insts[typevar_name] = lub_ty

self._typevar_insts = typevar_insts
return self

@model_validator(mode="after")
def _check_bindings_typings(self) -> Self:
# Check that the types in the provided schema match typings.
for col, ty in self.schema.items():
metavar = self._bindings_inv[col]
req_ty = self.spec.typings[metavar]
match req_ty:
case RenkonType():
if not ty.is_subtype(req_ty):
msg = f"Column '{col}' has incompatible type '{ty}', expected '{req_ty}'."
raise TypeError(msg)
case str():
raise NotImplementedError
mvar = self._bindings_inv[col]
req_ty = self.spec.typings[mvar]

if isinstance(req_ty, str):
req_ty = self._typevar_insts[req_ty]

if not ty.is_subtype(req_ty):
msg = f"Column '{col}' has incompatible type '{ty}', does not satisfy bound '{req_ty}'."
raise TypeError(msg)

return self
41 changes: 23 additions & 18 deletions src/renkon/core/model/type/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,11 @@ def __eq__(self, other: object) -> bool:
return super().__eq__(other)
return self.is_equal(other)

def __and__(self, other: RenkonType) -> RenkonType:
return union(self).intersect(union(other)).normalize()

def __or__(self, other: RenkonType) -> UnionType:
return UnionType(ts=frozenset({self, other})).canonicalize()
return union(self, other).canonicalize()

@abstractmethod
def __hash__(self) -> int: ...
Expand Down Expand Up @@ -333,21 +336,21 @@ class BoolType(PrimitiveType, name="bool"): ...
class UnionType(RenkonType):
ts: frozenset[RenkonType]

@property
def is_empty_union(self) -> bool:
return not self.ts

@property
def is_trivial_union(self) -> bool:
"""True if the union is of one unique type."""
return len(self.ts) == 1

@property
def contains_top(self) -> bool:
"""True if the union contains (at any depth) a TopType."""
return bool(any(isinstance(t, TopType) for t in self.flatten().ts))

@property
def contains_bottom(self) -> bool:
"""True if the union contains (at any depth) a BottomType."""
return bool(any(isinstance(t, BottomType) for t in self.flatten().ts))

def contains_union(self) -> bool:
"""True if the union contains an immediate child Union."""
return bool(any(isinstance(t, UnionType) for t in self.ts))
Expand Down Expand Up @@ -376,7 +379,7 @@ def canonicalize(self) -> UnionType:
ts = flat.ts

# If a top type is present, leave only it.
if self.contains_top:
if self.contains_top():
return UnionType(ts=frozenset({TopType()}))

# Remove any bottom types.
Expand All @@ -387,30 +390,35 @@ def canonicalize(self) -> UnionType:
def normalize(self) -> RenkonType:
canon = self.canonicalize()

if canon.contains_top:
if canon.contains_top():
return TopType()
if canon.is_empty_union:
if canon.is_empty_union():
return BottomType()
if canon.is_trivial_union:
if canon.is_trivial_union():
return canon.single()
return canon

def union(self, other: UnionType) -> UnionType:
return UnionType(ts=self.ts.union(other.ts)).canonicalize()

def intersect(self, other: UnionType) -> UnionType:
if self.contains_top():
return other.canonicalize()
if other.contains_top():
return self.canonicalize()

return UnionType(ts=self.ts.intersection(other.ts)).canonicalize()

def dump_string(self) -> str:
if self.is_empty_union:
return " | "
if self.is_trivial_union:
return f"{self.single().dump_string()} | "
if self.is_empty_union():
return "none | none"
if self.is_trivial_union():
return f"{self.single().dump_string()} | none"
return " | ".join(sorted(t.dump_string() for t in self.ts))

def flatten(self) -> UnionType:
"""Recursively flatten nested unions."""
if not self.contains_union:
if not self.contains_union():
return self
ts: set[RenkonType] = set()
for t in self.ts:
Expand All @@ -421,17 +429,14 @@ def flatten(self) -> UnionType:
return UnionType.model_validate({"ts": ts})

def single(self) -> RenkonType:
if not self.is_trivial_union:
if not self.is_trivial_union():
msg = "Union is not trivial, a single type"
raise ValueError(msg)
return next(iter(self.ts))

def __hash__(self) -> int:
return hash((type(self), self.ts))

def __and__(self, other: UnionType) -> UnionType:
return UnionType(ts=self.ts.intersection(other.ts)).canonicalize()


# endregion

Expand Down
4 changes: 2 additions & 2 deletions src/renkon/core/trait/compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from typing import Any, ClassVar, Literal, final

from renkon.core.model import TraitKind, TraitPattern, TraitSpec
from renkon.core.model.type import comparable, numeric
from renkon.core.model.type import comparable, numeric, equatable
from renkon.core.trait.base import Trait

type _CmpOpStr = Literal["<", "≤", "=", "≥", ">"]
Expand Down Expand Up @@ -39,7 +39,7 @@ def __init_subclass__(cls, *, op_str: _CmpOpStr, **kwargs: Any):
kind=TraitKind.LOGICAL,
pattern=TraitPattern("{A}" f" {op_str} " "{B}"),
commutors=[{"A", "B"}],
typevars={"T": numeric() if op_str == "=" else comparable()},
typevars={"T": equatable() if op_str == "=" else comparable()},
typings={"A": "T", "B": "T"},
)

Expand Down
12 changes: 9 additions & 3 deletions tests/renkon/core/model/test_sketch.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,21 @@ def test_sketch_linear2():

def test_sketch_incorrect_typing():
schema = Schema({"x": rk.int_(), "name": rk.str_()})
with pytest.raises(TypeError, match="incompatible type"):
with pytest.raises(TypeError, match="incompatible type .* does not satisfy bound"):
TraitSketch(spec=Linear2.spec, schema=schema, bindings={"X": "x", "Y": "name"})


def test_sketch_typevars():
def test_sketch_typevar_incorrect_typing():
schema = Schema({"a": rk.float_(), "b": rk.float_()})
with pytest.raises(TypeError, match="incompatible type .* does not satisfy bound .* of typevar"):
TraitSketch(spec=Equal.spec, schema=schema, bindings={"A": "a", "B": "b"})


def test_sketch_typevar_instantiation():
for ty1, ty2 in it.product(rk.equatable().ts, repeat=2):
schema = Schema({"a": ty1, "b": ty2})
if ty1 == ty2:
TraitSketch(spec=Equal.spec, schema=schema, bindings={"A": "a", "B": "b"})
else:
with pytest.raises(TypeError):
with pytest.raises(TypeError, match=r"Could not instantiate .* given concrete .*"):
TraitSketch(spec=Equal.spec, schema=schema, bindings={"A": "a", "B": "b"})
5 changes: 4 additions & 1 deletion tests/renkon/core/model/type/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def test_type_model_dump_primitive():
def test_type_model_dump_union():
assert union(int_(), float_()).model_dump() == "float | int"
assert union(int_(), str_()).model_dump() == "int | string"
assert union().model_dump() == " | "
assert union().model_dump() == "none | none"


def test_type_model_dump_any():
Expand Down Expand Up @@ -158,6 +158,9 @@ def test_union_intersect():
assert union(int_(), float_()).intersect(union(str_(), bool_())) == union()


def test_union_intersect_any():
assert union(any_()).intersect(union(int_(), str_())) == union(int_(), str_())

def test_union_dump_python():
assert union(int_(), float_()).model_dump() == "float | int"

Expand Down

0 comments on commit 9d9911c

Please sign in to comment.