From a15dd7a18e6c5b98d1fc4c49ffab3fa4f926cd24 Mon Sep 17 00:00:00 2001 From: James Ward Date: Sat, 11 Nov 2023 00:15:41 -0500 Subject: [PATCH] feat: support `wait` as a function --- asyncio/__init__.py | 1 + asyncio/funcs.py | 249 +++++++++++++++++++++++++------------------- 2 files changed, 142 insertions(+), 108 deletions(-) diff --git a/asyncio/__init__.py b/asyncio/__init__.py index ce8837d..1748f1d 100644 --- a/asyncio/__init__.py +++ b/asyncio/__init__.py @@ -17,6 +17,7 @@ __repo__ = "https://github.com/Adafruit/Adafruit_CircuitPython_asyncio.git" _attrs = { + "wait": "funcs", "wait_for": "funcs", "wait_for_ms": "funcs", "gather": "funcs", diff --git a/asyncio/funcs.py b/asyncio/funcs.py index b1bb24a..d6836c5 100644 --- a/asyncio/funcs.py +++ b/asyncio/funcs.py @@ -14,28 +14,109 @@ Functions ========= """ +try: + from typing import List, Tuple, Optional, Union + from .task import TaskQueue, Task +except ImportError: + pass from . import core -async def _run(waiter, aw): - try: - result = await aw - status = True - except BaseException as er: - result = None - status = er - if waiter.data is None: - # The waiter is still waiting, cancel it. - if waiter.cancel(): - # Waiter was cancelled by us, change its CancelledError to an instance of - # CancelledError that contains the status and result of waiting on aw. - # If the wait_for task subsequently gets cancelled externally then this - # instance will be reset to a CancelledError instance without arguments. - waiter.data = core.CancelledError(status, result) - -async def wait_for(aw, timeout, sleep=core.sleep): +ALL_COMPLETED = 'ALL_COMPLETED' +FIRST_COMPLETED = 'FIRST_COMPLETED' +FIRST_EXCEPTION = 'FIRST_EXCEPTION' + + +async def wait( + *aws, + timeout: Optional[Union[int, float]]=None, + return_when: Union[ALL_COMPLETED, FIRST_COMPLETED, FIRST_EXCEPTION]=ALL_COMPLETED +) -> Tuple[List[Task], List[Task]]: + """ + Wait for the awaitables given by aws to complete. + + Returns two lists of tasks: (done, pending) + + Usage: + + done, pending = await asyncio.wait(aws) + + If a timeout is set and occurs, any tasks that haven't completed will be returns + in the second list of tasks (pending) + + This is a coroutine. + """ + if not aws: + raise ValueError('Set of awaitable is empty.') + + if return_when not in (FIRST_COMPLETED, FIRST_EXCEPTION, ALL_COMPLETED): + raise ValueError(f'Invalid return_when value: {return_when}') + + aws = [core._promote_to_task(aw) for aw in aws] + task_self = core.cur_task + + tasks_done: List[Task] = [aw for aw in aws if aw.done()] + tasks_pending: List[Task] = [aw for aw in aws if not aw.done()] + + if len(done) > 0 and return_when == FIRST_COMPLETED: + return tasks_done, tasks_pending + + if len(pending) == 0 and return_when == ALL_COMPLETED: + return tasks_done, tasks_pending + + if return_when == FIRST_EXCEPTION: + has_exception = any([ + ( + not isinstance(t.data, core.CancelledError) and + not isinstance(t.data, StopIteration) and + isinstance(t.data, Exception) + ) + for t in tasks_done + ]) + + if has_exception: + return tasks_done, tasks_pending + + def _done_callback(t: Task, er): + tasks_pending.remove(t) + tasks_done.add(t) + + if len(pending) == 0: + core._task_queue.push_head(task_self) + elif return_when == FIRST_COMPLETED: + core._task_queue.push_head(task_self) + elif er is not None and return_when == FIRST_EXCEPTION: + core._task_queue.push_head(task_self) + return + + for t in pending: + t.state = _done_callback + + task_timeout = None + if timeout is not None: + def _timeout_callback(): + core._task_queue.push_head(task_self) + + task_timeout = core._promote_to_task(core.sleep(timeout)) + task_timeout.state = _timeout_callback + + # Pass back to the task queue until needed + await core._never() + + if task_timeout is not None: + task_timeout.cancel() + + # Clean up and remove the callback from pending tasks + for t in pending: + if t.state is _done_callback: + t.state = True + + return tasks_done, tasks_pending + + +async def wait_for(aw, timeout: Union[int, float]): """Wait for the *aw* awaitable to complete, but cancel if it takes longer than *timeout* seconds. If *aw* is not a task then a task will be created from it. @@ -48,33 +129,27 @@ async def wait_for(aw, timeout, sleep=core.sleep): This is a coroutine. """ - aw = core._promote_to_task(aw) - if timeout is None: - return await aw - - # Run aw in a separate runner task that manages its exceptions. - runner_task = core.create_task(_run(core.cur_task, aw)) + task_aw = core._promote_to_task(aw) try: # Wait for the timeout to elapse. - await sleep(timeout) - except core.CancelledError as er: - status = er.args[0] if er.args else None - if status is None: - # This wait_for was cancelled externally, so cancel aw and re-raise. - runner_task.cancel() - raise er - elif status is True: - # aw completed successfully and cancelled the sleep, so return aw's result. - return er.args[1] - else: - # aw raised an exception, propagate it out to the caller. - raise status + done, pending = await wait(aw, timeout=timeout) + + if len(pending) > 0: + # If our tasks are still pending we timed out + # Per the Python 3.11 docs + # > If a timeout occurs, it cancels the task and raises TimeoutError. + for t in pending: + t.cancel() + raise core.TimeoutError() + except core.CancelledError: + # Per the Python 3.11 docs + # > If the wait is cancelled, the future aw is also cancelled. + task_aw.cancel() + raise - # The sleep finished before aw, so cancel aw and raise TimeoutError. - runner_task.cancel() - await runner_task - raise core.TimeoutError + # This should be completed, so it should immediately return the value or exception when awaiting it. + return await task_aw def wait_for_ms(aw, timeout): @@ -82,14 +157,7 @@ def wait_for_ms(aw, timeout): This is a coroutine, and a MicroPython extension. """ - - return wait_for(aw, timeout, core.sleep_ms) - - -class _Remove: - @staticmethod - def remove(t): - pass + return wait_for(aw, timeout / 1000) async def gather(*aws, return_exceptions=False): @@ -101,65 +169,30 @@ async def gather(*aws, return_exceptions=False): if not aws: return [] - def done(t, er): - # Sub-task "t" has finished, with exception "er". - nonlocal state - if gather_task.data is not _Remove: - # The main gather task has already been scheduled, so do nothing. - # This happens if another sub-task already raised an exception and - # woke the main gather task (via this done function), or if the main - # gather task was cancelled externally. - return - elif not return_exceptions and not isinstance(er, StopIteration): - # A sub-task raised an exception, indicate that to the gather task. - state = er - else: - state -= 1 - if state: - # Still some sub-tasks running. - return - # Gather waiting is done, schedule the main gather task. - core._task_queue.push_head(gather_task) - - ts = [core._promote_to_task(aw) for aw in aws] - for i in range(len(ts)): - if ts[i].state is not True: - # Task is not running, gather not currently supported for this case. - raise RuntimeError("can't gather") - # Register the callback to call when the task is done. - ts[i].state = done - - # Set the state for execution of the gather. - gather_task = core.cur_task - state = len(ts) - cancel_all = False - - # Wait for the a sub-task to need attention. - gather_task.data = _Remove + tasks = [core._promote_to_task(aw) for aw in aws] + try: - await core._never() - except core.CancelledError as er: - cancel_all = True - state = er - - # Clean up tasks. - for i in range(len(ts)): - if ts[i].state is done: - # Sub-task is still running, deregister the callback and cancel if needed. - ts[i].state = True - if cancel_all: - ts[i].cancel() - elif isinstance(ts[i].data, StopIteration): - # Sub-task ran to completion, get its return value. - ts[i] = ts[i].data.value + if not return_exceptions: + await wait(tasks, return_when=FIRST_EXCEPTION) else: - # Sub-task had an exception with return_exceptions==True, so get its exception. - ts[i] = ts[i].data - - # Either this gather was cancelled, or one of the sub-tasks raised an exception with - # return_exceptions==False, so reraise the exception here. - if state is not 0: - raise state - - # Return the list of return values of each sub-task. - return ts + await wait(tasks, return_when=ALL_COMPLETED) + except core.CancelledError: + for task in tasks: + task.cancel() + raise + + results = [] + for task in tasks: + if not task.done(): + results.append(None) + continue + + try: + results.append(task.result()) + except BaseException as e: + if not return_exceptions: + raise e + + results.append(e) + + return results