From 9d9911cba1032e418911e7bddb3917ca94452265 Mon Sep 17 00:00:00 2001 From: Dylan Lukes Date: Mon, 26 Aug 2024 19:23:12 -0700 Subject: [PATCH] typevar validation for Sketch --- src/renkon/core/model/trait/sketch.py | 65 +++++++++++++++++++---- src/renkon/core/model/type/base.py | 41 +++++++------- src/renkon/core/trait/compare.py | 4 +- tests/renkon/core/model/test_sketch.py | 12 +++-- tests/renkon/core/model/type/test_base.py | 5 +- 5 files changed, 93 insertions(+), 34 deletions(-) diff --git a/src/renkon/core/model/trait/sketch.py b/src/renkon/core/model/trait/sketch.py index 371d8cb..d0fbb1e 100644 --- a/src/renkon/core/model/trait/sketch.py +++ b/src/renkon/core/model/trait/sketch.py @@ -5,6 +5,7 @@ from pydantic import BaseModel, model_validator +import renkon.core.model.type as rk from renkon.core.model.schema import Schema from renkon.core.model.trait.spec import TraitSpec from renkon.core.model.type import RenkonType @@ -26,6 +27,9 @@ class TraitSketch(BaseModel): # Inverted lookup from column name to metavariable _bindings_inv: dict[str, str] = {} + # Instantiations of typevars to concrete types + _typevar_insts: dict[str, RenkonType] = {} + @model_validator(mode="after") def _populate_bindings_inv(self) -> Self: self._bindings_inv = {v: k for (k, v) in self.bindings.items()} @@ -57,18 +61,59 @@ def _check_bindings_values(self) -> Self: raise ValueError(msg) return self + @model_validator(mode="after") + def _populate_typevar_insts(self) -> Self: + """(Try to) instantiate each type variable to a concrete type.""" + + col_to_type = self.schema + mvar_to_col = self.bindings + mvar_to_typing = self.spec.typings + mvar_to_type = {mvar: col_to_type[col] for (mvar, col) in mvar_to_col.items()} + + typevars = self.spec.typevars + typevar_insts: dict[str, RenkonType] = self._typevar_insts + + for typevar_name, typevar_bound in typevars.items(): + # Filter mvar_to_type to only entries that reference this type variable. + typevar_mvar_to_type = { + mvar: mvar_to_type[mvar] + for (mvar, mvar_typing) in mvar_to_typing.items() + if isinstance(mvar_typing, str) and mvar_typing == typevar_name + } + + # Check that all the bounds are satisfied. + for mvar, mvar_type in typevar_mvar_to_type.items(): + if not mvar_type.is_subtype(typevar_bound): + msg = (f"Column '{mvar_to_col[mvar]} has incompatible type '{mvar_type}', " + f"does not satisfy bound '{typevar_bound}' of typevar '{typevar_name}'.") + raise TypeError(msg) + + # Attempt to find a least upper bound to instantiate the typevar to. + lub_ty = rk.union(rk.any_()) + for mvar_type in typevar_mvar_to_type.values(): + lub_ty &= rk.union(mvar_type) + lub_ty = lub_ty.normalize() + + if lub_ty == rk.none(): + msg = f"Could not instantiate typevar '{typevar_name}' given concrete typings {typevar_mvar_to_type}" + raise TypeError(msg) + + typevar_insts[typevar_name] = lub_ty + + self._typevar_insts = typevar_insts + return self + @model_validator(mode="after") def _check_bindings_typings(self) -> Self: - # Check that the types in the provided schema match typings. for col, ty in self.schema.items(): - metavar = self._bindings_inv[col] - req_ty = self.spec.typings[metavar] - match req_ty: - case RenkonType(): - if not ty.is_subtype(req_ty): - msg = f"Column '{col}' has incompatible type '{ty}', expected '{req_ty}'." - raise TypeError(msg) - case str(): - raise NotImplementedError + mvar = self._bindings_inv[col] + req_ty = self.spec.typings[mvar] + + if isinstance(req_ty, str): + req_ty = self._typevar_insts[req_ty] + + if not ty.is_subtype(req_ty): + msg = f"Column '{col}' has incompatible type '{ty}', does not satisfy bound '{req_ty}'." + raise TypeError(msg) return self diff --git a/src/renkon/core/model/type/base.py b/src/renkon/core/model/type/base.py index d74d30b..e15d710 100644 --- a/src/renkon/core/model/type/base.py +++ b/src/renkon/core/model/type/base.py @@ -164,8 +164,11 @@ def __eq__(self, other: object) -> bool: return super().__eq__(other) return self.is_equal(other) + def __and__(self, other: RenkonType) -> RenkonType: + return union(self).intersect(union(other)).normalize() + def __or__(self, other: RenkonType) -> UnionType: - return UnionType(ts=frozenset({self, other})).canonicalize() + return union(self, other).canonicalize() @abstractmethod def __hash__(self) -> int: ... @@ -333,21 +336,21 @@ class BoolType(PrimitiveType, name="bool"): ... class UnionType(RenkonType): ts: frozenset[RenkonType] - @property def is_empty_union(self) -> bool: return not self.ts - @property def is_trivial_union(self) -> bool: """True if the union is of one unique type.""" return len(self.ts) == 1 - @property def contains_top(self) -> bool: """True if the union contains (at any depth) a TopType.""" return bool(any(isinstance(t, TopType) for t in self.flatten().ts)) - @property + def contains_bottom(self) -> bool: + """True if the union contains (at any depth) a BottomType.""" + return bool(any(isinstance(t, BottomType) for t in self.flatten().ts)) + def contains_union(self) -> bool: """True if the union contains an immediate child Union.""" return bool(any(isinstance(t, UnionType) for t in self.ts)) @@ -376,7 +379,7 @@ def canonicalize(self) -> UnionType: ts = flat.ts # If a top type is present, leave only it. - if self.contains_top: + if self.contains_top(): return UnionType(ts=frozenset({TopType()})) # Remove any bottom types. @@ -387,11 +390,11 @@ def canonicalize(self) -> UnionType: def normalize(self) -> RenkonType: canon = self.canonicalize() - if canon.contains_top: + if canon.contains_top(): return TopType() - if canon.is_empty_union: + if canon.is_empty_union(): return BottomType() - if canon.is_trivial_union: + if canon.is_trivial_union(): return canon.single() return canon @@ -399,18 +402,23 @@ def union(self, other: UnionType) -> UnionType: return UnionType(ts=self.ts.union(other.ts)).canonicalize() def intersect(self, other: UnionType) -> UnionType: + if self.contains_top(): + return other.canonicalize() + if other.contains_top(): + return self.canonicalize() + return UnionType(ts=self.ts.intersection(other.ts)).canonicalize() def dump_string(self) -> str: - if self.is_empty_union: - return "⊥ | ⊥" - if self.is_trivial_union: - return f"{self.single().dump_string()} | ⊥" + if self.is_empty_union(): + return "none | none" + if self.is_trivial_union(): + return f"{self.single().dump_string()} | none" return " | ".join(sorted(t.dump_string() for t in self.ts)) def flatten(self) -> UnionType: """Recursively flatten nested unions.""" - if not self.contains_union: + if not self.contains_union(): return self ts: set[RenkonType] = set() for t in self.ts: @@ -421,7 +429,7 @@ def flatten(self) -> UnionType: return UnionType.model_validate({"ts": ts}) def single(self) -> RenkonType: - if not self.is_trivial_union: + if not self.is_trivial_union(): msg = "Union is not trivial, a single type" raise ValueError(msg) return next(iter(self.ts)) @@ -429,9 +437,6 @@ def single(self) -> RenkonType: def __hash__(self) -> int: return hash((type(self), self.ts)) - def __and__(self, other: UnionType) -> UnionType: - return UnionType(ts=self.ts.intersection(other.ts)).canonicalize() - # endregion diff --git a/src/renkon/core/trait/compare.py b/src/renkon/core/trait/compare.py index 3b6c3a6..ec61267 100644 --- a/src/renkon/core/trait/compare.py +++ b/src/renkon/core/trait/compare.py @@ -9,7 +9,7 @@ from typing import Any, ClassVar, Literal, final from renkon.core.model import TraitKind, TraitPattern, TraitSpec -from renkon.core.model.type import comparable, numeric +from renkon.core.model.type import comparable, numeric, equatable from renkon.core.trait.base import Trait type _CmpOpStr = Literal["<", "≤", "=", "≥", ">"] @@ -39,7 +39,7 @@ def __init_subclass__(cls, *, op_str: _CmpOpStr, **kwargs: Any): kind=TraitKind.LOGICAL, pattern=TraitPattern("{A}" f" {op_str} " "{B}"), commutors=[{"A", "B"}], - typevars={"T": numeric() if op_str == "=" else comparable()}, + typevars={"T": equatable() if op_str == "=" else comparable()}, typings={"A": "T", "B": "T"}, ) diff --git a/tests/renkon/core/model/test_sketch.py b/tests/renkon/core/model/test_sketch.py index b3205dc..565eef8 100644 --- a/tests/renkon/core/model/test_sketch.py +++ b/tests/renkon/core/model/test_sketch.py @@ -29,15 +29,21 @@ def test_sketch_linear2(): def test_sketch_incorrect_typing(): schema = Schema({"x": rk.int_(), "name": rk.str_()}) - with pytest.raises(TypeError, match="incompatible type"): + with pytest.raises(TypeError, match="incompatible type .* does not satisfy bound"): TraitSketch(spec=Linear2.spec, schema=schema, bindings={"X": "x", "Y": "name"}) -def test_sketch_typevars(): +def test_sketch_typevar_incorrect_typing(): + schema = Schema({"a": rk.float_(), "b": rk.float_()}) + with pytest.raises(TypeError, match="incompatible type .* does not satisfy bound .* of typevar"): + TraitSketch(spec=Equal.spec, schema=schema, bindings={"A": "a", "B": "b"}) + + +def test_sketch_typevar_instantiation(): for ty1, ty2 in it.product(rk.equatable().ts, repeat=2): schema = Schema({"a": ty1, "b": ty2}) if ty1 == ty2: TraitSketch(spec=Equal.spec, schema=schema, bindings={"A": "a", "B": "b"}) else: - with pytest.raises(TypeError): + with pytest.raises(TypeError, match=r"Could not instantiate .* given concrete .*"): TraitSketch(spec=Equal.spec, schema=schema, bindings={"A": "a", "B": "b"}) diff --git a/tests/renkon/core/model/type/test_base.py b/tests/renkon/core/model/type/test_base.py index 208f22f..59129e6 100644 --- a/tests/renkon/core/model/type/test_base.py +++ b/tests/renkon/core/model/type/test_base.py @@ -29,7 +29,7 @@ def test_type_model_dump_primitive(): def test_type_model_dump_union(): assert union(int_(), float_()).model_dump() == "float | int" assert union(int_(), str_()).model_dump() == "int | string" - assert union().model_dump() == "⊥ | ⊥" + assert union().model_dump() == "none | none" def test_type_model_dump_any(): @@ -158,6 +158,9 @@ def test_union_intersect(): assert union(int_(), float_()).intersect(union(str_(), bool_())) == union() +def test_union_intersect_any(): + assert union(any_()).intersect(union(int_(), str_())) == union(int_(), str_()) + def test_union_dump_python(): assert union(int_(), float_()).model_dump() == "float | int"