Skip to content

Commit

Permalink
chore: various renames and fixes (#54)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jmgr authored Nov 15, 2024
1 parent 12eec64 commit c9ad05a
Show file tree
Hide file tree
Showing 10 changed files with 394 additions and 378 deletions.
59 changes: 29 additions & 30 deletions nada_dsl/ast_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ class ASTOperation(ABC):
source_ref: SourceRef
ty: NadaTypeRepr

def inner_operations(self) -> List[int]:
"""Returns the list of identifiers of all the inner operations of this operation."""
def child_operations(self) -> List[int]:
"""Returns the list of identifiers of all the child operations of this operation."""
return []

def to_mir(self):
Expand All @@ -68,7 +68,7 @@ class BinaryASTOperation(ASTOperation):
left: int
right: int

def inner_operations(self) -> List[int]:
def child_operations(self) -> List[int]:
return [self.left, self.right]

def to_mir(self):
Expand All @@ -88,17 +88,17 @@ class UnaryASTOperation(ASTOperation):
"""Superclass of all the unary operations in AST representation"""

name: str
inner: int
child: int

def inner_operations(self):
return [self.inner]
def child_operations(self):
return [self.child]

def to_mir(self):

return {
self.name: {
"id": self.id,
"this": self.inner,
"this": self.child,
"type": self.ty,
"source_ref_index": self.source_ref.to_index(),
}
Expand All @@ -109,20 +109,20 @@ def to_mir(self):
class IfElseASTOperation(ASTOperation):
"""AST Representation of an IfElse operation."""

this: int
arg_0: int
arg_1: int
condition: int
true_branch_child: int
false_branch_child: int

def inner_operations(self):
return [self.this, self.arg_0, self.arg_1]
def child_operations(self):
return [self.condition, self.true_branch_child, self.false_branch_child]

def to_mir(self):
return {
"IfElse": {
"id": self.id,
"this": self.this,
"arg_0": self.arg_0,
"arg_1": self.arg_1,
"this": self.condition,
"arg_0": self.true_branch_child,
"arg_1": self.false_branch_child,
"type": self.ty,
"source_ref_index": self.source_ref.to_index(),
}
Expand All @@ -133,7 +133,7 @@ def to_mir(self):
class RandomASTOperation(ASTOperation):
"""AST Representation of a Random operation."""

def inner_operations(self):
def child_operations(self):
return []

def to_mir(self):
Expand Down Expand Up @@ -215,19 +215,19 @@ def to_mir(self):
class ReduceASTOperation(ASTOperation):
"""AST Representation of a Reduce operation."""

inner: int
child: int
fn: int
initial: int

def inner_operations(self):
return [self.inner, self.initial]
def child_operations(self):
return [self.child, self.initial]

def to_mir(self):
return {
"Reduce": {
"id": self.id,
"fn": self.fn,
"inner": self.inner,
"inner": self.child,
"initial": self.initial,
"type": self.ty,
"source_ref_index": self.source_ref.to_index(),
Expand All @@ -239,18 +239,18 @@ def to_mir(self):
class MapASTOperation(ASTOperation):
"""AST representation of a Map operation."""

inner: int
child: int
fn: int

def inner_operations(self):
return [self.inner]
def child_operations(self):
return [self.child]

def to_mir(self):
return {
"Map": {
"id": self.id,
"fn": self.fn,
"inner": self.inner,
"inner": self.child,
"type": self.ty,
"source_ref_index": self.source_ref.to_index(),
}
Expand All @@ -263,9 +263,8 @@ class NewASTOperation(ASTOperation):

name: str
elements: List[int]
inner_type: object

def inner_operations(self):
def child_operations(self):
return self.elements

def to_mir(self):
Expand All @@ -286,7 +285,7 @@ class NadaFunctionCallASTOperation(ASTOperation):
args: List[int]
fn: int

def inner_operations(self):
def child_operations(self):
return self.args

def to_mir(self):
Expand Down Expand Up @@ -327,7 +326,7 @@ class NadaFunctionASTOperation(ASTOperation):

name: str
args: List[int]
inner: int
child: int

# pylint: disable=arguments-differ
def to_mir(self, operations):
Expand All @@ -347,7 +346,7 @@ def to_mir(self, operations):
for arg in arg_operations
],
"function": self.name,
"return_operation_id": self.inner,
"return_operation_id": self.child,
"operations": operations,
"return_type": self.ty,
"source_ref_index": self.source_ref.to_index(),
Expand All @@ -364,7 +363,7 @@ class CastASTOperation(ASTOperation):

target: int

def inner_operations(self):
def child_operations(self):
return [self.target]

def to_mir(self):
Expand Down
6 changes: 3 additions & 3 deletions nada_dsl/compiler_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def nada_dsl_to_nada_mir(outputs: List[Output]) -> Dict[str, Any]:
timer.start(
f"nada_dsl.compiler_frontend.nada_dsl_to_nada_mir.{output.name}.process_operation"
)
out_operation_id = output.inner.inner.id
out_operation_id = output.child.child.id
extra_fns = traverse_and_process_operations(
out_operation_id, operations, FUNCTIONS
)
Expand Down Expand Up @@ -184,7 +184,7 @@ def to_mir_function_list(functions: Dict[int, NadaFunctionASTOperation]) -> List
function_operations = {}

extra_functions = traverse_and_process_operations(
function.inner,
function.child,
function_operations,
functions,
)
Expand Down Expand Up @@ -255,7 +255,7 @@ def traverse_and_process_operations(
extra_functions[wrapped_operation.extra_function.id] = (
wrapped_operation.extra_function
)
stack.extend(operation.inner_operations())
stack.extend(operation.child_operations())
return extra_functions


Expand Down
32 changes: 21 additions & 11 deletions nada_dsl/nada_types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def __init__(self, name):
"""Type alias for the NadaType representation.
This representation can be either a string ("SecretInteger")
or a dictionary (Array{inner_type=SecretInteger, size=3}).
or a dictionary (Array{contained_types=SecretInteger, size=3}).
"""


Expand Down Expand Up @@ -121,7 +121,7 @@ class NadaType:
This is the parent class of all nada types.
In Nada, all the types wrap Operations. For instance, an addition between two integers
is represented like this SecretInteger(inner=Addition(...)).
is represented like this SecretInteger(child=Addition(...)).
In MIR, the representation is based around operations. A MIR operation points to other
operations and has a return type.
Expand All @@ -130,25 +130,25 @@ class NadaType:
MIR-friendly format, as subclasses of ASTOperation.
Whenever the Python interpreter constructs an instance of NadaType, it will also store
in memory the corresponding inner operation. In order to do so, the ASTOperation will
in memory the corresponding child operation. In order to do so, the ASTOperation will
need the type in MIR format. Which is why all instances of `NadaType` provide an implementation
of `to_type()`.
of `to_mir()`.
"""

inner: OperationType
child: OperationType

def __init__(self, inner: OperationType):
def __init__(self, child: OperationType):
"""NadaType default constructor
Args:
inner (OperationType): The inner operation of this Data type
child (OperationType): The child operation of this Data type
"""
self.inner = inner
if self.inner is not None:
self.inner.store_in_ast(self.to_type())
self.child = child
if self.child is not None:
self.child.store_in_ast(self.to_mir())

def to_type(self):
def to_mir(self):
"""Default implementation for the Conversion of a type into MIR representation."""
return self.__class__.class_to_type()

Expand All @@ -163,3 +163,13 @@ def class_to_type(cls) -> str:

def __bool__(self):
raise NotImplementedError

@classmethod
def is_scalable(cls) -> bool:
"""Returns True if the type is a scalable."""
return False

@classmethod
def is_literal(cls) -> bool:
"""Returns True if the type is a literal."""
return False
Loading

0 comments on commit c9ad05a

Please sign in to comment.