Skip to content

Commit

Permalink
Improve typing annotations in types
Browse files Browse the repository at this point in the history
  • Loading branch information
Tinitto committed Feb 10, 2023
1 parent cb2d04f commit a97b09d
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 38 deletions.
80 changes: 50 additions & 30 deletions funml/types.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,37 @@
"""All types used by funml"""
from __future__ import annotations

import functools
from collections.abc import Awaitable
from collections.abc import Awaitable, Callable
from inspect import signature, Parameter, Signature
from typing import Any, Union, Callable, Optional, List, Tuple
from typing import Any, Optional, List, Tuple, Generic, Union, TypeVar

from typing_extensions import ParamSpec

from funml import errors
from funml.utils import is_equal_or_of_type


R = TypeVar("R")
P = ParamSpec("P")


class MLType:
"""An ML-enabled type, that can easily be used in pattern matching, piping etc.
Methods common to ML-enabled types are defined in this class.
"""

def generate_case(self, do: "Operation") -> Tuple[Callable, "Expression"]:
def generate_case(
self, do: "Operation"[P, R]
) -> Tuple[Callable[[Any], bool], "Expression"[P, R]]:
"""Generates a case statement for pattern matching.
Args:
do: The operation to do if the arg matches on this type
Returns:
A tuple (checker, expn) where checker is a function that checks if argument matches this case, and expn
A tuple (checker, expn) where checker is a function that checks if argument matches this case, and expn \
is the expression that is called when the case is matched.
"""
raise NotImplemented("generate_case not implemented")
Expand All @@ -38,7 +48,7 @@ def _is_like(self, other: Any) -> bool:
raise NotImplemented("_is_like not implemented")


