Skip to content

Commit

Permalink
chore: split NadaType and NadaValue / refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
Jmgr committed Nov 21, 2024
1 parent 5098d7b commit 6ff849b
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 93 deletions.
12 changes: 11 additions & 1 deletion nada_dsl/nada_types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from dataclasses import dataclass
from enum import Enum
from typing import Dict, TypeAlias, Union, Type
from typing import Any, Dict, TypeAlias, Union, Type
from nada_dsl.source_ref import SourceRef


Expand Down Expand Up @@ -171,3 +171,13 @@ def is_scalar(cls) -> bool:
def is_literal(cls) -> bool:
"""Returns True if the type is a literal."""
return False

def instantiate(self, child: Any) -> "NadaValue":
pass


@dataclass
class NadaValue:
@classmethod
def to_type(cls) -> NadaType:
pass
209 changes: 117 additions & 92 deletions nada_dsl/nada_types/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
)
from nada_dsl.nada_types.function import NadaFunction, nada_fn
from nada_dsl.nada_types.generics import U, T, R
from . import AllTypes, AllTypesType, NadaTypeRepr, OperationType
from . import AllTypes, AllTypesType, NadaTypeRepr, NadaValue, OperationType


def is_primitive_integer(nada_type_str: str):
Expand All @@ -47,73 +47,10 @@ def is_primitive_integer(nada_type_str: str):
)


@dataclass
class Collection(NadaType):
"""Superclass of collection types"""

left_type: AllTypesType
right_type: AllTypesType
contained_type: AllTypesType

def to_mir(self):
"""Convert operation wrapper to a dictionary representing its type."""
if isinstance(self, (Array, ArrayType)):
size = {"size": self.size} if self.size else {}
contained_type = self.retrieve_inner_type()
return {"Array": {"inner_type": contained_type, **size}}
if isinstance(self, (Tuple, TupleType)):
return {
"Tuple": {
"left_type": (
self.left_type.to_mir()
if isinstance(self.left_type, (NadaType, ArrayType, TupleType))
else self.left_type.class_to_mir()
),
"right_type": (
self.right_type.to_mir()
if isinstance(
self.right_type,
(NadaType, ArrayType, TupleType),
)
else self.right_type.class_to_mir()
),
}
}
if isinstance(self, NTuple):
return {
"NTuple": {
"types": [
(
ty.to_mir()
if isinstance(ty, (NadaType, ArrayType, TupleType))
else ty.class_to_mir()
)
for ty in [
type(value)
for value in self.values # pylint: disable=E1101
]
]
}
}
if isinstance(self, Object):
return {
"Object": {
"types": {
name: (
ty.to_mir()
if isinstance(ty, (NadaType, ArrayType, TupleType))
else ty.class_to_mir()
)
for name, ty in [
(name, type(value))
for name, value in self.values.items() # pylint: disable=E1101
]
}
}
}
raise InvalidTypeError(
f"{self.__class__.__name__} is not a valid Nada Collection"
)

def retrieve_inner_type(self):
"""Retrieves the child type of this collection"""
if isinstance(self.contained_type, TypeVar):
Expand All @@ -123,6 +60,7 @@ def retrieve_inner_type(self):
return self.contained_type.to_mir()


@dataclass
class Map(Generic[T, R]):
"""The Map operation"""

Expand Down Expand Up @@ -203,9 +141,11 @@ def to_mir(self):
}


@dataclass
class Tuple(Generic[T, U], Collection):
"""The Tuple type"""

# TODO: T and U have to inherit from NadaType?
left_type: T
right_type: U

Expand All @@ -215,14 +155,33 @@ def __init__(self, child, left_type: T, right_type: U):
self.child = child
super().__init__(self.child)

def to_mir(self):
return {
"Tuple": {
"left_type": (
self.left_type.to_mir()
if isinstance(self.left_type, (NadaType, ArrayType, TupleType))
else self.left_type.class_to_mir()
),
"right_type": (
self.right_type.to_mir()
if isinstance(
self.right_type,
(NadaType, ArrayType, TupleType),
)
else self.right_type.class_to_mir()
),
}
}

@classmethod
def new(cls, left_type: T, right_type: U) -> "Tuple[T, U]":
def new(cls, left_value: NadaValue, right_value: NadaValue) -> "Tuple[T, U]":
"""Constructs a new Tuple."""
return Tuple(
left_type=left_type,
right_type=right_type,
left_type=left_value.to_type(),
right_type=right_value.to_type(),
child=TupleNew(
child=(left_type, right_type),
child=(left_value.to_type(), right_value.to_type()),
source_ref=SourceRef.back_frame(),
),
)
Expand All @@ -233,55 +192,70 @@ def generic_type(cls, left_type: U, right_type: T) -> TupleType:
return TupleType(left_type=left_type, right_type=right_type)


def _generate_accessor(value: Any, accessor: Any) -> NadaType:
ty = type(value)

