Skip to content

Commit

Permalink
feat: Add support for Python builtin sum (#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
navasvarela authored Sep 24, 2024
1 parent 124026e commit efbedc2
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 0 deletions.
42 changes: 42 additions & 0 deletions nada_dsl/compile_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,48 @@ def test_compile_nada_fn_simple():
assert len(mir_functions) == 1


def test_compile_sum_integers():
mir_str = compile_script(f"{get_test_programs_folder()}/sum_integers.py").mir
assert mir_str != ""
mir = json.loads(mir_str)
# The MIR operations look like this:
# - 2 InputReference
# - 1 LiteralReference for the initial accumulator
# - 2 Additions, one for the first input reference and the literal,
# the other addition of this addition and the other input reference
literal_id = 0
input_ids = []
additions = {}
for operation in mir["operations"].values():
for name, op in operation.items():
op_id = op["id"]
if name == "LiteralReference":
literal_id = op_id
assert op["type"] == "Integer"
elif name == "InputReference":
input_ids.append(op_id)
elif name == "Addition":
additions[op_id] = op
else:
raise Exception(f"Unexpected operation: {name}")
assert literal_id != 0
assert len(input_ids) == 2
assert len(additions) == 2
# Now lets check that the two additions are well constructed
# left: input reference, right: literal
first_addition_found = False
# left: addition, right: input reference
second_addition_found = False
for addition in additions.values():
left_id = addition["left"]
right_id = addition["right"]
if left_id in input_ids and right_id == literal_id:
first_addition_found = True
if left_id in additions.keys() and right_id in input_ids:
second_addition_found = True
assert first_addition_found and second_addition_found


def test_compile_nada_fn_compound():
program_str = """
from nada_dsl import *
Expand Down
10 changes: 10 additions & 0 deletions nada_dsl/nada_types/scalar_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,16 @@ def __ge__(self, other):
"GreaterOrEqualThan", ">=", self, other, lambda lhs, rhs: lhs >= rhs
)

def __radd__(self, other):
"""This adds support for builtin `sum` operation for numeric types."""
if isinstance(other, int):
other_type = new_scalar_type(mode=Mode.CONSTANT, base_type=self.base_type)(
other
)
return self.__add__(other_type)

return self.__add__(other)


def binary_arithmetic_operation(
operation, operator, left: ScalarType, right: ScalarType, f
Expand Down
11 changes: 11 additions & 0 deletions test-programs/sum_integers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from nada_dsl import *


def nada_main():
party1 = Party(name="Party1")
my_int1 = SecretInteger(Input(name="my_int1", party=party1))
my_int2 = SecretInteger(Input(name="my_int2", party=party1))

total = sum([my_int1, my_int2], -2)

return [Output(total, "my_output", party1)]

0 comments on commit efbedc2

Please sign in to comment.