Skip to content

Commit

Permalink
chore: remove MIR changes for now -- makes this PR non-breaking
Browse files Browse the repository at this point in the history
  • Loading branch information
Jmgr committed Nov 14, 2024
1 parent ed72769 commit 5507db3
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 84 deletions.
24 changes: 12 additions & 12 deletions nada_dsl/ast_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,17 +88,17 @@ class UnaryASTOperation(ASTOperation):
"""Superclass of all the unary operations in AST representation"""

name: str
this: int
child: int

def child_operations(self):
return [self.this]
return [self.child]

def to_mir(self):

return {
self.name: {
"id": self.id,
"this": self.this,
"this": self.child,
"type": self.ty,
"source_ref_index": self.source_ref.to_index(),
}
Expand All @@ -109,20 +109,20 @@ def to_mir(self):
class IfElseASTOperation(ASTOperation):
"""AST Representation of an IfElse operation."""

this: int
arg_0: int
arg_1: int
condition: int
true_branch_child: int
false_branch_child: int

def child_operations(self):
return [self.this, self.arg_0, self.arg_1]
return [self.condition, self.true_branch_child, self.false_branch_child]

def to_mir(self):
return {
"IfElse": {
"id": self.id,
"this": self.this,
"arg_0": self.arg_0,
"arg_1": self.arg_1,
"this": self.condition,
"arg_0": self.true_branch_child,
"arg_1": self.false_branch_child,
"type": self.ty,
"source_ref_index": self.source_ref.to_index(),
}
Expand Down Expand Up @@ -227,7 +227,7 @@ def to_mir(self):
"Reduce": {
"id": self.id,
"fn": self.fn,
"child": self.child,
"inner": self.child,
"initial": self.initial,
"type": self.ty,
"source_ref_index": self.source_ref.to_index(),
Expand All @@ -250,7 +250,7 @@ def to_mir(self):
"Map": {
"id": self.id,
"fn": self.fn,
"child": self.child,
"inner": self.child,
"type": self.ty,
"source_ref_index": self.source_ref.to_index(),
}
Expand Down
110 changes: 56 additions & 54 deletions nada_dsl/nada_types/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,17 @@ class Collection(NadaType):

left_type: AllTypesType
right_type: AllTypesType
contained_types: AllTypesType
contained_type: AllTypesType

def to_mir(self):
"""Convert operation wrapper to a dictionary representing its type."""
if isinstance(self, (Array, ArrayType)):
size = {"size": self.size} if self.size else {}
contained_types = self.retrieve_inner_type()
return {"Array": {"contained_types": contained_types, **size}}
contained_type = self.retrieve_inner_type()
return {"Array": {"inner_type": contained_type, **size}}
if isinstance(self, (Vector, VectorType)):
contained_types = self.retrieve_inner_type()
return {"Vector": {"contained_types": contained_types}}
contained_type = self.retrieve_inner_type()
return {"Vector": {"inner_type": contained_type}}
if isinstance(self, (Tuple, TupleType)):
return {
"Tuple": {
Expand All @@ -82,11 +82,11 @@ def to_mir(self):

def retrieve_inner_type(self):
"""Retrieves the child type of this collection"""
if isinstance(self.contained_types, TypeVar):
if isinstance(self.contained_type, TypeVar):
return "T"
if inspect.isclass(self.contained_types):
return self.contained_types.class_to_type()
return self.contained_types.to_mir()
if inspect.isclass(self.contained_type):
return self.contained_type.class_to_type()
return self.contained_type.to_mir()


class Map(Generic[T, R]):
Expand Down Expand Up @@ -295,11 +295,13 @@ def to_mir(self):
}


def get_inner_type(contained_types):
"""Utility that returns the child type for a composite type."""
contained_types = copy.copy(contained_types)
setattr(contained_types, "child", None)
return contained_types
# pylint: disable=W0511
# TODO: remove this
def get_inner_type(inner_type):
"""Utility that returns the inner type for a composite type."""
inner_type = copy.copy(inner_type)
setattr(inner_type, "inner", None)
return inner_type


class Zip:
Expand Down Expand Up @@ -336,7 +338,7 @@ def store_in_ast(self, ty: NadaTypeRepr):
AST_OPERATIONS[self.id] = UnaryASTOperation(
id=self.id,
name="Unzip",
this=self.child.child.id,
child=self.child.child.id,
source_ref=self.source_ref,
ty=ty,
)
Expand Down Expand Up @@ -367,14 +369,14 @@ def store_in_ast(self, ty: NadaTypeRepr):
class ArrayType:
"""Marker type for arrays."""

contained_types: AllTypesType
contained_type: AllTypesType
size: int

def to_mir(self):
"""Convert this generic type into a MIR Nada type."""
return {
"Array": {
"contained_types": self.contained_types.to_mir(),
"inner_type": self.contained_type.to_mir(),
"size": self.size,
}
}
Expand All @@ -388,26 +390,26 @@ class Array(Generic[T], Collection):
Attributes
----------
contained_types: T
contained_type: T
The type of the array
child:
The optional child operation
size: int
The size of the array
"""

contained_types: T
contained_type: T
size: int

def __init__(self, child, size: int, contained_types: T = None):
self.contained_types = (
contained_types
if (child is None or contained_types is not None)
def __init__(self, child, size: int, contained_type: T = None):
self.contained_type = (
contained_type
if (child is None or contained_type is not None)
else get_inner_type(child)
)
self.size = size
self.child = (
child if contained_types is not None else getattr(child, "child", None)
child if contained_type is not None else getattr(child, "child", None)
)
if self.child is not None:
self.child.store_in_ast(self.to_mir())
Expand All @@ -424,7 +426,7 @@ def map(self: "Array[T]", function) -> "Array":
nada_function = nada_fn(function)
return Array(
size=self.size,
contained_types=nada_function.return_type,
contained_type=nada_function.return_type,
child=Map(child=self, fn=nada_function, source_ref=SourceRef.back_frame()),
)

Expand All @@ -447,9 +449,9 @@ def zip(self: "Array[T]", other: "Array[U]") -> "Array[Tuple[T, U]]":
raise IncompatibleTypesError("Cannot zip arrays of different size")
return Array(
size=self.size,
contained_types=Tuple(
left_type=self.contained_types,
right_type=other.contained_types,
contained_type=Tuple(
left_type=self.contained_type,
right_type=other.contained_type,
child=None,
),
child=Zip(left=self, right=other, source_ref=SourceRef.back_frame()),
Expand All @@ -465,12 +467,12 @@ def inner_product(self: "Array[T]", other: "Array[T]") -> T:
if is_primitive_integer(self.retrieve_inner_type()) and is_primitive_integer(
other.retrieve_inner_type()
):
contained_types = (
self.contained_types
if inspect.isclass(self.contained_types)
else self.contained_types.__class__
contained_type = (
self.contained_type
if inspect.isclass(self.contained_type)
else self.contained_type.__class__
)
return contained_types(
return contained_type(
child=InnerProduct(
left=self, right=other, source_ref=SourceRef.back_frame()
)
Expand All @@ -491,7 +493,7 @@ def new(cls, *args) -> "Array[T]":
raise TypeError("All arguments must be of the same type")

return Array(
contained_types=first_arg,
contained_type=first_arg,
size=len(args),
child=ArrayNew(
child=args,
Expand All @@ -500,21 +502,21 @@ def new(cls, *args) -> "Array[T]":
)

@classmethod
def generic_type(cls, contained_types: T, size: int) -> ArrayType:
def generic_type(cls, contained_type: T, size: int) -> ArrayType:
"""Return the generic type of the Array."""
return ArrayType(contained_types=contained_types, size=size)
return ArrayType(contained_type=contained_type, size=size)

@classmethod
def init_as_template_type(cls, contained_types) -> "Array[T]":
def init_as_template_type(cls, contained_type) -> "Array[T]":
"""Construct an empty template array with the given child type."""
return Array(child=None, contained_types=contained_types, size=None)
return Array(child=None, contained_type=contained_type, size=None)


@dataclass
class VectorType(Collection):
"""The generic type for Vectors."""

contained_types: AllTypesType
contained_type: AllTypesType


@dataclass
Expand All @@ -527,17 +529,17 @@ class Vector(Generic[T], Collection):
its size may change at runtime.
"""

contained_types: T
contained_type: T
size: int

def __init__(self, child, size, contained_types=None):
self.contained_types = (
contained_types
if (child is None or contained_types is not None)
def __init__(self, child, size, contained_type=None):
self.contained_type = (
contained_type
if (child is None or contained_type is not None)
else get_inner_type(child)
)
self.size = size
self.child = child if contained_types else getattr(child, "child", None)
self.child = child if contained_type else getattr(child, "child", None)
self.child.store_in_ast(self.to_mir())

def __iter__(self):
Expand All @@ -550,16 +552,16 @@ def map(self: "Vector[T]", function: NadaFunction[T, R]) -> "Vector[R]":
"""The map operation for Nada Vectors."""
return Vector(
size=self.size,
contained_types=function.return_type,
contained_type=function.return_type,
child=(Map(child=self, fn=function, source_ref=SourceRef.back_frame())),
)

def zip(self: "Vector[T]", other: "Vector[R]") -> "Vector[Tuple[T, R]]":
"""The Zip operation for Nada Vectors."""
return Vector(
size=self.size,
contained_types=Tuple.generic_type(
self.contained_types, other.contained_types
contained_type=Tuple.generic_type(
self.contained_type, other.contained_type
),
child=Zip(left=self, right=other, source_ref=SourceRef.back_frame()),
)
Expand All @@ -578,14 +580,14 @@ def reduce(
) # type: ignore

@classmethod
def generic_type(cls, contained_types: T) -> VectorType:
def generic_type(cls, contained_type: T) -> VectorType:
"""Returns the generic type for a Vector with the given child type."""
return VectorType(child=None, contained_types=contained_types)
return VectorType(child=None, contained_type=contained_type)

@classmethod
def init_as_template_type(cls, contained_types) -> "Vector[T]":
def init_as_template_type(cls, contained_type) -> "Vector[T]":
"""Construct an empty Vector with the given child type."""
return Vector(child=None, contained_types=contained_types, size=None)
return Vector(child=None, contained_type=contained_type, size=None)


class TupleNew(Generic[T, U]):
Expand Down Expand Up @@ -666,10 +668,10 @@ def store_in_ast(self, ty: object):
def unzip(array: Array[Tuple[T, R]]) -> Tuple[Array[T], Array[R]]:
"""The Unzip operation for Arrays."""
right_type = ArrayType(
contained_types=array.contained_types.right_type, size=array.size
contained_type=array.contained_type.right_type, size=array.size
)
left_type = ArrayType(
contained_types=array.contained_types.left_type, size=array.size
contained_type=array.contained_type.left_type, size=array.size
)

return Tuple(
Expand Down
8 changes: 4 additions & 4 deletions nada_dsl/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def store_in_ast(self, ty: object):
AST_OPERATIONS[self.id] = UnaryASTOperation(
id=self.id,
name=self.__class__.__name__,
this=self.child.child.id,
child=self.child.child.id,
source_ref=self.source_ref,
ty=ty,
)
Expand Down Expand Up @@ -167,9 +167,9 @@ def store_in_ast(self, ty):
"""Store object in AST."""
AST_OPERATIONS[self.id] = IfElseASTOperation(
id=self.id,
this=self.this.child.id,
arg_0=self.arg_0.child.id,
arg_1=self.arg_1.child.id,
condition=self.this.child.id,
true_branch_child=self.arg_0.child.id,
false_branch_child=self.arg_1.child.id,
ty=ty,
source_ref=self.source_ref,
)
Expand Down
4 changes: 2 additions & 2 deletions tests/compile_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,12 @@ def test_compile_map_simple():
if name == "InputReference":
array_input_id = op_id
assert op["type"] == {
"Array": {"contained_types": "SecretInteger", "size": 3}
"Array": {"inner_type": "SecretInteger", "size": 3}
}
operations_found += 1
elif name == "Map":
assert op["fn"] == function_id
map_inner = op["child"]
map_inner = op["inner"]
function_op_id = op["id"]
operations_found += 1
else:
Expand Down
Loading

0 comments on commit 5507db3

Please sign in to comment.