diff --git a/nada_dsl/ast_util.py b/nada_dsl/ast_util.py index f48dc6f..c3d8b33 100644 --- a/nada_dsl/ast_util.py +++ b/nada_dsl/ast_util.py @@ -376,3 +376,47 @@ def to_mir(self): "source_ref_index": self.source_ref.to_index(), } } + + +@dataclass +class NTupleAccessorASTOperation(ASTOperation): + """AST representation of a n tuple accessor operation.""" + + index: int + source: int + + def child_operations(self): + return [self.source] + + def to_mir(self): + return { + "NTupleAccessor": { + "id": self.id, + "index": self.index, + "source": self.source, + "type": self.ty, + "source_ref_index": self.source_ref.to_index(), + } + } + + +@dataclass +class ObjectAccessorASTOperation(ASTOperation): + """AST representation of an object accessor operation.""" + + key: str + source: int + + def child_operations(self): + return [self.source] + + def to_mir(self): + return { + "ObjectAccessor": { + "id": self.id, + "key": self.key, + "source": self.source, + "type": self.ty, + "source_ref_index": self.source_ref.to_index(), + } + } diff --git a/nada_dsl/compiler_frontend.py b/nada_dsl/compiler_frontend.py index 033e867..6efcffc 100644 --- a/nada_dsl/compiler_frontend.py +++ b/nada_dsl/compiler_frontend.py @@ -20,10 +20,12 @@ InputASTOperation, LiteralASTOperation, MapASTOperation, + NTupleAccessorASTOperation, NadaFunctionASTOperation, NadaFunctionArgASTOperation, NadaFunctionCallASTOperation, NewASTOperation, + ObjectAccessorASTOperation, RandomASTOperation, ReduceASTOperation, UnaryASTOperation, @@ -296,6 +298,8 @@ def process_operation( NewASTOperation, RandomASTOperation, NadaFunctionArgASTOperation, + NTupleAccessorASTOperation, + ObjectAccessorASTOperation, ), ): processed_operation = ProcessOperationOutput(operation.to_mir(), None) diff --git a/nada_dsl/nada_types/__init__.py b/nada_dsl/nada_types/__init__.py index 2e7ab0f..d6f541b 100644 --- a/nada_dsl/nada_types/__init__.py +++ b/nada_dsl/nada_types/__init__.py @@ -150,10 +150,10 @@ def __init__(self, child: OperationType): def to_mir(self): """Default implementation for the Conversion of a type into MIR representation.""" - return self.__class__.class_to_type() + return self.__class__.class_to_mir() @classmethod - def class_to_type(cls) -> str: + def class_to_mir(cls) -> str: """Converts a class into a MIR Nada type.""" name = cls.__name__ # Rename public variables so they are considered as the same as literals. @@ -165,8 +165,8 @@ def __bool__(self): raise NotImplementedError @classmethod - def is_scalable(cls) -> bool: - """Returns True if the type is a scalable.""" + def is_scalar(cls) -> bool: + """Returns True if the type is a scalar.""" return False @classmethod diff --git a/nada_dsl/nada_types/collections.py b/nada_dsl/nada_types/collections.py index a4cbd73..7a58e94 100644 --- a/nada_dsl/nada_types/collections.py +++ b/nada_dsl/nada_types/collections.py @@ -3,7 +3,7 @@ import copy from dataclasses import dataclass import inspect -from typing import Dict, Generic, List, Optional +from typing import Any, Dict, Generic, List, Optional import typing from typing import TypeVar @@ -11,7 +11,9 @@ AST_OPERATIONS, BinaryASTOperation, MapASTOperation, + NTupleAccessorASTOperation, NewASTOperation, + ObjectAccessorASTOperation, ReduceASTOperation, UnaryASTOperation, ) @@ -58,7 +60,7 @@ def to_mir(self): size = {"size": self.size} if self.size else {} contained_type = self.retrieve_inner_type() return {"Array": {"inner_type": contained_type, **size}} - if isinstance(self, (Vector, VectorType)): + if isinstance(self, Vector): contained_type = self.retrieve_inner_type() return {"Vector": {"inner_type": contained_type}} if isinstance(self, (Tuple, TupleType)): @@ -67,15 +69,50 @@ def to_mir(self): "left_type": ( self.left_type.to_mir() if isinstance(self.left_type, (NadaType, ArrayType, TupleType)) - else self.left_type.class_to_type() + else self.left_type.class_to_mir() ), "right_type": ( self.right_type.to_mir() - if isinstance(self.right_type, (NadaType, ArrayType, TupleType)) - else self.right_type.class_to_type() + if isinstance( + self.right_type, + (NadaType, ArrayType, TupleType), + ) + else self.right_type.class_to_mir() ), } } + if isinstance(self, NTuple): + return { + "NTuple": { + "types": [ + ( + ty.to_mir() + if isinstance(ty, (NadaType, ArrayType, TupleType)) + else ty.class_to_mir() + ) + for ty in [ + type(value) + for value in self.values # pylint: disable=E1101 + ] + ] + } + } + if isinstance(self, Object): + return { + "Object": { + "types": { + name: ( + ty.to_mir() + if isinstance(ty, (NadaType, ArrayType, TupleType)) + else ty.class_to_mir() + ) + for name, ty in [ + (name, type(value)) + for name, value in self.values.items() # pylint: disable=E1101 + ] + } + } + } raise InvalidTypeError( f"{self.__class__.__name__} is not a valid Nada Collection" ) @@ -85,7 +122,7 @@ def retrieve_inner_type(self): if isinstance(self.contained_type, TypeVar): return "T" if inspect.isclass(self.contained_type): - return self.contained_type.class_to_type() + return self.contained_type.class_to_mir() return self.contained_type.to_mir() @@ -199,100 +236,160 @@ def generic_type(cls, left_type: U, right_type: T) -> TupleType: return TupleType(left_type=left_type, right_type=right_type) -@dataclass -class NTupleType: - """Marker type for NTuples.""" - - types: List[NadaType] +def _generate_accessor(value: Any, accessor: Any) -> NadaType: + ty = type(value) - def to_mir(self): - """Convert a n tuple object into a Nada type.""" - return { - "NTuple": { - "types": [ty.to_mir() for ty in self.types], - } - } + if ty.is_scalar(): + if ty.is_literal(): + return value + return ty(child=accessor) + if ty == Array: + return Array( + child=accessor, + contained_type=value.contained_type, + size=value.size, + ) + if ty == NTuple: + return NTuple( + child=accessor, + values=value.values, + ) + if ty == Object: + return Object( + child=accessor, + values=value.values, + ) + raise TypeError(f"Unsupported type for accessor: {ty}") -class NTuple(NadaType): +class NTuple(Collection): """The NTuple type""" - types: List[NadaType] + values: List[NadaType] - def __init__(self, child, types: List[NadaType]): - self.types = types + def __init__(self, child, values: List[NadaType]): + self.values = values self.child = child super().__init__(self.child) @classmethod - def new(cls, types: List[NadaType]) -> "NTuple": + def new(cls, values: List[NadaType]) -> "NTuple": """Constructs a new NTuple.""" return NTuple( - types=types, + values=values, child=NTupleNew( - child=types, + child=values, source_ref=SourceRef.back_frame(), ), ) - @classmethod - def generic_type(cls, types: List[NadaType]) -> NTupleType: - """Returns the generic type for this NTuple""" - return NTupleType(types=types) + def __getitem__(self, index: int) -> NadaType: + if index >= len(self.values): + raise IndexError(f"Invalid index {index} for NTuple.") - def to_mir(self): - """Convert operation wrapper to a dictionary representing its type.""" - return {"NTuple": {"types": [ty.to_mir() for ty in self.types]}} + accessor = NTupleAccessor( + index=index, + child=self, + source_ref=SourceRef.back_frame(), + ) + + return _generate_accessor(self.values[index], accessor) @dataclass -class ObjectType: - """Marker type for Objects.""" +class NTupleAccessor: + """Accessor for NTuple""" - types: Dict[str, NadaType] + child: NTuple + index: int + source_ref: SourceRef - def to_mir(self): - """Convert an object into a Nada type.""" - return { - "Object": { - "types": {name: ty.to_mir() for name, ty in self.types.items()}, - } - } + def __init__( + self, + child: NTuple, + index: int, + source_ref: SourceRef, + ): + self.id = next_operation_id() + self.child = child + self.index = index + self.source_ref = source_ref + def store_in_ast(self, ty: object): + """Store this accessor in the AST.""" + AST_OPERATIONS[self.id] = NTupleAccessorASTOperation( + id=self.id, + source=self.child.child.id, + index=self.index, + source_ref=self.source_ref, + ty=ty, + ) -class Object(NadaType): + +class Object(Collection): """The Object type""" - types: Dict[str, NadaType] + values: Dict[str, NadaType] - def __init__(self, child, types: Dict[str, NadaType]): - self.types = types + def __init__(self, child, values: Dict[str, NadaType]): + self.values = values self.child = child super().__init__(self.child) @classmethod - def new(cls, types: Dict[str, NadaType]) -> "Object": + def new(cls, values: Dict[str, NadaType]) -> "Object": """Constructs a new Object.""" return Object( - types=types, + values=values, child=ObjectNew( - child=types, + child=values, source_ref=SourceRef.back_frame(), ), ) - @classmethod - def generic_type(cls, types: Dict[str, NadaType]) -> ObjectType: - """Returns the generic type for this Object""" - return ObjectType(types=types) + def __getattr__(self, attr: str) -> NadaType: + if attr not in self.values: + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{attr}'" + ) - def to_mir(self): - """Convert operation wrapper to a dictionary representing its type.""" - return { - "Object": { - "types": {name: ty.to_mir() for name, ty in self.types.items()}, - } - } + accessor = ObjectAccessor( + key=attr, + child=self, + source_ref=SourceRef.back_frame(), + ) + + return _generate_accessor(self.values[attr], accessor) + + +@dataclass +class ObjectAccessor: + """Accessor for Object""" + + child: Object + key: str + source_ref: SourceRef + + def __init__( + self, + child: Object, + key: str, + source_ref: SourceRef, + ): + self.id = next_operation_id() + self.child = child + self.key = key + self.source_ref = source_ref + + def store_in_ast(self, ty: object): + """Store this accessor in the AST.""" + AST_OPERATIONS[self.id] = ObjectAccessorASTOperation( + id=self.id, + source=self.child.child.id, + key=self.key, + source_ref=self.source_ref, + ty=ty, + ) # pylint: disable=W0511 @@ -501,24 +598,12 @@ def new(cls, *args) -> "Array[T]": ), ) - @classmethod - def generic_type(cls, contained_type: T, size: int) -> ArrayType: - """Return the generic type of the Array.""" - return ArrayType(contained_type=contained_type, size=size) - @classmethod def init_as_template_type(cls, contained_type) -> "Array[T]": """Construct an empty template array with the given child type.""" return Array(child=None, contained_type=contained_type, size=None) -@dataclass -class VectorType(Collection): - """The generic type for Vectors.""" - - contained_type: AllTypesType - - @dataclass class Vector(Generic[T], Collection): """ @@ -579,11 +664,6 @@ def reduce( ) ) # type: ignore - @classmethod - def generic_type(cls, contained_type: T) -> VectorType: - """Returns the generic type for a Vector with the given child type.""" - return VectorType(child=None, contained_type=contained_type) - @classmethod def init_as_template_type(cls, contained_type) -> "Vector[T]": """Construct an empty Vector with the given child type.""" @@ -621,10 +701,10 @@ class NTupleNew: Represents the creation of a new Tuple. """ - child: typing.Tuple + child: List[NadaType] source_ref: SourceRef - def __init__(self, child: typing.Tuple, source_ref: SourceRef): + def __init__(self, child: List[NadaType], source_ref: SourceRef): self.id = next_operation_id() self.child = child self.source_ref = source_ref @@ -646,10 +726,10 @@ class ObjectNew: Represents the creation of a new Object. """ - child: typing.Dict + child: Dict[str, NadaType] source_ref: SourceRef - def __init__(self, child: typing.Dict, source_ref: SourceRef): + def __init__(self, child: Dict[str, NadaType], source_ref: SourceRef): self.id = next_operation_id() self.child = child self.source_ref = source_ref @@ -681,7 +761,6 @@ def unzip(array: Array[Tuple[T, R]]) -> Tuple[Array[T], Array[R]]: ) -@dataclass class ArrayNew(Generic[T]): """MIR Array new operation""" diff --git a/nada_dsl/nada_types/function.py b/nada_dsl/nada_types/function.py index c72bab1..4209d8d 100644 --- a/nada_dsl/nada_types/function.py +++ b/nada_dsl/nada_types/function.py @@ -101,7 +101,7 @@ def store_in_ast(self): name=self.function.__name__, args=[arg.id for arg in self.args], id=self.id, - ty=self.return_type.class_to_type(), + ty=self.return_type.class_to_mir(), source_ref=self.source_ref, child=self.child.child.id, ) @@ -125,7 +125,7 @@ def __init__(self, nada_function, args, source_ref): self.args = args self.fn = nada_function self.source_ref = source_ref - self.store_in_ast(nada_function.return_type.class_to_type()) + self.store_in_ast(nada_function.return_type.class_to_mir()) def store_in_ast(self, ty): """Store this function call in the AST.""" diff --git a/nada_dsl/nada_types/scalar_types.py b/nada_dsl/nada_types/scalar_types.py index f9c12d4..0c6bae8 100644 --- a/nada_dsl/nada_types/scalar_types.py +++ b/nada_dsl/nada_types/scalar_types.py @@ -77,7 +77,7 @@ def to_public(self) -> Self: return self @classmethod - def is_scalable(cls) -> bool: + def is_scalar(cls) -> bool: return True diff --git a/test-programs/ntuple_accessor.py b/test-programs/ntuple_accessor.py new file mode 100644 index 0000000..6edef15 --- /dev/null +++ b/test-programs/ntuple_accessor.py @@ -0,0 +1,24 @@ +from nada_dsl import * + + +def nada_main(): + party1 = Party(name="Party1") + my_int1 = PublicInteger(Input(name="my_int1", party=party1)) + my_int2 = PublicInteger(Input(name="my_int2", party=party1)) + + array = Array.new(my_int1, my_int1) + + # Store a scalar, a compound type and a literal. + tuple = NTuple.new([my_int1, array, Integer(42)]) + + scalar = tuple[0] + array = tuple[1] + literal = tuple[2] + + @nada_fn + def add(a: PublicInteger) -> PublicInteger: + return a + my_int2 + + sum = array.reduce(add, Integer(0)) + + return [Output(scalar + literal + sum, "my_output", party1)] diff --git a/test-programs/object_accessor.py b/test-programs/object_accessor.py new file mode 100644 index 0000000..0258b8e --- /dev/null +++ b/test-programs/object_accessor.py @@ -0,0 +1,24 @@ +from nada_dsl import * + + +def nada_main(): + party1 = Party(name="Party1") + my_int1 = PublicInteger(Input(name="my_int1", party=party1)) + my_int2 = PublicInteger(Input(name="my_int2", party=party1)) + + array = Array.new(my_int1, my_int1) + + # Store a scalar, a compound type and a literal. + object = Object.new({"a": my_int1, "b": array, "c": Integer(42)}) + + scalar = object.a + array = object.b + literal = object.c + + @nada_fn + def add(a: PublicInteger) -> PublicInteger: + return a + my_int2 + + sum = array.reduce(add, Integer(0)) + + return [Output(scalar + literal + sum, "my_output", party1)] diff --git a/tests/compile_test.py b/tests/compile_test.py index ddabb12..225e94f 100644 --- a/tests/compile_test.py +++ b/tests/compile_test.py @@ -178,3 +178,13 @@ def nada_main(): encoded_program_str = base64.b64encode(bytes(program_str, "utf-8")).decode("utf_8") output = compile_string(encoded_program_str) print_output(output) + + +def test_compile_ntuple(): + mir_str = compile_script(f"{get_test_programs_folder()}/ntuple_accessor.py").mir + assert mir_str != "" + + +def test_compile_object(): + mir_str = compile_script(f"{get_test_programs_folder()}/object_accessor.py").mir + assert mir_str != "" diff --git a/tests/nada_type_test.py b/tests/nada_type_test.py index a4763ba..ad20ea5 100644 --- a/tests/nada_type_test.py +++ b/tests/nada_type_test.py @@ -13,6 +13,6 @@ (PublicBoolean, "Boolean"), ], ) -def test_class_to_type(cls: NadaType, expected: str): - """Tests `NadaType.class_to_type()""" - assert cls.class_to_type() == expected +def test_class_to_mir(cls: NadaType, expected: str): + """Tests `NadaType.class_to_mir()""" + assert cls.class_to_mir() == expected