Skip to content

Commit

Permalink
Implement exception handling for async generators.
Browse files Browse the repository at this point in the history
Fixes #12.

The default behavior is to cache Exceptions. However, there is an option
to retry exceptions, which will also respect the concurrency guarentees
from once.
  • Loading branch information
aebrahim committed Oct 20, 2023
1 parent 059c90c commit 7825066
Show file tree
Hide file tree
Showing 3 changed files with 343 additions and 149 deletions.
112 changes: 83 additions & 29 deletions once/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,16 @@ def return_value(self, value: typing.Any) -> None:
_ONCE_FACTORY_TYPE = collections.abc.Callable # type: ignore


class _CachedException:
def __init__(self, exception: Exception):
self.exception = exception


def _wrap(
func: collections.abc.Callable,
once_factory: _ONCE_FACTORY_TYPE,
fn_type: _WrappedFunctionType,
retry_exceptions: bool,
) -> collections.abc.Callable:
"""Generate a wrapped function appropriate to the function type.
Expand All @@ -119,7 +125,7 @@ async def wrapped(*args, **kwargs) -> typing.Any:
async with once_base.async_lock:
if not once_base.called:
once_base.return_value = _iterator_wrappers.AsyncGeneratorWrapper(
func, *args, **kwargs
retry_exceptions, func, *args, **kwargs
)
once_base.called = True
return_value = once_base.return_value
Expand All @@ -132,24 +138,58 @@ async def wrapped(*args, **kwargs) -> typing.Any:
return

elif fn_type == _WrappedFunctionType.ASYNC_FUNCTION:
if retry_exceptions:

async def wrapped(*args, **kwargs) -> typing.Any:
once_base: _OnceBase = once_factory()
async with once_base.async_lock:
if not once_base.called:
once_base.return_value = await func(*args, **kwargs)
once_base.called = True
return once_base.return_value
async def wrapped(*args, **kwargs) -> typing.Any:
once_base: _OnceBase = once_factory()
async with once_base.async_lock:
if not once_base.called:
once_base.return_value = await func(*args, **kwargs)
once_base.called = True
return once_base.return_value

else:

async def wrapped(*args, **kwargs) -> typing.Any:
once_base: _OnceBase = once_factory()
async with once_base.async_lock:
if not once_base.called:
try:
once_base.return_value = await func(*args, **kwargs)
except Exception as exception:
once_base.return_value = _CachedException(exception)
once_base.called = True
return_value = once_base.return_value
if isinstance(return_value, _CachedException):
raise return_value.exception
return return_value

elif fn_type == _WrappedFunctionType.SYNC_FUNCTION:
if retry_exceptions:

def wrapped(*args, **kwargs) -> typing.Any:
once_base: _OnceBase = once_factory()
with once_base.lock:
if not once_base.called:
once_base.return_value = func(*args, **kwargs)
once_base.called = True
return once_base.return_value
def wrapped(*args, **kwargs) -> typing.Any:
once_base: _OnceBase = once_factory()
with once_base.lock:
if not once_base.called:
once_base.return_value = func(*args, **kwargs)
once_base.called = True
return once_base.return_value

else:

def wrapped(*args, **kwargs) -> typing.Any:
once_base: _OnceBase = once_factory()
with once_base.lock:
if not once_base.called:
try:
once_base.return_value = func(*args, **kwargs)
except Exception as exception:
once_base.return_value = _CachedException(exception)
once_base.called = True
return_value = once_base.return_value
if isinstance(return_value, _CachedException):
raise return_value.exception
return return_value

elif fn_type == _WrappedFunctionType.SYNC_GENERATOR:

Expand All @@ -158,7 +198,7 @@ def wrapped(*args, **kwargs) -> typing.Any:
with once_base.lock:
if not once_base.called:
once_base.return_value = _iterator_wrappers.GeneratorWrapper(
func, *args, **kwargs
retry_exceptions, func, *args, **kwargs
)
once_base.called = True
iterator = once_base.return_value
Expand Down Expand Up @@ -195,7 +235,7 @@ def _once_factory(is_async: bool, per_thread: bool) -> _ONCE_FACTORY_TYPE:
return lambda: singleton_once


def once(*args, per_thread=False) -> collections.abc.Callable:
def once(*args, per_thread=False, retry_exceptions=False) -> collections.abc.Callable:
"""Decorator to ensure a function is only called once.
The restriction of only one call also holds across threads. However, this
Expand Down Expand Up @@ -225,15 +265,15 @@ def once(*args, per_thread=False) -> collections.abc.Callable:
# This trick lets this function be a decorator directly, or be called
# to create a decorator.
# Both @once and @once() will function correctly.
return functools.partial(once, per_thread=per_thread)
return functools.partial(once, per_thread=per_thread, retry_exceptions=retry_exceptions)
if _is_method(func):
raise SyntaxError(
"Attempting to use @once.once decorator on method "
"instead of @once.once_per_class or @once.once_per_instance"
)
fn_type = _wrapped_function_type(func)
once_factory = _once_factory(is_async=fn_type in _ASYNC_FN_TYPES, per_thread=per_thread)
return _wrap(func, once_factory, fn_type)
return _wrap(func, once_factory, fn_type, retry_exceptions)


