Skip to content

Commit

Permalink
feat: support wait as a function
Browse files Browse the repository at this point in the history
  • Loading branch information
imnotjames committed Nov 12, 2023
1 parent 7e8bfde commit a15dd7a
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 108 deletions.
1 change: 1 addition & 0 deletions asyncio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
__repo__ = "https://github.com/Adafruit/Adafruit_CircuitPython_asyncio.git"

_attrs = {
"wait": "funcs",
"wait_for": "funcs",
"wait_for_ms": "funcs",
"gather": "funcs",
Expand Down
249 changes: 141 additions & 108 deletions asyncio/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,109 @@
Functions
=========
"""
try:
from typing import List, Tuple, Optional, Union

from .task import TaskQueue, Task
except ImportError:
pass

from . import core


async def _run(waiter, aw):
try:
result = await aw
status = True
except BaseException as er:
result = None
status = er
if waiter.data is None:
# The waiter is still waiting, cancel it.
if waiter.cancel():
# Waiter was cancelled by us, change its CancelledError to an instance of
# CancelledError that contains the status and result of waiting on aw.
# If the wait_for task subsequently gets cancelled externally then this
# instance will be reset to a CancelledError instance without arguments.
waiter.data = core.CancelledError(status, result)

async def wait_for(aw, timeout, sleep=core.sleep):
ALL_COMPLETED = 'ALL_COMPLETED'
FIRST_COMPLETED = 'FIRST_COMPLETED'
FIRST_EXCEPTION = 'FIRST_EXCEPTION'


async def wait(
*aws,
timeout: Optional[Union[int, float]]=None,
return_when: Union[ALL_COMPLETED, FIRST_COMPLETED, FIRST_EXCEPTION]=ALL_COMPLETED
) -> Tuple[List[Task], List[Task]]:
"""
Wait for the awaitables given by aws to complete.
Returns two lists of tasks: (done, pending)
Usage:
done, pending = await asyncio.wait(aws)
If a timeout is set and occurs, any tasks that haven't completed will be returns
in the second list of tasks (pending)
This is a coroutine.
"""
if not aws:
raise ValueError('Set of awaitable is empty.')

if return_when not in (FIRST_COMPLETED, FIRST_EXCEPTION, ALL_COMPLETED):
raise ValueError(f'Invalid return_when value: {return_when}')

aws = [core._promote_to_task(aw) for aw in aws]
task_self = core.cur_task

tasks_done: List[Task] = [aw for aw in aws if aw.done()]
tasks_pending: List[Task] = [aw for aw in aws if not aw.done()]

if len(done) > 0 and return_when == FIRST_COMPLETED:
return tasks_done, tasks_pending

if len(pending) == 0 and return_when == ALL_COMPLETED:
return tasks_done, tasks_pending

if return_when == FIRST_EXCEPTION:
has_exception = any([
(
not isinstance(t.data, core.CancelledError) and
not isinstance(t.data, StopIteration) and
isinstance(t.data, Exception)
)
for t in tasks_done
])

if has_exception:
return tasks_done, tasks_pending

def _done_callback(t: Task, er):
tasks_pending.remove(t)
tasks_done.add(t)

if len(pending) == 0:
core._task_queue.push_head(task_self)
elif return_when == FIRST_COMPLETED:
core._task_queue.push_head(task_self)
elif er is not None and return_when == FIRST_EXCEPTION:
core._task_queue.push_head(task_self)
return

for t in pending:
t.state = _done_callback

task_timeout = None
if timeout is not None:
def _timeout_callback():
core._task_queue.push_head(task_self)

task_timeout = core._promote_to_task(core.sleep(timeout))
task_timeout.state = _timeout_callback

# Pass back to the task queue until needed
await core._never()

if task_timeout is not None:
task_timeout.cancel()

# Clean up and remove the callback from pending tasks
for t in pending:
if t.state is _done_callback:
t.state = True

return tasks_done, tasks_pending


async def wait_for(aw, timeout: Union[int, float]):
"""Wait for the *aw* awaitable to complete, but cancel if it takes longer
than *timeout* seconds. If *aw* is not a task then a task will be created
from it.
Expand All @@ -48,48 +129,35 @@ async def wait_for(aw, timeout, sleep=core.sleep):
This is a coroutine.
"""

aw = core._promote_to_task(aw)
if timeout is None:
return await aw

# Run aw in a separate runner task that manages its exceptions.
runner_task = core.create_task(_run(core.cur_task, aw))
task_aw = core._promote_to_task(aw)

try:
# Wait for the timeout to elapse.
await sleep(timeout)
except core.CancelledError as er:
status = er.args[0] if er.args else None
if status is None:
# This wait_for was cancelled externally, so cancel aw and re-raise.
runner_task.cancel()
raise er
elif status is True:
# aw completed successfully and cancelled the sleep, so return aw's result.
return er.args[1]
else:
# aw raised an exception, propagate it out to the caller.
raise status
done, pending = await wait(aw, timeout=timeout)

if len(pending) > 0:
# If our tasks are still pending we timed out
# Per the Python 3.11 docs
# > If a timeout occurs, it cancels the task and raises TimeoutError.
for t in pending:
t.cancel()
raise core.TimeoutError()
except core.CancelledError:
# Per the Python 3.11 docs
# > If the wait is cancelled, the future aw is also cancelled.
task_aw.cancel()
raise

# The sleep finished before aw, so cancel aw and raise TimeoutError.
runner_task.cancel()
await runner_task
raise core.TimeoutError
# This should be completed, so it should immediately return the value or exception when awaiting it.
return await task_aw


def wait_for_ms(aw, timeout):
"""Similar to `wait_for` but *timeout* is an integer in milliseconds.
This is a coroutine, and a MicroPython extension.
"""

return wait_for(aw, timeout, core.sleep_ms)


class _Remove:
@staticmethod
def remove(t):
pass
return wait_for(aw, timeout / 1000)


async def gather(*aws, return_exceptions=False):
Expand All @@ -101,65 +169,30 @@ async def gather(*aws, return_exceptions=False):
if not aws:
return []

def done(t, er):
# Sub-task "t" has finished, with exception "er".
nonlocal state
if gather_task.data is not _Remove:
# The main gather task has already been scheduled, so do nothing.
# This happens if another sub-task already raised an exception and
# woke the main gather task (via this done function), or if the main
# gather task was cancelled externally.
return
elif not return_exceptions and not isinstance(er, StopIteration):
# A sub-task raised an exception, indicate that to the gather task.
state = er
else:
state -= 1
if state:
# Still some sub-tasks running.
return
# Gather waiting is done, schedule the main gather task.
core._task_queue.push_head(gather_task)

ts = [core._promote_to_task(aw) for aw in aws]
for i in range(len(ts)):
if ts[i].state is not True:
# Task is not running, gather not currently supported for this case.
raise RuntimeError("can't gather")
# Register the callback to call when the task is done.
ts[i].state = done

# Set the state for execution of the gather.
gather_task = core.cur_task
state = len(ts)
cancel_all = False

# Wait for the a sub-task to need attention.
gather_task.data = _Remove
tasks = [core._promote_to_task(aw) for aw in aws]

try:
await core._never()
except core.CancelledError as er:
cancel_all = True
state = er

# Clean up tasks.
for i in range(len(ts)):
if ts[i].state is done:
# Sub-task is still running, deregister the callback and cancel if needed.
ts[i].state = True
if cancel_all:
ts[i].cancel()
elif isinstance(ts[i].data, StopIteration):
# Sub-task ran to completion, get its return value.
ts[i] = ts[i].data.value
if not return_exceptions:
await wait(tasks, return_when=FIRST_EXCEPTION)
else:
# Sub-task had an exception with return_exceptions==True, so get its exception.
ts[i] = ts[i].data

# Either this gather was cancelled, or one of the sub-tasks raised an exception with
# return_exceptions==False, so reraise the exception here.
if state is not 0:
raise state

# Return the list of return values of each sub-task.
return ts
await wait(tasks, return_when=ALL_COMPLETED)
except core.CancelledError:
for task in tasks:
task.cancel()
raise

results = []
for task in tasks:
if not task.done():
results.append(None)
continue

try:
results.append(task.result())
except BaseException as e:
if not return_exceptions:
raise e

results.append(e)

return results

0 comments on commit a15dd7a

Please sign in to comment.