Skip to content

Commit

Permalink
chore: more fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Jmgr committed Nov 22, 2024
1 parent f249951 commit 0822056
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 74 deletions.
3 changes: 1 addition & 2 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -452,8 +452,7 @@ timeout-methods=requests.api.delete,requests.api.get,requests.api.head,requests.

# List of note tags to take in consideration, separated by a comma.
notes=FIXME,
XXX,
TODO
XXX

# Regular expression of note tags to take in consideration.
notes-rgx=
Expand Down
17 changes: 6 additions & 11 deletions nada_dsl/nada_types/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
"""Nada type definitions."""

from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, TypeAlias, Union, Type
from typing import Dict, TypeAlias, Union, Type
from nada_dsl.source_ref import SourceRef
from abc import abstractmethod


@dataclass
Expand Down Expand Up @@ -147,15 +147,6 @@ def __init__(self, child: OperationType):
if self.child is not None:
self.child.store_in_ast(self.metatype().to_mir())

# @classmethod
# def class_to_mir(cls) -> str:
# """Converts a class into a MIR Nada type."""
# name = cls.__name__
# # Rename public variables so they are considered as the same as literals.
# if name.startswith("Public"):
# name = name[len("Public") :].lstrip()
# return name

def __bool__(self):
raise NotImplementedError

Expand All @@ -168,3 +159,7 @@ def is_scalar(cls) -> bool:
def is_literal(cls) -> bool:
"""Returns True if the type is a literal."""
return False

@abstractmethod
def metatype(self):
"""Returns a meta type for this NadaType."""
53 changes: 30 additions & 23 deletions nada_dsl/nada_types/collections.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
"""Nada Collection type definitions."""

import copy
from dataclasses import dataclass
import inspect
import traceback
from typing import Any, Dict, Generic, List
import typing
from typing import TypeVar

from nada_dsl.ast_util import (
AST_OPERATIONS,
Expand Down Expand Up @@ -114,14 +110,15 @@ def store_in_ast(self, ty):

class TupleMetaType(MetaType):
"""Marker type for Tuples."""

is_compound = True

def __init__(self, left_type: MetaType, right_type: MetaType):
self.left_type = left_type
self.right_type = right_type

def instantiate(self, child):
return Tuple(child, self.left_type, self.right_type)
def instantiate(self, child_or_value):
return Tuple(child_or_value, self.left_type, self.right_type)

def to_mir(self):
"""Convert a tuple object into a Nada type."""
Expand Down Expand Up @@ -164,6 +161,7 @@ def generic_type(cls, left_type: U, right_type: T) -> TupleMetaType:
return TupleMetaType(left_type=left_type, right_type=right_type)

def metatype(self):
"""Metatype for Tuple"""
return TupleMetaType(self.left_type, self.right_type)


Expand All @@ -175,13 +173,14 @@ def _generate_accessor(ty: Any, accessor: Any) -> NadaType:

class NTupleMetaType(MetaType):
"""Marker type for NTuples."""

is_compound = True

def __init__(self, types: List[MetaType]):
self.types = types

def instantiate(self, child):
return NTuple(child, self.types)
def instantiate(self, child_or_value):
return NTuple(child_or_value, self.types)

def to_mir(self):
"""Convert a tuple object into a Nada type."""
Expand Down Expand Up @@ -228,6 +227,7 @@ def __getitem__(self, index: int) -> NadaType:
return _generate_accessor(self.types[index], accessor)

def metatype(self):
"""Metatype for NTuple"""
return NTupleMetaType(self.types)


Expand Down Expand Up @@ -263,6 +263,7 @@ def store_in_ast(self, ty: object):

class ObjectMetaType(MetaType):
"""Marker type for Objects."""

is_compound = True

def __init__(self, types: Dict[str, MetaType]):
Expand All @@ -271,13 +272,11 @@ def __init__(self, types: Dict[str, MetaType]):
def to_mir(self):
"""Convert an object into a Nada type."""
return {
"Object": {
"types": { name: ty.to_mir() for name, ty in self.types.items() }
}
"Object": {"types": {name: ty.to_mir() for name, ty in self.types.items()}}
}

def instantiate(self, child):
return Object(child, self.types)
def instantiate(self, child_or_value):
return Object(child_or_value, self.types)


@dataclass
Expand Down Expand Up @@ -318,6 +317,7 @@ def __getattr__(self, attr: str) -> NadaType:
return _generate_accessor(self.types[attr], accessor)

def metatype(self):
"""Metatype for Object"""
return ObjectMetaType(types=self.types)


Expand Down Expand Up @@ -411,10 +411,11 @@ def store_in_ast(self, ty: NadaTypeRepr):
ty=ty,
)


