Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add support for Python builtin sum #25

Merged
merged 1 commit into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions nada_dsl/compile_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,48 @@ def test_compile_nada_fn_simple():
assert len(mir_functions) == 1


def test_compile_sum_integers():
mir_str = compile_script(f"{get_test_programs_folder()}/sum_integers.py").mir
assert mir_str != ""
mir = json.loads(mir_str)
# The MIR operations look like this:
# - 2 InputReference
# - 1 LiteralReference for the initial accumulator
# - 2 Additions, one for the first input reference and the literal,
# the other addition of this addition and the other input reference
literal_id = 0
input_ids = []
additions = {}
for operation in mir["operations"].values():
for name, op in operation.items():
op_id = op["id"]
if name == "LiteralReference":
literal_id = op_id
assert op["type"] == "Integer"
elif name == "InputReference":
input_ids.append(op_id)
elif name == "Addition":
additions[op_id] = op
else:
raise Exception(f"Unexpected operation: {name}")
assert literal_id != 0
assert len(input_ids) == 2
assert len(additions) == 2
# Now lets check that the two additions are well constructed
# left: input reference, right: literal
first_addition_found = False
# left: addition, right: input reference
second_addition_found = False
for addition in additions.values():
left_id = addition["left"]
right_id = addition["right"]
if left_id in input_ids and right_id == literal_id:
first_addition_found = True
if left_id in additions.keys() and right_id in input_ids:
second_addition_found = True
assert first_addition_found and second_addition_found


def test_compile_nada_fn_compound():
program_str = """
from nada_dsl import *
Expand Down
10 changes: 10 additions & 0 deletions nada_dsl/nada_types/scalar_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,16 @@ def __ge__(self, other):
"GreaterOrEqualThan", ">=", self, other, lambda lhs, rhs: lhs >= rhs
)

def __radd__(self, other):
"""This adds support for builtin `sum` operation for numeric types."""
if isinstance(other, int):
other_type = new_scalar_type(mode=Mode.CONSTANT, base_type=self.base_type)(
other
)
return self.__add__(other_type)

return self.__add__(other)


def binary_arithmetic_operation(
operation, operator, left: ScalarType, right: ScalarType, f
Expand Down
11 changes: 11 additions & 0 deletions test-programs/sum_integers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from nada_dsl import *


def nada_main():
party1 = Party(name="Party1")
my_int1 = SecretInteger(Input(name="my_int1", party=party1))
my_int2 = SecretInteger(Input(name="my_int2", party=party1))

total = sum([my_int1, my_int2], -2)

return [Output(total, "my_output", party1)]
Loading