Skip to content

Commit

Permalink
feat: add support for Objects (#43)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jmgr authored Nov 5, 2024
1 parent f697877 commit 9892065
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 2 deletions.
27 changes: 26 additions & 1 deletion nada_dsl/compiler_frontend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
traverse_and_process_operations,
)
from nada_dsl.nada_types import AllTypes, Party
from nada_dsl.nada_types.collections import Array, Vector, Tuple, NTuple, unzip
from nada_dsl.nada_types.collections import Array, Vector, Tuple, NTuple, Object, unzip
from nada_dsl.nada_types.function import NadaFunctionArg, NadaFunctionCall, nada_fn


Expand Down Expand Up @@ -545,6 +545,31 @@ def test_n_tuple_new():
}


def test_object_new():
first_input = create_input(SecretInteger, "first", "party", **{})
second_input = create_input(PublicInteger, "second", "party", **{})
third_input = create_input(SecretInteger, "third", "party", **{})
object = Object.new({"a": first_input, "b": second_input, "c": third_input})
array_ast = AST_OPERATIONS[object.inner.id]

op = process_operation(array_ast, {}).mir

assert list(op.keys()) == ["New"]

inner = op["New"]

first_ast = AST_OPERATIONS[inner["elements"][0]]
second_ast = AST_OPERATIONS[inner["elements"][1]]
third_ast = AST_OPERATIONS[inner["elements"][2]]
assert first_ast.name == "first"
assert second_ast.name == "second"
assert third_ast.name == "third"
print(f"inner = {inner}")
assert inner["type"]["Object"] == {
"types": {"a": "SecretInteger", "b": "Integer", "c": "SecretInteger"},
}


@pytest.mark.parametrize(
("binary_operator", "name", "ty"),
[
Expand Down
2 changes: 2 additions & 0 deletions nada_dsl/nada_types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(self, name):
"Vector",
"Tuple",
"NTuple",
"Object",
]
AllTypesType = Union[
Type["Integer"],
Expand All @@ -53,6 +54,7 @@ def __init__(self, name):
Type["Vector"],
Type["Tuple"],
Type["NTuple"],
Type["Object"],
]
OperationType = Union[
"Addition",
Expand Down
85 changes: 84 additions & 1 deletion nada_dsl/nada_types/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import copy
from dataclasses import dataclass
import inspect
from typing import Generic, List, Optional
from typing import Dict, Generic, List, Optional
import typing
from typing import TypeVar

Expand Down Expand Up @@ -254,6 +254,59 @@ def to_type(self):
}
}


@dataclass
class ObjectType:
"""Marker type for Objects."""

types: Dict[str, NadaType]

def to_type(self):
"""Convert an object into a Nada type."""
return {
"Object": {
"types": {name: ty.to_type() for name, ty in self.types.items()},
}
}


class Object(NadaType):
"""The Object type"""

types: Dict[str, NadaType]

def __init__(self, inner, types: Dict[str, NadaType]):
self.types = types
self.inner = inner
super().__init__(self.inner)

@classmethod
def new(cls, types: Dict[str, NadaType]) -> "Object":
"""Constructs a new Object."""
return Object(
types=types,
inner=ObjectNew(
inner=types,
source_ref=SourceRef.back_frame(),
inner_type=Object(
types=types, inner=None
),
),
)

@classmethod
def generic_type(cls, types: Dict[str, NadaType]) -> ObjectType:
"""Returns the generic type for this Object"""
return ObjectType(types=types)

def to_type(self):
"""Convert operation wrapper to a dictionary representing its type."""
return {
"Object": {
"types": {name: ty.to_type() for name, ty in self.types.items()},
}
}

def get_inner_type(inner_type):
"""Utility that returns the inner type for a composite type."""
inner_type = copy.copy(inner_type)
Expand Down Expand Up @@ -597,6 +650,36 @@ def store_in_ast(self, ty: object):
)


class ObjectNew:
"""MIR Object new operation.
Represents the creation of a new Object.
"""

inner_types: Dict[str, NadaType]
inner: typing.Dict
source_ref: SourceRef

def __init__(
self, inner_type: NadaType, inner: typing.Dict, source_ref: SourceRef
):
self.id = next_operation_id()
self.inner = inner
self.source_ref = source_ref
self.inner_type = inner_type

def store_in_ast(self, ty: object):
"""Store this Object in the AST."""
AST_OPERATIONS[self.id] = NewASTOperation(
id=self.id,
name=self.__class__.__name__,
elements=[element.inner.id for element in self.inner.values()],
source_ref=self.source_ref,
ty=ty,
inner_type=self.inner_type,
)


def unzip(array: Array[Tuple[T, R]]) -> Tuple[Array[T], Array[R]]:
"""The Unzip operation for Arrays."""
right_type = ArrayType(inner_type=array.inner_type.right_type, size=array.size)
Expand Down

0 comments on commit 9892065

Please sign in to comment.