Skip to content

Commit

Permalink
backward compatibility with Python <=3.10
Browse files Browse the repository at this point in the history
Signed-off-by: Achille Roussel <[email protected]>
  • Loading branch information
achille-roussel committed Jun 4, 2024
1 parent 1c7bef8 commit 09bd3e8
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 5 deletions.
105 changes: 105 additions & 0 deletions src/dispatch/asyncio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import asyncio
import functools
import inspect
import signal
import threading

class Runner:
"""Runner is a class similar to asyncio.Runner but that we use for backward
compatibility with Python 3.10 and earlier.
"""

def __init__(self):
self._loop = asyncio.new_event_loop()
self._interrupt_count = 0

def __enter__(self):
return self

def __exit__(self, *args, **kwargs):
self.close()

def close(self):
try:
loop = self._loop
_cancel_all_tasks(loop)
loop.run_until_complete(loop.shutdown_asyncgens())
if hasattr(loop, 'shutdown_default_executor'): # Python 3.9+
loop.run_until_complete(loop.shutdown_default_executor())
finally:
loop.close()

def get_loop(self):
return self._loop

def run(self, coro):
if not inspect.iscoroutine(coro):
raise ValueError("a coroutine was expected, got {!r}".format(coro))

try:
asyncio.get_running_loop()
except RuntimeError:
pass
else:
raise RuntimeError("Runner.run() cannot be called from a running event loop")

task = self._loop.create_task(coro)
sigint_handler = None

if (threading.current_thread() is threading.main_thread()
and signal.getsignal(signal.SIGINT) is signal.default_int_handler
):
sigint_handler = functools.partial(self._on_sigint, main_task=task)
try:
signal.signal(signal.SIGINT, sigint_handler)
except ValueError:
# `signal.signal` may throw if `threading.main_thread` does
# not support signals (e.g. embedded interpreter with signals
# not registered - see gh-91880)
sigint_handler = None

self._interrupt_count = 0
try:
asyncio.set_event_loop(self._loop)
return self._loop.run_until_complete(task)
except asyncio.CancelledError:
if self._interrupt_count > 0:
uncancel = getattr(task, "uncancel", None)
if uncancel is not None and uncancel() == 0:
raise KeyboardInterrupt()
raise # CancelledError
finally:
asyncio.set_event_loop(None)
if (sigint_handler is not None
and signal.getsignal(signal.SIGINT) is sigint_handler
):
signal.signal(signal.SIGINT, signal.default_int_handler)

def _on_sigint(self, signum, frame, main_task):
self._interrupt_count += 1
if self._interrupt_count == 1 and not main_task.done():
main_task.cancel()
# wakeup loop if it is blocked by select() with long timeout
self._loop.call_soon_threadsafe(lambda: None)
return
raise KeyboardInterrupt()

def _cancel_all_tasks(loop):
to_cancel = asyncio.all_tasks(loop)
if not to_cancel:
return

for task in to_cancel:
task.cancel()

loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True))

for task in to_cancel:
if task.cancelled():
continue
if task.exception() is not None:
loop.call_exception_handler({
'message': 'unhandled exception during asyncio.run() shutdown',
'exception': task.exception(),
'task': task,
})
3 changes: 2 additions & 1 deletion src/dispatch/experimental/lambda_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def handler(event, context):

from awslambdaric.lambda_context import LambdaContext

from dispatch.asyncio import Runner
from dispatch.function import Registry
from dispatch.proto import Input
from dispatch.sdk.v1 import function_pb2 as function_pb
Expand Down Expand Up @@ -93,7 +94,7 @@ def handle(

input = Input(req)
try:
with asyncio.Runner() as runner:
with Runner() as runner:
output = runner.run(func._primitive_call(input))
except Exception:
logger.error("function '%s' fatal error", req.function, exc_info=True)
Expand Down
3 changes: 2 additions & 1 deletion src/dispatch/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def read_root():

from flask import Flask, make_response, request

from dispatch.asyncio import Runner
from dispatch.function import Registry
from dispatch.http import FunctionServiceError, function_service_run
from dispatch.signature import Ed25519PublicKey, parse_verification_key
Expand Down Expand Up @@ -90,7 +91,7 @@ def _handle_error(self, exc: FunctionServiceError):
def _execute(self):
data: bytes = request.get_data(cache=False)

with asyncio.Runner() as runner:
with Runner() as runner:
content = runner.run(
function_service_run(
request.url,
Expand Down
3 changes: 2 additions & 1 deletion src/dispatch/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from http_message_signatures import InvalidSignature

from dispatch.asyncio import Runner
from dispatch.function import Registry
from dispatch.proto import Input
from dispatch.sdk.v1 import function_pb2 as function_pb
Expand Down Expand Up @@ -121,7 +122,7 @@ def do_POST(self):
url = self.requestline # TODO: need full URL

try:
with asyncio.Runner() as runner:
with Runner() as runner:
content = runner.run(
function_service_run(
url,
Expand Down
4 changes: 2 additions & 2 deletions tests/dispatch/test_scheduler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import unittest
from typing import Any, Callable, List, Optional, Set, Type

from dispatch.asyncio import Runner
from dispatch.coroutine import AnyException, any, call, gather, race
from dispatch.experimental.durable import durable
from dispatch.proto import Arguments, Call, CallResult, Error, Input, Output, TailCall
Expand Down Expand Up @@ -55,7 +55,7 @@ async def raises_error():

class TestOneShotScheduler(unittest.TestCase):
def setUp(self):
self.runner = asyncio.Runner()
self.runner = Runner()

def tearDown(self):
self.runner.close()
Expand Down

0 comments on commit 09bd3e8

Please sign in to comment.