def _generate_accessor(ty: Any, accessor: Any) -> NadaType:
if ty.is_scalar():
if ty.is_literal():
return value
return ty # value.instantiate(child=accessor) ?
return ty(child=accessor)
if ty == Array:
return Array(
child=accessor,
contained_type=value.contained_type,
size=value.size,
contained_type=ty.contained_type,
size=ty.size,
)
if ty == NTuple:
return NTuple(
child=accessor,
values=value.values,
types=ty.values,
)
if ty == Object:
return Object(
child=accessor,
values=value.values,
values=ty.values,
)
raise TypeError(f"Unsupported type for accessor: {ty}")


@dataclass
class NTupleType:
"""Marker type for NTuples."""

types: List[NadaType]

def to_mir(self):
"""Convert a tuple object into a Nada type."""
return {
"NTuple": {
"types": [ty.to_mir() for ty in self.types],
}
}


@dataclass
class NTuple(Collection):
"""The NTuple type"""

values: List[NadaType]
types: List[NadaValue]

def __init__(self, child, values: List[NadaType]):
self.values = values
def __init__(self, child, types: List[NadaType]):
self.types = types
self.child = child
super().__init__(self.child)

@classmethod
def new(cls, values: List[NadaType]) -> "NTuple":
def new(cls, values: List[NadaValue]) -> "NTuple":
"""Constructs a new NTuple."""
types = [value.to_type() for value in values]
return NTuple(
values=values,
types=types,
child=NTupleNew(
child=values,
child=types,
source_ref=SourceRef.back_frame(),
),
)

def __getitem__(self, index: int) -> NadaType:
if index >= len(self.values):
if index >= len(self.types):
raise IndexError(f"Invalid index {index} for NTuple.")

accessor = NTupleAccessor(
Expand All @@ -290,7 +264,21 @@ def __getitem__(self, index: int) -> NadaType:
source_ref=SourceRef.back_frame(),
)

return _generate_accessor(self.values[index], accessor)
return _generate_accessor(self.types[index], accessor)

def to_mir(self):
return {
"NTuple": {
"types": [
(
ty.to_mir()
if isinstance(ty, (NadaType, ArrayType, TupleType))
else ty.class_to_mir()
)
for ty in self.types
]
}
}


@dataclass
Expand Down Expand Up @@ -323,29 +311,42 @@ def store_in_ast(self, ty: object):
)


@dataclass
class ObjectType:
"""Marker type for Objects."""

types: Dict[str, NadaType]

def to_mir(self):
"""Convert an object into a Nada type."""
return {"Object": {name: ty.to_mir() for name, ty in self.types.items()}}


@dataclass
class Object(Collection):
"""The Object type"""

values: Dict[str, NadaType]
types: Dict[str, NadaType]

def __init__(self, child, values: Dict[str, NadaType]):
self.values = values
def __init__(self, child, types: Dict[str, NadaType]):
self.types = types
self.child = child
super().__init__(self.child)

@classmethod
def new(cls, values: Dict[str, NadaType]) -> "Object":
"""Constructs a new Object."""
types = {key: value.to_type() for key, value in values.items()}
return Object(
values=values,
types=types,
child=ObjectNew(
child=values,
child=types,
source_ref=SourceRef.back_frame(),
),
)

def __getattr__(self, attr: str) -> NadaType:
if attr not in self.values:
if attr not in self.types:
raise AttributeError(
f"'{self.__class__.__name__}' object has no attribute '{attr}'"
)
Expand All @@ -356,7 +357,21 @@ def __getattr__(self, attr: str) -> NadaType:
source_ref=SourceRef.back_frame(),
)

return _generate_accessor(self.values[attr], accessor)
return _generate_accessor(self.types[attr], accessor)

def to_mir(self):
return {
"Object": {
"types": {
name: (
ty.to_mir()
if isinstance(ty, (NadaType, ArrayType, TupleType))
else ty.class_to_mir()
)
for name, ty in self.types.items()
}
}
}


@dataclass
Expand Down Expand Up @@ -476,6 +491,7 @@ def to_mir(self):
}


@dataclass
class Array(Generic[T], Collection):
"""Nada Array type.
Expand Down Expand Up @@ -576,6 +592,11 @@ def inner_product(self: "Array[T]", other: "Array[T]") -> T:
"Inner product is only implemented for arrays of integer types"
)

def to_mir(self):
size = {"size": self.size} if self.size else {}
contained_type = self.retrieve_inner_type()
return {"Array": {"inner_type": contained_type, **size}}

@classmethod
def new(cls, *args) -> "Array[T]":
"""Constructs a new Array."""
Expand All @@ -601,6 +622,7 @@ def init_as_template_type(cls, contained_type) -> "Array[T]":
return Array(child=None, contained_type=contained_type, size=None)


@dataclass
class TupleNew(Generic[T, U]):
"""MIR Tuple new operation.
Expand All @@ -626,6 +648,7 @@ def store_in_ast(self, ty: object):
)


@dataclass
class NTupleNew:
"""MIR NTuple new operation.
Expand All @@ -651,6 +674,7 @@ def store_in_ast(self, ty: object):
)


@dataclass
class ObjectNew:
"""MIR Object new operation.
Expand Down Expand Up @@ -692,6 +716,7 @@ def unzip(array: Array[Tuple[T, R]]) -> Tuple[Array[T], Array[R]]:
)


@dataclass
class ArrayNew(Generic[T]):
"""MIR Array new operation"""

Expand Down

0 comments on commit 6ff849b

Please sign in to comment.