Skip to content

Commit

Permalink
Add support for lifespan state
Browse files Browse the repository at this point in the history
This allows ASGI apps to store state during startup that is then
passed in every scope.
  • Loading branch information
adriangb authored and pgjones committed May 27, 2024
1 parent ba3d813 commit d1c1a23
Show file tree
Hide file tree
Showing 34 changed files with 278 additions and 56 deletions.
12 changes: 10 additions & 2 deletions src/hypercorn/asyncio/lifespan.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Any, Callable

from ..config import Config
from ..typing import AppWrapper, ASGIReceiveEvent, ASGISendEvent, LifespanScope
from ..typing import AppWrapper, ASGIReceiveEvent, ASGISendEvent, LifespanScope, LifespanState
from ..utils import LifespanFailureError, LifespanTimeoutError


Expand All @@ -14,14 +14,21 @@ class UnexpectedMessageError(Exception):


class Lifespan:
def __init__(self, app: AppWrapper, config: Config, loop: asyncio.AbstractEventLoop) -> None:
def __init__(
self,
app: AppWrapper,
config: Config,
loop: asyncio.AbstractEventLoop,
lifespan_state: LifespanState,
) -> None:
self.app = app
self.config = config
self.startup = asyncio.Event()
self.shutdown = asyncio.Event()
self.app_queue: asyncio.Queue = asyncio.Queue(config.max_app_queue_size)
self.supported = True
self.loop = loop
self.state = lifespan_state

# This mimics the Trio nursery.start task_status and is
# required to ensure the support has been checked before
Expand All @@ -33,6 +40,7 @@ async def handle_lifespan(self) -> None:
scope: LifespanScope = {
"type": "lifespan",
"asgi": {"spec_version": "2.0", "version": "3.0"},
"state": self.state,
}

