diff --git a/funml/expressions.py b/funml/expressions.py index 4d805a6..15cf7b9 100644 --- a/funml/expressions.py +++ b/funml/expressions.py @@ -1,11 +1,12 @@ """Assigning variables and literals """ -from typing import Any +from collections.abc import Callable +from typing import Any, Union from funml.types import to_expn, Expression -def val(v: Any) -> Expression: +def val(v: Union[Expression, Callable, Any]) -> Expression: """Converts a generic value or lambda expression into a functional expression. This is useful when one needs to use piping on a non-ml function or diff --git a/funml/types.py b/funml/types.py index ec70a7a..d09fcba 100644 --- a/funml/types.py +++ b/funml/types.py @@ -4,18 +4,12 @@ import functools from collections.abc import Awaitable, Callable from inspect import signature, Parameter, Signature -from typing import Any, Optional, List, Tuple, Generic, Union, TypeVar - -from typing_extensions import ParamSpec +from typing import Any, Optional, List, Tuple, Union from funml import errors from funml.utils import is_equal_or_of_type -R = TypeVar("R") -P = ParamSpec("P") - - class MLType: """An ML-enabled type, that can easily be used in pattern matching, piping etc. @@ -23,8 +17,8 @@ class MLType: """ def generate_case( - self, do: "Operation"[P, R] - ) -> Tuple[Callable[[Any], bool], "Expression"[P, R]]: + self, do: "Operation" + ) -> Tuple[Callable[[Any], bool], "Expression"]: """Generates a case statement for pattern matching. Args: @@ -48,7 +42,7 @@ def _is_like(self, other: Any) -> bool: raise NotImplemented("_is_like not implemented") -class Pipeline(Generic[P, R]): +class Pipeline: """A series of logic blocks that operate on the same data in sequence. This has internal state so it is not be used in such stuff as recursion. @@ -60,8 +54,13 @@ def __init__(self): self._is_terminated = False def __rshift__( - self, nxt: Union["Expression"[P, R], Callable[P, R], "Pipeline"[P, R]] - ) -> Union["Pipeline"[P, R], R]: + self, + nxt: Union[ + "Expression", + Callable, + "Pipeline", + ], + ) -> Union["Pipeline", Any]: """Uses `>>` to append the nxt expression, callable, pipeline to this pipeline. Args: @@ -82,7 +81,7 @@ def __rshift__( def __call__( self, *args: Any, **kwargs: Any - ) -> Union[R, "Expression"[P, R], Awaitable[R], Awaitable["Expression"[P, R]]]: + ) -> Union["Expression", Awaitable["Expression"], Awaitable[Any], Any,]: """Computes the logic within the pipeline and returns the value. This method runs all those expressions in the queue sequentially, @@ -122,14 +121,21 @@ def __copy__(self): new_pipeline._is_terminated = self._is_terminated return new_pipeline - def __as_async(self) -> "AsyncPipeline"[P, R]: + def __as_async(self) -> "AsyncPipeline": """Creates an async pipeline from this pipeline""" pipe = AsyncPipeline() pipe._queue = [*self._queue] pipe._is_terminated = self._is_terminated return pipe - def __update_queue(self, nxt): + def __update_queue( + self, + nxt: Union[ + "Expression", + Callable, + "Pipeline", + ], + ): """Appends a pipeline or an expression to the queue.""" if self._is_terminated: raise ValueError("a terminated pipeline cannot be extended.") @@ -149,7 +155,7 @@ class AsyncPipeline(Pipeline): See more details in the [base class](funml.types.Pipeline) """ - async def __call__(self, *args: Any, **kwargs: Any) -> Union[R, "Expression"[P, R]]: + async def __call__(self, *args: Any, **kwargs: Any) -> Union["Expression", Any]: """Computes the logic within the pipeline and returns the value. This method runs all those expressions in the queue sequentially, @@ -178,7 +184,46 @@ async def __call__(self, *args: Any, **kwargs: Any) -> Union[R, "Expression"[P, return output -class Expression(Generic[P, R]): +class Operation: + """A computation. + + Args: + func: the logic to run as part of the operation. + """ + + def __init__(self, func: Callable): + self.__signature = _get_func_signature(func) + self.__args_length = _get_non_variable_args_length(self.__signature) + + if len(self.__signature.parameters) == 0: + # be more fault tolerant by using variable params + self.__f = lambda *args, **kwargs: func() + else: + self.__f = func + + def __call__(self, *args: Any, **kwargs: Any) -> Union["Operation", Any]: + """Applies the logic attached to this operation and returns output. + + Args: + args: the args passed + kwargs: the context in which the operation is being run. + + Returns: + the final output of the operation's logic code or a partial operation if the args and kwargs \ + provided are less than those expected. + """ + try: + args_length = _get_num_of_relevant_args(self.__signature, *args, **kwargs) + if args_length < self.__args_length: + return Operation(func=functools.partial(self.__f, *args, **kwargs)) + except TypeError: + # binding is impossible so just use the default implementation + pass + + return self.__f(*args, **kwargs) + + +class Expression: """Logic that returns a value when applied. This is the basic building block of all functions and thus @@ -188,10 +233,10 @@ class Expression(Generic[P, R]): f: the operation or logic to run as part of this expression """ - def __init__(self, f: Optional["Operation"[P, R]] = None): + def __init__(self, f: Optional["Operation"] = None): self._f = f if f is not None else Operation(lambda x, *args, **kwargs: x) - def __call__(self, *args: Any, **kwargs: Any) -> Union[R, "Expression"[P, R]]: + def __call__(self, *args: Any, **kwargs: Any) -> Union["Expression", Any]: """Computes the logic within and returns the value. Args: @@ -208,8 +253,13 @@ def __call__(self, *args: Any, **kwargs: Any) -> Union[R, "Expression"[P, R]]: return value def __rshift__( - self, nxt: Union["Expression"[P, R], "Pipeline"[P, R], Callable[P, R]] - ) -> Union["Pipeline"[P, R], R]: + self, + nxt: Union[ + "Expression", + "Pipeline", + Callable, + ], + ) -> Union["Pipeline", Any]: """This makes piping using the '>>' symbol possible. Combines with the given `nxt` expression or pipeline to produce a new pipeline @@ -252,12 +302,10 @@ class MatchExpression(Expression): def __init__(self, arg: Optional[Any] = None): super().__init__() - self._matches: List[Tuple[Callable[[Any], bool], Expression[..., P]]] = [] + self._matches: List[Tuple[Callable[[Any], bool], Expression[T]]] = [] self.__arg = arg - def case( - self, pattern: Union[MLType, Any], do: Callable[P, R] - ) -> "MatchExpression"[P, R]: + def case(self, pattern: Union[MLType, Any], do: Callable) -> "MatchExpression": """Adds a case to a match statement. This is chainable, allowing multiple cases to be added to the same @@ -279,7 +327,7 @@ def case( self.__add_match(check=check, expn=expn) return self - def __add_match(self, check: Callable[[Any], bool], expn: "Expression"[P, R]): + def __add_match(self, check: Callable[[Any], bool], expn: "Expression"): """Adds a match set to the list of match sets A match set comprises a checker function and an expression. @@ -300,7 +348,7 @@ def __add_match(self, check: Callable[[Any], bool], expn: "Expression"[P, R]): self._matches.append((check, expn)) - def __call__(self, arg: Optional[Any] = None) -> Union[R, "Expression"[P, R]]: + def __call__(self, arg: Optional[Any] = None) -> Union["Expression", Any]: """Applies the matched case and returns the output. The match cases are surveyed for any that matches the given argument @@ -326,45 +374,6 @@ def __call__(self, arg: Optional[Any] = None) -> Union[R, "Expression"[P, R]]: raise errors.MatchError(arg) -class Operation(Generic[P, R]): - """A computation. - - Args: - func: the logic to run as part of the operation. - """ - - def __init__(self, func: Callable[P, R]): - self.__signature = _get_func_signature(func) - self.__args_length = _get_non_variable_args_length(self.__signature) - - if len(self.__signature.parameters) == 0: - # be more fault tolerant by using variable params - self.__f = lambda *args, **kwargs: func() - else: - self.__f = func - - def __call__(self, *args: Any, **kwargs: Any) -> Union[R, "Operation"[P, R]]: - """Applies the logic attached to this operation and returns output. - - Args: - args: the args passed - kwargs: the context in which the operation is being run. - - Returns: - the final output of the operation's logic code or a partial operation if the args and kwargs \ - provided are less than those expected. - """ - try: - args_length = _get_num_of_relevant_args(self.__signature, *args, **kwargs) - if args_length < self.__args_length: - return Operation(func=functools.partial(self.__f, *args, **kwargs)) - except TypeError: - # binding is impossible so just use the default implementation - pass - - return self.__f(*args, **kwargs) - - def _get_num_of_relevant_args(sig: Signature, *args, **kwargs) -> int: """Computes the number of args and kwargs relevant to the signature @@ -400,7 +409,7 @@ def _get_func_signature(func: Callable): return signature(func.__call__) -def to_expn(v: Union["Expression"[P, R], Callable[P, R], R]) -> "Expression"[P, R]: +def to_expn(v: Union["Expression", Callable, Any]) -> "Expression": """Converts a Callable or Expression into an Expression""" if isinstance(v, Expression): return v