Skip to content

Commit

Permalink
chore: fix test and delete @nada_fn
Browse files Browse the repository at this point in the history
  • Loading branch information
lumasepa committed Nov 22, 2024
1 parent e2635fa commit f249951
Show file tree
Hide file tree
Showing 10 changed files with 92 additions and 435 deletions.
158 changes: 40 additions & 118 deletions nada_dsl/nada_types/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
InvalidTypeError,
NotAllowedException,
)
from nada_dsl.nada_types.function import NadaFunction, nada_fn
from nada_dsl.nada_types.function import NadaFunction, create_nada_fn
from nada_dsl.nada_types.generics import U, T, R
from . import AllTypes, AllTypesType, NadaTypeRepr, OperationType

Expand Down Expand Up @@ -112,12 +112,13 @@ def store_in_ast(self, ty):
)


@dataclass
class TupleMetaType(MetaType):
"""Marker type for Tuples."""
is_compound = True

left_type: NadaType
right_type: NadaType
def __init__(self, left_type: MetaType, right_type: MetaType):
self.left_type = left_type
self.right_type = right_type

def instantiate(self, child):
return Tuple(child, self.left_type, self.right_type)
Expand Down Expand Up @@ -145,29 +146,6 @@ def __init__(self, child, left_type: T, right_type: U):
self.child = child
super().__init__(self.child)

"""TODO this should be deleted and use MetaType.to_mir"""

# def to_mir(self):
# return {
# "Tuple": {
# "left_type": (
# self.left_type.to_mir()
# if isinstance(
# self.left_type, (NadaType, ArrayMetaType, TupleMetaType)
# )
# else self.left_type.class_to_mir()
# ),
# "right_type": (
# self.right_type.to_mir()
# if isinstance(
# self.right_type,
# (NadaType, ArrayMetaType, TupleMetaType),
# )
# else self.right_type.class_to_mir()
# ),
# }
# }

@classmethod
def new(cls, left_value: NadaType, right_value: NadaType) -> "Tuple[T, U]":
"""Constructs a new Tuple."""
Expand All @@ -193,34 +171,14 @@ def _generate_accessor(ty: Any, accessor: Any) -> NadaType:
if hasattr(ty, "ty") and ty.ty.is_literal(): # TODO: fix
raise TypeError("Literals are not supported in accessors")
return ty.instantiate(accessor)
# if ty.is_scalar():
# if ty.is_literal():
# return ty # value.instantiate(child=accessor) ?
# return ty(child=accessor)
# if ty == Array:
# return Array(
# child=accessor,
# contained_type=ty.contained_type,
# size=ty.size,
# )
# if ty == NTuple:
# return NTuple(
# child=accessor,
# types=ty.types,
# )
# if ty == Object:
# return Object(
# child=accessor,
# types=ty.types,
# )
# raise TypeError(f"Unsupported type for accessor: {ty}")


@dataclass
class NTupleMetaType(MetaType):
"""Marker type for NTuples."""
is_compound = True

types: List[NadaType]
def __init__(self, types: List[MetaType]):
self.types = types

def instantiate(self, child):
return NTuple(child, self.types)
Expand Down Expand Up @@ -269,22 +227,6 @@ def __getitem__(self, index: int) -> NadaType:

return _generate_accessor(self.types[index], accessor)

"""TODO this should be deleted and use MetaType.to_mir"""

# def to_mir(self):
# return {
# "NTuple": {
# "types": [
# (
# ty.to_mir()
# if isinstance(ty, (NadaType, ArrayMetaType, TupleMetaType))
# else ty.class_to_mir()
# )
# for ty in self.types
# ]
# }
# }

def metatype(self):
return NTupleMetaType(self.types)

Expand Down Expand Up @@ -319,15 +261,20 @@ def store_in_ast(self, ty: object):
)


@dataclass
class ObjectMetaType(MetaType):
"""Marker type for Objects."""
is_compound = True

types: Dict[str, Any]
def __init__(self, types: Dict[str, MetaType]):
self.types = types

def to_mir(self):
"""Convert an object into a Nada type."""
return {"Object": {name: ty.to_mir() for name, ty in self.types.items()}}
return {
"Object": {
"types": { name: ty.to_mir() for name, ty in self.types.items() }
}
}