def _call_soon(func: Callable, *args: Any) -> Any:
Expand Down
9 changes: 5 additions & 4 deletions src/hypercorn/asyncio/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .udp_server import UDPServer
from .worker_context import WorkerContext
from ..config import Config, Sockets
from ..typing import AppWrapper
from ..typing import AppWrapper, LifespanState
from ..utils import (
check_multiprocess_shutdown_event,
load_application,
Expand Down Expand Up @@ -77,7 +77,8 @@ def _signal_handler(*_: Any) -> None: # noqa: N803

shutdown_trigger = signal_event.wait

lifespan = Lifespan(app, config, loop)
lifespan_state: LifespanState = {}
lifespan = Lifespan(app, config, loop, lifespan_state)

lifespan_task = loop.create_task(lifespan.handle_lifespan())
await lifespan.wait_for_startup()
Expand Down Expand Up @@ -106,7 +107,7 @@ async def _server_callback(reader: asyncio.StreamReader, writer: asyncio.StreamW
task = asyncio.current_task(loop)
server_tasks.add(task)
task.add_done_callback(server_tasks.discard)
await TCPServer(app, loop, config, context, reader, writer)
await TCPServer(app, loop, config, context, lifespan_state, reader, writer)

servers = []
for sock in sockets.secure_sockets:
Expand Down Expand Up @@ -140,7 +141,7 @@ async def _server_callback(reader: asyncio.StreamReader, writer: asyncio.StreamW
sock = _share_socket(sock)

_, protocol = await loop.create_datagram_endpoint(
lambda: UDPServer(app, loop, config, context), sock=sock
lambda: UDPServer(app, loop, config, context, lifespan_state), sock=sock
)
task = loop.create_task(protocol.run())
server_tasks.add(task)
Expand Down
5 changes: 4 additions & 1 deletion src/hypercorn/asyncio/tcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from ..config import Config
from ..events import Closed, Event, RawData, Updated
from ..protocol import ProtocolWrapper
from ..typing import AppWrapper
from ..typing import AppWrapper, ConnectionState, LifespanState
from ..utils import parse_socket_addr

MAX_RECV = 2**16
Expand All @@ -22,6 +22,7 @@ def __init__(
loop: asyncio.AbstractEventLoop,
config: Config,
context: WorkerContext,
state: LifespanState,
reader: asyncio.StreamReader,
writer: asyncio.StreamWriter,
) -> None:
Expand All @@ -33,6 +34,7 @@ def __init__(
self.reader = reader
self.writer = writer
self.send_lock = asyncio.Lock()
self.state = state
self.idle_task = AsyncioSingleTask()

def __await__(self) -> Generator[Any, None, None]:
Expand All @@ -58,6 +60,7 @@ async def run(self) -> None:
self.config,
self.context,
task_group,
ConnectionState(self.state.copy()),
ssl,
client,
server,
Expand Down
12 changes: 10 additions & 2 deletions src/hypercorn/asyncio/udp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .worker_context import WorkerContext
from ..config import Config
from ..events import Event, RawData
from ..typing import AppWrapper
from ..typing import AppWrapper, ConnectionState, LifespanState
from ..utils import parse_socket_addr

if TYPE_CHECKING:
Expand All @@ -22,6 +22,7 @@ def __init__(
loop: asyncio.AbstractEventLoop,
config: Config,
context: WorkerContext,
state: LifespanState,
) -> None:
self.app = app
self.config = config
Expand All @@ -30,6 +31,7 @@ def __init__(
self.protocol: "QuicProtocol"
self.protocol_queue: asyncio.Queue = asyncio.Queue(10)
self.transport: Optional[asyncio.DatagramTransport] = None
self.state = state

def connection_made(self, transport: asyncio.DatagramTransport) -> None: # type: ignore
self.transport = transport
Expand All @@ -48,7 +50,13 @@ async def run(self) -> None:
server = parse_socket_addr(socket.family, socket.getsockname())
async with TaskGroup(self.loop) as task_group:
self.protocol = QuicProtocol(
self.app, self.config, self.context, task_group, server, self.protocol_send
self.app,
self.config,
self.context,
task_group,
ConnectionState(self.state.copy()),
server,
self.protocol_send,
)

while not self.context.terminated.is_set() or not self.protocol.idle:
Expand Down
8 changes: 7 additions & 1 deletion src/hypercorn/protocol/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .h11 import H2CProtocolRequiredError, H2ProtocolAssumedError, H11Protocol
from ..config import Config
from ..events import Event, RawData
from ..typing import AppWrapper, TaskGroup, WorkerContext
from ..typing import AppWrapper, ConnectionState, TaskGroup, WorkerContext


class ProtocolWrapper:
Expand All @@ -16,6 +16,7 @@ def __init__(
config: Config,
context: WorkerContext,
task_group: TaskGroup,
state: ConnectionState,
ssl: bool,
client: Optional[Tuple[str, int]],
server: Optional[Tuple[str, int]],
Expand All @@ -30,13 +31,15 @@ def __init__(
self.client = client
self.server = server
self.send = send
self.state = state
self.protocol: Union[H11Protocol, H2Protocol]
if alpn_protocol == "h2":
self.protocol = H2Protocol(
self.app,
self.config,
self.context,
self.task_group,
self.state,
self.ssl,
self.client,
self.server,
Expand All @@ -48,6 +51,7 @@ def __init__(
self.config,
self.context,
self.task_group,
self.state,
self.ssl,
self.client,
self.server,
Expand All @@ -66,6 +70,7 @@ async def handle(self, event: Event) -> None:
self.config,
self.context,
self.task_group,
self.state,
self.ssl,
self.client,
self.server,
Expand All @@ -80,6 +85,7 @@ async def handle(self, event: Event) -> None:
self.config,
self.context,
self.task_group,
self.state,
self.ssl,
self.client,
self.server,
Expand Down
3 changes: 3 additions & 0 deletions src/hypercorn/protocol/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from dataclasses import dataclass
from typing import List, Tuple

from hypercorn.typing import ConnectionState


@dataclass(frozen=True)
class Event:
Expand All @@ -15,6 +17,7 @@ class Request(Event):
http_version: str
method: str
raw_path: bytes
state: ConnectionState


@dataclass(frozen=True)
Expand Down
5 changes: 4 additions & 1 deletion src/hypercorn/protocol/h11.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .ws_stream import WSStream
from ..config import Config
from ..events import Closed, Event, RawData, Updated
from ..typing import AppWrapper, H11SendableEvent, TaskGroup, WorkerContext
from ..typing import AppWrapper, ConnectionState, H11SendableEvent, TaskGroup, WorkerContext

STREAM_ID = 1

Expand Down Expand Up @@ -84,6 +84,7 @@ def __init__(
config: Config,
context: WorkerContext,
task_group: TaskGroup,
connection_state: ConnectionState,
ssl: bool,
client: Optional[Tuple[str, int]],
server: Optional[Tuple[str, int]],
Expand All @@ -103,6 +104,7 @@ def __init__(
self.ssl = ssl
self.stream: Optional[Union[HTTPStream, WSStream]] = None
self.task_group = task_group
self.connection_state = connection_state

async def initiate(self) -> None:
pass
Expand Down Expand Up @@ -236,6 +238,7 @@ async def _create_stream(self, request: h11.Request) -> None:
http_version=request.http_version.decode(),
method=request.method.decode("ascii").upper(),
raw_path=request.target,
state=self.connection_state,
)
)
self.keep_alive_requests += 1
Expand Down
5 changes: 4 additions & 1 deletion src/hypercorn/protocol/h2.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from .ws_stream import WSStream
from ..config import Config
from ..events import Closed, Event, RawData, Updated
from ..typing import AppWrapper, Event as IOEvent, TaskGroup, WorkerContext
from ..typing import AppWrapper, ConnectionState, Event as IOEvent, TaskGroup, WorkerContext
from ..utils import filter_pseudo_headers

BUFFER_HIGH_WATER = 2 * 2**14 # Twice the default max frame size (two frames worth)
Expand Down Expand Up @@ -85,6 +85,7 @@ def __init__(
config: Config,
context: WorkerContext,
task_group: TaskGroup,
connection_state: ConnectionState,
ssl: bool,
client: Optional[Tuple[str, int]],
server: Optional[Tuple[str, int]],
Expand All @@ -96,6 +97,7 @@ def __init__(
self.config = config
self.context = context
self.task_group = task_group
self.connection_state = connection_state

self.connection = h2.connection.H2Connection(
config=h2.config.H2Configuration(client_side=False, header_encoding=None)
Expand Down Expand Up @@ -360,6 +362,7 @@ async def _create_stream(self, request: h2.events.RequestReceived) -> None:
http_version="2",
method=method,
raw_path=raw_path,
state=self.connection_state,
)
)
self.keep_alive_requests += 1
Expand Down
5 changes: 4 additions & 1 deletion src/hypercorn/protocol/h3.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from .http_stream import HTTPStream
from .ws_stream import WSStream
from ..config import Config
from ..typing import AppWrapper, TaskGroup, WorkerContext
from ..typing import AppWrapper, ConnectionState, TaskGroup, WorkerContext
from ..utils import filter_pseudo_headers


Expand All @@ -34,6 +34,7 @@ def __init__(
config: Config,
context: WorkerContext,
task_group: TaskGroup,
state: ConnectionState,
client: Optional[Tuple[str, int]],
server: Optional[Tuple[str, int]],
quic: QuicConnection,
Expand All @@ -48,6 +49,7 @@ def __init__(
self.server = server
self.streams: Dict[int, Union[HTTPStream, WSStream]] = {}
self.task_group = task_group
self.state = state

async def handle(self, quic_event: QuicEvent) -> None:
for event in self.connection.handle_event(quic_event):
Expand Down Expand Up @@ -127,6 +129,7 @@ async def _create_stream(self, request: HeadersReceived) -> None:
http_version="3",
method=method,
raw_path=raw_path,
state=self.state,
)
)
await self.context.mark_request()
Expand Down
2 changes: 2 additions & 0 deletions src/hypercorn/protocol/http_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ async def handle(self, event: Event) -> None:
"headers": event.headers,
"client": self.client,
"server": self.server,
"state": event.state,
"extensions": {},
}

Expand Down Expand Up @@ -158,6 +159,7 @@ async def app_send(self, message: Optional[ASGISendEvent]) -> None:
http_version=self.scope["http_version"],
method="GET",
raw_path=message["path"].encode(),
state=self.scope["state"],
)
)
elif (
Expand Down
5 changes: 4 additions & 1 deletion src/hypercorn/protocol/quic.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from .h3 import H3Protocol
from ..config import Config
from ..events import Closed, Event, RawData
from ..typing import AppWrapper, SingleTask, TaskGroup, WorkerContext
from ..typing import AppWrapper, ConnectionState, SingleTask, TaskGroup, WorkerContext


@dataclass
Expand All @@ -41,6 +41,7 @@ def __init__(
config: Config,
context: WorkerContext,
task_group: TaskGroup,
state: ConnectionState,
server: Optional[Tuple[str, int]],
send: Callable[[Event], Awaitable[None]],
) -> None:
Expand All @@ -51,6 +52,7 @@ def __init__(
self.send = send
self.server = server
self.task_group = task_group
self.state = state

self.quic_config = QuicConfiguration(alpn_protocols=H3_ALPN, is_client=False)
self.quic_config.load_cert_chain(certfile=config.certfile, keyfile=config.keyfile)
Expand Down Expand Up @@ -128,6 +130,7 @@ async def _handle_events(
self.config,
self.context,
self.task_group,
self.state,
client,
self.server,
connection,
Expand Down
1 change: 1 addition & 0 deletions src/hypercorn/protocol/ws_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ async def handle(self, event: Event) -> None:
"headers": event.headers,
"client": self.client,
"server": self.server,
"state": event.state,
"subprotocols": self.handshake.subprotocols or [],
"extensions": {"websocket.http.response": {}},
}
Expand Down
Loading

0 comments on commit d1c1a23

Please sign in to comment.