diff --git a/nada_dsl/nada_types/collections.py b/nada_dsl/nada_types/collections.py index a9aa911..7735f5b 100644 --- a/nada_dsl/nada_types/collections.py +++ b/nada_dsl/nada_types/collections.py @@ -28,7 +28,7 @@ InvalidTypeError, NotAllowedException, ) -from nada_dsl.nada_types.function import NadaFunction, nada_fn +from nada_dsl.nada_types.function import NadaFunction, create_nada_fn from nada_dsl.nada_types.generics import U, T, R from . import AllTypes, AllTypesType, NadaTypeRepr, OperationType @@ -112,12 +112,13 @@ def store_in_ast(self, ty): ) -@dataclass class TupleMetaType(MetaType): """Marker type for Tuples.""" + is_compound = True - left_type: NadaType - right_type: NadaType + def __init__(self, left_type: MetaType, right_type: MetaType): + self.left_type = left_type + self.right_type = right_type def instantiate(self, child): return Tuple(child, self.left_type, self.right_type) @@ -145,29 +146,6 @@ def __init__(self, child, left_type: T, right_type: U): self.child = child super().__init__(self.child) - """TODO this should be deleted and use MetaType.to_mir""" - - # def to_mir(self): - # return { - # "Tuple": { - # "left_type": ( - # self.left_type.to_mir() - # if isinstance( - # self.left_type, (NadaType, ArrayMetaType, TupleMetaType) - # ) - # else self.left_type.class_to_mir() - # ), - # "right_type": ( - # self.right_type.to_mir() - # if isinstance( - # self.right_type, - # (NadaType, ArrayMetaType, TupleMetaType), - # ) - # else self.right_type.class_to_mir() - # ), - # } - # } - @classmethod def new(cls, left_value: NadaType, right_value: NadaType) -> "Tuple[T, U]": """Constructs a new Tuple.""" @@ -193,34 +171,14 @@ def _generate_accessor(ty: Any, accessor: Any) -> NadaType: if hasattr(ty, "ty") and ty.ty.is_literal(): # TODO: fix raise TypeError("Literals are not supported in accessors") return ty.instantiate(accessor) - # if ty.is_scalar(): - # if ty.is_literal(): - # return ty # value.instantiate(child=accessor) ? - # return ty(child=accessor) - # if ty == Array: - # return Array( - # child=accessor, - # contained_type=ty.contained_type, - # size=ty.size, - # ) - # if ty == NTuple: - # return NTuple( - # child=accessor, - # types=ty.types, - # ) - # if ty == Object: - # return Object( - # child=accessor, - # types=ty.types, - # ) - # raise TypeError(f"Unsupported type for accessor: {ty}") -@dataclass class NTupleMetaType(MetaType): """Marker type for NTuples.""" + is_compound = True - types: List[NadaType] + def __init__(self, types: List[MetaType]): + self.types = types def instantiate(self, child): return NTuple(child, self.types) @@ -269,22 +227,6 @@ def __getitem__(self, index: int) -> NadaType: return _generate_accessor(self.types[index], accessor) - """TODO this should be deleted and use MetaType.to_mir""" - - # def to_mir(self): - # return { - # "NTuple": { - # "types": [ - # ( - # ty.to_mir() - # if isinstance(ty, (NadaType, ArrayMetaType, TupleMetaType)) - # else ty.class_to_mir() - # ) - # for ty in self.types - # ] - # } - # } - def metatype(self): return NTupleMetaType(self.types) @@ -319,15 +261,20 @@ def store_in_ast(self, ty: object): ) -@dataclass class ObjectMetaType(MetaType): """Marker type for Objects.""" + is_compound = True - types: Dict[str, Any] + def __init__(self, types: Dict[str, MetaType]): + self.types = types def to_mir(self): """Convert an object into a Nada type.""" - return {"Object": {name: ty.to_mir() for name, ty in self.types.items()}} + return { + "Object": { + "types": { name: ty.to_mir() for name, ty in self.types.items() } + } + } def instantiate(self, child): return Object(child, self.types) @@ -351,7 +298,7 @@ def new(cls, values: Dict[str, Any]) -> "Object": return Object( types=types, child=ObjectNew( - child=types, + child=values, source_ref=SourceRef.back_frame(), ), ) @@ -370,22 +317,6 @@ def __getattr__(self, attr: str) -> NadaType: return _generate_accessor(self.types[attr], accessor) - """TODO delete this use Meta.to_mir""" - - # def to_mir(self): - # return { - # "Object": { - # "types": { - # name: ( - # ty.to_mir() - # if isinstance(ty, (NadaType, ArrayMetaType, TupleMetaType)) - # else ty.class_to_mir() - # ) - # for name, ty in self.types.items() - # } - # } - # } - def metatype(self): return ObjectMetaType(types=self.types) @@ -480,16 +411,23 @@ def store_in_ast(self, ty: NadaTypeRepr): ty=ty, ) - -@dataclass class ArrayMetaType(MetaType): """Marker type for arrays.""" + is_compound = True - contained_type: AllTypesType - size: int + + def __init__(self, contained_type: AllTypesType, size: int): + self.contained_type = contained_type + self.size = size def to_mir(self): """Convert this generic type into a MIR Nada type.""" + # TODO size is None when array used in function argument and used @nada_fn + # So you know the type but not the size, we should stop using @nada_fn decorator + # and apply the same logic when the function gets passed to .map() or .reduce() + # so we now the size of the array + if self.size is None: + raise NotImplementedError("ArrayMetaType.to_mir") size = {"size": self.size} if self.size else {} return { "Array": {"inner_type": self.contained_type.to_mir(), **size} # TODO: why? @@ -520,16 +458,7 @@ class Array(Generic[T], NadaType): size: int def __init__(self, child, size: int, contained_type: T = None): - self.contained_type = ( - contained_type if (child is None or contained_type is not None) else child - ) - - # TODO: can we simplify the following 10 lines? - # If it's not a metatype, fetch it - if self.contained_type is not None and not isinstance( - self.contained_type, MetaType - ): - self.contained_type = self.contained_type.metatype() + self.contained_type = contained_type or child.metatype() self.size = size self.child = ( @@ -543,11 +472,14 @@ def __iter__(self): "Cannot loop over a Nada Array, use functional style Array operations (map, reduce, zip)." ) + def check_not_constant(self, ty): + if ty.is_constant: + raise NotAllowedException("functors (map and reduce) can't be called with constant args") + def map(self: "Array[T]", function) -> "Array": """The map operation for Arrays.""" - nada_function = function - if not isinstance(function, NadaFunction): - nada_function = nada_fn(function) + self.check_not_constant(self.contained_type) + nada_function = create_nada_fn(function, args_ty=[self.contained_type]) return Array( size=self.size, contained_type=nada_function.return_type, @@ -556,9 +488,10 @@ def map(self: "Array[T]", function) -> "Array": def reduce(self: "Array[T]", function, initial: R) -> R: """The Reduce operation for arrays.""" - if not isinstance(function, NadaFunction): - function = nada_fn(function) - return function.return_type( + self.check_not_constant(self.contained_type) + self.check_not_constant(initial.metatype()) + function = create_nada_fn(function, args_ty=[initial.metatype(), self.contained_type]) + return function.return_type.instantiate( Reduce( child=self, fn=function, @@ -601,12 +534,6 @@ def inner_product(self: "Array[T]", other: "Array[T]") -> T: "Inner product is only implemented for arrays of integer types" ) - # TODO delete - - # def to_mir(self): - # size = {"size": self.size} if self.size else {} - # return {"Array": {"inner_type": self.contained_type, **size}} - @classmethod def new(cls, *args) -> "Array[T]": """Constructs a new Array.""" @@ -618,7 +545,7 @@ def new(cls, *args) -> "Array[T]": raise TypeError("All arguments must be of the same type") return Array( - contained_type=first_arg, + contained_type=first_arg.metatype(), size=len(args), child=ArrayNew( child=args, @@ -626,11 +553,6 @@ def new(cls, *args) -> "Array[T]": ), ) - @classmethod - def init_as_template_type(cls, contained_type) -> "Array[T]": - """Construct an empty template array with the given child type.""" - return Array(child=None, contained_type=contained_type, size=None) - def metatype(self): return ArrayMetaType(self.contained_type, self.size) diff --git a/nada_dsl/nada_types/function.py b/nada_dsl/nada_types/function.py index 7d8ba7f..28d1c3b 100644 --- a/nada_dsl/nada_types/function.py +++ b/nada_dsl/nada_types/function.py @@ -35,7 +35,7 @@ def __init__(self, function_id: int, name: str, arg_type: T, source_ref: SourceR self.name = name self.type = arg_type self.source_ref = source_ref - self.store_in_ast(arg_type.metatype().to_mir()) + self.store_in_ast(arg_type.to_mir()) def store_in_ast(self, ty): """Store object in AST.""" @@ -53,8 +53,6 @@ class NadaFunction(Generic[T, R]): Represents a Nada Function. Nada functions are special types of functions that are used in map / reduce operations. - - They are decorated using the `@nada_fn` decorator. """ id: int @@ -72,20 +70,6 @@ def __init__( source_ref: SourceRef, child: NadaType, ): - if issubclass(return_type, ScalarType) and return_type.mode == Mode.CONSTANT: - raise NotAllowedException( - "Nada functions with literal return types are not allowed" - ) - # Nada functions with literal argument types are not supported. - # This is because the compiler consolidates operations between literals. - if all( - issubclass(arg.type.__class__, ScalarType) - and arg.type.mode == Mode.CONSTANT - for arg in args - ): - raise NotAllowedException( - "Nada functions with literal argument types are not allowed" - ) self.child = child self.id = function_id self.args = args @@ -101,7 +85,7 @@ def store_in_ast(self): name=self.function.__name__, args=[arg.id for arg in self.args], id=self.id, - ty=self.return_type.metatype().to_mir(), + ty=self.return_type.to_mir(), source_ref=self.source_ref, child=self.child.child.id, ) @@ -137,21 +121,7 @@ def store_in_ast(self, ty): ty=ty, ) - -def contained_types(ty): - """Utility function that calculates the child type for a function argument.""" - - origin_ty = getattr(ty, "__origin__", ty) - if not issubclass(origin_ty, ScalarType): - inner_ty = getattr(ty, "__args__", None) - inner_ty = contained_types(inner_ty[0]) if inner_ty else T - return origin_ty.init_as_template_type(inner_ty) - if origin_ty.mode == Mode.CONSTANT: - return origin_ty(value=0) - return origin_ty(child=None) - - -def nada_fn(fn, args_ty=None, return_ty=None) -> NadaFunction[T, R]: +def create_nada_fn(fn, args_ty) -> NadaFunction[T, R]: """ Can be used also for lambdas ```python @@ -165,28 +135,21 @@ def nada_fn(fn, args_ty=None, return_ty=None) -> NadaFunction[T, R]: args = inspect.getfullargspec(fn) nada_args = [] function_id = next_operation_id() - for arg in args.args: - arg_type = args_ty[arg] if args_ty else args.annotations[arg] - arg_type = contained_types(arg_type) + nada_args_type_wrapped = [] + for arg, arg_ty in zip(args.args, args_ty): # We'll get the function source ref for now nada_arg = NadaFunctionArg( function_id, name=arg, - arg_type=arg_type, + arg_type=arg_ty, source_ref=SourceRef.back_frame(), ) nada_args.append(nada_arg) - - nada_args_type_wrapped = [] - - for arg in nada_args: - arg_type = copy(arg.type) - arg_type.child = arg - nada_args_type_wrapped.append(arg_type) + nada_args_type_wrapped.append(arg_ty.instantiate(nada_arg)) child = fn(*nada_args_type_wrapped) - return_type = return_ty if return_ty else args.annotations["return"] + return_type = child.metatype() return NadaFunction( function_id, function=fn, diff --git a/nada_dsl/nada_types/scalar_types.py b/nada_dsl/nada_types/scalar_types.py index 4f8553a..964ac37 100644 --- a/nada_dsl/nada_types/scalar_types.py +++ b/nada_dsl/nada_types/scalar_types.py @@ -1,7 +1,7 @@ # pylint:disable=W0401,W0614 """The Nada Scalar type definitions.""" -from abc import ABC +from abc import ABC, abstractmethod from dataclasses import dataclass from typing import Union, TypeVar from typing_extensions import Self @@ -349,16 +349,24 @@ def binary_logical_operation( return SecretBoolean(child=operation) -@dataclass class MetaType(ABC): - pass + is_constant = False + is_scalar = False + is_compound = False + + @abstractmethod + def to_mir(self): + pass + + @abstractmethod + def instantiate(self, child_or_value): + pass -@dataclass class MetaTypePassthroughMixin(MetaType): @classmethod def instantiate(cls, child_or_value): - cls.ty(child_or_value) + return cls.ty(child_or_value) @classmethod def to_mir(cls): @@ -398,7 +406,8 @@ def metatype(cls): class IntegerMetaType(MetaTypePassthroughMixin): ty = Integer - + is_constant = True + is_scalar = True @dataclass @register_scalar_type(Mode.CONSTANT, BaseType.UNSIGNED_INTEGER) @@ -432,6 +441,8 @@ def metatype(cls): class UnsignedIntegerMetaType(MetaTypePassthroughMixin): ty = UnsignedInteger + is_constant = True + is_scalar = True @register_scalar_type(Mode.CONSTANT, BaseType.BOOLEAN) @@ -471,6 +482,8 @@ def metatype(cls): class BooleanMetaType(MetaTypePassthroughMixin): ty = Boolean + is_constant = True + is_scalar = True @register_scalar_type(Mode.PUBLIC, BaseType.INTEGER) @@ -499,6 +512,7 @@ def metatype(cls): class PublicIntegerMetaType(MetaTypePassthroughMixin): ty = PublicInteger + is_scalar = True @register_scalar_type(Mode.PUBLIC, BaseType.UNSIGNED_INTEGER) @@ -527,6 +541,7 @@ def metatype(cls): class PublicUnsignedIntegerMetaType(MetaTypePassthroughMixin): ty = PublicUnsignedInteger + is_scalar = True @dataclass @@ -560,6 +575,7 @@ def metatype(cls): class PublicBooleanMetaType(MetaTypePassthroughMixin): ty = PublicBoolean + is_scalar = True @dataclass @@ -613,6 +629,7 @@ def metatype(cls): class SecretIntegerMetaType(MetaTypePassthroughMixin): ty = SecretInteger + is_scalar = True @dataclass @@ -668,6 +685,7 @@ def metatype(cls): class SecretUnsignedIntegerMetaType(MetaTypePassthroughMixin): ty = SecretUnsignedInteger + is_scalar = True @dataclass @@ -702,6 +720,7 @@ def metatype(cls): class SecretBooleanMetaType(MetaTypePassthroughMixin): ty = SecretBoolean + is_scalar = True @dataclass diff --git a/test-programs/map_simple.py b/test-programs/map_simple.py index bc72c11..2b7b33c 100644 --- a/test-programs/map_simple.py +++ b/test-programs/map_simple.py @@ -6,7 +6,6 @@ def nada_main(): my_array_1 = Array(SecretInteger(Input(name="my_array_1", party=party1)), size=3) my_int = SecretInteger(Input(name="my_int", party=party1)) - @nada_fn def inc(a: SecretInteger) -> SecretInteger: return a + my_int diff --git a/test-programs/nada_fn_literal.py b/test-programs/nada_fn_literal.py deleted file mode 100644 index f7258ad..0000000 --- a/test-programs/nada_fn_literal.py +++ /dev/null @@ -1,12 +0,0 @@ -from nada_dsl import * - - -def nada_main(): - party1 = Party(name="Party1") - - @nada_fn - def add(a: Integer, b: Integer) -> Integer: - return a + b - - new_int = add(Integer(2), Integer(-5)) - return [Output(new_int, "my_output", party1)] diff --git a/test-programs/nada_fn_simple.py b/test-programs/nada_fn_simple.py deleted file mode 100644 index cf53bbf..0000000 --- a/test-programs/nada_fn_simple.py +++ /dev/null @@ -1,14 +0,0 @@ -from nada_dsl import * - - -def nada_main(): - party1 = Party(name="Party1") - my_int1 = SecretInteger(Input(name="my_int1", party=party1)) - my_int2 = SecretInteger(Input(name="my_int2", party=party1)) - - @nada_fn - def add(a: SecretInteger, b: SecretInteger) -> SecretInteger: - return a + b - - new_int = add(my_int1, my_int2) - return [Output(new_int, "my_output", party1)] diff --git a/test-programs/ntuple_accessor.py b/test-programs/ntuple_accessor.py index 9701830..a5e9e2a 100644 --- a/test-programs/ntuple_accessor.py +++ b/test-programs/ntuple_accessor.py @@ -15,10 +15,16 @@ def nada_main(): array = tup[1] scalar2 = tup[2] - @nada_fn - def add(a: PublicInteger) -> PublicInteger: - return a + my_int2 + def add(acc: PublicInteger, a: PublicInteger) -> PublicInteger: + return a + acc - result = array.reduce(add, Integer(0)) + result = array.reduce(add, my_int1) - return [Output(scalar + scalar2 + result, "my_output", party1)] + scalar_sum = scalar + scalar2 + + final = result + scalar_sum + + return [Output(final, "my_output", party1)] + +if __name__ == "__main__": + nada_main() \ No newline at end of file diff --git a/test-programs/object_accessor.py b/test-programs/object_accessor.py index 0258b8e..0f5679f 100644 --- a/test-programs/object_accessor.py +++ b/test-programs/object_accessor.py @@ -9,16 +9,16 @@ def nada_main(): array = Array.new(my_int1, my_int1) # Store a scalar, a compound type and a literal. - object = Object.new({"a": my_int1, "b": array, "c": Integer(42)}) + object = Object.new({"a": my_int1, "b": array, "c": my_int2}) scalar = object.a array = object.b - literal = object.c + scalar_2 = object.c - @nada_fn - def add(a: PublicInteger) -> PublicInteger: - return a + my_int2 - sum = array.reduce(add, Integer(0)) + def add(acc: PublicInteger, a: PublicInteger) -> PublicInteger: + return acc + a - return [Output(scalar + literal + sum, "my_output", party1)] + sum = array.reduce(add, my_int2) + + return [Output(scalar + scalar_2 + sum, "my_output", party1)] diff --git a/tests/compile_test.py b/tests/compile_test.py index 225e94f..3045d29 100644 --- a/tests/compile_test.py +++ b/tests/compile_test.py @@ -29,14 +29,6 @@ def get_test_programs_folder(): return this_directory + "../test-programs/" -def test_compile_nada_fn_simple(): - mir_str = compile_script(f"{get_test_programs_folder()}/nada_fn_simple.py").mir - assert mir_str != "" - mir = json.loads(mir_str) - mir_functions = mir["functions"] - assert len(mir_functions) == 1 - - def test_compile_sum_integers(): mir_str = compile_script(f"{get_test_programs_folder()}/sum_integers.py").mir assert mir_str != "" @@ -88,11 +80,9 @@ def nada_main(): my_int1 = SecretInteger(Input(name="my_int1", party=party1)) my_int2 = SecretInteger(Input(name="my_int2", party=party1)) - @nada_fn def add(a: SecretInteger, b: SecretInteger) -> SecretInteger: return a + b - @nada_fn def add_times(a: SecretInteger, b: SecretInteger) -> SecretInteger: return a * add(a, b) @@ -104,33 +94,6 @@ def add_times(a: SecretInteger, b: SecretInteger) -> SecretInteger: compile_string(encoded_program_str) -# TODO recursive programs fail with `NameError` for now. This is incorrect. -def test_compile_program_with_recursion(): - program_str = """from nada_dsl import * - -def nada_main(): - party1 = Party(name="Party1") - my_int1 = SecretInteger(Input(name="my_int1", party=party1)) - my_int2 = SecretInteger(Input(name="my_int2", party=party1)) - - @nada_fn - def add_times(a: SecretInteger, b: SecretInteger) -> SecretInteger: - return a * add_times(a, b) - - new_int = add_times(my_int1, my_int2) - return [Output(new_int, "my_output", party1)] -""" - encoded_program_str = base64.b64encode(bytes(program_str, "utf-8")).decode("utf_8") - - with pytest.raises(NameError): - compile_string(encoded_program_str) - - -def test_compile_nada_fn_literals(): - with pytest.raises(NotAllowedException): - mir_str = compile_script(f"{get_test_programs_folder()}/nada_fn_literal.py").mir - - def test_compile_map_simple(): mir_str = compile_script(f"{get_test_programs_folder()}/map_simple.py").mir assert mir_str != "" diff --git a/tests/compiler_frontend_test.py b/tests/compiler_frontend_test.py index b53d0f7..1a53c8a 100644 --- a/tests/compiler_frontend_test.py +++ b/tests/compiler_frontend_test.py @@ -31,7 +31,7 @@ ) from nada_dsl.nada_types import AllTypes, Party from nada_dsl.nada_types.collections import Array, Tuple, NTuple, Object, unzip -from nada_dsl.nada_types.function import NadaFunctionArg, NadaFunctionCall, nada_fn +from nada_dsl.nada_types.function import NadaFunctionArg, NadaFunctionCall, create_nada_fn @pytest.fixture(autouse=True) @@ -197,7 +197,6 @@ def test_unzip(input_type: type[Array]): ], ) def test_map(input_type, input_name): - @nada_fn def nada_function(a: SecretInteger) -> SecretInteger: return a + a @@ -227,7 +226,6 @@ def nada_function(a: SecretInteger) -> SecretInteger: def test_reduce(input_type: type[Array]): c = create_input(SecretInteger, "c", "party", **{}) - @nada_fn def nada_function(a: SecretInteger, b: SecretInteger) -> SecretInteger: return a + b @@ -262,193 +260,6 @@ def check_nada_function_arg_ref(arg_ref, function_id, name, ty): assert arg_ref["NadaFunctionArgRef"]["type"] == ty -def nada_function_to_mir(function_name: str): - nada_function: NadaFunctionASTOperation = find_function_in_ast(function_name) - assert isinstance(nada_function, NadaFunctionASTOperation) - fn_ops = {} - traverse_and_process_operations(nada_function.child, fn_ops, {}) - return nada_function.to_mir(fn_ops) - - -def test_nada_function_simple(): - @nada_fn - def nada_function(a: SecretInteger, b: SecretInteger) -> SecretInteger: - return a + b - - nada_function = nada_function_to_mir("nada_function") - assert nada_function["function"] == "nada_function" - args = nada_function["args"] - assert len(args) == 2 - check_arg(args[0], "a", "SecretInteger") - check_arg(args[1], "b", "SecretInteger") - assert nada_function["return_type"] == "SecretInteger" - - operations = nada_function["operations"] - return_op = operations[nada_function["return_operation_id"]] - assert list(return_op.keys()) == ["Addition"] - addition = return_op["Addition"] - - check_nada_function_arg_ref( - operations[addition["left"]], nada_function["id"], "a", "SecretInteger" - ) - check_nada_function_arg_ref( - operations[addition["right"]], nada_function["id"], "b", "SecretInteger" - ) - - -def test_nada_function_using_inputs(): - c = create_input(SecretInteger, "c", "party", **{}) - - @nada_fn - def nada_function(a: SecretInteger, b: SecretInteger) -> SecretInteger: - return a + b + c - - nada_function = nada_function_to_mir("nada_function") - assert nada_function["function"] == "nada_function" - args = nada_function["args"] - assert len(args) == 2 - check_arg(args[0], "a", "SecretInteger") - check_arg(args[1], "b", "SecretInteger") - assert nada_function["return_type"] == "SecretInteger" - - operation = nada_function["operations"] - return_op = operation[nada_function["return_operation_id"]] - - assert list(return_op.keys()) == ["Addition"] - addition = return_op["Addition"] - addition_right = operation[addition["right"]] - assert input_reference(addition_right) == "c" - addition_left = operation[addition["left"]] - assert list(addition_left.keys()) == ["Addition"] - - addition = addition_left["Addition"] - - check_nada_function_arg_ref( - operation[addition["left"]], nada_function["id"], "a", "SecretInteger" - ) - check_nada_function_arg_ref( - operation[addition["right"]], nada_function["id"], "b", "SecretInteger" - ) - - -def test_nada_function_call(): - - c = create_input(SecretInteger, "c", "party", **{}) - d = create_input(SecretInteger, "c", "party", **{}) - - @nada_fn - def nada_function(a: SecretInteger, b: SecretInteger) -> SecretInteger: - return a + b - - nada_fn_call_return = nada_function(c, d) - nada_fn_type = nada_function_to_mir("nada_function") - - nada_function_call = nada_fn_call_return.child - assert isinstance(nada_function_call, NadaFunctionCall) - assert nada_function_call.fn.id == nada_fn_type["id"] - - -def test_nada_function_using_operations(): - - c = create_input(SecretInteger, "c", "party", **{}) - d = create_input(SecretInteger, "d", "party", **{}) - - @nada_fn - def nada_function(a: SecretInteger, b: SecretInteger) -> SecretInteger: - return a + b + c + d - - nada_function_ast = nada_function_to_mir("nada_function") - assert nada_function_ast["function"] == "nada_function" - args = nada_function_ast["args"] - assert len(args) == 2 - check_arg(args[0], "a", "SecretInteger") - check_arg(args[1], "b", "SecretInteger") - assert nada_function_ast["return_type"] == "SecretInteger" - - operation = nada_function_ast["operations"] - return_op = operation[nada_function_ast["return_operation_id"]] - - assert list(return_op.keys()) == ["Addition"] - addition = return_op["Addition"] - - assert input_reference(operation[addition["right"]]) == "d" - addition_left = operation[addition["left"]] - assert list(addition_left.keys()) == ["Addition"] - addition = addition_left["Addition"] - assert input_reference(operation[addition["right"]]) == "c" - - addition_left = operation[addition["left"]] - assert list(addition_left.keys()) == ["Addition"] - addition = addition_left["Addition"] - - check_nada_function_arg_ref( - operation[addition["left"]], nada_function_ast["id"], "a", "SecretInteger" - ) - check_nada_function_arg_ref( - operation[addition["right"]], nada_function_ast["id"], "b", "SecretInteger" - ) - - -def find_function_in_ast(fn_name: str): - for op in AST_OPERATIONS.values(): - if isinstance(op, NadaFunctionASTOperation) and op.name == fn_name: - return op - return None - - -@pytest.mark.parametrize( - ("input_type", "input_name"), - [ - (Array, "Array"), - ], -) -def test_nada_function_using_matrix(input_type, input_name): - c = create_input(SecretInteger, "c", "party", **{}) - - @nada_fn - def add(a: SecretInteger, b: SecretInteger) -> SecretInteger: - return a + b - - @nada_fn - def matrix_addition( - a: input_type[SecretInteger], b: input_type[SecretInteger] - ) -> SecretInteger: - return a.zip(b).map(add).reduce(add, c) - - add_fn = nada_function_to_mir("add") - matrix_addition_fn = nada_function_to_mir("matrix_addition") - assert matrix_addition_fn["function"] == "matrix_addition" - args = matrix_addition_fn["args"] - assert len(args) == 2 - a_arg_type = {input_name: {"inner_type": "SecretInteger"}} - check_arg(args[0], "a", a_arg_type) - b_arg_type = {input_name: {"inner_type": "SecretInteger"}} - check_arg(args[1], "b", b_arg_type) - assert matrix_addition_fn["return_type"] == "SecretInteger" - - operations = matrix_addition_fn["operations"] - return_op = operations[matrix_addition_fn["return_operation_id"]] - assert list(return_op.keys()) == ["Reduce"] - reduce_op = return_op["Reduce"] - reduce_op["function_id"] = add_fn["id"] - reduce_op["type"] = "SecretInteger" - - reduce_op_inner = operations[reduce_op["inner"]] - assert list(reduce_op_inner.keys()) == ["Map"] - map_op = reduce_op_inner["Map"] - map_op["function_id"] = add_fn["id"] - map_op["type"] = {input_name: {"inner_type": "SecretInteger", "size": None}} - - map_op_inner = operations[map_op["inner"]] - assert list(map_op_inner.keys()) == ["Zip"] - zip_op = map_op_inner["Zip"] - - zip_op_left = operations[zip_op["left"]] - zip_op_right = operations[zip_op["right"]] - check_nada_function_arg_ref(zip_op_left, matrix_addition_fn["id"], "a", a_arg_type) - check_nada_function_arg_ref(zip_op_right, matrix_addition_fn["id"], "b", b_arg_type) - - def test_array_new(): first_input = create_input(SecretInteger, "first", "party", **{}) second_input = create_input(SecretInteger, "second", "party", **{})