class ArrayMetaType(MetaType):
"""Marker type for arrays."""
is_compound = True

is_compound = True

def __init__(self, contained_type: AllTypesType, size: int):
self.contained_type = contained_type
Expand All @@ -428,13 +429,15 @@ def to_mir(self):
# so we now the size of the array
if self.size is None:
raise NotImplementedError("ArrayMetaType.to_mir")
size = {"size": self.size} if self.size else {}
return {
"Array": {"inner_type": self.contained_type.to_mir(), **size} # TODO: why?
"Array": {
"inner_type": self.contained_type.to_mir(),
"size": self.size,
}
}

def instantiate(self, child):
return Array(child, self.size, self.contained_type)
def instantiate(self, child_or_value):
return Array(child_or_value, self.size, self.contained_type)


@dataclass
Expand Down Expand Up @@ -473,8 +476,11 @@ def __iter__(self):
)

def check_not_constant(self, ty):
"""Checks that a type is not a constant"""
if ty.is_constant:
raise NotAllowedException("functors (map and reduce) can't be called with constant args")
raise NotAllowedException(
"functors (map and reduce) can't be called with constant args"
)

def map(self: "Array[T]", function) -> "Array":
"""The map operation for Arrays."""
Expand All @@ -490,7 +496,9 @@ def reduce(self: "Array[T]", function, initial: R) -> R:
"""The Reduce operation for arrays."""
self.check_not_constant(self.contained_type)
self.check_not_constant(initial.metatype())
function = create_nada_fn(function, args_ty=[initial.metatype(), self.contained_type])
function = create_nada_fn(
function, args_ty=[initial.metatype(), self.contained_type]
)
return function.return_type.instantiate(
Reduce(
child=self,
Expand Down Expand Up @@ -525,9 +533,7 @@ def inner_product(self: "Array[T]", other: "Array[T]") -> T:
):

return self.contained_type.instantiate(
child=InnerProduct(
left=self, right=other, source_ref=SourceRef.back_frame()
)
InnerProduct(left=self, right=other, source_ref=SourceRef.back_frame())
) # type: ignore

raise InvalidTypeError(
Expand All @@ -554,6 +560,7 @@ def new(cls, *args) -> "Array[T]":
)

def metatype(self):
"""Metatype for Array"""
return ArrayMetaType(self.contained_type, self.size)


Expand Down
6 changes: 2 additions & 4 deletions nada_dsl/nada_types/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import inspect
from dataclasses import dataclass
from typing import Generic, List, Callable
from copy import copy
from nada_dsl import SourceRef
from nada_dsl.ast_util import (
AST_OPERATIONS,
Expand All @@ -16,9 +15,7 @@
next_operation_id,
)
from nada_dsl.nada_types.generics import T, R
from nada_dsl.nada_types import Mode, NadaType
from nada_dsl.nada_types.scalar_types import ScalarType
from nada_dsl.errors import NotAllowedException
from nada_dsl.nada_types import NadaType


class NadaFunctionArg(Generic[T]):
Expand Down Expand Up @@ -121,6 +118,7 @@ def store_in_ast(self, ty):
ty=ty,
)


def create_nada_fn(fn, args_ty) -> NadaFunction[T, R]:
"""
Can be used also for lambdas
Expand Down
Loading

0 comments on commit 0822056

Please sign in to comment.