def instantiate(self, child):
return Object(child, self.types)
Expand All @@ -351,7 +298,7 @@ def new(cls, values: Dict[str, Any]) -> "Object":
return Object(
types=types,
child=ObjectNew(
child=types,
child=values,
source_ref=SourceRef.back_frame(),
),
)
Expand All @@ -370,22 +317,6 @@ def __getattr__(self, attr: str) -> NadaType:

return _generate_accessor(self.types[attr], accessor)

"""TODO delete this use Meta.to_mir"""

# def to_mir(self):
# return {
# "Object": {
# "types": {
# name: (
# ty.to_mir()
# if isinstance(ty, (NadaType, ArrayMetaType, TupleMetaType))
# else ty.class_to_mir()
# )
# for name, ty in self.types.items()
# }
# }
# }

def metatype(self):
return ObjectMetaType(types=self.types)

Expand Down Expand Up @@ -480,16 +411,23 @@ def store_in_ast(self, ty: NadaTypeRepr):
ty=ty,
)


@dataclass
class ArrayMetaType(MetaType):
"""Marker type for arrays."""
is_compound = True

contained_type: AllTypesType
size: int

def __init__(self, contained_type: AllTypesType, size: int):
self.contained_type = contained_type
self.size = size

def to_mir(self):
"""Convert this generic type into a MIR Nada type."""
# TODO size is None when array used in function argument and used @nada_fn
# So you know the type but not the size, we should stop using @nada_fn decorator
# and apply the same logic when the function gets passed to .map() or .reduce()
# so we now the size of the array
if self.size is None:
raise NotImplementedError("ArrayMetaType.to_mir")
size = {"size": self.size} if self.size else {}
return {
"Array": {"inner_type": self.contained_type.to_mir(), **size} # TODO: why?
Expand Down Expand Up @@ -520,16 +458,7 @@ class Array(Generic[T], NadaType):
size: int

def __init__(self, child, size: int, contained_type: T = None):
self.contained_type = (
contained_type if (child is None or contained_type is not None) else child
)

# TODO: can we simplify the following 10 lines?
# If it's not a metatype, fetch it
if self.contained_type is not None and not isinstance(
self.contained_type, MetaType
):
self.contained_type = self.contained_type.metatype()
self.contained_type = contained_type or child.metatype()

self.size = size
self.child = (
Expand All @@ -543,11 +472,14 @@ def __iter__(self):
"Cannot loop over a Nada Array, use functional style Array operations (map, reduce, zip)."
)

def check_not_constant(self, ty):
if ty.is_constant:
raise NotAllowedException("functors (map and reduce) can't be called with constant args")

def map(self: "Array[T]", function) -> "Array":
"""The map operation for Arrays."""
nada_function = function
if not isinstance(function, NadaFunction):
nada_function = nada_fn(function)
self.check_not_constant(self.contained_type)
nada_function = create_nada_fn(function, args_ty=[self.contained_type])
return Array(
size=self.size,
contained_type=nada_function.return_type,
Expand All @@ -556,9 +488,10 @@ def map(self: "Array[T]", function) -> "Array":

def reduce(self: "Array[T]", function, initial: R) -> R:
"""The Reduce operation for arrays."""
if not isinstance(function, NadaFunction):
function = nada_fn(function)
return function.return_type(
self.check_not_constant(self.contained_type)
self.check_not_constant(initial.metatype())
function = create_nada_fn(function, args_ty=[initial.metatype(), self.contained_type])
return function.return_type.instantiate(
Reduce(
child=self,
fn=function,
Expand Down Expand Up @@ -601,12 +534,6 @@ def inner_product(self: "Array[T]", other: "Array[T]") -> T:
"Inner product is only implemented for arrays of integer types"
)

# TODO delete

# def to_mir(self):
# size = {"size": self.size} if self.size else {}
# return {"Array": {"inner_type": self.contained_type, **size}}

@classmethod
def new(cls, *args) -> "Array[T]":
"""Constructs a new Array."""
Expand All @@ -618,19 +545,14 @@ def new(cls, *args) -> "Array[T]":
raise TypeError("All arguments must be of the same type")

return Array(
contained_type=first_arg,
contained_type=first_arg.metatype(),
size=len(args),
child=ArrayNew(
child=args,
source_ref=SourceRef.back_frame(),
),
)

@classmethod
def init_as_template_type(cls, contained_type) -> "Array[T]":
"""Construct an empty template array with the given child type."""
return Array(child=None, contained_type=contained_type, size=None)

def metatype(self):
return ArrayMetaType(self.contained_type, self.size)

Expand Down
53 changes: 8 additions & 45 deletions nada_dsl/nada_types/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, function_id: int, name: str, arg_type: T, source_ref: SourceR
self.name = name
self.type = arg_type
self.source_ref = source_ref
self.store_in_ast(arg_type.metatype().to_mir())
self.store_in_ast(arg_type.to_mir())

def store_in_ast(self, ty):
"""Store object in AST."""
Expand All @@ -53,8 +53,6 @@ class NadaFunction(Generic[T, R]):
Represents a Nada Function. Nada functions are special types of functions that are used
in map / reduce operations.
They are decorated using the `@nada_fn` decorator.
"""

id: int
Expand All @@ -72,20 +70,6 @@ def __init__(
source_ref: SourceRef,
child: NadaType,
):
if issubclass(return_type, ScalarType) and return_type.mode == Mode.CONSTANT:
raise NotAllowedException(
"Nada functions with literal return types are not allowed"
)
# Nada functions with literal argument types are not supported.
# This is because the compiler consolidates operations between literals.
if all(
issubclass(arg.type.__class__, ScalarType)
and arg.type.mode == Mode.CONSTANT
for arg in args
):
raise NotAllowedException(
"Nada functions with literal argument types are not allowed"
)
self.child = child
self.id = function_id
self.args = args
Expand All @@ -101,7 +85,7 @@ def store_in_ast(self):
name=self.function.__name__,
args=[arg.id for arg in self.args],
id=self.id,
ty=self.return_type.metatype().to_mir(),
ty=self.return_type.to_mir(),
source_ref=self.source_ref,
child=self.child.child.id,
)
Expand Down Expand Up @@ -137,21 +121,7 @@ def store_in_ast(self, ty):
ty=ty,
)


def contained_types(ty):
"""Utility function that calculates the child type for a function argument."""

origin_ty = getattr(ty, "__origin__", ty)
if not issubclass(origin_ty, ScalarType):
inner_ty = getattr(ty, "__args__", None)
inner_ty = contained_types(inner_ty[0]) if inner_ty else T
return origin_ty.init_as_template_type(inner_ty)
if origin_ty.mode == Mode.CONSTANT:
return origin_ty(value=0)
return origin_ty(child=None)


def nada_fn(fn, args_ty=None, return_ty=None) -> NadaFunction[T, R]:
def create_nada_fn(fn, args_ty) -> NadaFunction[T, R]:
"""
Can be used also for lambdas
```python
Expand All @@ -165,28 +135,21 @@ def nada_fn(fn, args_ty=None, return_ty=None) -> NadaFunction[T, R]:
args = inspect.getfullargspec(fn)
nada_args = []
function_id = next_operation_id()
for arg in args.args:
arg_type = args_ty[arg] if args_ty else args.annotations[arg]
arg_type = contained_types(arg_type)
nada_args_type_wrapped = []
for arg, arg_ty in zip(args.args, args_ty):
# We'll get the function source ref for now
nada_arg = NadaFunctionArg(
function_id,
name=arg,
arg_type=arg_type,
arg_type=arg_ty,
source_ref=SourceRef.back_frame(),
)
nada_args.append(nada_arg)

nada_args_type_wrapped = []

for arg in nada_args:
arg_type = copy(arg.type)
arg_type.child = arg
nada_args_type_wrapped.append(arg_type)
nada_args_type_wrapped.append(arg_ty.instantiate(nada_arg))

child = fn(*nada_args_type_wrapped)

return_type = return_ty if return_ty else args.annotations["return"]
return_type = child.metatype()
return NadaFunction(
function_id,
function=fn,
Expand Down
Loading

0 comments on commit f249951

Please sign in to comment.