Skip to content

Commit

Permalink
Remove typing annotations on Types
Browse files Browse the repository at this point in the history
  • Loading branch information
Tinitto committed Feb 13, 2023
1 parent dda9524 commit fb42ea8
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 69 deletions.
5 changes: 3 additions & 2 deletions funml/expressions.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""Assigning variables and literals
"""
from typing import Any
from collections.abc import Callable
from typing import Any, Union

from funml.types import to_expn, Expression


def val(v: Any) -> Expression:
def val(v: Union[Expression, Callable, Any]) -> Expression:
"""Converts a generic value or lambda expression into a functional expression.
This is useful when one needs to use piping on a non-ml function or
Expand Down
143 changes: 76 additions & 67 deletions funml/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,21 @@
import functools
from collections.abc import Awaitable, Callable
from inspect import signature, Parameter, Signature
from typing import Any, Optional, List, Tuple, Generic, Union, TypeVar

from typing_extensions import ParamSpec
from typing import Any, Optional, List, Tuple, Union

from funml import errors
from funml.utils import is_equal_or_of_type


R = TypeVar("R")
P = ParamSpec("P")


class MLType:
"""An ML-enabled type, that can easily be used in pattern matching, piping etc.
Methods common to ML-enabled types are defined in this class.
"""

def generate_case(
self, do: "Operation"[P, R]
) -> Tuple[Callable[[Any], bool], "Expression"[P, R]]:
self, do: "Operation"
) -> Tuple[Callable[[Any], bool], "Expression"]:
"""Generates a case statement for pattern matching.
Args:
Expand All @@ -48,7 +42,7 @@ def _is_like(self, other: Any) -> bool:
raise NotImplemented("_is_like not implemented")


class Pipeline(Generic[P, R]):
class Pipeline:
"""A series of logic blocks that operate on the same data in sequence.
This has internal state so it is not be used in such stuff as recursion.
Expand All @@ -60,8 +54,13 @@ def __init__(self):
self._is_terminated = False

def __rshift__(
self, nxt: Union["Expression"[P, R], Callable[P, R], "Pipeline"[P, R]]
) -> Union["Pipeline"[P, R], R]:
self,
nxt: Union[
"Expression",
Callable,
"Pipeline",
],
) -> Union["Pipeline", Any]:
"""Uses `>>` to append the nxt expression, callable, pipeline to this pipeline.
Args:
Expand All @@ -82,7 +81,7 @@ def __rshift__(

def __call__(
self, *args: Any, **kwargs: Any
) -> Union[R, "Expression"[P, R], Awaitable[R], Awaitable["Expression"[P, R]]]:
) -> Union["Expression", Awaitable["Expression"], Awaitable[Any], Any,]:
"""Computes the logic within the pipeline and returns the value.
This method runs all those expressions in the queue sequentially,
Expand Down Expand Up @@ -122,14 +121,21 @@ def __copy__(self):
new_pipeline._is_terminated = self._is_terminated
return new_pipeline

def __as_async(self) -> "AsyncPipeline"[P, R]:
def __as_async(self) -> "AsyncPipeline":
"""Creates an async pipeline from this pipeline"""
pipe = AsyncPipeline()
pipe._queue = [*self._queue]
pipe._is_terminated = self._is_terminated
return pipe