class once_per_class: # pylint: disable=invalid-name
Expand All @@ -243,15 +283,21 @@ class once_per_class: # pylint: disable=invalid-name
is_staticmethod: bool

@classmethod
def with_options(cls, per_thread: bool = False):
return lambda func: cls(func, per_thread=per_thread)

def __init__(self, func: collections.abc.Callable, per_thread: bool = False) -> None:
def with_options(cls, per_thread: bool = False, retry_exceptions=False):
return lambda func: cls(func, per_thread=per_thread, retry_exceptions=retry_exceptions)

def __init__(
self,
func: collections.abc.Callable,
per_thread: bool = False,
retry_exceptions: bool = False,
) -> None:
self.func = self._inspect_function(func)
self.fn_type = _wrapped_function_type(self.func)
self.once_factory = _once_factory(
is_async=self.fn_type in _ASYNC_FN_TYPES, per_thread=per_thread
)
self.retry_exceptions = retry_exceptions

def _inspect_function(self, func: collections.abc.Callable):
if not _is_method(func):
Expand Down Expand Up @@ -280,17 +326,22 @@ def __get__(self, obj, cls) -> collections.abc.Callable:
func = self.func
else:
func = functools.partial(self.func, obj)
return _wrap(func, self.once_factory, self.fn_type)
return _wrap(func, self.once_factory, self.fn_type, self.retry_exceptions)


class once_per_instance: # pylint: disable=invalid-name
"""A version of once for class methods which runs once per instance."""

@classmethod
def with_options(cls, per_thread: bool = False):
return lambda func: cls(func, per_thread=per_thread)

def __init__(self, func: collections.abc.Callable, per_thread: bool = False) -> None:
def with_options(cls, per_thread: bool = False, retry_exceptions=False):
return lambda func: cls(func, per_thread=per_thread, retry_exceptions=retry_exceptions)

def __init__(
self,
func: collections.abc.Callable,
per_thread: bool = False,
retry_exceptions: bool = False,
) -> None:
self.func = self._inspect_function(func)
self.fn_type = _wrapped_function_type(self.func)
self.is_async_fn = self.fn_type in _ASYNC_FN_TYPES
Expand All @@ -299,6 +350,7 @@ def __init__(self, func: collections.abc.Callable, per_thread: bool = False) ->
typing.Any, collections.abc.Callable
] = weakref.WeakKeyDictionary()
self.per_thread = per_thread
self.retry_exceptions = retry_exceptions

def once_factory(self) -> _ONCE_FACTORY_TYPE:
"""Generate a new once factory.
Expand All @@ -324,6 +376,8 @@ def __get__(self, obj, cls) -> collections.abc.Callable:
with self.callables_lock:
if (callable := self.callables.get(obj)) is None:
bound_func = functools.partial(self.func, obj)
callable = _wrap(bound_func, self.once_factory(), self.fn_type)
callable = _wrap(
bound_func, self.once_factory(), self.fn_type, self.retry_exceptions
)
self.callables[obj] = callable
return callable
Loading

0 comments on commit 7825066

Please sign in to comment.