diff --git a/funml/types.py b/funml/types.py index d093a13..ec70a7a 100644 --- a/funml/types.py +++ b/funml/types.py @@ -1,27 +1,37 @@ """All types used by funml""" +from __future__ import annotations + import functools -from collections.abc import Awaitable +from collections.abc import Awaitable, Callable from inspect import signature, Parameter, Signature -from typing import Any, Union, Callable, Optional, List, Tuple +from typing import Any, Optional, List, Tuple, Generic, Union, TypeVar + +from typing_extensions import ParamSpec from funml import errors from funml.utils import is_equal_or_of_type +R = TypeVar("R") +P = ParamSpec("P") + + class MLType: """An ML-enabled type, that can easily be used in pattern matching, piping etc. Methods common to ML-enabled types are defined in this class. """ - def generate_case(self, do: "Operation") -> Tuple[Callable, "Expression"]: + def generate_case( + self, do: "Operation"[P, R] + ) -> Tuple[Callable[[Any], bool], "Expression"[P, R]]: """Generates a case statement for pattern matching. Args: do: The operation to do if the arg matches on this type Returns: - A tuple (checker, expn) where checker is a function that checks if argument matches this case, and expn + A tuple (checker, expn) where checker is a function that checks if argument matches this case, and expn \ is the expression that is called when the case is matched. """ raise NotImplemented("generate_case not implemented") @@ -38,7 +48,7 @@ def _is_like(self, other: Any) -> bool: raise NotImplemented("_is_like not implemented") -class Pipeline: +class Pipeline(Generic[P, R]): """A series of logic blocks that operate on the same data in sequence. This has internal state so it is not be used in such stuff as recursion. @@ -50,15 +60,15 @@ def __init__(self): self._is_terminated = False def __rshift__( - self, nxt: Union["Expression", Callable, "Pipeline"] - ) -> Union["Pipeline", Any]: + self, nxt: Union["Expression"[P, R], Callable[P, R], "Pipeline"[P, R]] + ) -> Union["Pipeline"[P, R], R]: """Uses `>>` to append the nxt expression, callable, pipeline to this pipeline. Args: nxt: the next expression, pipeline, or callable to apply after the current one. Returns: - the updated pipeline or the value when the pipeline is executed in case `nxt` is of + the updated pipeline or the value when the pipeline is executed in case `nxt` is of \ type `ExecutionExpression` Raises: @@ -70,7 +80,9 @@ def __rshift__( return self - def __call__(self, *args: Any, **kwargs: Any) -> Any: + def __call__( + self, *args: Any, **kwargs: Any + ) -> Union[R, "Expression"[P, R], Awaitable[R], Awaitable["Expression"[P, R]]]: """Computes the logic within the pipeline and returns the value. This method runs all those expressions in the queue sequentially, @@ -82,7 +94,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: kwargs: any key-word arguments passed Returns: - the computed output of this pipeline. + the computed output of this pipeline or a partial expression if the args and kwargs provided are less than those expected. """ output = None queue = self._queue[:-1] if self._is_terminated else self._queue @@ -110,7 +122,7 @@ def __copy__(self): new_pipeline._is_terminated = self._is_terminated return new_pipeline - def __as_async(self) -> "AsyncPipeline": + def __as_async(self) -> "AsyncPipeline"[P, R]: """Creates an async pipeline from this pipeline""" pipe = AsyncPipeline() pipe._queue = [*self._queue] @@ -137,7 +149,7 @@ class AsyncPipeline(Pipeline): See more details in the [base class](funml.types.Pipeline) """ - async def __call__(self, *args: Any, **kwargs: Any) -> Any: + async def __call__(self, *args: Any, **kwargs: Any) -> Union[R, "Expression"[P, R]]: """Computes the logic within the pipeline and returns the value. This method runs all those expressions in the queue sequentially, @@ -149,7 +161,8 @@ async def __call__(self, *args: Any, **kwargs: Any) -> Any: kwargs: any key-word arguments passed Returns: - the computed output of this pipeline. + the computed output of this pipeline or a partial expression if the args and kwargs provided \ + are less than expected. """ output = None queue = self._queue[:-1] if self._is_terminated else self._queue @@ -165,7 +178,7 @@ async def __call__(self, *args: Any, **kwargs: Any) -> Any: return output -class Expression: +class Expression(Generic[P, R]): """Logic that returns a value when applied. This is the basic building block of all functions and thus @@ -175,10 +188,10 @@ class Expression: f: the operation or logic to run as part of this expression """ - def __init__(self, f: Optional["Operation"] = None): + def __init__(self, f: Optional["Operation"[P, R]] = None): self._f = f if f is not None else Operation(lambda x, *args, **kwargs: x) - def __call__(self, *args: Any, **kwargs: Any) -> Any: + def __call__(self, *args: Any, **kwargs: Any) -> Union[R, "Expression"[P, R]]: """Computes the logic within and returns the value. Args: @@ -186,13 +199,17 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: kwargs: any key-word arguments passed Returns: - the computed output of this expression. + the computed output of this expression or another expression with a partial operation \ + in it if the args provided are less than expected. """ - return self._f(*args, **kwargs) + value = self._f(*args, **kwargs) + if isinstance(value, Operation): + return Expression(f=value) + return value def __rshift__( - self, nxt: Union["Expression", "Pipeline", Callable] - ) -> Union["Pipeline", Any]: + self, nxt: Union["Expression"[P, R], "Pipeline"[P, R], Callable[P, R]] + ) -> Union["Pipeline"[P, R], R]: """This makes piping using the '>>' symbol possible. Combines with the given `nxt` expression or pipeline to produce a new pipeline @@ -202,7 +219,7 @@ def __rshift__( nxt: the next expression, pipeline, or callable to apply after the current one. Returns: - a new pipeline where the first expression is the current expression followed by `nxt` + a new pipeline where the first expression is the current expression followed by `nxt` \ or returns the value when the pipeline is executed in case `nxt` is of type `ExecutionExpression` """ new_pipeline = Pipeline() @@ -235,10 +252,12 @@ class MatchExpression(Expression): def __init__(self, arg: Optional[Any] = None): super().__init__() - self._matches: List[Tuple[Callable, Expression]] = [] + self._matches: List[Tuple[Callable[[Any], bool], Expression[..., P]]] = [] self.__arg = arg - def case(self, pattern: Union[MLType, Any], do: Callable) -> "MatchExpression": + def case( + self, pattern: Union[MLType, Any], do: Callable[P, R] + ) -> "MatchExpression"[P, R]: """Adds a case to a match statement. This is chainable, allowing multiple cases to be added to the same @@ -260,7 +279,7 @@ def case(self, pattern: Union[MLType, Any], do: Callable) -> "MatchExpression": self.__add_match(check=check, expn=expn) return self - def __add_match(self, check: Callable, expn: "Expression"): + def __add_match(self, check: Callable[[Any], bool], expn: "Expression"[P, R]): """Adds a match set to the list of match sets A match set comprises a checker function and an expression. @@ -281,7 +300,7 @@ def __add_match(self, check: Callable, expn: "Expression"): self._matches.append((check, expn)) - def __call__(self, arg: Optional[Any] = None) -> Any: + def __call__(self, arg: Optional[Any] = None) -> Union[R, "Expression"[P, R]]: """Applies the matched case and returns the output. The match cases are surveyed for any that matches the given argument @@ -307,14 +326,14 @@ def __call__(self, arg: Optional[Any] = None) -> Any: raise errors.MatchError(arg) -class Operation: +class Operation(Generic[P, R]): """A computation. Args: func: the logic to run as part of the operation. """ - def __init__(self, func: Callable): + def __init__(self, func: Callable[P, R]): self.__signature = _get_func_signature(func) self.__args_length = _get_non_variable_args_length(self.__signature) @@ -324,7 +343,7 @@ def __init__(self, func: Callable): else: self.__f = func - def __call__(self, *args: Any, **kwargs: Any) -> Any: + def __call__(self, *args: Any, **kwargs: Any) -> Union[R, "Operation"[P, R]]: """Applies the logic attached to this operation and returns output. Args: @@ -332,7 +351,8 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: kwargs: the context in which the operation is being run. Returns: - the final output of the operation's logic code. + the final output of the operation's logic code or a partial operation if the args and kwargs \ + provided are less than those expected. """ try: args_length = _get_num_of_relevant_args(self.__signature, *args, **kwargs) @@ -380,7 +400,7 @@ def _get_func_signature(func: Callable): return signature(func.__call__) -def to_expn(v: Union["Expression", Callable, Any]) -> "Expression": +def to_expn(v: Union["Expression"[P, R], Callable[P, R], R]) -> "Expression"[P, R]: """Converts a Callable or Expression into an Expression""" if isinstance(v, Expression): return v diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 557bcf7..f12432d 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -1,9 +1,7 @@ -import functools - import pytest from funml import val -from funml.types import Operation +from funml.types import Expression def test_val_literals(): @@ -19,9 +17,9 @@ def test_val_expressions(): fn = val(min) >> str test_data = [ ([2, 6, 8], "2"), - # ([2, -12, 8], "-12"), - # ([20, 6, 18], "6"), - # ([0.2, 6.0, 0.08], "0.08"), + ([2, -12, 8], "-12"), + ([20, 6, 18], "6"), + ([0.2, 6.0, 0.08], "0.08"), ] for v, expected in test_data: @@ -91,7 +89,7 @@ def test_currying(): assert add_2_to_1_or_2_more(40) == 42 assert add_2_to_1_or_2_more(20, 3) == 25 - assert isinstance(add_2_to_1_or_2_more(), Operation) + assert isinstance(add_2_to_1_or_2_more(), Expression) with pytest.raises(TypeError): # raise error if many args are provided @@ -99,7 +97,7 @@ def test_currying(): assert add_2_to_2_or_3_more(15, 3) == 20 assert add_2_to_2_or_3_more(15, 3, -4) == 16 - assert isinstance(add_2_to_2_or_3_more(15), Operation) + assert isinstance(add_2_to_2_or_3_more(15), Expression) with pytest.raises(TypeError): # raise error if many args are provided