Skip to content

Commit

Permalink
chore: split NadaType and NadaValue / refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
Jmgr committed Nov 20, 2024
1 parent 5098d7b commit e9c6324
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 72 deletions.
5 changes: 5 additions & 0 deletions nada_dsl/nada_types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,8 @@ def is_scalar(cls) -> bool:
def is_literal(cls) -> bool:
"""Returns True if the type is a literal."""
return False


@dataclass
class NadaValue:
pass
173 changes: 101 additions & 72 deletions nada_dsl/nada_types/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
)
from nada_dsl.nada_types.function import NadaFunction, nada_fn
from nada_dsl.nada_types.generics import U, T, R
from . import AllTypes, AllTypesType, NadaTypeRepr, OperationType
from . import AllTypes, AllTypesType, NadaTypeRepr, NadaValue, OperationType


def is_primitive_integer(nada_type_str: str):
Expand All @@ -47,73 +47,10 @@ def is_primitive_integer(nada_type_str: str):
)


@dataclass
class Collection(NadaType):
"""Superclass of collection types"""

left_type: AllTypesType
right_type: AllTypesType
contained_type: AllTypesType

def to_mir(self):
"""Convert operation wrapper to a dictionary representing its type."""
if isinstance(self, (Array, ArrayType)):
size = {"size": self.size} if self.size else {}
contained_type = self.retrieve_inner_type()
return {"Array": {"inner_type": contained_type, **size}}
if isinstance(self, (Tuple, TupleType)):
return {
"Tuple": {
"left_type": (
self.left_type.to_mir()
if isinstance(self.left_type, (NadaType, ArrayType, TupleType))
else self.left_type.class_to_mir()
),
"right_type": (
self.right_type.to_mir()
if isinstance(
self.right_type,
(NadaType, ArrayType, TupleType),
)
else self.right_type.class_to_mir()
),
}
}
if isinstance(self, NTuple):
return {
"NTuple": {
"types": [
(
ty.to_mir()
if isinstance(ty, (NadaType, ArrayType, TupleType))
else ty.class_to_mir()
)
for ty in [
type(value)
for value in self.values # pylint: disable=E1101
]
]
}
}
if isinstance(self, Object):
return {
"Object": {
"types": {
name: (
ty.to_mir()
if isinstance(ty, (NadaType, ArrayType, TupleType))
else ty.class_to_mir()
)
for name, ty in [
(name, type(value))
for name, value in self.values.items() # pylint: disable=E1101
]
}
}
}
raise InvalidTypeError(
f"{self.__class__.__name__} is not a valid Nada Collection"
)

def retrieve_inner_type(self):
"""Retrieves the child type of this collection"""
if isinstance(self.contained_type, TypeVar):
Expand All @@ -123,6 +60,7 @@ def retrieve_inner_type(self):
return self.contained_type.to_mir()


@dataclass
class Map(Generic[T, R]):
"""The Map operation"""

Expand Down Expand Up @@ -203,6 +141,7 @@ def to_mir(self):
}


@dataclass
class Tuple(Generic[T, U], Collection):
"""The Tuple type"""

Expand All @@ -215,6 +154,25 @@ def __init__(self, child, left_type: T, right_type: U):
self.child = child
super().__init__(self.child)

def to_mir(self):
return {
"Tuple": {
"left_type": (
self.left_type.to_mir()
if isinstance(self.left_type, (NadaType, ArrayType, TupleType))
else self.left_type.class_to_mir()
),
"right_type": (
self.right_type.to_mir()
if isinstance(
self.right_type,
(NadaType, ArrayType, TupleType),
)
else self.right_type.class_to_mir()
),
}
}

@classmethod
def new(cls, left_type: T, right_type: U) -> "Tuple[T, U]":
"""Constructs a new Tuple."""
Expand Down Expand Up @@ -259,12 +217,28 @@ def _generate_accessor(value: Any, accessor: Any) -> NadaType:
raise TypeError(f"Unsupported type for accessor: {ty}")


@dataclass
class NTupleType:
"""Marker type for NTuples."""