class Pipeline:
class Pipeline(Generic[P, R]):
"""A series of logic blocks that operate on the same data in sequence.
This has internal state so it is not be used in such stuff as recursion.
Expand All @@ -50,15 +60,15 @@ def __init__(self):
self._is_terminated = False

def __rshift__(
self, nxt: Union["Expression", Callable, "Pipeline"]
) -> Union["Pipeline", Any]:
self, nxt: Union["Expression"[P, R], Callable[P, R], "Pipeline"[P, R]]
) -> Union["Pipeline"[P, R], R]:
"""Uses `>>` to append the nxt expression, callable, pipeline to this pipeline.
Args:
nxt: the next expression, pipeline, or callable to apply after the current one.
Returns:
the updated pipeline or the value when the pipeline is executed in case `nxt` is of
the updated pipeline or the value when the pipeline is executed in case `nxt` is of \
type `ExecutionExpression`
Raises:
Expand All @@ -70,7 +80,9 @@ def __rshift__(

return self

def __call__(self, *args: Any, **kwargs: Any) -> Any:
def __call__(
self, *args: Any, **kwargs: Any
) -> Union[R, "Expression"[P, R], Awaitable[R], Awaitable["Expression"[P, R]]]:
"""Computes the logic within the pipeline and returns the value.
This method runs all those expressions in the queue sequentially,
Expand All @@ -82,7 +94,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
kwargs: any key-word arguments passed
Returns:
the computed output of this pipeline.
the computed output of this pipeline or a partial expression if the args and kwargs provided are less than those expected.
"""
output = None
queue = self._queue[:-1] if self._is_terminated else self._queue
Expand Down Expand Up @@ -110,7 +122,7 @@ def __copy__(self):
new_pipeline._is_terminated = self._is_terminated
return new_pipeline

def __as_async(self) -> "AsyncPipeline":
def __as_async(self) -> "AsyncPipeline"[P, R]:
"""Creates an async pipeline from this pipeline"""
pipe = AsyncPipeline()
pipe._queue = [*self._queue]
Expand All @@ -137,7 +149,7 @@ class AsyncPipeline(Pipeline):
See more details in the [base class](funml.types.Pipeline)
"""

async def __call__(self, *args: Any, **kwargs: Any) -> Any:
async def __call__(self, *args: Any, **kwargs: Any) -> Union[R, "Expression"[P, R]]:
"""Computes the logic within the pipeline and returns the value.
This method runs all those expressions in the queue sequentially,
Expand All @@ -149,7 +161,8 @@ async def __call__(self, *args: Any, **kwargs: Any) -> Any:
kwargs: any key-word arguments passed
Returns:
the computed output of this pipeline.
the computed output of this pipeline or a partial expression if the args and kwargs provided \
are less than expected.
"""
output = None
queue = self._queue[:-1] if self._is_terminated else self._queue
Expand All @@ -165,7 +178,7 @@ async def __call__(self, *args: Any, **kwargs: Any) -> Any:
return output


class Expression:
class Expression(Generic[P, R]):
"""Logic that returns a value when applied.
This is the basic building block of all functions and thus
Expand All @@ -175,24 +188,28 @@ class Expression:
f: the operation or logic to run as part of this expression
"""

def __init__(self, f: Optional["Operation"] = None):
def __init__(self, f: Optional["Operation"[P, R]] = None):
self._f = f if f is not None else Operation(lambda x, *args, **kwargs: x)

def __call__(self, *args: Any, **kwargs: Any) -> Any:
def __call__(self, *args: Any, **kwargs: Any) -> Union[R, "Expression"[P, R]]:
"""Computes the logic within and returns the value.
Args:
args: any arguments passed.
kwargs: any key-word arguments passed
Returns:
the computed output of this expression.
the computed output of this expression or another expression with a partial operation \
in it if the args provided are less than expected.
"""
return self._f(*args, **kwargs)
value = self._f(*args, **kwargs)
if isinstance(value, Operation):
return Expression(f=value)
return value

def __rshift__(
self, nxt: Union["Expression", "Pipeline", Callable]
) -> Union["Pipeline", Any]:
self, nxt: Union["Expression"[P, R], "Pipeline"[P, R], Callable[P, R]]
) -> Union["Pipeline"[P, R], R]:
"""This makes piping using the '>>' symbol possible.
Combines with the given `nxt` expression or pipeline to produce a new pipeline
Expand All @@ -202,7 +219,7 @@ def __rshift__(
nxt: the next expression, pipeline, or callable to apply after the current one.
Returns:
a new pipeline where the first expression is the current expression followed by `nxt`
a new pipeline where the first expression is the current expression followed by `nxt` \
or returns the value when the pipeline is executed in case `nxt` is of type `ExecutionExpression`
"""
new_pipeline = Pipeline()
Expand Down Expand Up @@ -235,10 +252,12 @@ class MatchExpression(Expression):

def __init__(self, arg: Optional[Any] = None):
super().__init__()
self._matches: List[Tuple[Callable, Expression]] = []
self._matches: List[Tuple[Callable[[Any], bool], Expression[..., P]]] = []
self.__arg = arg

def case(self, pattern: Union[MLType, Any], do: Callable) -> "MatchExpression":
def case(
self, pattern: Union[MLType, Any], do: Callable[P, R]
) -> "MatchExpression"[P, R]:
"""Adds a case to a match statement.
This is chainable, allowing multiple cases to be added to the same
Expand All @@ -260,7 +279,7 @@ def case(self, pattern: Union[MLType, Any], do: Callable) -> "MatchExpression":
self.__add_match(check=check, expn=expn)
return self

def __add_match(self, check: Callable, expn: "Expression"):
def __add_match(self, check: Callable[[Any], bool], expn: "Expression"[P, R]):
"""Adds a match set to the list of match sets
A match set comprises a checker function and an expression.
Expand All @@ -281,7 +300,7 @@ def __add_match(self, check: Callable, expn: "Expression"):

self._matches.append((check, expn))

def __call__(self, arg: Optional[Any] = None) -> Any:
def __call__(self, arg: Optional[Any] = None) -> Union[R, "Expression"[P, R]]:
"""Applies the matched case and returns the output.
The match cases are surveyed for any that matches the given argument
Expand All @@ -307,14 +326,14 @@ def __call__(self, arg: Optional[Any] = None) -> Any:
raise errors.MatchError(arg)


class Operation:
class Operation(Generic[P, R]):
"""A computation.
Args:
func: the logic to run as part of the operation.
"""

def __init__(self, func: Callable):
def __init__(self, func: Callable[P, R]):
self.__signature = _get_func_signature(func)
self.__args_length = _get_non_variable_args_length(self.__signature)

Expand All @@ -324,15 +343,16 @@ def __init__(self, func: Callable):
else:
self.__f = func

def __call__(self, *args: Any, **kwargs: Any) -> Any:
def __call__(self, *args: Any, **kwargs: Any) -> Union[R, "Operation"[P, R]]:
"""Applies the logic attached to this operation and returns output.
Args:
args: the args passed
kwargs: the context in which the operation is being run.
Returns:
the final output of the operation's logic code.
the final output of the operation's logic code or a partial operation if the args and kwargs \
provided are less than those expected.
"""
try:
args_length = _get_num_of_relevant_args(self.__signature, *args, **kwargs)
Expand Down Expand Up @@ -380,7 +400,7 @@ def _get_func_signature(func: Callable):
return signature(func.__call__)


def to_expn(v: Union["Expression", Callable, Any]) -> "Expression":
def to_expn(v: Union["Expression"[P, R], Callable[P, R], R]) -> "Expression"[P, R]:
"""Converts a Callable or Expression into an Expression"""
if isinstance(v, Expression):
return v
Expand Down
14 changes: 6 additions & 8 deletions tests/test_expressions.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import functools

import pytest

from funml import val
from funml.types import Operation
from funml.types import Expression


def test_val_literals():
Expand All @@ -19,9 +17,9 @@ def test_val_expressions():
fn = val(min) >> str
test_data = [
([2, 6, 8], "2"),
# ([2, -12, 8], "-12"),
# ([20, 6, 18], "6"),
# ([0.2, 6.0, 0.08], "0.08"),
([2, -12, 8], "-12"),
([20, 6, 18], "6"),
([0.2, 6.0, 0.08], "0.08"),
]

for v, expected in test_data:
Expand Down Expand Up @@ -91,15 +89,15 @@ def test_currying():

assert add_2_to_1_or_2_more(40) == 42
assert add_2_to_1_or_2_more(20, 3) == 25
assert isinstance(add_2_to_1_or_2_more(), Operation)
assert isinstance(add_2_to_1_or_2_more(), Expression)

with pytest.raises(TypeError):
# raise error if many args are provided
add_2_to_1_or_2_more(12, 45, 8)

assert add_2_to_2_or_3_more(15, 3) == 20
assert add_2_to_2_or_3_more(15, 3, -4) == 16
assert isinstance(add_2_to_2_or_3_more(15), Operation)
assert isinstance(add_2_to_2_or_3_more(15), Expression)

with pytest.raises(TypeError):
# raise error if many args are provided
Expand Down

0 comments on commit a97b09d

Please sign in to comment.