Skip to content

Commit

Permalink
feat: add support for Object
Browse files Browse the repository at this point in the history
  • Loading branch information
Jmgr committed Nov 5, 2024
1 parent f697877 commit 09c61e6
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 2 deletions.
27 changes: 26 additions & 1 deletion nada_dsl/compiler_frontend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
traverse_and_process_operations,
)
from nada_dsl.nada_types import AllTypes, Party
from nada_dsl.nada_types.collections import Array, Vector, Tuple, NTuple, unzip
from nada_dsl.nada_types.collections import Array, Vector, Tuple, NTuple, Object, unzip
from nada_dsl.nada_types.function import NadaFunctionArg, NadaFunctionCall, nada_fn


Expand Down Expand Up @@ -545,6 +545,31 @@ def test_n_tuple_new():
}


def test_object_new():
first_input = create_input(SecretInteger, "first", "party", **{})
second_input = create_input(PublicInteger, "second", "party", **{})
third_input = create_input(SecretInteger, "third", "party", **{})
object = Object.new({"a": first_input, "b": second_input, "c": third_input})
array_ast = AST_OPERATIONS[object.inner.id]

op = process_operation(array_ast, {}).mir

assert list(op.keys()) == ["New"]

inner = op["New"]

first_ast = AST_OPERATIONS[inner["elements"][0]]
second_ast = AST_OPERATIONS[inner["elements"][1]]
third_ast = AST_OPERATIONS[inner["elements"][2]]
assert first_ast.name == "first"
assert second_ast.name == "second"
assert third_ast.name == "third"
print(f"inner = {inner}")
assert inner["type"]["Object"] == {
"types": {"a": "SecretInteger", "b": "Integer", "c": "SecretInteger"},
}


@pytest.mark.parametrize(
("binary_operator", "name", "ty"),
[
Expand Down
2 changes: 2 additions & 0 deletions nada_dsl/nada_types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(self, name):
"Vector",
"Tuple",
"NTuple",
"Object",
]
AllTypesType = Union[
Type["Integer"],
Expand All @@ -53,6 +54,7 @@ def __init__(self, name):
Type["Vector"],
Type["Tuple"],
Type["NTuple"],
Type["Object"],
]
OperationType = Union[
"Addition",
Expand Down
91 changes: 90 additions & 1 deletion nada_dsl/nada_types/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import copy
from dataclasses import dataclass
import inspect
from typing import Generic, List, Optional
from typing import Dict, Generic, List, Optional
import typing
from typing import TypeVar

Expand Down Expand Up @@ -261,6 +261,65 @@ def get_inner_type(inner_type):
return inner_type


@dataclass
class ObjectType:
"""Marker type for Objects."""

types: Dict[str, NadaType]

def to_type(self):
"""Convert an object into a Nada type."""
return {
"Object": {
"types": {name: ty.to_type() for name, ty in self.types.items()},
}
}


class Object(NadaType):
"""The Object type"""

types: Dict[str, NadaType]

def __init__(self, inner, types: Dict[str, NadaType]):
self.types = types
self.inner = inner
super().__init__(self.inner)

@classmethod
def new(cls, types: Dict[str, NadaType]) -> "Object":
"""Constructs a new Object."""
return Object(
types=types,
inner=ObjectNew(
inner=types,
source_ref=SourceRef.back_frame(),
inner_type=Object(
types=types, inner=None
),
),
)

@classmethod
def generic_type(cls, types: Dict[str, NadaType]) -> ObjectType:
"""Returns the generic type for this Object"""
return ObjectType(types=types)

def to_type(self):
"""Convert operation wrapper to a dictionary representing its type."""
return {
"Object": {
"types": {name: ty.to_type() for name, ty in self.types.items()},
}
}

def get_inner_type(inner_type):
"""Utility that returns the inner type for a composite type."""
inner_type = copy.copy(inner_type)
setattr(inner_type, "inner", None)
return inner_type


class Zip:
"""The Zip operation."""

Expand Down Expand Up @@ -597,6 +656,36 @@ def store_in_ast(self, ty: object):
)


class ObjectNew:
"""MIR Object new operation.
Represents the creation of a new Object.
"""

inner_types: Dict[str, NadaType]
inner: typing.Dict
source_ref: SourceRef

def __init__(
self, inner_type: NadaType, inner: typing.Dict, source_ref: SourceRef
):
self.id = next_operation_id()
self.inner = inner
self.source_ref = source_ref
self.inner_type = inner_type

def store_in_ast(self, ty: object):
"""Store this Object in the AST."""
AST_OPERATIONS[self.id] = NewASTOperation(
id=self.id,
name=self.__class__.__name__,
elements=[element.inner.id for element in self.inner.values()],
source_ref=self.source_ref,
ty=ty,
inner_type=self.inner_type,
)


def unzip(array: Array[Tuple[T, R]]) -> Tuple[Array[T], Array[R]]:
"""The Unzip operation for Arrays."""
right_type = ArrayType(inner_type=array.inner_type.right_type, size=array.size)
Expand Down

0 comments on commit 09c61e6

Please sign in to comment.