types: List[NadaType]

def to_mir(self):
"""Convert a tuple object into a Nada type."""
return {
"NTuple": {
"types": [ty.to_mir() for ty in self.types],
}
}


@dataclass
class NTuple(Collection):
"""The NTuple type"""

values: List[NadaType]
values: List[NadaValue]

def __init__(self, child, values: List[NadaType]):
def __init__(self, child, values: List[NadaValue]):
self.values = values
self.child = child
super().__init__(self.child)
Expand Down Expand Up @@ -292,6 +266,22 @@ def __getitem__(self, index: int) -> NadaType:

return _generate_accessor(self.values[index], accessor)

def to_mir(self):
return {
"NTuple": {
"types": [
(
ty.to_mir()
if isinstance(ty, (NadaType, ArrayType, TupleType))
else ty.class_to_mir()
)
for ty in [
type(value) for value in self.values # pylint: disable=E1101
]
]
}
}


@dataclass
class NTupleAccessor:
Expand Down Expand Up @@ -323,13 +313,25 @@ def store_in_ast(self, ty: object):
)


@dataclass
class ObjectType:
"""Marker type for Objects."""

types: Dict[str, NadaType]

def to_mir(self):
"""Convert an object into a Nada type."""
return {"Object": {name: ty.to_mir() for name, ty in self.types.items()}}


@dataclass
class Object(Collection):
"""The Object type"""

values: Dict[str, NadaType]
types: Dict[str, NadaType]

def __init__(self, child, values: Dict[str, NadaType]):
self.values = values
def __init__(self, child, values: Dict[str, NadaValue]):
self.types = values
self.child = child
super().__init__(self.child)

Expand All @@ -345,7 +347,7 @@ def new(cls, values: Dict[str, NadaType]) -> "Object":
)

def __getattr__(self, attr: str) -> NadaType:
if attr not in self.values:
if attr not in self.types:
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{attr}'"
)
Expand All @@ -356,7 +358,24 @@ def __getattr__(self, attr: str) -> NadaType:
source_ref=SourceRef.back_frame(),
)

return _generate_accessor(self.values[attr], accessor)
return _generate_accessor(self.types[attr], accessor)

def to_mir(self):
return {
"Object": {
"types": {
name: (
ty.to_mir()
if isinstance(ty, (NadaType, ArrayType, TupleType))
else ty.class_to_mir()
)
for name, ty in [
(name, type(value))
for name, value in self.types.items() # pylint: disable=E1101
]
}
}
}


@dataclass
Expand Down Expand Up @@ -476,6 +495,7 @@ def to_mir(self):
}


@dataclass
class Array(Generic[T], Collection):
"""Nada Array type.
Expand Down Expand Up @@ -576,6 +596,11 @@ def inner_product(self: "Array[T]", other: "Array[T]") -> T:
"Inner product is only implemented for arrays of integer types"
)

def to_mir(self):
size = {"size": self.size} if self.size else {}
contained_type = self.retrieve_inner_type()
return {"Array": {"inner_type": contained_type, **size}}

@classmethod
def new(cls, *args) -> "Array[T]":
"""Constructs a new Array."""
Expand All @@ -601,6 +626,7 @@ def init_as_template_type(cls, contained_type) -> "Array[T]":
return Array(child=None, contained_type=contained_type, size=None)


@dataclass
class TupleNew(Generic[T, U]):
"""MIR Tuple new operation.
Expand All @@ -626,6 +652,7 @@ def store_in_ast(self, ty: object):
)


@dataclass
class NTupleNew:
"""MIR NTuple new operation.
Expand All @@ -651,6 +678,7 @@ def store_in_ast(self, ty: object):
)


@dataclass
class ObjectNew:
"""MIR Object new operation.
Expand Down Expand Up @@ -692,6 +720,7 @@ def unzip(array: Array[Tuple[T, R]]) -> Tuple[Array[T], Array[R]]:
)


@dataclass
class ArrayNew(Generic[T]):
"""MIR Array new operation"""

Expand Down

0 comments on commit e9c6324

Please sign in to comment.