Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
Jmgr committed Nov 13, 2024
1 parent ef1e3b3 commit 6b4ef11
Show file tree
Hide file tree
Showing 8 changed files with 344 additions and 30 deletions.
22 changes: 22 additions & 0 deletions nada_dsl/ast_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,3 +377,25 @@ def to_mir(self):
"source_ref_index": self.source_ref.to_index(),
}
}


@dataclass
class NTupleAccessorASTOperation(ASTOperation):
"""AST representation of a n tuple accessor operation."""

index: int
source: int

def inner_operations(self):
return [self.source]

def to_mir(self):
return {
"NTupleAccessor": {
"id": self.id,
"index": self.index,
"source": self.source,
"type": self.ty,
"source_ref_index": self.source_ref.to_index(),
}
}
97 changes: 97 additions & 0 deletions nada_dsl/compile_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,100 @@ def test_compile_map_simple():
raise Exception(f"Unexpected operation: {name}")
assert map_inner > 0 and array_input_id > 0 and map_inner == array_input_id
assert function_op_id > 0 and output_id == function_op_id


def test_compile_tuple_accessor():
mir_str = compile_script(f"{get_test_programs_folder()}/tuple_accessor.py").mir
assert mir_str != ""
print(f"MIR: {mir_str}")
# mir = json.loads(mir_str)
# assert len(mir["operations"]) == 2
# assert len(mir["functions"]) == 1
# function_id = mir["functions"][0]["id"]
# operations_found = 0
# array_input_id = 0
# map_inner = 0
# output_id = mir["outputs"][0]["operation_id"]
# function_op_id = 0
# for operation in mir["operations"].values():
# for name, op in operation.items():
# op_id = op["id"]
# if name == "InputReference":
# array_input_id = op_id
# assert op["type"] == {
# "Array": {"inner_type": "SecretInteger", "size": 3}
# }
# operations_found += 1
# elif name == "Map":
# assert op["fn"] == function_id
# map_inner = op["inner"]
# function_op_id = op["id"]
# operations_found += 1
# else:
# raise Exception(f"Unexpected operation: {name}")
# assert map_inner > 0 and array_input_id > 0 and map_inner == array_input_id
# assert function_op_id > 0 and output_id == function_op_id

def test_compile_ntuple_accessor():
mir_str = compile_script(f"{get_test_programs_folder()}/ntuple_accessor.py").mir
assert mir_str != ""
print(f"MIR: {mir_str}")
# mir = json.loads(mir_str)
# assert len(mir["operations"]) == 2
# assert len(mir["functions"]) == 1
# function_id = mir["functions"][0]["id"]
# operations_found = 0
# array_input_id = 0
# map_inner = 0
# output_id = mir["outputs"][0]["operation_id"]
# function_op_id = 0
# for operation in mir["operations"].values():
# for name, op in operation.items():
# op_id = op["id"]
# if name == "InputReference":
# array_input_id = op_id
# assert op["type"] == {
# "Array": {"inner_type": "SecretInteger", "size": 3}
# }
# operations_found += 1
# elif name == "Map":
# assert op["fn"] == function_id
# map_inner = op["inner"]
# function_op_id = op["id"]
# operations_found += 1
# else:
# raise Exception(f"Unexpected operation: {name}")
# assert map_inner > 0 and array_input_id > 0 and map_inner == array_input_id
# assert function_op_id > 0 and output_id == function_op_id

def test_compile_object_accessor():
mir_str = compile_script(f"{get_test_programs_folder()}/object_accessor.py").mir
assert mir_str != ""
print(f"MIR: {mir_str}")
# mir = json.loads(mir_str)
# assert len(mir["operations"]) == 2
# assert len(mir["functions"]) == 1
# function_id = mir["functions"][0]["id"]
# operations_found = 0
# array_input_id = 0
# map_inner = 0
# output_id = mir["outputs"][0]["operation_id"]
# function_op_id = 0
# for operation in mir["operations"].values():
# for name, op in operation.items():
# op_id = op["id"]
# if name == "InputReference":
# array_input_id = op_id
# assert op["type"] == {
# "Array": {"inner_type": "SecretInteger", "size": 3}
# }
# operations_found += 1
# elif name == "Map":
# assert op["fn"] == function_id
# map_inner = op["inner"]
# function_op_id = op["id"]
# operations_found += 1
# else:
# raise Exception(f"Unexpected operation: {name}")
# assert map_inner > 0 and array_input_id > 0 and map_inner == array_input_id
# assert function_op_id > 0 and output_id == function_op_id
2 changes: 2 additions & 0 deletions nada_dsl/compiler_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
RandomASTOperation,
ReduceASTOperation,
UnaryASTOperation,
NTupleAccessorASTOperation,
)
from nada_dsl.timer import timer
from nada_dsl.source_ref import SourceRef
Expand Down Expand Up @@ -296,6 +297,7 @@ def process_operation(
NewASTOperation,
RandomASTOperation,
NadaFunctionArgASTOperation,
NTupleAccessorASTOperation,
),
):
processed_operation = ProcessOperationOutput(operation.to_mir(), None)
Expand Down
6 changes: 6 additions & 0 deletions nada_dsl/nada_types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,9 @@ def class_to_type(cls) -> str:

def __bool__(self):
raise NotImplementedError

def is_scalar(self):
return False

def is_literal(self):
return False
Loading

0 comments on commit 6b4ef11

Please sign in to comment.