From 8dad24a8ea1185760fd3ca1ebcaa08ec2d06ffe4 Mon Sep 17 00:00:00 2001 From: Mauve Date: Sun, 1 Oct 2023 22:11:36 -0400 Subject: [PATCH 1/2] Remove python websocket server framework --- .../web/api/framework/multiplexing.py | 107 ------- src/modlunky2/web/api/framework/pubsub.py | 196 ------------ src/modlunky2/web/api/framework/serde_tag.py | 47 --- src/modlunky2/web/api/framework/session.py | 87 ------ src/modlunky2/web/demo.py | 59 +--- .../web/api/framework/multiplexing_test.py | 120 -------- src/tests/web/api/framework/pubsub_test.py | 278 ------------------ src/tests/web/api/framework/serde_tag_test.py | 81 ----- src/tests/web/api/framework/session_test.py | 161 ---------- 9 files changed, 2 insertions(+), 1134 deletions(-) delete mode 100644 src/modlunky2/web/api/framework/multiplexing.py delete mode 100644 src/modlunky2/web/api/framework/pubsub.py delete mode 100644 src/modlunky2/web/api/framework/serde_tag.py delete mode 100644 src/modlunky2/web/api/framework/session.py delete mode 100644 src/tests/web/api/framework/multiplexing_test.py delete mode 100644 src/tests/web/api/framework/pubsub_test.py delete mode 100644 src/tests/web/api/framework/serde_tag_test.py delete mode 100644 src/tests/web/api/framework/session_test.py diff --git a/src/modlunky2/web/api/framework/multiplexing.py b/src/modlunky2/web/api/framework/multiplexing.py deleted file mode 100644 index 749994412..000000000 --- a/src/modlunky2/web/api/framework/multiplexing.py +++ /dev/null @@ -1,107 +0,0 @@ -from anyio import create_task_group -from anyio.abc import TaskGroup -from dataclasses import dataclass -import logging -from starlette.routing import WebSocketRoute -from starlette.websockets import WebSocket, WebSocketDisconnect -from typing import ( - Any, - Awaitable, - Callable, - Dict, - Generic, - Iterable, - Type, - TypeVar, -) - -from modlunky2.web.api.framework.serde_tag import ( - TagException, - TagDeserializer, - TaggedMessage, - to_tagged_dict, -) -from modlunky2.web.api.framework.session import ( - SessionException, - SessionId, - SessionManager, -) - -logger = logging.getLogger(__name__) - -ParamType = TypeVar("ParamType") - - -class SendConnection: - def __init__( - self, websocket: WebSocket, session_id: SessionId, task_group: TaskGroup - ) -> None: - self._websocket = websocket - self._session_id = session_id - self._task_group = task_group - - @property - def session_id(self) -> SessionId: - return self._session_id - - @property - def task_group(self) -> TaskGroup: - return self._task_group - - async def send(self, obj: Any) -> None: - data = to_tagged_dict(obj) - await self._websocket.send_json(data) - - -MultiplexerEndpoint = Callable[[SendConnection, ParamType], Awaitable[None]] - - -@dataclass(frozen=True) -class WSMultiplexerRoute(Generic[ParamType]): - param_type: Type[ParamType] - endpoint: MultiplexerEndpoint[ParamType] - - -class WSMultiplexer: - def __init__(self, routes: Iterable[WSMultiplexerRoute[Any]]) -> None: - self._param_to_endpoint: Dict[Type[Any], MultiplexerEndpoint[Any]] = {} - for r in routes: - self._param_to_endpoint[r.param_type] = r.endpoint - - self._param_deserializer = TagDeserializer(self._param_to_endpoint.keys()) - self._sid_manager = SessionManager("ws_session_id") - self._route = WebSocketRoute("/{ws_session_id}", self._endpoint) - - @property - def starlette_route(self) -> WebSocketRoute: - return self._route - - async def _endpoint(self, websocket: WebSocket) -> None: - await websocket.accept() - try: - with self._sid_manager.session_for(websocket) as session_id: - async with create_task_group() as tg: - while True: - await self._dispatch_one(websocket, session_id, tg) - except SessionException as ex: - logger.warning("Rejected websocket session", exc_info=ex) - await websocket.close(reason=str(ex)) - except WebSocketDisconnect: - # We swallow disconnect since it's OK for clients to close the connection. - # We don't swallow close since we don't (currently) expect handlers to close the connection. - pass - - async def _dispatch_one( - self, websocket: WebSocket, session_id: SessionId, task_group: TaskGroup - ) -> None: - data: TaggedMessage = await websocket.receive_json() - try: - obj = self._param_deserializer.from_tagged_dict(data) - except (TagException, TypeError) as ex: - logger.warning("Deserializing request failed %s", ex) - return - - endpoint = self._param_to_endpoint[type(obj)] - sender = SendConnection(websocket, session_id, task_group) - - await endpoint(sender, obj) diff --git a/src/modlunky2/web/api/framework/pubsub.py b/src/modlunky2/web/api/framework/pubsub.py deleted file mode 100644 index 2ed107eac..000000000 --- a/src/modlunky2/web/api/framework/pubsub.py +++ /dev/null @@ -1,196 +0,0 @@ -from __future__ import annotations -import dataclasses -from enum import Enum, auto -import logging -import math -from anyio import ( - TASK_STATUS_IGNORED, - WouldBlock, - create_memory_object_stream, - create_task_group, -) -from anyio.abc import TaskStatus -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from dataclasses import InitVar, dataclass -from serde import serde -from typing import ( - Any, - Dict, - Iterable, - Set, - Tuple, - Type, -) - - -from modlunky2.web.api.framework.multiplexing import SendConnection, WSMultiplexerRoute -from modlunky2.web.api.framework.serde_tag import to_tagged_dict -from modlunky2.web.api.framework.session import SessionId - -logger = logging.getLogger(__name__) - - -@serde -@dataclass(frozen=True) -class Subscribe: - topics: Set[str] - - -@serde -@dataclass(frozen=True) -class Unsubscribe: - topics: Set[str] - - -@serde -@dataclass(frozen=True) -class Published: - # Should be TaggedMessage, but I'm unsure how to use pyserde with it - message: Dict[str, Any] - - -class ServiceLevel(Enum): - MUST_DELIVER = auto() - MAY_DROP = auto() - - -@dataclass(frozen=True) -class PubSubTopic: - typ: Type[Any] - service_level: ServiceLevel - - -@dataclass -class _TopicInfo: - service_level: ServiceLevel - subscribers: Set[SessionId] = dataclasses.field(default_factory=set) - - -@dataclass -class _StreamPair: - max_buffer_size: InitVar[float] - - send: MemoryObjectSendStream[Published] = dataclasses.field(init=False) - recv: MemoryObjectReceiveStream[Published] = dataclasses.field(init=False) - - def __post_init__(self, max_buffer_size: float): - self.send, self.recv = create_memory_object_stream(max_buffer_size, Published) - - -@dataclass -class _SessionCopier: - connection: SendConnection - - _may_drop: _StreamPair = dataclasses.field(init=False) - _must_deliver: _StreamPair = dataclasses.field(init=False) - - def __post_init__(self): - self._may_drop = _StreamPair(0) - self._must_deliver = _StreamPair(math.inf) - - def send(self, level: ServiceLevel, pub: Published): - if level is ServiceLevel.MAY_DROP: - stream = self._may_drop.send - elif level is ServiceLevel.MUST_DELIVER: - stream = self._must_deliver.send - else: - raise ValueError(f"Unknown service level {level}") # pragma: no cover - - try: - stream.send_nowait(pub) - except WouldBlock: - pass - - async def run( - self, - *, - task_status: TaskStatus = TASK_STATUS_IGNORED, - ) -> None: - """Send messages until the client disconnects""" - async with self._may_drop.recv, self._must_deliver.recv, create_task_group() as tg: - await tg.start(self._run_one, self._may_drop.recv) - await tg.start(self._run_one, self._must_deliver.recv) - task_status.started() - - async def _run_one( - self, - recv: MemoryObjectReceiveStream[Published], - *, - task_status: TaskStatus = TASK_STATUS_IGNORED, - ) -> None: - task_status.started() - async for pub in recv: - await self.connection.send(pub) - - -@dataclass -class PubSubManager: - topics: InitVar[Iterable[PubSubTopic]] - - _topic_info: Dict[str, _TopicInfo] = dataclasses.field( - init=False, default_factory=dict - ) - _sessions: Dict[SessionId, _SessionCopier] = dataclasses.field( - init=False, default_factory=dict - ) - - def __post_init__(self, topics: Iterable[PubSubTopic]): - for t in topics: - name = t.typ.__name__ - if name in self._topic_info: - raise ValueError(f"Topic {name} appears more than once") - self._topic_info[name] = _TopicInfo(t.service_level) - - @property - def multiplexer_routes(self) -> Tuple[WSMultiplexerRoute[Any], ...]: - return ( - WSMultiplexerRoute(Subscribe, self._subscribe), - WSMultiplexerRoute(Unsubscribe, self._unsubscribe), - ) - - def publish(self, msg: Any): - topic_name = type(msg).__name__ - if topic_name not in self._topic_info: - raise ValueError(f"Topic {topic_name} is unknown") - info = self._topic_info[topic_name] - pub = Published(to_tagged_dict(msg)) - for sid in info.subscribers: - self._sessions[sid].send(info.service_level, pub) - - async def _subscribe(self, connection: SendConnection, req: Subscribe) -> None: - self._check_topics(req.topics) - self._maybe_add_session(connection) - for topic in req.topics: - self._topic_info[topic].subscribers.add(connection.session_id) - - async def _unsubscribe(self, connection: SendConnection, req: Unsubscribe) -> None: - self._check_topics(req.topics) - for topic in req.topics: - self._topic_info[topic].subscribers.discard(connection.session_id) - - def _check_topics(self, raw_topics: Set[str]) -> None: - unknown_topics: Set[str] = set() - for t in raw_topics: - if t not in self._topic_info: - unknown_topics.add(t) - if unknown_topics: - raise ValueError(f"Request contains unknown topics {unknown_topics!r}") - - def _maybe_add_session(self, connection: SendConnection): - if connection.session_id in self._sessions: - return - self._sessions[connection.session_id] = _SessionCopier(connection) - connection.task_group.start_soon( - self._run_session, - connection.session_id, - name=f"pubsub copier for sid {connection.session_id}", - ) - - async def _run_session(self, session_id: SessionId): - try: - await self._sessions[session_id].run() - finally: - # Cleanup the session - for ti in self._topic_info.values(): - ti.subscribers.discard(session_id) - del self._sessions[session_id] diff --git a/src/modlunky2/web/api/framework/serde_tag.py b/src/modlunky2/web/api/framework/serde_tag.py deleted file mode 100644 index ea694a1d4..000000000 --- a/src/modlunky2/web/api/framework/serde_tag.py +++ /dev/null @@ -1,47 +0,0 @@ -from __future__ import annotations - -from dataclasses import InitVar, dataclass, field -from types import MappingProxyType -from typing import Any, Dict, Iterable, NewType, Type - -from serde.de import from_dict, is_deserializable -from serde.se import to_dict - -TaggedMessage = NewType("TaggedMessage", Dict[str, Any]) - - -class TagException(Exception): - pass - - -def to_tagged_dict(obj: Any) -> TaggedMessage: - data = to_dict(obj, convert_sets=True) - tag = type(obj).__name__ - return TaggedMessage({tag: data}) - - -@dataclass(frozen=True) -class TagDeserializer: - types: InitVar[Iterable[Type[Any]]] - - _type_map: Dict[str, Type[Any]] = field(default_factory=dict, init=False) - - def __post_init__(self, types: Iterable[Type[Any]]): - type_map: Dict[str, Type[Any]] = {} - for typ in types: - if not is_deserializable(typ): - raise TagException(f"{typ} isn't deserializable with pyserde") - - type_map[typ.__name__] = typ - object.__setattr__(self, "_type_map", MappingProxyType(type_map)) - - def from_tagged_dict(self, data: TaggedMessage) -> Any: - if len(data) != 1: - raise TagException(f"Data has {len(data)} keys, expected 1") - tag = next(iter(data.keys())) - - if tag not in self._type_map: - raise TagException(f"Data has tag {tag}, which isn't a known type") - typ = self._type_map[tag] - - return from_dict(typ, data[tag]) diff --git a/src/modlunky2/web/api/framework/session.py b/src/modlunky2/web/api/framework/session.py deleted file mode 100644 index 95a4d1cbf..000000000 --- a/src/modlunky2/web/api/framework/session.py +++ /dev/null @@ -1,87 +0,0 @@ -from contextlib import contextmanager -from dataclasses import dataclass, field -import logging -from starlette.datastructures import Address -from starlette.websockets import WebSocket -from typing import Dict, Generator, NewType -from urllib.parse import urlsplit - -logger = logging.getLogger(__name__) - -SessionId = NewType("SessionId", str) - - -class SessionException(Exception): - pass - - -@dataclass -class SessionManager: - route_sid_param_name: str - sid_to_addr: Dict[SessionId, Address] = field(default_factory=dict) - - @contextmanager - def session_for(self, websocket: WebSocket) -> Generator[SessionId, None, None]: - # If there's no origin, this isn't a browser request. - # We assume it's from a local process and allow it. - if "origin" in websocket.headers: - _validate_origin(websocket.headers["origin"]) - - sid = SessionId(websocket.path_params[self.route_sid_param_name]) - addr = websocket.client - if addr is None: - raise SessionException(f"Session {sid} has no client address") - - if sid in self.sid_to_addr: - logger.warning( - "Session %s already in use by %r, attempted reuse by %r", - sid, - self.sid_to_addr[sid], - addr, - ) - raise SessionException( - f"Session ID {sid} is currently being used by another connection" - ) - self.sid_to_addr[sid] = addr - - try: - yield sid - finally: - del self.sid_to_addr[sid] - - -# Allow the most common localhost 'spellings.' This doesn't bother with -# supporting alternative IPv4 loopbacks (e.g. 127.0.0.2). Notably IPv6 gets -# canonicalized elsewhere; we'll never see (e.g.) "[0:0:0:0:0:0:000:1]" -_ALLOWED_ORIGINS = frozenset(["localhost", "127.0.0.1", "[::1]"]) - - -def _validate_origin(origin: str) -> None: - # It might be nice to allow file: scheme, but we'd also have to allow the others. - # Unfortunately, it's possible to construct (e.g.) a data: URL from any origin - if origin == "null": - raise SessionException( - "Connection from privacy-sensitive origin (e.g. file:, data:)" - ) - - try: - parsed = urlsplit(origin) - except ValueError as err: - raise SessionException("Failed to parse origin") from err - - hostname = _extract_hostname(parsed.netloc) - if hostname not in _ALLOWED_ORIGINS: - raise SessionException(f"Unauthorized hostname ({hostname}) in origin {origin}") - - -def _extract_hostname(netloc: str) -> str: - # Try to ignore the port. For IPv6 addresses, there may be ":" in the address - colon_index = netloc.rfind(":") - close_index = netloc.rfind("]") - if colon_index < 0: - return netloc - if close_index >= 0 and close_index > colon_index: - return netloc - if colon_index < 0: - return netloc - return netloc[0:colon_index] diff --git a/src/modlunky2/web/demo.py b/src/modlunky2/web/demo.py index cdfaa1c72..bf21ab8aa 100644 --- a/src/modlunky2/web/demo.py +++ b/src/modlunky2/web/demo.py @@ -1,19 +1,6 @@ -from contextlib import asynccontextmanager -from functools import partial import logging -import time -from typing import NoReturn -from anyio import create_task_group, current_time, sleep -from dataclasses import dataclass -from serde import serde from starlette.applications import Starlette -from starlette.routing import Mount, Route -from modlunky2.web.api.framework.multiplexing import ( - SendConnection, - WSMultiplexer, - WSMultiplexerRoute, -) -from modlunky2.web.api.framework.pubsub import PubSubManager, PubSubTopic, ServiceLevel +from starlette.routing import Route from starlette.requests import Request from starlette.responses import PlainTextResponse @@ -24,50 +11,8 @@ async def hello_world(_request: Request): return PlainTextResponse("Hello world!") -@serde -@dataclass -class EchoMessage: - message: str - - -async def echo(connection: SendConnection, msg: EchoMessage): - await connection.send(msg) - - -@serde -@dataclass -class TimeTick: - now: float - - -async def ticker(pub_manager: PubSubManager) -> NoReturn: - period = 5 # in seconds - while True: - delay = period - current_time() % period - civil_time = time.time() - - pub_manager.publish(TimeTick(civil_time)) - await sleep(delay) - - -@asynccontextmanager -async def app_lifespan(pub_manager: PubSubManager, _app: Starlette): - async with create_task_group() as tg: - tg.start_soon(ticker, pub_manager) - yield - tg.cancel_scope.cancel() - - def make_asgi_app() -> Starlette: - topics = [PubSubTopic(TimeTick, ServiceLevel.MAY_DROP)] - pub_manager = PubSubManager(topics) - multiplexer = WSMultiplexer( - (WSMultiplexerRoute(EchoMessage, echo),) + pub_manager.multiplexer_routes - ) routes = [ Route("/", hello_world), - Mount("/echo", routes=[multiplexer.starlette_route]), ] - return Starlette( - debug=True, routes=routes, lifespan=partial(app_lifespan, pub_manager) - ) + return Starlette(debug=True, routes=routes) diff --git a/src/tests/web/api/framework/multiplexing_test.py b/src/tests/web/api/framework/multiplexing_test.py deleted file mode 100644 index 3d6329b58..000000000 --- a/src/tests/web/api/framework/multiplexing_test.py +++ /dev/null @@ -1,120 +0,0 @@ -from dataclasses import dataclass -import logging -from typing import Generator -import pytest -from serde import serde -from starlette.applications import Starlette -from starlette.testclient import TestClient, WebSocketTestSession -from starlette.websockets import WebSocketDisconnect - -from modlunky2.web.api.framework.multiplexing import ( - SendConnection, - WSMultiplexer, - WSMultiplexerRoute, -) - -logger = logging.getLogger(__name__) - - -@serde -@dataclass -class Person: - name: str - - -@serde -@dataclass -class Greeting: - phrase: str - - -@serde -@dataclass -class Oops: - content: str - - -class OopsException(Exception): - """Testing exception""" - - -async def hello(connection: SendConnection, person: Person): - await connection.send(Greeting(f"hi {person.name} from {connection.session_id}")) - - -async def whoops(_connection: SendConnection, oops: Oops): - raise OopsException(f"Got {oops.content}") - - -@pytest.fixture(name="client") -def make_client() -> Generator[TestClient, None, None]: - multiplexer = WSMultiplexer( - [ - WSMultiplexerRoute[Person](Person, hello), - WSMultiplexerRoute[Oops](Oops, whoops), - ] - ) - app = Starlette(routes=[multiplexer.starlette_route]) - test_client = TestClient(app=app) - with test_client: - yield test_client - - -def test_requests(client: TestClient): - connection: WebSocketTestSession = client.websocket_connect("/123") - with connection: - connection.send_json({"Person": {"name": "Ana"}}) - assert connection.receive_json() == {"Greeting": {"phrase": "hi Ana from 123"}} - - connection.send_json({"Person": {"name": "Terra"}}) - assert connection.receive_json() == { - "Greeting": {"phrase": "hi Terra from 123"} - } - - -def test_duplicate_session(client: TestClient): - connection1: WebSocketTestSession = client.websocket_connect("/456") - connection2: WebSocketTestSession = client.websocket_connect("/456") - with connection1, connection2: - with pytest.raises(WebSocketDisconnect, match=r"ID 456"): - connection2.receive_json() - - -def test_two_connections(client: TestClient): - connection1: WebSocketTestSession = client.websocket_connect("/123") - connection2: WebSocketTestSession = client.websocket_connect("/456") - with connection1, connection2: - connection1.send_json({"Person": {"name": "Roffy"}}) - assert connection1.receive_json() == { - "Greeting": {"phrase": "hi Roffy from 123"} - } - - connection2.send_json({"Person": {"name": "Colin"}}) - assert connection2.receive_json() == { - "Greeting": {"phrase": "hi Colin from 456"} - } - - -def test_skip_unrecognized(client: TestClient): - connection: WebSocketTestSession = client.websocket_connect("/789") - with connection: - connection.send_json({"Mysterious": {}}) - connection.send_json({"Person": {"name": "Tina"}}) - assert connection.receive_json() == {"Greeting": {"phrase": "hi Tina from 789"}} - - -def test_exception_handling(): - multiplexer = WSMultiplexer( - [ - WSMultiplexerRoute[Person](Person, hello), - WSMultiplexerRoute[Oops](Oops, whoops), - ] - ) - app = Starlette(routes=[multiplexer.starlette_route]) - test_client = TestClient(app=app, backend_options={"debug": True}) - with test_client: - connection: WebSocketTestSession = test_client.websocket_connect("/257") - with pytest.raises(OopsException), connection: - connection.send_json({"Oops": {"content": "uh"}}) - # This ensures we wait for the exception - connection.receive_json() diff --git a/src/tests/web/api/framework/pubsub_test.py b/src/tests/web/api/framework/pubsub_test.py deleted file mode 100644 index 1d114df37..000000000 --- a/src/tests/web/api/framework/pubsub_test.py +++ /dev/null @@ -1,278 +0,0 @@ -from contextlib import asynccontextmanager -from anyio import create_memory_object_stream, create_task_group -from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from dataclasses import dataclass -from starlette.applications import Starlette -from starlette.testclient import TestClient, WebSocketTestSession -from typing import Any, Generator, Type, TypeVar, cast -import pytest -from serde import serde - -from modlunky2.web.api.framework.multiplexing import ( - SendConnection, - WSMultiplexer, - WSMultiplexerRoute, -) -from modlunky2.web.api.framework.pubsub import ( - _SessionCopier, - PubSubManager, - PubSubTopic, - Published, - ServiceLevel, - Subscribe, - Unsubscribe, -) -from modlunky2.web.api.framework.serde_tag import ( - TagDeserializer, - TaggedMessage, - to_tagged_dict, -) - - -@serde -@dataclass -class Echo: - msg: str - - -@serde -@dataclass -class Broadcast: - announcement: str - - -@serde -@dataclass -class Notice: - message: str - - -@dataclass -class Broadcaster: - manager: PubSubManager - - async def handler(self, _connection: SendConnection, req: Broadcast) -> None: - self.manager.publish(Notice(req.announcement)) - - -async def echo_handler(connection: SendConnection, req: Echo) -> None: - await connection.send(req) - - -T = TypeVar("T") - - -def from_published(tag_de: TagDeserializer, data: Any, typ: Type[T]) -> T: - pub = tag_de.from_tagged_dict(data) - assert isinstance(pub, Published) - - msg = tag_de.from_tagged_dict(TaggedMessage(pub.message)) - assert isinstance(msg, typ) - return msg - - -@pytest.fixture(name="tag_de") -def make_tag_de() -> TagDeserializer: - return TagDeserializer([Echo, Notice, Published]) - - -@pytest.fixture(name="client") -def make_client() -> Generator[TestClient, None, None]: - manager = PubSubManager([PubSubTopic(Notice, ServiceLevel.MUST_DELIVER)]) - - broadcaster = Broadcaster(manager) - multiplexer = WSMultiplexer( - ( - WSMultiplexerRoute[Broadcast](Broadcast, broadcaster.handler), - WSMultiplexerRoute[Echo](Echo, echo_handler), - ) - + manager.multiplexer_routes - ) - - app = Starlette(routes=[multiplexer.starlette_route]) - test_client = TestClient(app=app) - with test_client: - yield test_client - - -def test_not_subscribed(client: TestClient): - connection: WebSocketTestSession = client.websocket_connect("/123") - with connection: - connection.send_json(to_tagged_dict(Broadcast("not listening"))) - # When the connection context exits, it asserts there are no messages in its queue - - -def test_one_subscribed(client: TestClient, tag_de: TagDeserializer): - connection: WebSocketTestSession = client.websocket_connect("/123") - with connection: - connection.send_json(to_tagged_dict(Subscribe({Notice.__name__}))) - - to_send = "listen up" - connection.send_json(to_tagged_dict(Broadcast(to_send))) - - data = connection.receive_json() - got_msg = from_published(tag_de, data, Notice) - assert got_msg == Notice(to_send) - - -def test_one_subscribed_duplicate(client: TestClient, tag_de: TagDeserializer): - connection: WebSocketTestSession = client.websocket_connect("/123") - with connection: - connection.send_json(to_tagged_dict(Subscribe({Notice.__name__}))) - # This duplication should be OK - connection.send_json(to_tagged_dict(Subscribe({Notice.__name__}))) - - to_send = "still ok" - connection.send_json(to_tagged_dict(Broadcast(to_send))) - - data = connection.receive_json() - got_msg = from_published(tag_de, data, Notice) - assert got_msg == Notice(to_send) - - -def test_unsubscribe_wo_sub(client: TestClient): - connection: WebSocketTestSession = client.websocket_connect("/123") - with connection: - connection.send_json(to_tagged_dict(Unsubscribe({Notice.__name__}))) - connection.send_json(to_tagged_dict(Broadcast("can't hear this"))) - - -def test_unsubscribe_after_sub(client: TestClient): - connection: WebSocketTestSession = client.websocket_connect("/123") - with connection: - connection.send_json(to_tagged_dict(Subscribe({Notice.__name__}))) - connection.send_json(to_tagged_dict(Unsubscribe({Notice.__name__}))) - connection.send_json(to_tagged_dict(Broadcast("not heard"))) - - -def test_two_subscribed(client: TestClient, tag_de: TagDeserializer): - c1: WebSocketTestSession = client.websocket_connect("/123") - c2: WebSocketTestSession = client.websocket_connect("/456") - with c1: - with c2: - c1.send_json(to_tagged_dict(Subscribe({Notice.__name__}))) - c2.send_json(to_tagged_dict(Subscribe({Notice.__name__}))) - - # WebSocketTestSession.send_json() doesn't wait for the message to be processed before returning. - # So, we use echos to ensure both Subscribe requests have been processed. - echo1 = to_tagged_dict(Echo("yo")) - c1.send_json(echo1) - assert c1.receive_json() == echo1 - - echo2 = to_tagged_dict(Echo("hi")) - c2.send_json(echo2) - assert c2.receive_json() == echo2 - - # Make the publish event happen - to_send = "you two" - c2.send_json(to_tagged_dict(Broadcast(to_send))) - - data1 = c1.receive_json() - got1 = from_published(tag_de, data1, Notice) - assert got1 == Notice(to_send) - - data2 = c2.receive_json() - got2 = from_published(tag_de, data2, Notice) - assert got2 == Notice(to_send) - - # Try publishing after a client disconnected - to_send = "just us" - c1.send_json(to_tagged_dict(Broadcast(to_send))) - - data3 = c1.receive_json() - got3 = from_published(tag_de, data3, Notice) - assert got3 == Notice(to_send) - - -@pytest.mark.parametrize("to_send", [Subscribe({"bogus"}), Unsubscribe({"bad"})]) -def test_subscribe_unknown(client: TestClient, to_send: Any): - with client: - connection: WebSocketTestSession = client.websocket_connect("/123") - with pytest.raises(ValueError, match=r"unknown topics"), connection: - connection.send_json(to_tagged_dict(to_send)) - # This ensures we wait for the exception - connection.receive_json() - - -def test_duplicate_topic_name(): - topics = [ - PubSubTopic(Notice, ServiceLevel.MAY_DROP), - PubSubTopic(Notice, ServiceLevel.MUST_DELIVER), - ] - with pytest.raises(ValueError, match=r"more than once"): - PubSubManager(topics) - - -def test_publish_unknown(): - pub_manager = PubSubManager([PubSubTopic(Notice, ServiceLevel.MUST_DELIVER)]) - with pytest.raises(ValueError, match=r"unknown"): - pub_manager.publish(Broadcast("shouldn't work")) - - -@dataclass -class FakeSendConnection: - _send: MemoryObjectSendStream[Published] - - async def send(self, pub: Published): - await self._send.send(pub) - - -@dataclass -class CopierFixture: - copier: _SessionCopier - receive: MemoryObjectReceiveStream[Published] - - -@asynccontextmanager -async def make_copier_fixture(): - # Note: We use a context manager to safely manage the task group - - send, receive = create_memory_object_stream(1, Published) - connection = cast(SendConnection, FakeSendConnection(send)) - copier = _SessionCopier(connection) - - async with send, receive, create_task_group() as tg: - await tg.start(copier.run) - - yield CopierFixture(copier, receive) - - # Stop the copier - tg.cancel_scope.cancel() - - # Make sure there's nothing left in the stream - assert receive.statistics().current_buffer_used == 0 - - -@pytest.mark.anyio -async def test_copier_must_deliver(): - async with make_copier_fixture() as fix: - fix.copier.send(ServiceLevel.MUST_DELIVER, Published({"started": True})) - assert await fix.receive.receive() == Published({"started": True}) - - -@pytest.mark.anyio -async def test_copier_may_drop(): - async with make_copier_fixture() as fix: - fix.copier.send(ServiceLevel.MAY_DROP, Published({"kept": True})) - assert await fix.receive.receive() == Published({"kept": True}) - - -@pytest.mark.anyio -async def test_copier_priority_case1(): - async with make_copier_fixture() as fix: - fix.copier.send(ServiceLevel.MUST_DELIVER, Published({"clogging": True})) - fix.copier.send(ServiceLevel.MAY_DROP, Published({"buffered": True})) - fix.copier.send(ServiceLevel.MUST_DELIVER, Published({"must_keep": True})) - for k in ["clogging", "buffered", "must_keep"]: - assert await fix.receive.receive() == Published({k: True}) - - -@pytest.mark.anyio -async def test_copier_priority_case2(): - async with make_copier_fixture() as fix: - fix.copier.send(ServiceLevel.MAY_DROP, Published({"clogging": True})) - fix.copier.send(ServiceLevel.MAY_DROP, Published({"dropped": True})) - fix.copier.send(ServiceLevel.MUST_DELIVER, Published({"must_keep_1": True})) - fix.copier.send(ServiceLevel.MUST_DELIVER, Published({"must_keep_2": True})) - for k in ["clogging", "must_keep_1", "must_keep_2"]: - assert await fix.receive.receive() == Published({k: True}) diff --git a/src/tests/web/api/framework/serde_tag_test.py b/src/tests/web/api/framework/serde_tag_test.py deleted file mode 100644 index 4bcbc04fc..000000000 --- a/src/tests/web/api/framework/serde_tag_test.py +++ /dev/null @@ -1,81 +0,0 @@ -from __future__ import annotations # PEP 563 -from contextlib import AbstractContextManager -from dataclasses import dataclass -import pytest -from serde import serde -from typing import Any - -from modlunky2.web.api.framework.serde_tag import ( - TaggedMessage, - TagException, - TagDeserializer, - to_tagged_dict, -) - - -@serde -@dataclass -class A: - a_num: int - - -@serde -@dataclass -class B: - a_str: str - - -@serde -@dataclass -class Rogue: - a_float: float - - -class NotSerde: - pass - - -@pytest.mark.parametrize( - "obj", - [A(3), B("woah")], -) -def test_round_trip(obj: Any): - de = TagDeserializer([A, B]) - assert de.from_tagged_dict(to_tagged_dict(obj)) == obj - - -@pytest.mark.parametrize( - "data,obj", - [ - ({"A": {"a_num": 23}}, A(23)), - ({"B": {"a_str": "what"}}, B("what")), - ], -) -def test_ok_dict(data: TaggedMessage, obj: Any): - de = TagDeserializer([A, B]) - assert de.from_tagged_dict(data) == obj - assert to_tagged_dict(obj) == data - - -@pytest.mark.parametrize( - "data,expectation", - [ - ( - {"C": {"some": "thing"}}, - pytest.raises(TagException, match=r"isn't a known type"), - ), - ( - {"A": {"a_num": 99}, "more": 3}, - pytest.raises(TagException, match=r"expected 1"), - ), - ], -) -def test_bad_dict(data: TaggedMessage, expectation: AbstractContextManager[None]): - de = TagDeserializer([A, B]) - with expectation: - de.from_tagged_dict(data) - - -def test_bad_type(): - with pytest.raises(TagException, match=r"isn't deserializable"): - TagDeserializer([NotSerde]) diff --git a/src/tests/web/api/framework/session_test.py b/src/tests/web/api/framework/session_test.py deleted file mode 100644 index 8ef5b9851..000000000 --- a/src/tests/web/api/framework/session_test.py +++ /dev/null @@ -1,161 +0,0 @@ -from dataclasses import dataclass, field -import pytest -from starlette.datastructures import Address -from starlette.websockets import WebSocket, WebSocketDisconnect -from typing import Dict, cast - -from modlunky2.web.api.framework.session import ( - _extract_hostname, - _validate_origin, - SessionException, - SessionId, - SessionManager, -) - - -@dataclass -class FakeWebSocket: - path_params: Dict[str, str] - client: Address - headers: Dict[str, str] = field(default_factory=dict) - - -PARAM_NAME = "sid" - - -@pytest.mark.parametrize("sid", [SessionId("123"), SessionId("abc")]) -def test_one_session(sid: SessionId): - websocket = cast(WebSocket, FakeWebSocket({PARAM_NAME: sid}, Address("local", 456))) - - manager = SessionManager(PARAM_NAME) - with manager.session_for(websocket) as got_sid: - assert got_sid == sid - - -def test_two_sessions_ok(): - sid1 = SessionId("123") - sid2 = SessionId("456") - websocket1 = cast( - WebSocket, FakeWebSocket({PARAM_NAME: sid1}, Address("local", 1011)) - ) - websocket2 = cast( - WebSocket, FakeWebSocket({PARAM_NAME: sid2}, Address("local", 2022)) - ) - - manager = SessionManager(PARAM_NAME) - with manager.session_for(websocket1) as got_sid1: - assert got_sid1 == sid1 - with manager.session_for(websocket2) as got_sid2: - assert got_sid2 == sid2 - - -def test_reuse_after_disconnect(): - sid = SessionId("654") - websocket = cast(WebSocket, FakeWebSocket({PARAM_NAME: sid}, Address("local", 987))) - - manager = SessionManager(PARAM_NAME) - # We use try/except because pyright doesn't know pytest.raises() swallows exceptions - try: - with manager.session_for(websocket) as got_sid: - assert got_sid == sid - raise WebSocketDisconnect() - except WebSocketDisconnect: - pass - else: - # This shouldn't be reached - assert False - - with manager.session_for(websocket) as got_sid: - assert got_sid == sid - - -@pytest.mark.parametrize( - "sid,client1,client2", - [ - # Duplicate from different clients - ( - SessionId("123"), - Address("local", 1011), - Address("local", 2022), - ), - # Duplicate from the same client - ( - SessionId("abc"), - Address("local", 1011), - Address("local", 1011), - ), - ], -) -def test_two_sessions_conflict(sid: SessionId, client1: Address, client2: Address): - websocket1 = cast(WebSocket, FakeWebSocket({PARAM_NAME: sid}, client1)) - websocket2 = cast(WebSocket, FakeWebSocket({PARAM_NAME: sid}, client2)) - - manager = SessionManager(PARAM_NAME) - with pytest.raises(SessionException): - with manager.session_for(websocket1) as got_sid1: - assert got_sid1 == sid - with manager.session_for(websocket2): - pass - - -def test_session_with_origin_accepted(): - websocket = cast( - WebSocket, - FakeWebSocket( - {PARAM_NAME: SessionId("2001")}, - Address("local", 456), - {"origin": "http://[::1]:9526"}, - ), - ) - - manager = SessionManager(PARAM_NAME) - with manager.session_for(websocket) as got_sid: - assert got_sid == "2001" - - -def test_session_with_origin_rejected(): - websocket = cast( - WebSocket, - FakeWebSocket( - {PARAM_NAME: SessionId("2002")}, - Address("local", 456), - {"origin": "null"}, - ), - ) - - manager = SessionManager(PARAM_NAME) - with pytest.raises(SessionException): - with manager.session_for(websocket): - pass - - -@pytest.mark.parametrize( - "origin", ["http://localhost", "https://localhost:3000", "http://127.0.0.1:3000"] -) -def test_validate_origin_accepted(origin): - # Just checking that no exception is thrown - _validate_origin(origin) - - -@pytest.mark.parametrize("origin", ["null", "foo", "http://example.com"]) -def test_validate_origin_rejected(origin): - with pytest.raises(SessionException): - _validate_origin(origin) - - -@pytest.mark.parametrize( - "netloc,expected", - [ - ("example1.com", "example1.com"), - ("example2.com:443", "example2.com"), - ("localhost", "localhost"), - ("localhost:3000", "localhost"), - ("1.2.3.4", "1.2.3.4"), - ("5.6.7.8:99", "5.6.7.8"), - ("[::1]", "[::1]"), - ("[::1]:80", "[::1]"), - ], -) -def test_extract_hostname(netloc, expected): - actual = _extract_hostname(netloc) - assert actual == expected From fc9ce02d36f5b80c6376aeb81787d0a906b59740 Mon Sep 17 00:00:00 2001 From: Mauve Date: Sun, 1 Oct 2023 22:12:51 -0400 Subject: [PATCH 2/2] Upgrade anyio --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 8cec4f741..8f2fced89 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ Pillow==10.0.1 -anyio==3.7.1 +anyio==4.0.0 colorhash==2.0.0 fnvhash==0.1.0 httpx==0.25.0