diff --git a/nada_dsl/compile_test.py b/nada_dsl/compile_test.py index f7c2079..a9e57e8 100644 --- a/nada_dsl/compile_test.py +++ b/nada_dsl/compile_test.py @@ -36,6 +36,48 @@ def test_compile_nada_fn_simple(): assert len(mir_functions) == 1 +def test_compile_sum_integers(): + mir_str = compile_script(f"{get_test_programs_folder()}/sum_integers.py").mir + assert mir_str != "" + mir = json.loads(mir_str) + # The MIR operations look like this: + # - 2 InputReference + # - 1 LiteralReference for the initial accumulator + # - 2 Additions, one for the first input reference and the literal, + # the other addition of this addition and the other input reference + literal_id = 0 + input_ids = [] + additions = {} + for operation in mir["operations"].values(): + for name, op in operation.items(): + op_id = op["id"] + if name == "LiteralReference": + literal_id = op_id + assert op["type"] == "Integer" + elif name == "InputReference": + input_ids.append(op_id) + elif name == "Addition": + additions[op_id] = op + else: + raise Exception(f"Unexpected operation: {name}") + assert literal_id != 0 + assert len(input_ids) == 2 + assert len(additions) == 2 + # Now lets check that the two additions are well constructed + # left: input reference, right: literal + first_addition_found = False + # left: addition, right: input reference + second_addition_found = False + for addition in additions.values(): + left_id = addition["left"] + right_id = addition["right"] + if left_id in input_ids and right_id == literal_id: + first_addition_found = True + if left_id in additions.keys() and right_id in input_ids: + second_addition_found = True + assert first_addition_found and second_addition_found + + def test_compile_nada_fn_compound(): program_str = """ from nada_dsl import * diff --git a/nada_dsl/nada_types/scalar_types.py b/nada_dsl/nada_types/scalar_types.py index a458c15..f8509b7 100644 --- a/nada_dsl/nada_types/scalar_types.py +++ b/nada_dsl/nada_types/scalar_types.py @@ -161,6 +161,16 @@ def __ge__(self, other): "GreaterOrEqualThan", ">=", self, other, lambda lhs, rhs: lhs >= rhs ) + def __radd__(self, other): + """This adds support for builtin `sum` operation for numeric types.""" + if isinstance(other, int): + other_type = new_scalar_type(mode=Mode.CONSTANT, base_type=self.base_type)( + other + ) + return self.__add__(other_type) + + return self.__add__(other) + def binary_arithmetic_operation( operation, operator, left: ScalarType, right: ScalarType, f diff --git a/test-programs/sum_integers.py b/test-programs/sum_integers.py new file mode 100644 index 0000000..9d3c079 --- /dev/null +++ b/test-programs/sum_integers.py @@ -0,0 +1,11 @@ +from nada_dsl import * + + +def nada_main(): + party1 = Party(name="Party1") + my_int1 = SecretInteger(Input(name="my_int1", party=party1)) + my_int2 = SecretInteger(Input(name="my_int2", party=party1)) + + total = sum([my_int1, my_int2], -2) + + return [Output(total, "my_output", party1)]