def __update_queue(self, nxt):
def __update_queue(
self,
nxt: Union[
"Expression",
Callable,
"Pipeline",
],
):
"""Appends a pipeline or an expression to the queue."""
if self._is_terminated:
raise ValueError("a terminated pipeline cannot be extended.")
Expand All @@ -149,7 +155,7 @@ class AsyncPipeline(Pipeline):
See more details in the [base class](funml.types.Pipeline)
"""

async def __call__(self, *args: Any, **kwargs: Any) -> Union[R, "Expression"[P, R]]:
async def __call__(self, *args: Any, **kwargs: Any) -> Union["Expression", Any]:
"""Computes the logic within the pipeline and returns the value.
This method runs all those expressions in the queue sequentially,
Expand Down Expand Up @@ -178,7 +184,46 @@ async def __call__(self, *args: Any, **kwargs: Any) -> Union[R, "Expression"[P,
return output


class Expression(Generic[P, R]):
class Operation:
"""A computation.
Args:
func: the logic to run as part of the operation.
"""

def __init__(self, func: Callable):
self.__signature = _get_func_signature(func)
self.__args_length = _get_non_variable_args_length(self.__signature)

if len(self.__signature.parameters) == 0:
# be more fault tolerant by using variable params
self.__f = lambda *args, **kwargs: func()
else:
self.__f = func

def __call__(self, *args: Any, **kwargs: Any) -> Union["Operation", Any]:
"""Applies the logic attached to this operation and returns output.
Args:
args: the args passed
kwargs: the context in which the operation is being run.
Returns:
the final output of the operation's logic code or a partial operation if the args and kwargs \
provided are less than those expected.
"""
try:
args_length = _get_num_of_relevant_args(self.__signature, *args, **kwargs)
if args_length < self.__args_length:
return Operation(func=functools.partial(self.__f, *args, **kwargs))
except TypeError:
# binding is impossible so just use the default implementation
pass

return self.__f(*args, **kwargs)


class Expression:
"""Logic that returns a value when applied.
This is the basic building block of all functions and thus
Expand All @@ -188,10 +233,10 @@ class Expression(Generic[P, R]):
f: the operation or logic to run as part of this expression
"""

def __init__(self, f: Optional["Operation"[P, R]] = None):
def __init__(self, f: Optional["Operation"] = None):
self._f = f if f is not None else Operation(lambda x, *args, **kwargs: x)

def __call__(self, *args: Any, **kwargs: Any) -> Union[R, "Expression"[P, R]]:
def __call__(self, *args: Any, **kwargs: Any) -> Union["Expression", Any]:
"""Computes the logic within and returns the value.
Args:
Expand All @@ -208,8 +253,13 @@ def __call__(self, *args: Any, **kwargs: Any) -> Union[R, "Expression"[P, R]]:
return value

def __rshift__(
self, nxt: Union["Expression"[P, R], "Pipeline"[P, R], Callable[P, R]]
) -> Union["Pipeline"[P, R], R]:
self,
nxt: Union[
"Expression",
"Pipeline",
Callable,
],
) -> Union["Pipeline", Any]:
"""This makes piping using the '>>' symbol possible.
Combines with the given `nxt` expression or pipeline to produce a new pipeline
Expand Down Expand Up @@ -252,12 +302,10 @@ class MatchExpression(Expression):

def __init__(self, arg: Optional[Any] = None):
super().__init__()
self._matches: List[Tuple[Callable[[Any], bool], Expression[..., P]]] = []
self._matches: List[Tuple[Callable[[Any], bool], Expression[T]]] = []
self.__arg = arg

def case(
self, pattern: Union[MLType, Any], do: Callable[P, R]
) -> "MatchExpression"[P, R]:
def case(self, pattern: Union[MLType, Any], do: Callable) -> "MatchExpression":
"""Adds a case to a match statement.
This is chainable, allowing multiple cases to be added to the same
Expand All @@ -279,7 +327,7 @@ def case(
self.__add_match(check=check, expn=expn)
return self

def __add_match(self, check: Callable[[Any], bool], expn: "Expression"[P, R]):
def __add_match(self, check: Callable[[Any], bool], expn: "Expression"):
"""Adds a match set to the list of match sets
A match set comprises a checker function and an expression.
Expand All @@ -300,7 +348,7 @@ def __add_match(self, check: Callable[[Any], bool], expn: "Expression"[P, R]):

self._matches.append((check, expn))

def __call__(self, arg: Optional[Any] = None) -> Union[R, "Expression"[P, R]]:
def __call__(self, arg: Optional[Any] = None) -> Union["Expression", Any]:
"""Applies the matched case and returns the output.
The match cases are surveyed for any that matches the given argument
Expand All @@ -326,45 +374,6 @@ def __call__(self, arg: Optional[Any] = None) -> Union[R, "Expression"[P, R]]:
raise errors.MatchError(arg)


class Operation(Generic[P, R]):
"""A computation.
Args:
func: the logic to run as part of the operation.
"""

def __init__(self, func: Callable[P, R]):
self.__signature = _get_func_signature(func)
self.__args_length = _get_non_variable_args_length(self.__signature)

if len(self.__signature.parameters) == 0:
# be more fault tolerant by using variable params
self.__f = lambda *args, **kwargs: func()
else:
self.__f = func

def __call__(self, *args: Any, **kwargs: Any) -> Union[R, "Operation"[P, R]]:
"""Applies the logic attached to this operation and returns output.
Args:
args: the args passed
kwargs: the context in which the operation is being run.
Returns:
the final output of the operation's logic code or a partial operation if the args and kwargs \
provided are less than those expected.
"""
try:
args_length = _get_num_of_relevant_args(self.__signature, *args, **kwargs)
if args_length < self.__args_length:
return Operation(func=functools.partial(self.__f, *args, **kwargs))
except TypeError:
# binding is impossible so just use the default implementation
pass

return self.__f(*args, **kwargs)


def _get_num_of_relevant_args(sig: Signature, *args, **kwargs) -> int:
"""Computes the number of args and kwargs relevant to the signature
Expand Down Expand Up @@ -400,7 +409,7 @@ def _get_func_signature(func: Callable):
return signature(func.__call__)


def to_expn(v: Union["Expression"[P, R], Callable[P, R], R]) -> "Expression"[P, R]:
def to_expn(v: Union["Expression", Callable, Any]) -> "Expression":
"""Converts a Callable or Expression into an Expression"""
if isinstance(v, Expression):
return v
Expand Down

0 comments on commit fb42ea8

Please sign in to comment.