Skip to content

Commit

Permalink
fix: compiling nada functions overrides AST operations (#33)
Browse files Browse the repository at this point in the history
  • Loading branch information
Juan M Salamanca authored Oct 4, 2024
1 parent dab55ca commit dd3a769
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 11 deletions.
9 changes: 0 additions & 9 deletions nada_dsl/ast_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
from nada_dsl.source_ref import SourceRef

OPERATION_ID_COUNTER = 0
FUNCTION_ID_COUNTER = 0


def next_operation_id() -> int:
"""Returns the next value of the operation id counter."""
Expand All @@ -19,13 +17,6 @@ def next_operation_id() -> int:
return OPERATION_ID_COUNTER


def next_function_id() -> int:
"""Returns the next value of the function id counter."""
global FUNCTION_ID_COUNTER
FUNCTION_ID_COUNTER += 1
return FUNCTION_ID_COUNTER


@dataclass
class ASTOperation(ABC):
"""AST Operations.
Expand Down
32 changes: 32 additions & 0 deletions nada_dsl/compile_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,35 @@ def add_times(a: SecretInteger, b: SecretInteger) -> SecretInteger:
def test_compile_nada_fn_literals():
with pytest.raises(NotAllowedException):
mir_str = compile_script(f"{get_test_programs_folder()}/nada_fn_literal.py").mir


def test_compile_map_simple():
mir_str = compile_script(f"{get_test_programs_folder()}/map_simple.py").mir
assert mir_str != ""
mir = json.loads(mir_str)
assert len(mir["operations"]) == 2
assert len(mir["functions"]) == 1
function_id = mir["functions"][0]["id"]
operations_found = 0
array_input_id = 0
map_inner = 0
output_id = mir["outputs"][0]["operation_id"]
function_op_id = 0
for operation in mir["operations"].values():
for name, op in operation.items():
op_id = op["id"]
if name == "InputReference":
array_input_id = op_id
assert op["type"] == {
"Array": {"inner_type": "SecretInteger", "size": 3}
}
operations_found += 1
elif name == "Map":
assert op["fn"] == function_id
map_inner = op["inner"]
function_op_id = op["id"]
operations_found += 1
else:
raise Exception(f"Unexpected operation: {name}")
assert map_inner > 0 and array_input_id > 0 and map_inner == array_input_id
assert function_op_id > 0 and output_id == function_op_id
3 changes: 1 addition & 2 deletions nada_dsl/nada_types/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
NadaFunctionASTOperation,
NadaFunctionArgASTOperation,
NadaFunctionCallASTOperation,
next_function_id,
next_operation_id,
)
from nada_dsl.nada_types.generics import T, R
Expand Down Expand Up @@ -165,7 +164,7 @@ def nada_fn(fn, args_ty=None, return_ty=None) -> NadaFunction[T, R]:

args = inspect.getfullargspec(fn)
nada_args = []
function_id = next_function_id()
function_id = next_operation_id()
for arg in args.args:
arg_type = args_ty[arg] if args_ty else args.annotations[arg]
arg_type = inner_type(arg_type)
Expand Down
17 changes: 17 additions & 0 deletions test-programs/map_simple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from nada_dsl import *


def nada_main():
party1 = Party(name="Party1")
my_array_1 = Array(SecretInteger(Input(name="my_array_1", party=party1)), size=3)
my_int = SecretInteger(Input(name="my_int", party=party1))

@nada_fn
def inc(a: SecretInteger) -> SecretInteger:
return a + my_int

new_array = my_array_1.map(inc)

out = Output(new_array, "my_output", party1)

return [out]

0 comments on commit dd3a769

Please sign in to comment.