Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Flatten the implementation of the pipeline() decorator. #43

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 27 additions & 98 deletions slicerator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import collections.abc
import itertools
from functools import wraps
from functools import partial, wraps
from copy import copy
import inspect

Expand Down Expand Up @@ -503,7 +503,7 @@ def __setstate__(self, data_as_list):
return self.__init__(lambda x: x, data_as_list)


def pipeline(func=None, **kwargs):
def pipeline(func=None, *, retain_doc=False, ancestor_count=1):
"""Decorator to enable lazy evaluation of a function.

When the function is applied to a Slicerator or Pipeline object, it
Expand Down Expand Up @@ -540,8 +540,8 @@ def pipeline(func=None, **kwargs):
Apply the pipeline decorator to your image processing function.

>>> @pipeline
... def color_channel(image, channel):
... return image[channel, :, :]
... def color_channel(image, channel):
... return image[channel, :, :]
...


Expand Down Expand Up @@ -583,94 +583,19 @@ def pipeline(func=None, **kwargs):
... def sum_offset(img1, img2, offset):
... return img1 + img2 + offset
"""
def wrapper(f):
return _pipeline(f, **kwargs)

if func is None:
return wrapper
else:
return wrapper(func)
return partial(
pipeline, retain_doc=retain_doc, ancestor_count=ancestor_count)

if ancestor_count == 'all':
ancestor_count = len(
p for p in inspect.signature(func).parameters
if p.kind.name in ["POSITIONAL_ONLY", "POSITIONAL_OR_KEYWORD"])

def _pipeline(func_or_class, **kwargs):
try:
is_class = issubclass(func_or_class, Pipeline)
is_class = issubclass(func, Pipeline)
except TypeError:
is_class = False
if is_class:
return _pipeline_fromclass(func_or_class, **kwargs)
else:
return _pipeline_fromfunc(func_or_class, **kwargs)


def _pipeline_fromclass(cls, retain_doc=False, ancestor_count=1):
"""Actual `pipeline` implementation

Parameters
----------
func : class
Class for lazy evaluation
retain_doc : bool
If True, don't modify `func`'s doc string to say that it has been
made lazy
ancestor_count : int or 'all', optional
Number of inputs to the pipeline. Defaults to 1.

Returns
-------
Pipeline
Lazy function evaluation :py:class:`Pipeline` for `func`.
"""
if ancestor_count == 'all':
# subtract 1 for `self`
ancestor_count = len(inspect.getfullargspec(cls).args) - 1

@wraps(cls)
def process(*args, **kwargs):
ancestors = args[:ancestor_count]
args = args[ancestor_count:]
all_pipe = all(hasattr(a, '_slicerator_flag') or
isinstance(a, Slicerator) or
isinstance(a, Pipeline) for a in ancestors)
if all_pipe:
return cls(*(ancestors + args), **kwargs)
else:
# Fall back on normal behavior of func, interpreting input
# as a single image.
return cls(*(tuple([a] for a in ancestors) + args), **kwargs)[0]

if not retain_doc:
if process.__doc__ is None:
process.__doc__ = ''
process.__doc__ = ("This function has been made lazy. When passed\n"
"a Slicerator, it will return a \n"
"Pipeline of the results. When passed \n"
"any other objects, its behavior is "
"unchanged.\n\n") + process.__doc__
process.__name__ = cls.__name__
return process


def _pipeline_fromfunc(func, retain_doc=False, ancestor_count=1):
"""Actual `pipeline` implementation

Parameters
----------
func : callable
Function for lazy evaluation
retain_doc : bool
If True, don't modify `func`'s doc string to say that it has been
made lazy
ancestor_count : int or 'all', optional
Number of inputs to the pipeline. Defaults to 1.

Returns
-------
Pipeline
Lazy function evaluation :py:class:`Pipeline` for `func`.
"""
if ancestor_count == 'all':
ancestor_count = len(inspect.getfullargspec(func).args)

@wraps(func)
def process(*args, **kwargs):
Expand All @@ -679,24 +604,28 @@ def process(*args, **kwargs):
all_pipe = all(hasattr(a, '_slicerator_flag') or
isinstance(a, Slicerator) or
isinstance(a, Pipeline) for a in ancestors)
if all_pipe:
def proc_func(*x):
return func(*(x + args), **kwargs)

return Pipeline(proc_func, *ancestors)
if is_class:
return (func(*ancestors, *args, **kwargs)
if all_pipe else
# Fall back on normal behavior of func, interpreting input
# as a single image.
func(*[[a] for a in ancestors], *args, **kwargs)[0])

else:
# Fall back on normal behavior of func, interpreting input
# as a single image.
return func(*(ancestors + args), **kwargs)
return (Pipeline(lambda *x: func(*x, *args, **kwargs), *ancestors)
if all_pipe else
# Fall back on normal behavior of func, interpreting input
# as a single image.
func(*ancestors, *args, **kwargs))

if not retain_doc:
if process.__doc__ is None:
process.__doc__ = ''
process.__doc__ = ("This function has been made lazy. When passed\n"
"a Slicerator, it will return a \n"
"Pipeline of the results. When passed \n"
"any other objects, its behavior is "
"unchanged.\n\n") + process.__doc__
process.__doc__ = (
"This function has been made lazy. When passed a Slicerator, it \n"
"will return a Pipeline of the results. When passed any other \n"
"objects, its behavior is unchanged.\n\n" + process.__doc__)
process.__name__ = func.__name__
return process

Expand Down