From e9c6324ff99dc240b22189bfe1eaae8a7d4863e1 Mon Sep 17 00:00:00 2001 From: Jmgr Date: Wed, 20 Nov 2024 21:50:38 +0000 Subject: [PATCH] chore: split NadaType and NadaValue / refactoring --- nada_dsl/nada_types/__init__.py | 5 + nada_dsl/nada_types/collections.py | 173 +++++++++++++++++------------ 2 files changed, 106 insertions(+), 72 deletions(-) diff --git a/nada_dsl/nada_types/__init__.py b/nada_dsl/nada_types/__init__.py index f65a47f..11055fe 100644 --- a/nada_dsl/nada_types/__init__.py +++ b/nada_dsl/nada_types/__init__.py @@ -171,3 +171,8 @@ def is_scalar(cls) -> bool: def is_literal(cls) -> bool: """Returns True if the type is a literal.""" return False + + +@dataclass +class NadaValue: + pass diff --git a/nada_dsl/nada_types/collections.py b/nada_dsl/nada_types/collections.py index dbba76a..b131734 100644 --- a/nada_dsl/nada_types/collections.py +++ b/nada_dsl/nada_types/collections.py @@ -29,7 +29,7 @@ ) from nada_dsl.nada_types.function import NadaFunction, nada_fn from nada_dsl.nada_types.generics import U, T, R -from . import AllTypes, AllTypesType, NadaTypeRepr, OperationType +from . import AllTypes, AllTypesType, NadaTypeRepr, NadaValue, OperationType def is_primitive_integer(nada_type_str: str): @@ -47,73 +47,10 @@ def is_primitive_integer(nada_type_str: str): ) +@dataclass class Collection(NadaType): """Superclass of collection types""" - left_type: AllTypesType - right_type: AllTypesType - contained_type: AllTypesType - - def to_mir(self): - """Convert operation wrapper to a dictionary representing its type.""" - if isinstance(self, (Array, ArrayType)): - size = {"size": self.size} if self.size else {} - contained_type = self.retrieve_inner_type() - return {"Array": {"inner_type": contained_type, **size}} - if isinstance(self, (Tuple, TupleType)): - return { - "Tuple": { - "left_type": ( - self.left_type.to_mir() - if isinstance(self.left_type, (NadaType, ArrayType, TupleType)) - else self.left_type.class_to_mir() - ), - "right_type": ( - self.right_type.to_mir() - if isinstance( - self.right_type, - (NadaType, ArrayType, TupleType), - ) - else self.right_type.class_to_mir() - ), - } - } - if isinstance(self, NTuple): - return { - "NTuple": { - "types": [ - ( - ty.to_mir() - if isinstance(ty, (NadaType, ArrayType, TupleType)) - else ty.class_to_mir() - ) - for ty in [ - type(value) - for value in self.values # pylint: disable=E1101 - ] - ] - } - } - if isinstance(self, Object): - return { - "Object": { - "types": { - name: ( - ty.to_mir() - if isinstance(ty, (NadaType, ArrayType, TupleType)) - else ty.class_to_mir() - ) - for name, ty in [ - (name, type(value)) - for name, value in self.values.items() # pylint: disable=E1101 - ] - } - } - } - raise InvalidTypeError( - f"{self.__class__.__name__} is not a valid Nada Collection" - ) - def retrieve_inner_type(self): """Retrieves the child type of this collection""" if isinstance(self.contained_type, TypeVar): @@ -123,6 +60,7 @@ def retrieve_inner_type(self): return self.contained_type.to_mir() +@dataclass class Map(Generic[T, R]): """The Map operation""" @@ -203,6 +141,7 @@ def to_mir(self): } +@dataclass class Tuple(Generic[T, U], Collection): """The Tuple type""" @@ -215,6 +154,25 @@ def __init__(self, child, left_type: T, right_type: U): self.child = child super().__init__(self.child) + def to_mir(self): + return { + "Tuple": { + "left_type": ( + self.left_type.to_mir() + if isinstance(self.left_type, (NadaType, ArrayType, TupleType)) + else self.left_type.class_to_mir() + ), + "right_type": ( + self.right_type.to_mir() + if isinstance( + self.right_type, + (NadaType, ArrayType, TupleType), + ) + else self.right_type.class_to_mir() + ), + } + } + @classmethod def new(cls, left_type: T, right_type: U) -> "Tuple[T, U]": """Constructs a new Tuple.""" @@ -259,12 +217,28 @@ def _generate_accessor(value: Any, accessor: Any) -> NadaType: raise TypeError(f"Unsupported type for accessor: {ty}") +@dataclass +class NTupleType: + """Marker type for NTuples.""" + + types: List[NadaType] + + def to_mir(self): + """Convert a tuple object into a Nada type.""" + return { + "NTuple": { + "types": [ty.to_mir() for ty in self.types], + } + } + + +@dataclass class NTuple(Collection): """The NTuple type""" - values: List[NadaType] + values: List[NadaValue] - def __init__(self, child, values: List[NadaType]): + def __init__(self, child, values: List[NadaValue]): self.values = values self.child = child super().__init__(self.child) @@ -292,6 +266,22 @@ def __getitem__(self, index: int) -> NadaType: return _generate_accessor(self.values[index], accessor) + def to_mir(self): + return { + "NTuple": { + "types": [ + ( + ty.to_mir() + if isinstance(ty, (NadaType, ArrayType, TupleType)) + else ty.class_to_mir() + ) + for ty in [ + type(value) for value in self.values # pylint: disable=E1101 + ] + ] + } + } + @dataclass class NTupleAccessor: @@ -323,13 +313,25 @@ def store_in_ast(self, ty: object): ) +@dataclass +class ObjectType: + """Marker type for Objects.""" + + types: Dict[str, NadaType] + + def to_mir(self): + """Convert an object into a Nada type.""" + return {"Object": {name: ty.to_mir() for name, ty in self.types.items()}} + + +@dataclass class Object(Collection): """The Object type""" - values: Dict[str, NadaType] + types: Dict[str, NadaType] - def __init__(self, child, values: Dict[str, NadaType]): - self.values = values + def __init__(self, child, values: Dict[str, NadaValue]): + self.types = values self.child = child super().__init__(self.child) @@ -345,7 +347,7 @@ def new(cls, values: Dict[str, NadaType]) -> "Object": ) def __getattr__(self, attr: str) -> NadaType: - if attr not in self.values: + if attr not in self.types: raise AttributeError( f"'{self.__class__.__name__}' object has no attribute '{attr}'" ) @@ -356,7 +358,24 @@ def __getattr__(self, attr: str) -> NadaType: source_ref=SourceRef.back_frame(), ) - return _generate_accessor(self.values[attr], accessor) + return _generate_accessor(self.types[attr], accessor) + + def to_mir(self): + return { + "Object": { + "types": { + name: ( + ty.to_mir() + if isinstance(ty, (NadaType, ArrayType, TupleType)) + else ty.class_to_mir() + ) + for name, ty in [ + (name, type(value)) + for name, value in self.types.items() # pylint: disable=E1101 + ] + } + } + } @dataclass @@ -476,6 +495,7 @@ def to_mir(self): } +@dataclass class Array(Generic[T], Collection): """Nada Array type. @@ -576,6 +596,11 @@ def inner_product(self: "Array[T]", other: "Array[T]") -> T: "Inner product is only implemented for arrays of integer types" ) + def to_mir(self): + size = {"size": self.size} if self.size else {} + contained_type = self.retrieve_inner_type() + return {"Array": {"inner_type": contained_type, **size}} + @classmethod def new(cls, *args) -> "Array[T]": """Constructs a new Array.""" @@ -601,6 +626,7 @@ def init_as_template_type(cls, contained_type) -> "Array[T]": return Array(child=None, contained_type=contained_type, size=None) +@dataclass class TupleNew(Generic[T, U]): """MIR Tuple new operation. @@ -626,6 +652,7 @@ def store_in_ast(self, ty: object): ) +@dataclass class NTupleNew: """MIR NTuple new operation. @@ -651,6 +678,7 @@ def store_in_ast(self, ty: object): ) +@dataclass class ObjectNew: """MIR Object new operation. @@ -692,6 +720,7 @@ def unzip(array: Array[Tuple[T, R]]) -> Tuple[Array[T], Array[R]]: ) +@dataclass class ArrayNew(Generic[T]): """MIR Array new operation"""