Skip to content

Commit

Permalink
chore: rename inner_type into contained_types
Browse files Browse the repository at this point in the history
  • Loading branch information
Jmgr committed Nov 14, 2024
1 parent bc081d8 commit 0d75641
Show file tree
Hide file tree
Showing 6 changed files with 77 additions and 64 deletions.
2 changes: 1 addition & 1 deletion nada_dsl/nada_types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def __init__(self, name):
"""Type alias for the NadaType representation.
This representation can be either a string ("SecretInteger")
or a dictionary (Array{inner_type=SecretInteger, size=3}).
or a dictionary (Array{contained_types=SecretInteger, size=3}).
"""


Expand Down
113 changes: 63 additions & 50 deletions nada_dsl/nada_types/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,17 @@ class Collection(NadaType):

left_type: AllTypesType
right_type: AllTypesType
inner_type: AllTypesType
contained_types: AllTypesType

def to_type(self):
"""Convert operation wrapper to a dictionary representing its type."""
if isinstance(self, (Array, ArrayType)):
size = {"size": self.size} if self.size else {}
inner_type = self.retrieve_inner_type()
return {"Array": {"inner_type": inner_type, **size}}
contained_types = self.retrieve_inner_type()
return {"Array": {"contained_types": contained_types, **size}}
if isinstance(self, (Vector, VectorType)):
inner_type = self.retrieve_inner_type()
return {"Vector": {"inner_type": inner_type}}
contained_types = self.retrieve_inner_type()
return {"Vector": {"contained_types": contained_types}}
if isinstance(self, (Tuple, TupleType)):
return {
"Tuple": {
Expand All @@ -82,11 +82,11 @@ def to_type(self):

def retrieve_inner_type(self):
"""Retrieves the parent type of this collection"""
if isinstance(self.inner_type, TypeVar):
if isinstance(self.contained_types, TypeVar):
return "T"
if inspect.isclass(self.inner_type):
return self.inner_type.class_to_type()
return self.inner_type.to_type()
if inspect.isclass(self.contained_types):
return self.contained_types.class_to_type()
return self.contained_types.to_type()


class Map(Generic[T, R]):
Expand Down Expand Up @@ -295,11 +295,11 @@ def to_type(self):
}


def get_inner_type(inner_type):
def get_inner_type(contained_types):
"""Utility that returns the parent type for a composite type."""
inner_type = copy.copy(inner_type)
setattr(inner_type, "parent", None)
return inner_type
contained_types = copy.copy(contained_types)
setattr(contained_types, "parent", None)
return contained_types


class Zip:
Expand Down Expand Up @@ -367,12 +367,17 @@ def store_in_ast(self, ty: NadaTypeRepr):
class ArrayType:
"""Marker type for arrays."""

inner_type: AllTypesType
contained_types: AllTypesType
size: int

def to_type(self):
"""Convert this generic type into a MIR Nada type."""
return {"Array": {"inner_type": self.inner_type.to_type(), "size": self.size}}
return {
"Array": {
"contained_types": self.contained_types.to_type(),
"size": self.size,
}
}


class Array(Generic[T], Collection):
Expand All @@ -383,26 +388,26 @@ class Array(Generic[T], Collection):
Attributes
----------
inner_type: T
contained_types: T
The type of the array
parent:
The optional parent operation
size: int
The size of the array
"""

inner_type: T
contained_types: T
size: int

def __init__(self, parent, size: int, inner_type: T = None):
self.inner_type = (
inner_type
if (parent is None or inner_type is not None)
def __init__(self, parent, size: int, contained_types: T = None):
self.contained_types = (
contained_types
if (parent is None or contained_types is not None)
else get_inner_type(parent)
)
self.size = size
self.parent = (
parent if inner_type is not None else getattr(parent, "parent", None)
parent if contained_types is not None else getattr(parent, "parent", None)
)
if self.parent is not None:
self.parent.store_in_ast(self.to_type())
Expand All @@ -419,7 +424,7 @@ def map(self: "Array[T]", function) -> "Array":
nada_function = nada_fn(function)
return Array(
size=self.size,
inner_type=nada_function.return_type,
contained_types=nada_function.return_type,
parent=Map(
parent=self, fn=nada_function, source_ref=SourceRef.back_frame()
),
Expand All @@ -444,8 +449,10 @@ def zip(self: "Array[T]", other: "Array[U]") -> "Array[Tuple[T, U]]":
raise IncompatibleTypesError("Cannot zip arrays of different size")
return Array(
size=self.size,
inner_type=Tuple(
left_type=self.inner_type, right_type=other.inner_type, parent=None
contained_types=Tuple(
left_type=self.contained_types,
right_type=other.contained_types,
parent=None,
),
parent=Zip(left=self, right=other, source_ref=SourceRef.back_frame()),
)
Expand All @@ -460,12 +467,12 @@ def inner_product(self: "Array[T]", other: "Array[T]") -> T:
if is_primitive_integer(self.retrieve_inner_type()) and is_primitive_integer(
other.retrieve_inner_type()
):
inner_type = (
self.inner_type
if inspect.isclass(self.inner_type)
else self.inner_type.__class__
contained_types = (
self.contained_types
if inspect.isclass(self.contained_types)
else self.contained_types.__class__
)
return inner_type(
return contained_types(
parent=InnerProduct(
left=self, right=other, source_ref=SourceRef.back_frame()
)
Expand All @@ -486,7 +493,7 @@ def new(cls, *args) -> "Array[T]":
raise TypeError("All arguments must be of the same type")

return Array(
inner_type=first_arg,
contained_types=first_arg,
size=len(args),
parent=ArrayNew(
parent=args,
Expand All @@ -495,21 +502,21 @@ def new(cls, *args) -> "Array[T]":
)

@classmethod
def generic_type(cls, inner_type: T, size: int) -> ArrayType:
def generic_type(cls, contained_types: T, size: int) -> ArrayType:
"""Return the generic type of the Array."""
return ArrayType(inner_type=inner_type, size=size)
return ArrayType(contained_types=contained_types, size=size)

@classmethod
def init_as_template_type(cls, inner_type) -> "Array[T]":
def init_as_template_type(cls, contained_types) -> "Array[T]":
"""Construct an empty template array with the given parent type."""
return Array(parent=None, inner_type=inner_type, size=None)
return Array(parent=None, contained_types=contained_types, size=None)


@dataclass
class VectorType(Collection):
"""The generic type for Vectors."""

inner_type: AllTypesType
contained_types: AllTypesType


@dataclass
Expand All @@ -522,17 +529,17 @@ class Vector(Generic[T], Collection):
its size may change at runtime.
"""

inner_type: T
contained_types: T
size: int

def __init__(self, parent, size, inner_type=None):
self.inner_type = (
inner_type
if (parent is None or inner_type is not None)
def __init__(self, parent, size, contained_types=None):
self.contained_types = (
contained_types
if (parent is None or contained_types is not None)
else get_inner_type(parent)
)
self.size = size
self.parent = parent if inner_type else getattr(parent, "parent", None)
self.parent = parent if contained_types else getattr(parent, "parent", None)
self.parent.store_in_ast(self.to_type())

def __iter__(self):
Expand All @@ -545,15 +552,17 @@ def map(self: "Vector[T]", function: NadaFunction[T, R]) -> "Vector[R]":
"""The map operation for Nada Vectors."""
return Vector(
size=self.size,
inner_type=function.return_type,
contained_types=function.return_type,
parent=(Map(parent=self, fn=function, source_ref=SourceRef.back_frame())),
)

def zip(self: "Vector[T]", other: "Vector[R]") -> "Vector[Tuple[T, R]]":
"""The Zip operation for Nada Vectors."""
return Vector(
size=self.size,
inner_type=Tuple.generic_type(self.inner_type, other.inner_type),
contained_types=Tuple.generic_type(
self.contained_types, other.contained_types
),
parent=Zip(left=self, right=other, source_ref=SourceRef.back_frame()),
)

Expand All @@ -571,14 +580,14 @@ def reduce(
) # type: ignore

@classmethod
def generic_type(cls, inner_type: T) -> VectorType:
def generic_type(cls, contained_types: T) -> VectorType:
"""Returns the generic type for a Vector with the given parent type."""
return VectorType(parent=None, inner_type=inner_type)
return VectorType(parent=None, contained_types=contained_types)

@classmethod
def init_as_template_type(cls, inner_type) -> "Vector[T]":
def init_as_template_type(cls, contained_types) -> "Vector[T]":
"""Construct an empty Vector with the given parent type."""
return Vector(parent=None, inner_type=inner_type, size=None)
return Vector(parent=None, contained_types=contained_types, size=None)


class TupleNew(Generic[T, U]):
Expand Down Expand Up @@ -658,8 +667,12 @@ def store_in_ast(self, ty: object):

def unzip(array: Array[Tuple[T, R]]) -> Tuple[Array[T], Array[R]]:
"""The Unzip operation for Arrays."""
right_type = ArrayType(inner_type=array.inner_type.right_type, size=array.size)
left_type = ArrayType(inner_type=array.inner_type.left_type, size=array.size)
right_type = ArrayType(
contained_types=array.contained_types.right_type, size=array.size
)
left_type = ArrayType(
contained_types=array.contained_types.left_type, size=array.size
)

return Tuple(
right_type=right_type,
Expand Down
6 changes: 3 additions & 3 deletions nada_dsl/nada_types/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,13 @@ def store_in_ast(self, ty):
)


def inner_type(ty):
def contained_types(ty):
"""Utility function that calculates the parent type for a function argument."""

origin_ty = getattr(ty, "__origin__", ty)
if not issubclass(origin_ty, ScalarType):
inner_ty = getattr(ty, "__args__", None)
inner_ty = inner_type(inner_ty[0]) if inner_ty else T
inner_ty = contained_types(inner_ty[0]) if inner_ty else T
return origin_ty.init_as_template_type(inner_ty)
if origin_ty.mode == Mode.CONSTANT:
return origin_ty(value=0)
Expand All @@ -167,7 +167,7 @@ def nada_fn(fn, args_ty=None, return_ty=None) -> NadaFunction[T, R]:
function_id = next_operation_id()
for arg in args.args:
arg_type = args_ty[arg] if args_ty else args.annotations[arg]
arg_type = inner_type(arg_type)
arg_type = contained_types(arg_type)
# We'll get the function source ref for now
nada_arg = NadaFunctionArg(
function_id,
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion tests/compile_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def test_compile_map_simple():
if name == "InputReference":
array_input_id = op_id
assert op["type"] == {
"Array": {"inner_type": "SecretInteger", "size": 3}
"Array": {"contained_types": "SecretInteger", "size": 3}
}
operations_found += 1
elif name == "Map":
Expand Down
16 changes: 8 additions & 8 deletions tests/compiler_frontend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def test_zip(input_type, input_name):
right = AST_OPERATIONS[zip_mir["right"]]
assert left.name == "left"
assert right.name == "right"
assert zip_mir["type"][input_name]["inner_type"] == {
assert zip_mir["type"][input_name]["contained_types"] == {
"Tuple": {
"left_type": "SecretInteger",
"right_type": "SecretInteger",
Expand Down Expand Up @@ -187,8 +187,8 @@ def test_unzip(input_type: type[Array]):
assert zip_ast.name == "Zip"
assert unzip_mir["type"] == {
"Tuple": {
"left_type": {"Array": {"inner_type": "SecretInteger", "size": 10}},
"right_type": {"Array": {"inner_type": "SecretInteger", "size": 10}},
"left_type": {"Array": {"contained_types": "SecretInteger", "size": 10}},
"right_type": {"Array": {"contained_types": "SecretInteger", "size": 10}},
}
}

Expand Down Expand Up @@ -218,7 +218,7 @@ def nada_function(a: SecretInteger) -> SecretInteger:
assert list(parent["type"].keys()) == [input_name]
inner_inner = AST_OPERATIONS[parent["parent"]]
assert inner_inner.name == "parent"
assert parent["type"][input_name]["inner_type"] == "SecretInteger"
assert parent["type"][input_name]["contained_types"] == "SecretInteger"


@pytest.mark.parametrize(
Expand Down Expand Up @@ -423,9 +423,9 @@ def matrix_addition(
assert matrix_addition_fn["function"] == "matrix_addition"
args = matrix_addition_fn["args"]
assert len(args) == 2
a_arg_type = {input_name: {"inner_type": "SecretInteger"}}
a_arg_type = {input_name: {"contained_types": "SecretInteger"}}
check_arg(args[0], "a", a_arg_type)
b_arg_type = {input_name: {"inner_type": "SecretInteger"}}
b_arg_type = {input_name: {"contained_types": "SecretInteger"}}
check_arg(args[1], "b", b_arg_type)
assert matrix_addition_fn["return_type"] == "SecretInteger"

Expand All @@ -440,7 +440,7 @@ def matrix_addition(
assert list(reduce_op_inner.keys()) == ["Map"]
map_op = reduce_op_inner["Map"]
map_op["function_id"] = add_fn["id"]
map_op["type"] = {input_name: {"inner_type": "SecretInteger", "size": None}}
map_op["type"] = {input_name: {"contained_types": "SecretInteger", "size": None}}

map_op_inner = operations[map_op["parent"]]
assert list(map_op_inner.keys()) == ["Zip"]
Expand Down Expand Up @@ -469,7 +469,7 @@ def test_array_new():
assert first.name == "first"
assert second.name == "second"
assert parent["type"]["Array"] == {
"inner_type": "SecretInteger",
"contained_types": "SecretInteger",
"size": 2,
}

Expand Down

0 comments on commit 0d75641

Please sign in to comment.