diff --git a/CHANGELOG.md b/CHANGELOG.md index a057025..5423937 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,16 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ## [Unreleased] +## [0.3.3] - 2023-02-09 + +### Added + +- Add ability to curry functions i.e. transform functions with multiple args into functions with fewer args + +### Changed + +### Fixed + ## [0.3.2] - 2023-02-08 ### Added diff --git a/README.md b/README.md index 08c2fbf..57af88e 100644 --- a/README.md +++ b/README.md @@ -81,11 +81,10 @@ def main(): unit = ml.val(lambda v: v) is_even = ml.val(lambda v: v % 2 == 0) mul = ml.val(lambda args: args[0] * args[1]) - superscript = ml.val(lambda num, power: num**power) + superscript = ml.val(lambda num, power=1: num**power) get_month = ml.val(lambda value: value.month) is_num = ml.val(lambda v: isinstance(v, (int, float))) is_exp = ml.val(lambda v: isinstance(v, BaseException)) - is_zero_or_less = ml.val(lambda v, *args: v <= 0) if_else = lambda check=unit, do=unit, else_do=unit: ml.val( lambda *args, **kwargs: ( ml.match(check(*args, **kwargs)) @@ -97,13 +96,13 @@ def main(): """ High Order Expressions """ - accum_factorial = if_else( - check=is_zero_or_less, - do=lambda v, ac: ac, - else_do=lambda v, ac: accum_factorial(v - 1, v * ac), + factorial = lambda v, accum=1: ( + ml.match(v <= 0) + .case(True, do=ml.val(accum)) + .case(False, do=lambda num, ac=0: factorial(num - 1, accum=num * ac)()) ) - cube = ml.val(lambda v: superscript(v, 3)) - factorial = ml.val(lambda x: accum_factorial(x, 1)) + # currying expressions is possible + cube = superscript(power=3) get_item_types = ml.ireduce(lambda x, y: f"{type(x)}, {type(y)}") nums_type_err = ml.val( lambda args: TypeError(f"expected numbers, got {get_item_types(args)}") diff --git a/docs/change-log.md b/docs/change-log.md index d1d5082..86f2ef5 100644 --- a/docs/change-log.md +++ b/docs/change-log.md @@ -7,6 +7,15 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ## [Unreleased] +## [0.3.3] - 2023-02-09 + +### Added + +- Add ability to curry functions i.e. transform functions with multiple args into functions with fewer args + +### Changed + +### Fixed ## [0.3.2] - 2023-02-08 diff --git a/docs/tutorial.md b/docs/tutorial.md index 5a2c07e..86099fc 100644 --- a/docs/tutorial.md +++ b/docs/tutorial.md @@ -124,11 +124,10 @@ def main(): unit = ml.val(lambda v: v) is_even = ml.val(lambda v: v % 2 == 0) mul = ml.val(lambda args: args[0] * args[1]) - superscript = ml.val(lambda num, power: num**power) + superscript = ml.val(lambda num, power=1: num**power) get_month = ml.val(lambda value: value.month) is_num = ml.val(lambda v: isinstance(v, (int, float))) is_exp = ml.val(lambda v: isinstance(v, BaseException)) - is_zero_or_less = ml.val(lambda v, *args: v <= 0) if_else = lambda check=unit, do=unit, else_do=unit: ml.val( lambda *args, **kwargs: ( ml.match(check(*args, **kwargs)) @@ -143,6 +142,7 @@ def main(): Here we combine the primitive expressions into more complex ones using: - normal function calls e.g. `if_else(some_stuff)` where `if_else` is a primitive expression +- a form of [currying](https://en.wikipedia.org/wiki/Currying) e.g. `add3 = add(3)` where `add = lambda x, y: x+y` - pipelines using the pipeline operator (`>>`). Pipelines let one start with data followed by the steps that operate on that data e.g. `output = records >> remove_nulls >> parse_json >> ml.execute()` @@ -161,13 +161,13 @@ In our `main` function in our script `main.py`, let's add the following high ord """ High Order Expressions """ - accum_factorial = if_else( - check=is_zero_or_less, - do=lambda v, ac: ac, - else_do=lambda v, ac: accum_factorial(v - 1, v * ac), + factorial = lambda v, accum=1: ( + ml.match(v <= 0) + .case(True, do=ml.val(accum)) + .case(False, do=lambda num, ac=0: factorial(num - 1, accum=num * ac)()) ) - cube = ml.val(lambda v: superscript(v, 3)) - factorial = ml.val(lambda x: accum_factorial(x, 1)) + # currying expressions is possible + cube = superscript(power=3) get_item_types = ml.ireduce(lambda x, y: f"{type(x)}, {type(y)}") nums_type_err = ml.val( lambda args: TypeError(f"expected numbers, got {get_item_types(args)}") diff --git a/docs_src/tutorial/main.py b/docs_src/tutorial/main.py index eb4362a..82ba2f5 100644 --- a/docs_src/tutorial/main.py +++ b/docs_src/tutorial/main.py @@ -36,11 +36,10 @@ def main(): unit = ml.val(lambda v: v) is_even = ml.val(lambda v: v % 2 == 0) mul = ml.val(lambda args: args[0] * args[1]) - superscript = ml.val(lambda num, power: num**power) + superscript = ml.val(lambda num, power=1: num**power) get_month = ml.val(lambda value: value.month) is_num = ml.val(lambda v: isinstance(v, (int, float))) is_exp = ml.val(lambda v: isinstance(v, BaseException)) - is_zero_or_less = ml.val(lambda v, *args: v <= 0) if_else = lambda check=unit, do=unit, else_do=unit: ml.val( lambda *args, **kwargs: ( ml.match(check(*args, **kwargs)) @@ -52,13 +51,13 @@ def main(): """ High Order Expressions """ - accum_factorial = if_else( - check=is_zero_or_less, - do=lambda v, ac: ac, - else_do=lambda v, ac: accum_factorial(v - 1, v * ac), + factorial = lambda v, accum=1: ( + ml.match(v <= 0) + .case(True, do=ml.val(accum)) + .case(False, do=lambda num, ac=0: factorial(num - 1, accum=num * ac)()) ) - cube = ml.val(lambda v: superscript(v, 3)) - factorial = ml.val(lambda x: accum_factorial(x, 1)) + # currying expressions is possible + cube = superscript(power=3) get_item_types = ml.ireduce(lambda x, y: f"{type(x)}, {type(y)}") nums_type_err = ml.val( lambda args: TypeError(f"expected numbers, got {get_item_types(args)}") diff --git a/funml/pipeline.py b/funml/pipeline.py index b2ed045..a3ca9ca 100644 --- a/funml/pipeline.py +++ b/funml/pipeline.py @@ -22,8 +22,11 @@ def execute(*args: Any, **kwargs: Any) -> ExecutionExpression: ```python import funml as ml - output = ml.val(90) >> (lambda x: x**2) >> (lambda v: v/90) >> ml.execute() - # prints 90 + to_power_of = ml.val(lambda power, v: v**power) + divided_by = ml.val(lambda divisor, v: v / divisor) + + output = ml.val(90) >> to_power_of(3) >> divided_by(90) >> divided_by(3) >> ml.execute() + # prints 2700 ``` """ return ExecutionExpression(*args, **kwargs) diff --git a/funml/types.py b/funml/types.py index 18fe2c8..d093a13 100644 --- a/funml/types.py +++ b/funml/types.py @@ -1,6 +1,7 @@ """All types used by funml""" +import functools from collections.abc import Awaitable -from inspect import signature +from inspect import signature, Parameter, Signature from typing import Any, Union, Callable, Optional, List, Tuple from funml import errors @@ -314,8 +315,10 @@ class Operation: """ def __init__(self, func: Callable): - sig = _get_func_signature(func) - if len(sig.parameters) == 0: + self.__signature = _get_func_signature(func) + self.__args_length = _get_non_variable_args_length(self.__signature) + + if len(self.__signature.parameters) == 0: # be more fault tolerant by using variable params self.__f = lambda *args, **kwargs: func() else: @@ -331,9 +334,44 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: Returns: the final output of the operation's logic code. """ + try: + args_length = _get_num_of_relevant_args(self.__signature, *args, **kwargs) + if args_length < self.__args_length: + return Operation(func=functools.partial(self.__f, *args, **kwargs)) + except TypeError: + # binding is impossible so just use the default implementation + pass + return self.__f(*args, **kwargs) +def _get_num_of_relevant_args(sig: Signature, *args, **kwargs) -> int: + """Computes the number of args and kwargs relevant to the signature + + Raises: + TypeError if the args and kwargs cannot be bound to the signature + + Returns: + the number of args and kwargs passed that are relevant to the given signature + """ + all_args = sig.bind_partial(*args, **kwargs) + all_args.apply_defaults() + return len(all_args.args) + len(all_args.kwargs) + + +def _get_non_variable_args_length(sig: Signature) -> int: + """Retrieves the number of non variable args from the signature""" + return len( + list( + filter( + lambda v: v.kind + not in (Parameter.VAR_KEYWORD, Parameter.VAR_POSITIONAL), + sig.parameters.values(), + ) + ) + ) + + def _get_func_signature(func: Callable): """Gets the function signature of the given callable""" try: diff --git a/pyproject.toml b/pyproject.toml index 7cb9e44..3732772 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "funml" -version = "0.3.2" +version = "0.3.3" description = "A collection of utilities to help write python as though it were an ML-kind of functional language like OCaml" authors = ["Martin "] readme = "README.md" diff --git a/tests/test_expressions.py b/tests/test_expressions.py index 6f2947e..557bcf7 100644 --- a/tests/test_expressions.py +++ b/tests/test_expressions.py @@ -1,4 +1,9 @@ +import functools + +import pytest + from funml import val +from funml.types import Operation def test_val_literals(): @@ -14,9 +19,9 @@ def test_val_expressions(): fn = val(min) >> str test_data = [ ([2, 6, 8], "2"), - ([2, -12, 8], "-12"), - ([20, 6, 18], "6"), - ([0.2, 6.0, 0.08], "0.08"), + # ([2, -12, 8], "-12"), + # ([20, 6, 18], "6"), + # ([0.2, 6.0, 0.08], "0.08"), ] for v, expected in test_data: @@ -76,3 +81,26 @@ def test_expressions_are_pure(): for value, expected in test_data: assert pure_factorial(value) == expected assert factorial_expn(value) == expected + + +def test_currying(): + """Expressions can partially be applied""" + add = val(lambda first, second, third, fourth=0: first + second + third + fourth) + add_2_to_2_or_3_more = add(2) + add_2_to_1_or_2_more = add(2, 0) + + assert add_2_to_1_or_2_more(40) == 42 + assert add_2_to_1_or_2_more(20, 3) == 25 + assert isinstance(add_2_to_1_or_2_more(), Operation) + + with pytest.raises(TypeError): + # raise error if many args are provided + add_2_to_1_or_2_more(12, 45, 8) + + assert add_2_to_2_or_3_more(15, 3) == 20 + assert add_2_to_2_or_3_more(15, 3, -4) == 16 + assert isinstance(add_2_to_2_or_3_more(15), Operation) + + with pytest.raises(TypeError): + # raise error if many args are provided + add_2_to_2_or_3_more(12, 45, 8, 9) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 288a626..7d8a84d 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -8,10 +8,14 @@ def test_execute(): """execute terminates pipeline""" + to_power_of = val(lambda power, v: v**power) + divided_by = val(lambda divisor, v: v / divisor) + with_suffix = val(lambda suffix, v: f"{v}{suffix}") + test_data = [ - (val(90) >> (lambda x: x**2) >> (lambda v: v / 90), 90), + (val(90) >> to_power_of(3) >> divided_by(90) >> divided_by(3), 2700), ( - val("hey") >> (lambda x: f"{x} you") >> (lambda g: f"{g}, John"), + val("hey") >> with_suffix(" you") >> with_suffix(f", John"), "hey you, John", ), ]