diff --git a/nada_dsl/nada_types/__init__.py b/nada_dsl/nada_types/__init__.py index 9fb9c43..182150d 100644 --- a/nada_dsl/nada_types/__init__.py +++ b/nada_dsl/nada_types/__init__.py @@ -90,7 +90,7 @@ def __init__(self, name): """Type alias for the NadaType representation. This representation can be either a string ("SecretInteger") -or a dictionary (Array{inner_type=SecretInteger, size=3}). +or a dictionary (Array{contained_types=SecretInteger, size=3}). """ diff --git a/nada_dsl/nada_types/collections.py b/nada_dsl/nada_types/collections.py index a61e662..aacd227 100644 --- a/nada_dsl/nada_types/collections.py +++ b/nada_dsl/nada_types/collections.py @@ -50,17 +50,17 @@ class Collection(NadaType): left_type: AllTypesType right_type: AllTypesType - inner_type: AllTypesType + contained_types: AllTypesType def to_type(self): """Convert operation wrapper to a dictionary representing its type.""" if isinstance(self, (Array, ArrayType)): size = {"size": self.size} if self.size else {} - inner_type = self.retrieve_inner_type() - return {"Array": {"inner_type": inner_type, **size}} + contained_types = self.retrieve_inner_type() + return {"Array": {"contained_types": contained_types, **size}} if isinstance(self, (Vector, VectorType)): - inner_type = self.retrieve_inner_type() - return {"Vector": {"inner_type": inner_type}} + contained_types = self.retrieve_inner_type() + return {"Vector": {"contained_types": contained_types}} if isinstance(self, (Tuple, TupleType)): return { "Tuple": { @@ -82,11 +82,11 @@ def to_type(self): def retrieve_inner_type(self): """Retrieves the parent type of this collection""" - if isinstance(self.inner_type, TypeVar): + if isinstance(self.contained_types, TypeVar): return "T" - if inspect.isclass(self.inner_type): - return self.inner_type.class_to_type() - return self.inner_type.to_type() + if inspect.isclass(self.contained_types): + return self.contained_types.class_to_type() + return self.contained_types.to_type() class Map(Generic[T, R]): @@ -295,11 +295,11 @@ def to_type(self): } -def get_inner_type(inner_type): +def get_inner_type(contained_types): """Utility that returns the parent type for a composite type.""" - inner_type = copy.copy(inner_type) - setattr(inner_type, "parent", None) - return inner_type + contained_types = copy.copy(contained_types) + setattr(contained_types, "parent", None) + return contained_types class Zip: @@ -367,12 +367,17 @@ def store_in_ast(self, ty: NadaTypeRepr): class ArrayType: """Marker type for arrays.""" - inner_type: AllTypesType + contained_types: AllTypesType size: int def to_type(self): """Convert this generic type into a MIR Nada type.""" - return {"Array": {"inner_type": self.inner_type.to_type(), "size": self.size}} + return { + "Array": { + "contained_types": self.contained_types.to_type(), + "size": self.size, + } + } class Array(Generic[T], Collection): @@ -383,7 +388,7 @@ class Array(Generic[T], Collection): Attributes ---------- - inner_type: T + contained_types: T The type of the array parent: The optional parent operation @@ -391,18 +396,18 @@ class Array(Generic[T], Collection): The size of the array """ - inner_type: T + contained_types: T size: int - def __init__(self, parent, size: int, inner_type: T = None): - self.inner_type = ( - inner_type - if (parent is None or inner_type is not None) + def __init__(self, parent, size: int, contained_types: T = None): + self.contained_types = ( + contained_types + if (parent is None or contained_types is not None) else get_inner_type(parent) ) self.size = size self.parent = ( - parent if inner_type is not None else getattr(parent, "parent", None) + parent if contained_types is not None else getattr(parent, "parent", None) ) if self.parent is not None: self.parent.store_in_ast(self.to_type()) @@ -419,7 +424,7 @@ def map(self: "Array[T]", function) -> "Array": nada_function = nada_fn(function) return Array( size=self.size, - inner_type=nada_function.return_type, + contained_types=nada_function.return_type, parent=Map( parent=self, fn=nada_function, source_ref=SourceRef.back_frame() ), @@ -444,8 +449,10 @@ def zip(self: "Array[T]", other: "Array[U]") -> "Array[Tuple[T, U]]": raise IncompatibleTypesError("Cannot zip arrays of different size") return Array( size=self.size, - inner_type=Tuple( - left_type=self.inner_type, right_type=other.inner_type, parent=None + contained_types=Tuple( + left_type=self.contained_types, + right_type=other.contained_types, + parent=None, ), parent=Zip(left=self, right=other, source_ref=SourceRef.back_frame()), ) @@ -460,12 +467,12 @@ def inner_product(self: "Array[T]", other: "Array[T]") -> T: if is_primitive_integer(self.retrieve_inner_type()) and is_primitive_integer( other.retrieve_inner_type() ): - inner_type = ( - self.inner_type - if inspect.isclass(self.inner_type) - else self.inner_type.__class__ + contained_types = ( + self.contained_types + if inspect.isclass(self.contained_types) + else self.contained_types.__class__ ) - return inner_type( + return contained_types( parent=InnerProduct( left=self, right=other, source_ref=SourceRef.back_frame() ) @@ -486,7 +493,7 @@ def new(cls, *args) -> "Array[T]": raise TypeError("All arguments must be of the same type") return Array( - inner_type=first_arg, + contained_types=first_arg, size=len(args), parent=ArrayNew( parent=args, @@ -495,21 +502,21 @@ def new(cls, *args) -> "Array[T]": ) @classmethod - def generic_type(cls, inner_type: T, size: int) -> ArrayType: + def generic_type(cls, contained_types: T, size: int) -> ArrayType: """Return the generic type of the Array.""" - return ArrayType(inner_type=inner_type, size=size) + return ArrayType(contained_types=contained_types, size=size) @classmethod - def init_as_template_type(cls, inner_type) -> "Array[T]": + def init_as_template_type(cls, contained_types) -> "Array[T]": """Construct an empty template array with the given parent type.""" - return Array(parent=None, inner_type=inner_type, size=None) + return Array(parent=None, contained_types=contained_types, size=None) @dataclass class VectorType(Collection): """The generic type for Vectors.""" - inner_type: AllTypesType + contained_types: AllTypesType @dataclass @@ -522,17 +529,17 @@ class Vector(Generic[T], Collection): its size may change at runtime. """ - inner_type: T + contained_types: T size: int - def __init__(self, parent, size, inner_type=None): - self.inner_type = ( - inner_type - if (parent is None or inner_type is not None) + def __init__(self, parent, size, contained_types=None): + self.contained_types = ( + contained_types + if (parent is None or contained_types is not None) else get_inner_type(parent) ) self.size = size - self.parent = parent if inner_type else getattr(parent, "parent", None) + self.parent = parent if contained_types else getattr(parent, "parent", None) self.parent.store_in_ast(self.to_type()) def __iter__(self): @@ -545,7 +552,7 @@ def map(self: "Vector[T]", function: NadaFunction[T, R]) -> "Vector[R]": """The map operation for Nada Vectors.""" return Vector( size=self.size, - inner_type=function.return_type, + contained_types=function.return_type, parent=(Map(parent=self, fn=function, source_ref=SourceRef.back_frame())), ) @@ -553,7 +560,9 @@ def zip(self: "Vector[T]", other: "Vector[R]") -> "Vector[Tuple[T, R]]": """The Zip operation for Nada Vectors.""" return Vector( size=self.size, - inner_type=Tuple.generic_type(self.inner_type, other.inner_type), + contained_types=Tuple.generic_type( + self.contained_types, other.contained_types + ), parent=Zip(left=self, right=other, source_ref=SourceRef.back_frame()), ) @@ -571,14 +580,14 @@ def reduce( ) # type: ignore @classmethod - def generic_type(cls, inner_type: T) -> VectorType: + def generic_type(cls, contained_types: T) -> VectorType: """Returns the generic type for a Vector with the given parent type.""" - return VectorType(parent=None, inner_type=inner_type) + return VectorType(parent=None, contained_types=contained_types) @classmethod - def init_as_template_type(cls, inner_type) -> "Vector[T]": + def init_as_template_type(cls, contained_types) -> "Vector[T]": """Construct an empty Vector with the given parent type.""" - return Vector(parent=None, inner_type=inner_type, size=None) + return Vector(parent=None, contained_types=contained_types, size=None) class TupleNew(Generic[T, U]): @@ -658,8 +667,12 @@ def store_in_ast(self, ty: object): def unzip(array: Array[Tuple[T, R]]) -> Tuple[Array[T], Array[R]]: """The Unzip operation for Arrays.""" - right_type = ArrayType(inner_type=array.inner_type.right_type, size=array.size) - left_type = ArrayType(inner_type=array.inner_type.left_type, size=array.size) + right_type = ArrayType( + contained_types=array.contained_types.right_type, size=array.size + ) + left_type = ArrayType( + contained_types=array.contained_types.left_type, size=array.size + ) return Tuple( right_type=right_type, diff --git a/nada_dsl/nada_types/function.py b/nada_dsl/nada_types/function.py index ce7306a..87fea86 100644 --- a/nada_dsl/nada_types/function.py +++ b/nada_dsl/nada_types/function.py @@ -138,13 +138,13 @@ def store_in_ast(self, ty): ) -def inner_type(ty): +def contained_types(ty): """Utility function that calculates the parent type for a function argument.""" origin_ty = getattr(ty, "__origin__", ty) if not issubclass(origin_ty, ScalarType): inner_ty = getattr(ty, "__args__", None) - inner_ty = inner_type(inner_ty[0]) if inner_ty else T + inner_ty = contained_types(inner_ty[0]) if inner_ty else T return origin_ty.init_as_template_type(inner_ty) if origin_ty.mode == Mode.CONSTANT: return origin_ty(value=0) @@ -167,7 +167,7 @@ def nada_fn(fn, args_ty=None, return_ty=None) -> NadaFunction[T, R]: function_id = next_operation_id() for arg in args.args: arg_type = args_ty[arg] if args_ty else args.annotations[arg] - arg_type = inner_type(arg_type) + arg_type = contained_types(arg_type) # We'll get the function source ref for now nada_arg = NadaFunctionArg( function_id, diff --git a/nada_mir/src/nada_mir_proto/nillion/nada/types/v1/__init__.py b/nada_mir/src/nada_mir_proto/nillion/nada/types/v1/__init__.py index 1bcbff0..c9e3e1a 100644 --- a/nada_mir/src/nada_mir_proto/nillion/nada/types/v1/__init__.py +++ b/nada_mir/src/nada_mir_proto/nillion/nada/types/v1/__init__.py @@ -48,7 +48,7 @@ class Object(betterproto.Message): class Array(betterproto.Message): """Array type, defines a collection of homogeneous values""" - inner_type: "NadaType" = betterproto.message_field(1) + contained_types: "NadaType" = betterproto.message_field(1) """Inner type of the elements of this array""" size: int = betterproto.uint32_field(2) diff --git a/tests/compile_test.py b/tests/compile_test.py index 27643b3..611d5f6 100644 --- a/tests/compile_test.py +++ b/tests/compile_test.py @@ -149,7 +149,7 @@ def test_compile_map_simple(): if name == "InputReference": array_input_id = op_id assert op["type"] == { - "Array": {"inner_type": "SecretInteger", "size": 3} + "Array": {"contained_types": "SecretInteger", "size": 3} } operations_found += 1 elif name == "Map": diff --git a/tests/compiler_frontend_test.py b/tests/compiler_frontend_test.py index 117213e..e8e62f4 100644 --- a/tests/compiler_frontend_test.py +++ b/tests/compiler_frontend_test.py @@ -155,7 +155,7 @@ def test_zip(input_type, input_name): right = AST_OPERATIONS[zip_mir["right"]] assert left.name == "left" assert right.name == "right" - assert zip_mir["type"][input_name]["inner_type"] == { + assert zip_mir["type"][input_name]["contained_types"] == { "Tuple": { "left_type": "SecretInteger", "right_type": "SecretInteger", @@ -187,8 +187,8 @@ def test_unzip(input_type: type[Array]): assert zip_ast.name == "Zip" assert unzip_mir["type"] == { "Tuple": { - "left_type": {"Array": {"inner_type": "SecretInteger", "size": 10}}, - "right_type": {"Array": {"inner_type": "SecretInteger", "size": 10}}, + "left_type": {"Array": {"contained_types": "SecretInteger", "size": 10}}, + "right_type": {"Array": {"contained_types": "SecretInteger", "size": 10}}, } } @@ -218,7 +218,7 @@ def nada_function(a: SecretInteger) -> SecretInteger: assert list(parent["type"].keys()) == [input_name] inner_inner = AST_OPERATIONS[parent["parent"]] assert inner_inner.name == "parent" - assert parent["type"][input_name]["inner_type"] == "SecretInteger" + assert parent["type"][input_name]["contained_types"] == "SecretInteger" @pytest.mark.parametrize( @@ -423,9 +423,9 @@ def matrix_addition( assert matrix_addition_fn["function"] == "matrix_addition" args = matrix_addition_fn["args"] assert len(args) == 2 - a_arg_type = {input_name: {"inner_type": "SecretInteger"}} + a_arg_type = {input_name: {"contained_types": "SecretInteger"}} check_arg(args[0], "a", a_arg_type) - b_arg_type = {input_name: {"inner_type": "SecretInteger"}} + b_arg_type = {input_name: {"contained_types": "SecretInteger"}} check_arg(args[1], "b", b_arg_type) assert matrix_addition_fn["return_type"] == "SecretInteger" @@ -440,7 +440,7 @@ def matrix_addition( assert list(reduce_op_inner.keys()) == ["Map"] map_op = reduce_op_inner["Map"] map_op["function_id"] = add_fn["id"] - map_op["type"] = {input_name: {"inner_type": "SecretInteger", "size": None}} + map_op["type"] = {input_name: {"contained_types": "SecretInteger", "size": None}} map_op_inner = operations[map_op["parent"]] assert list(map_op_inner.keys()) == ["Zip"] @@ -469,7 +469,7 @@ def test_array_new(): assert first.name == "first" assert second.name == "second" assert parent["type"]["Array"] == { - "inner_type": "SecretInteger", + "contained_types": "SecretInteger", "size": 2, }