Skip to content

Commit

Permalink
cleanup pt 2
Browse files Browse the repository at this point in the history
  • Loading branch information
DylanLukes committed Aug 21, 2024
1 parent d013ebc commit fd2dfd0
Show file tree
Hide file tree
Showing 7 changed files with 188 additions and 178 deletions.
10 changes: 10 additions & 0 deletions src/renkon/core/model/type/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@
"BottomType",
"tyconv_rk_to_pl",
"tyconv_pl_to_rk",
"int_",
"float_",
"str_",
"bool_",
"bottom",
"union",
"equatable",
"comparable",
"numeric",
]

from renkon.core.model.type.base import (
Expand All @@ -27,5 +36,6 @@
UnionType,
TypeStr,
is_type_str,
int_, float_, str_, bool_, bottom, union, equatable, comparable, numeric,
)
from renkon.core.model.type.convert import tyconv_pl_to_rk, tyconv_rk_to_pl
100 changes: 49 additions & 51 deletions src/renkon/core/model/type/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,56 @@
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations

import functools
from abc import ABC, abstractmethod
from collections.abc import Hashable
from functools import lru_cache
from typing import Any, ClassVar, Self, override, Literal, Union, Annotated, TypeGuard

from lark.exceptions import LarkError
from annotated_types import Predicate
from lark import Transformer
from lark.exceptions import LarkError
from pydantic import BaseModel, GetCoreSchemaHandler
from pydantic_core import CoreSchema
from pydantic_core import core_schema as cs

from renkon.core.model.type.parser import parser


def int_() -> IntType:
return IntType()


def float_() -> FloatType:
return FloatType()


def str_() -> StringType:
return StringType()


def bool_() -> BoolType:
return BoolType()


def bottom() -> BottomType:
return BottomType()


def union(*types: Type) -> UnionType:
return UnionType(ts=frozenset(types)).canonicalize()


def equatable() -> UnionType:
return union(int_(), str_(), bool_())


def comparable() -> UnionType:
return union(int_(), float_(), str_())


def numeric() -> UnionType:
return union(int_(), float_())


def is_type_str(s: str) -> TypeGuard[TypeStr]:
try:
Type.parse_string(s)
Expand Down Expand Up @@ -99,13 +133,13 @@ def parse_string(cls, s: str) -> Type:
raise ValueError(msg) from e

def is_numeric(self) -> bool:
return self.is_subtype(Type.numeric())
return self.is_subtype(numeric())

def is_equatable(self) -> bool:
return self.is_subtype(Type.equatable())
return self.is_subtype(equatable())

def is_comparable(self) -> bool:
return self.is_subtype(Type.comparable())
return self.is_subtype(comparable())

def __str__(self) -> str:
return self.dump_string()
Expand Down Expand Up @@ -140,42 +174,6 @@ def __get_pydantic_core_schema__(cls, source: type[BaseModel], handler: GetCoreS
serialization=serializer,
)

@staticmethod
def bottom() -> BottomType:
return BottomType()

@staticmethod
def int() -> IntType:
return IntType()

@staticmethod
def float() -> FloatType:
return FloatType()

@staticmethod
def str() -> StringType:
return StringType()

@staticmethod
def bool() -> BoolType:
return BoolType()

@staticmethod
def union(*types: Type) -> UnionType:
return UnionType(ts=frozenset(types))

@staticmethod
def equatable() -> UnionType:
return Type.union(Type.int(), Type.str(), Type.bool())

@staticmethod
def comparable() -> UnionType:
return Type.union(Type.int(), Type.float(), Type.str())

@staticmethod
def numeric() -> UnionType:
return Type.union(Type.int(), Type.float())


class BottomType(Type):
def is_equal(self, other: Type) -> bool:
Expand Down Expand Up @@ -374,31 +372,31 @@ def type(self, type_: list[Type]):
return type_[0]

def int(self, _) -> IntType:
return Type.int()
return int_()

def float(self, _) -> FloatType:
return Type.float()
return float_()

def string(self, _) -> StringType:
return Type.str()
return str_()

def bool(self, _) -> BoolType:
return Type.bool()
return bool_()

def bottom(self, _) -> BottomType:
return Type.bottom()
return bottom()

def union(self, types: list[Type]) -> UnionType:
return Type.union(*types).canonicalize()
return union(*types)

def equatable(self, _) -> UnionType:
return Type.equatable()
return equatable()

def comparable(self, _) -> UnionType:
return Type.comparable()
return comparable()

def numeric(self, _) -> UnionType:
return Type.numeric()
return numeric()

def paren(self, type_: list[Type]) -> Type:
return type_[0]
Expand Down
22 changes: 11 additions & 11 deletions src/renkon/core/model/type/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,44 +4,44 @@

import polars as pl

from renkon.core.model.type.base import Type as RenkonType
from renkon.core.model import type as rk


def tyconv_pl_to_rk(pl_ty: pl.PolarsDataType) -> RenkonType:
def tyconv_pl_to_rk(pl_ty: pl.PolarsDataType) -> rk.Type:
"""
Convert a Polars data type to a Renkon data type.
"""
if pl_ty.is_integer():
return RenkonType.int()
return rk.int_()

if pl_ty.is_float():
return RenkonType.float()
return rk.float_()

if pl_ty.is_(pl.String):
return RenkonType.str()
return rk.str_()

if pl_ty.is_(pl.Boolean):
return RenkonType.bool()
return rk.bool_()

msg = f"Unsupported Polars data type: {pl_ty}"
raise ValueError(msg)


def tyconv_rk_to_pl(rk_ty: RenkonType) -> pl.PolarsDataType:
def tyconv_rk_to_pl(rk_ty: rk.Type) -> pl.PolarsDataType:
"""
Convert a Renkon data type to a Polars data type.
"""

if rk_ty.is_equal(RenkonType.int()):
if rk_ty.is_equal(rk.int_()):
return pl.Int64

if rk_ty.is_equal(RenkonType.float()) or rk_ty.is_equal(RenkonType.int() | RenkonType.float()):
if rk_ty.is_equal(rk.float_()) or rk_ty.is_equal(rk.int_() | rk.float_()):
return pl.Float64

if rk_ty.is_equal(RenkonType.str()):
if rk_ty.is_equal(rk.str_()):
return pl.Utf8

if rk_ty.is_equal(RenkonType.bool()):
if rk_ty.is_equal(rk.bool_()):
return pl.Boolean

msg = f"Unsupported Renkon data type: {rk_ty}"
Expand Down
9 changes: 5 additions & 4 deletions src/renkon/core/traits/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from renkon.core.model import TraitId, TraitKind, TraitPattern, TraitSketch, TraitSpec
from renkon.core.model.type import Type
import renkon.core.model.type as rk_type


class Trait(Protocol):
Expand Down Expand Up @@ -63,9 +64,9 @@ class Linear2(Trait):
kind=TraitKind.MODEL,
pattern=TraitPattern("{Y} = {a}*{X} + {b}"),
typings={
"X": Type.numeric(),
"Y": Type.numeric(),
"a": Type.float(),
"b": Type.float(),
"X": rk_type.numeric(),
"Y": rk_type.numeric(),
"a": rk_type.float_(),
"b": rk_type.float_(),
},
)
Loading

0 comments on commit fd2dfd0

Please sign in to comment.