Skip to content

Commit

Permalink
feat: add support for NTuple (#41)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jmgr authored Nov 5, 2024
1 parent b0f2056 commit f697877
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 3 deletions.
31 changes: 28 additions & 3 deletions nada_dsl/compiler_frontend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
traverse_and_process_operations,
)
from nada_dsl.nada_types import AllTypes, Party
from nada_dsl.nada_types.collections import Array, Vector, Tuple, unzip
from nada_dsl.nada_types.collections import Array, Vector, Tuple, NTuple, unzip
from nada_dsl.nada_types.function import NadaFunctionArg, NadaFunctionCall, nada_fn


Expand Down Expand Up @@ -492,8 +492,8 @@ def test_array_new_same_type():
def test_tuple_new():
first_input = create_input(SecretInteger, "first", "party", **{})
second_input = create_input(PublicInteger, "second", "party", **{})
array = Tuple.new(first_input, second_input)
array_ast = AST_OPERATIONS[array.inner.id]
tuple = Tuple.new(first_input, second_input)
array_ast = AST_OPERATIONS[tuple.inner.id]

op = process_operation(array_ast, {}).mir

Expand All @@ -520,6 +520,31 @@ def test_tuple_new_empty():
)


def test_n_tuple_new():
first_input = create_input(SecretInteger, "first", "party", **{})
second_input = create_input(PublicInteger, "second", "party", **{})
third_input = create_input(SecretInteger, "third", "party", **{})
tuple = NTuple.new([first_input, second_input, third_input])
array_ast = AST_OPERATIONS[tuple.inner.id]

op = process_operation(array_ast, {}).mir

assert list(op.keys()) == ["New"]

inner = op["New"]

first_ast = AST_OPERATIONS[inner["elements"][0]]
second_ast = AST_OPERATIONS[inner["elements"][1]]
third_ast = AST_OPERATIONS[inner["elements"][2]]
assert first_ast.name == "first"
assert second_ast.name == "second"
assert third_ast.name == "third"
print(f"inner = {inner}")
assert inner["type"]["NTuple"] == {
"types": ["SecretInteger", "Integer", "SecretInteger"],
}


@pytest.mark.parametrize(
("binary_operator", "name", "ty"),
[
Expand Down
2 changes: 2 additions & 0 deletions nada_dsl/nada_types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(self, name):
"Array",
"Vector",
"Tuple",
"NTuple",
]
AllTypesType = Union[
Type["Integer"],
Expand All @@ -51,6 +52,7 @@ def __init__(self, name):
Type["ArrayType"],
Type["Vector"],
Type["Tuple"],
Type["NTuple"],
]
OperationType = Union[
"Addition",
Expand Down
82 changes: 82 additions & 0 deletions nada_dsl/nada_types/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,58 @@ def generic_type(cls, left_type: U, right_type: T) -> TupleType:
return TupleType(left_type=left_type, right_type=right_type)


@dataclass
class NTupleType:
"""Marker type for NTuples."""

types: List[NadaType]

def to_type(self):
"""Convert a n tuple object into a Nada type."""
return {
"NTuple": {
"types": [ty.to_type() for ty in self.types],
}
}


class NTuple(NadaType):
"""The NTuple type"""

types: List[NadaType]

def __init__(self, inner, types: List[NadaType]):
self.types = types
self.inner = inner
super().__init__(self.inner)

@classmethod
def new(cls, types: List[NadaType]) -> "NTuple":
"""Constructs a new NTuple."""
return NTuple(
types=types,
inner=NTupleNew(
inner=types,
source_ref=SourceRef.back_frame(),
inner_type=NTuple(
types=types, inner=None
),
),
)

@classmethod
def generic_type(cls, types: List[NadaType]) -> NTupleType:
"""Returns the generic type for this NTuple"""
return NTupleType(types=types)

def to_type(self):
"""Convert operation wrapper to a dictionary representing its type."""
return {
"NTuple": {
"types": [ty.to_type() for ty in self.types]
}
}

def get_inner_type(inner_type):
"""Utility that returns the inner type for a composite type."""
inner_type = copy.copy(inner_type)
Expand Down Expand Up @@ -515,6 +567,36 @@ def store_in_ast(self, ty: object):
)


class NTupleNew:
"""MIR NTuple new operation.
Represents the creation of a new Tuple.
"""

inner_types: List[NadaType]
inner: typing.Tuple
source_ref: SourceRef

def __init__(
self, inner_type: NadaType, inner: typing.Tuple, source_ref: SourceRef
):
self.id = next_operation_id()
self.inner = inner
self.source_ref = source_ref
self.inner_type = inner_type

def store_in_ast(self, ty: object):
"""Store this NTupleNew in the AST."""
AST_OPERATIONS[self.id] = NewASTOperation(
id=self.id,
name=self.__class__.__name__,
elements=[element.inner.id for element in self.inner],
source_ref=self.source_ref,
ty=ty,
inner_type=self.inner_type,
)


def unzip(array: Array[Tuple[T, R]]) -> Tuple[Array[T], Array[R]]:
"""The Unzip operation for Arrays."""
right_type = ArrayType(inner_type=array.inner_type.right_type, size=array.size)
Expand Down

0 comments on commit f697877

Please sign in to comment.