From ca826db3e7c1bd6b9ea2048182d421c354f68e07 Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Thu, 27 Jul 2023 01:42:37 +0530 Subject: [PATCH 01/59] Implement websockets_sans_impl.py --- tests/conftest.py | 1 + .../websockets/websockets_sansio_impl.py | 384 ++++++++++++++++++ 2 files changed, 385 insertions(+) create mode 100644 uvicorn/protocols/websockets/websockets_sansio_impl.py diff --git a/tests/conftest.py b/tests/conftest.py index c1c136022..71a5390e3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -252,6 +252,7 @@ def unused_tcp_port() -> int: ), ), "uvicorn.protocols.websockets.websockets_impl:WebSocketProtocol", + "uvicorn.protocols.websockets.websockets_sansio_impl:WebSocketsSansIOProtocol" ] ) def ws_protocol_cls(request: pytest.FixtureRequest): diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py new file mode 100644 index 000000000..954fbab28 --- /dev/null +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -0,0 +1,384 @@ +import asyncio +import logging +import sys +import time +import typing +from asyncio.transports import BaseTransport, Transport +from urllib.parse import unquote + +import websockets +from websockets.http11 import Request +from websockets.frames import Frame, Close +from websockets.server import ServerConnection +from websockets.connection import State + +from uvicorn.config import Config +from uvicorn.logging import TRACE_LOG_LEVEL +from uvicorn.protocols.utils import ( + get_local_addr, + get_path_with_query_string, + get_remote_addr, + is_ssl, +) +from uvicorn.server import ServerState +from http import HTTPStatus + +if sys.version_info < (3, 8): + from typing_extensions import Literal +else: + from typing import Literal + +if typing.TYPE_CHECKING: + from asgiref.typing import ( + ASGISendEvent, + WebSocketAcceptEvent, + WebSocketCloseEvent, + WebSocketConnectEvent, + WebSocketDisconnectEvent, + WebSocketReceiveEvent, + WebSocketScope, + WebSocketSendEvent, + ) + + WebSocketEvent = typing.Union[ + "WebSocketReceiveEvent", + "WebSocketDisconnectEvent", + "WebSocketConnectEvent", + ] + + +class WebSocketsSansIOProtocol(asyncio.Protocol): + def __init__( + self, + config: Config, + server_state: ServerState, + app_state: typing.Dict[str, typing.Any], + _loop: typing.Optional[asyncio.AbstractEventLoop] = None, + ) -> None: + if not config.loaded: + config.load() + + self.config = config + self.app = config.loaded_app + self.loop = _loop or asyncio.get_event_loop() + self.logger = logging.getLogger("uvicorn.error") + self.root_path = config.root_path + self.app_state = app_state + + # Shared server state + self.connections = server_state.connections + self.tasks = server_state.tasks + self.default_headers = server_state.default_headers + + # Connection state + self.transport: asyncio.Transport = None # type: ignore[assignment] + self.server: typing.Optional[typing.Tuple[str, int]] = None + self.client: typing.Optional[typing.Tuple[str, int]] = None + self.scheme: Literal["wss", "ws"] = None # type: ignore[assignment] + + # WebSocket state + self.queue: asyncio.Queue["WebSocketEvent"] = asyncio.Queue() + self.handshake_initiated = False + self.handshake_complete = False + self.close_sent = False + + # extensions = [] + # if self.config.ws_per_message_deflate: + # extensions.append(ServerPerMessageDeflateFactory()) + self.conn = ServerConnection() + self.request = None + self.response = None + + self.read_paused = False + self.writable = asyncio.Event() + self.writable.set() + + # Buffers + self.bytes = b"" + self.text = "" + print(len(self.tasks)) + + def connection_made(self, transport: BaseTransport) -> None: + """Called when a connection is made.""" + transport = typing.cast(Transport, transport) + self.connections.add(self) + self.transport = transport + self.server = get_local_addr(transport) + self.client = get_remote_addr(transport) + self.scheme = "wss" if is_ssl(transport) else "ws" + + if self.logger.level <= TRACE_LOG_LEVEL: + prefix = "%s:%d - " % self.client if self.client else "" + self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix) + + def connection_lost(self, exc: typing.Optional[Exception]) -> None: + self.connections.remove(self) + print('came in connection lost : ', exc) + if self.logger.level <= TRACE_LOG_LEVEL: + prefix = "%s:%d - " % self.client if self.client else "" + self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix) + if self.handshake_initiated and not self.close_sent: + self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006}) + + + def data_received(self, data: bytes) -> None: + try: + self.conn.receive_data(data) + except Exception as exc: + self.logger.exception("Exception in ASGI server") + self.transport.close() + self.handle_events() + + def shutdown(self) -> None: + if not self.transport.is_closing(): + if self.handshake_complete: + self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012}) + self.close_send = True + self.conn.send_close(1012) + output = self.conn.data_to_send() + self.transport.writelines(output) + elif self.handshake_initiated: + self.send_500_response() + self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006}) + self.transport.close() + + def handle_events(self) -> None: + for event in self.conn.events_received(): + if isinstance(event, Request): + self.handle_connect(event) + if isinstance(event, Frame): + if event.opcode == websockets.frames.Opcode.CONT: + self.handle_cont(event) + elif event.opcode == websockets.frames.Opcode.TEXT: + self.handle_text(event) + elif event.opcode == websockets.frames.Opcode.BINARY: + self.handle_bytes(event) + elif event.opcode == websockets.frames.Opcode.PING: + self.handle_ping(event) + elif event.opcode == websockets.frames.Opcode.PONG: + self.handle_pong(event) + elif event.opcode == websockets.frames.Opcode.CLOSE: + self.handle_close(event) + + # Event handlers + + def handle_connect(self, event: Request) -> None: + self.request = event + self.response = self.conn.accept(event) + self.handshake_initiated = True + # if status_code is not 101 return response + if self.response.status_code != 101: + self.handshake_complete = True + self.close_sent = True + self.conn.send_response(self.response) + output = self.conn.data_to_send() + self.transport.writelines(output) + self.transport.close() + return + + headers = [ + (key.encode('ascii'), value.encode('ascii', errors='surrogateescape')) + for key, value in event.headers.raw_items() + ] + raw_path, _, query_string = event.path.partition("?") + self.scope: "WebSocketScope" = { # type: ignore[typeddict-item] + "type": "websocket", + "asgi": {"version": self.config.asgi_version, "spec_version": "2.3"}, + "http_version": "1.1", + "scheme": self.scheme, + "server": self.server, + "client": self.client, + "root_path": self.root_path, + "path": unquote(raw_path), + "raw_path": raw_path.encode("ascii"), + "query_string": query_string.encode("ascii"), + "headers": headers, + "subprotocols": event.headers.get_all("Sec-WebSocket-Protocol"), + "extensions": None, + "state": self.app_state.copy(), + } + self.queue.put_nowait({"type": "websocket.connect"}) + task = self.loop.create_task(self.run_asgi()) + task.add_done_callback(self.on_task_complete) + self.tasks.add(task) + + def handle_cont(self, event: Frame) -> None: + self.bytes += event.data + if event.fin: + self.send_receive_event_to_app() + + + def handle_text(self, event: Frame) -> None: + self.bytes = event.data + self.curr_msg_data_type = "text" + if event.fin: + self.send_receive_event_to_app() + + def handle_bytes(self, event: Frame) -> None: + self.bytes = event.data + self.curr_msg_data_type = "bytes" + if event.fin: + self.send_receive_event_to_app() + + def send_receive_event_to_app(self): + if self.curr_msg_data_type == "text": + data = self.bytes.decode() + else: + data = self.bytes + + msg: "WebSocketReceiveEvent" = { # type: ignore[typeddict-item] + "type": "websocket.receive", + self.curr_msg_data_type: data + } + self.queue.put_nowait(msg) + self.bytes = b"" + self.curr_msg_data_type = None + if not self.read_paused: + self.read_paused = True + self.transport.pause_reading() + + def handle_ping(self, event: Frame) -> None: + output = self.conn.data_to_send() + self.transport.writelines(output) + + def handle_pong(self, event: Frame) -> None: + pass + + def handle_close(self, event: Frame) -> None: + if not self.close_sent and not self.transport.is_closing(): + self.queue.put_nowait({"type": "websocket.disconnect", "code": self.conn.close_rcvd.code}) + output = self.conn.data_to_send() + self.transport.writelines(output) + self.close_sent = True + self.transport.close() + + def on_task_complete(self, task: asyncio.Task) -> None: + self.tasks.discard(task) + + async def run_asgi(self) -> None: + try: + result = await self.app(self.scope, self.receive, self.send) + except BaseException: + self.logger.exception("Exception in ASGI application\n") + if not self.handshake_complete: + self.send_500_response() + self.transport.close() + else: + if not self.handshake_complete: + msg = "ASGI callable returned without completing handshake." + self.logger.error(msg) + self.send_500_response() + self.transport.close() + elif result is not None: + msg = "ASGI callable should return None, but returned '%s'." + self.logger.error(msg, result) + self.transport.close() + + def send_500_response(self) -> None: + msg = b"Internal Server Error" + content = [ + b"HTTP/1.1 500 Internal Server Error\r\n" + b"content-type: text/plain; charset=utf-8\r\n", + b"content-length: " + str(len(msg)).encode("ascii") + b"\r\n", + b"connection: close\r\n", + b"\r\n", + msg, + ] + self.transport.write(b"".join(content)) + + async def send(self, message: "ASGISendEvent") -> None: + await self.writable.wait() + + message_type = message["type"] + + if not self.handshake_complete: + if message_type == "websocket.accept" and not self.transport.is_closing(): + self.logger.info( + '%s - "WebSocket %s" [accepted]', + self.scope["client"], + get_path_with_query_string(self.scope), + ) + headers = [ + (key.decode("ascii"), value.decode("ascii", errors="surrogateescape")) + for key, value in self.default_headers + + list(message.get("headers", [])) + ] + + self.accepted_subprotocol : str = message.get("subprotocol") + if self.accepted_subprotocol: + headers.append(('Sec-WebSocket-Protocol', self.accepted_subprotocol)) + + self.handshake_complete = True + self.response.headers.update(headers) + self.conn.send_response(self.response) + output = self.conn.data_to_send() + self.transport.writelines(output) + + elif message_type == "websocket.close" and not self.transport.is_closing(): + self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006}) + self.logger.info( + '%s - "WebSocket %s" 403', + self.scope["client"], + get_path_with_query_string(self.scope), + ) + extra_headers = [ + (key.decode(), value.decode()) + for key, value in self.default_headers + + list(message.get("headers", [])) + ] + response = self.conn.reject(HTTPStatus.FORBIDDEN, message.get('reason', '')) + response.headers.update(extra_headers) + self.conn.send_response(response) + output = self.conn.data_to_send() + self.close_sent = True + self.hankshake_complete = True + self.transport.writelines(output) + self.transport.close() + + else: + msg = ( + "Expected ASGI message 'websocket.accept' or 'websocket.close', " + "but got '%s'." + ) + raise RuntimeError(msg % message_type) + + elif not self.close_sent: + if message_type == "websocket.send" and not self.transport.is_closing(): + message = typing.cast("WebSocketSendEvent", message) + bytes_data : bytes = message.get("bytes") + text_data : str = message.get("text") + if text_data: + # need to add the logic of sending fragmented data here + self.conn.send_text(text_data.encode()) + elif bytes_data: + self.conn.send_binary(bytes_data) + output = self.conn.data_to_send() + self.transport.writelines(output) + + elif message_type == "websocket.close" and not self.transport.is_closing(): + message = typing.cast("WebSocketCloseEvent", message) + code = message.get("code", 1000) + reason = message.get("reason", "") or "" + self.queue.put_nowait({"type": "websocket.disconnect", "code": code}) + self.conn.send_close(code, reason) + output = self.conn.data_to_send() + self.transport.writelines(output) + self.close_sent = True + self.transport.close() + else: + msg = ( + "Expected ASGI message 'websocket.send' or 'websocket.close'," + " but got '%s'." + ) + raise RuntimeError(msg % message_type) + + else: + msg = "Unexpected ASGI message '%s', after sending 'websocket.close'." + raise RuntimeError(msg % message_type) + + async def receive(self) -> "WebSocketEvent": + message = await self.queue.get() + if self.read_paused and self.queue.empty(): + self.read_paused = False + self.transport.resume_reading() + return message \ No newline at end of file From 559617bd75633358c0631375314199ae9286b01f Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Thu, 27 Jul 2023 17:41:57 +0530 Subject: [PATCH 02/59] add surrogate errors in decode --- uvicorn/protocols/websockets/websockets_sansio_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index 954fbab28..abbc80e85 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -322,7 +322,7 @@ async def send(self, message: "ASGISendEvent") -> None: get_path_with_query_string(self.scope), ) extra_headers = [ - (key.decode(), value.decode()) + (key.decode("ascii"), value.decode("ascii", errors="surrogateescape")) for key, value in self.default_headers + list(message.get("headers", [])) ] From 56d2152fc8f077871ce3a617a46b37a2f608ec12 Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Thu, 27 Jul 2023 18:19:36 +0530 Subject: [PATCH 03/59] fix lint issues --- tests/conftest.py | 2 +- .../websockets/websockets_sansio_impl.py | 65 ++++++++++--------- 2 files changed, 37 insertions(+), 30 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 71a5390e3..cc6a35b9a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -252,7 +252,7 @@ def unused_tcp_port() -> int: ), ), "uvicorn.protocols.websockets.websockets_impl:WebSocketProtocol", - "uvicorn.protocols.websockets.websockets_sansio_impl:WebSocketsSansIOProtocol" + "uvicorn.protocols.websockets.websockets_sansio_impl:WebSocketsSansIOProtocol", ] ) def ws_protocol_cls(request: pytest.FixtureRequest): diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index abbc80e85..a593832e4 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -1,16 +1,15 @@ import asyncio import logging import sys -import time import typing from asyncio.transports import BaseTransport, Transport +from http import HTTPStatus from urllib.parse import unquote import websockets +from websockets.frames import Frame from websockets.http11 import Request -from websockets.frames import Frame, Close from websockets.server import ServerConnection -from websockets.connection import State from uvicorn.config import Config from uvicorn.logging import TRACE_LOG_LEVEL @@ -21,7 +20,6 @@ is_ssl, ) from uvicorn.server import ServerState -from http import HTTPStatus if sys.version_info < (3, 8): from typing_extensions import Literal @@ -31,7 +29,6 @@ if typing.TYPE_CHECKING: from asgiref.typing import ( ASGISendEvent, - WebSocketAcceptEvent, WebSocketCloseEvent, WebSocketConnectEvent, WebSocketDisconnectEvent, @@ -87,7 +84,7 @@ def __init__( # extensions.append(ServerPerMessageDeflateFactory()) self.conn = ServerConnection() self.request = None - self.response = None + self.response = None self.read_paused = False self.writable = asyncio.Event() @@ -97,7 +94,7 @@ def __init__( self.bytes = b"" self.text = "" print(len(self.tasks)) - + def connection_made(self, transport: BaseTransport) -> None: """Called when a connection is made.""" transport = typing.cast(Transport, transport) @@ -113,24 +110,23 @@ def connection_made(self, transport: BaseTransport) -> None: def connection_lost(self, exc: typing.Optional[Exception]) -> None: self.connections.remove(self) - print('came in connection lost : ', exc) + print("came in connection lost : ", exc) if self.logger.level <= TRACE_LOG_LEVEL: prefix = "%s:%d - " % self.client if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix) if self.handshake_initiated and not self.close_sent: self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006}) - def data_received(self, data: bytes) -> None: try: self.conn.receive_data(data) - except Exception as exc: + except Exception: self.logger.exception("Exception in ASGI server") self.transport.close() self.handle_events() def shutdown(self) -> None: - if not self.transport.is_closing(): + if not self.transport.is_closing(): if self.handshake_complete: self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012}) self.close_send = True @@ -141,7 +137,7 @@ def shutdown(self) -> None: self.send_500_response() self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006}) self.transport.close() - + def handle_events(self) -> None: for event in self.conn.events_received(): if isinstance(event, Request): @@ -165,7 +161,7 @@ def handle_events(self) -> None: def handle_connect(self, event: Request) -> None: self.request = event self.response = self.conn.accept(event) - self.handshake_initiated = True + self.handshake_initiated = True # if status_code is not 101 return response if self.response.status_code != 101: self.handshake_complete = True @@ -177,7 +173,7 @@ def handle_connect(self, event: Request) -> None: return headers = [ - (key.encode('ascii'), value.encode('ascii', errors='surrogateescape')) + (key.encode("ascii"), value.encode("ascii", errors="surrogateescape")) for key, value in event.headers.raw_items() ] raw_path, _, query_string = event.path.partition("?") @@ -206,7 +202,6 @@ def handle_cont(self, event: Frame) -> None: self.bytes += event.data if event.fin: self.send_receive_event_to_app() - def handle_text(self, event: Frame) -> None: self.bytes = event.data @@ -227,9 +222,9 @@ def send_receive_event_to_app(self): data = self.bytes msg: "WebSocketReceiveEvent" = { # type: ignore[typeddict-item] - "type": "websocket.receive", - self.curr_msg_data_type: data - } + "type": "websocket.receive", + self.curr_msg_data_type: data, + } self.queue.put_nowait(msg) self.bytes = b"" self.curr_msg_data_type = None @@ -246,7 +241,9 @@ def handle_pong(self, event: Frame) -> None: def handle_close(self, event: Frame) -> None: if not self.close_sent and not self.transport.is_closing(): - self.queue.put_nowait({"type": "websocket.disconnect", "code": self.conn.close_rcvd.code}) + self.queue.put_nowait( + {"type": "websocket.disconnect", "code": self.conn.close_rcvd.code} + ) output = self.conn.data_to_send() self.transport.writelines(output) self.close_sent = True @@ -254,7 +251,7 @@ def handle_close(self, event: Frame) -> None: def on_task_complete(self, task: asyncio.Task) -> None: self.tasks.discard(task) - + async def run_asgi(self) -> None: try: result = await self.app(self.scope, self.receive, self.send) @@ -299,18 +296,23 @@ async def send(self, message: "ASGISendEvent") -> None: get_path_with_query_string(self.scope), ) headers = [ - (key.decode("ascii"), value.decode("ascii", errors="surrogateescape")) + ( + key.decode("ascii"), + value.decode("ascii", errors="surrogateescape"), + ) for key, value in self.default_headers + list(message.get("headers", [])) ] - self.accepted_subprotocol : str = message.get("subprotocol") + self.accepted_subprotocol: str = message.get("subprotocol") if self.accepted_subprotocol: - headers.append(('Sec-WebSocket-Protocol', self.accepted_subprotocol)) + headers.append( + ("Sec-WebSocket-Protocol", self.accepted_subprotocol) + ) self.handshake_complete = True self.response.headers.update(headers) - self.conn.send_response(self.response) + self.conn.send_response(self.response) output = self.conn.data_to_send() self.transport.writelines(output) @@ -322,11 +324,16 @@ async def send(self, message: "ASGISendEvent") -> None: get_path_with_query_string(self.scope), ) extra_headers = [ - (key.decode("ascii"), value.decode("ascii", errors="surrogateescape")) + ( + key.decode("ascii"), + value.decode("ascii", errors="surrogateescape"), + ) for key, value in self.default_headers + list(message.get("headers", [])) ] - response = self.conn.reject(HTTPStatus.FORBIDDEN, message.get('reason', '')) + response = self.conn.reject( + HTTPStatus.FORBIDDEN, message.get("reason", "") + ) response.headers.update(extra_headers) self.conn.send_response(response) output = self.conn.data_to_send() @@ -345,8 +352,8 @@ async def send(self, message: "ASGISendEvent") -> None: elif not self.close_sent: if message_type == "websocket.send" and not self.transport.is_closing(): message = typing.cast("WebSocketSendEvent", message) - bytes_data : bytes = message.get("bytes") - text_data : str = message.get("text") + bytes_data: bytes = message.get("bytes") + text_data: str = message.get("text") if text_data: # need to add the logic of sending fragmented data here self.conn.send_text(text_data.encode()) @@ -381,4 +388,4 @@ async def receive(self) -> "WebSocketEvent": if self.read_paused and self.queue.empty(): self.read_paused = False self.transport.resume_reading() - return message \ No newline at end of file + return message From d12e72a96b74831559e25fad7bfe1a0d23906fa2 Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Thu, 27 Jul 2023 18:59:02 +0530 Subject: [PATCH 04/59] fix mypy failing issues --- .../protocols/websockets/websockets_sansio_impl.py | 14 +++++++------- uvicorn/server.py | 3 ++- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index a593832e4..d91667ee6 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -8,7 +8,7 @@ import websockets from websockets.frames import Frame -from websockets.http11 import Request +from websockets.http11 import Request, Response from websockets.server import ServerConnection from uvicorn.config import Config @@ -82,9 +82,9 @@ def __init__( # extensions = [] # if self.config.ws_per_message_deflate: # extensions.append(ServerPerMessageDeflateFactory()) - self.conn = ServerConnection() - self.request = None - self.response = None + self.conn: typing.Optional[ServerConnection] = ServerConnection() + self.request: typing.Optional[Request] = None + self.response: typing.Optional[Response] = None self.read_paused = False self.writable = asyncio.Event() @@ -177,7 +177,7 @@ def handle_connect(self, event: Request) -> None: for key, value in event.headers.raw_items() ] raw_path, _, query_string = event.path.partition("?") - self.scope: "WebSocketScope" = { # type: ignore[typeddict-item] + self.scope: "WebSocketScope" = { "type": "websocket", "asgi": {"version": self.config.asgi_version, "spec_version": "2.3"}, "http_version": "1.1", @@ -215,13 +215,13 @@ def handle_bytes(self, event: Frame) -> None: if event.fin: self.send_receive_event_to_app() - def send_receive_event_to_app(self): + def send_receive_event_to_app(self) -> None: if self.curr_msg_data_type == "text": data = self.bytes.decode() else: data = self.bytes - msg: "WebSocketReceiveEvent" = { # type: ignore[typeddict-item] + msg: "WebSocketReceiveEvent" = { "type": "websocket.receive", self.curr_msg_data_type: data, } diff --git a/uvicorn/server.py b/uvicorn/server.py index 3e0db9d01..4ef6d4348 100644 --- a/uvicorn/server.py +++ b/uvicorn/server.py @@ -20,8 +20,9 @@ from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol from uvicorn.protocols.websockets.wsproto_impl import WSProtocol + from uvicorn.protocols.websockets.websockets_sansio_impl import WebSocketsSansIOProtocol - Protocols = Union[H11Protocol, HttpToolsProtocol, WSProtocol, WebSocketProtocol] + Protocols = Union[H11Protocol, HttpToolsProtocol, WSProtocol, WebSocketProtocol, WebSocketsSansIOProtocol] HANDLED_SIGNALS = ( From f24527b7611d7f5dd6eb6bde17de1c0e039d42d4 Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Thu, 27 Jul 2023 19:03:03 +0530 Subject: [PATCH 05/59] fix lint issues --- .../protocols/websockets/websockets_sansio_impl.py | 2 +- uvicorn/server.py | 12 ++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index d91667ee6..a24be37f0 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -221,7 +221,7 @@ def send_receive_event_to_app(self) -> None: else: data = self.bytes - msg: "WebSocketReceiveEvent" = { + msg: "WebSocketReceiveEvent" = { "type": "websocket.receive", self.curr_msg_data_type: data, } diff --git a/uvicorn/server.py b/uvicorn/server.py index 4ef6d4348..29637058c 100644 --- a/uvicorn/server.py +++ b/uvicorn/server.py @@ -19,10 +19,18 @@ from uvicorn.protocols.http.h11_impl import H11Protocol from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol + from uvicorn.protocols.websockets.websockets_sansio_impl import ( + WebSocketsSansIOProtocol, + ) from uvicorn.protocols.websockets.wsproto_impl import WSProtocol - from uvicorn.protocols.websockets.websockets_sansio_impl import WebSocketsSansIOProtocol - Protocols = Union[H11Protocol, HttpToolsProtocol, WSProtocol, WebSocketProtocol, WebSocketsSansIOProtocol] + Protocols = Union[ + H11Protocol, + HttpToolsProtocol, + WSProtocol, + WebSocketProtocol, + WebSocketsSansIOProtocol, + ] HANDLED_SIGNALS = ( From ba972e065cb0f5be92b687ad70b7c5e8f07d719e Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Thu, 27 Jul 2023 20:50:30 +0530 Subject: [PATCH 06/59] fix typing issues --- .../websockets/websockets_sansio_impl.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index a24be37f0..49f8ef8d3 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -82,17 +82,17 @@ def __init__( # extensions = [] # if self.config.ws_per_message_deflate: # extensions.append(ServerPerMessageDeflateFactory()) - self.conn: typing.Optional[ServerConnection] = ServerConnection() - self.request: typing.Optional[Request] = None - self.response: typing.Optional[Response] = None + self.conn = ServerConnection() + self.request: Request + self.response: Response + self.curr_msg_data_type: str self.read_paused = False self.writable = asyncio.Event() self.writable.set() # Buffers - self.bytes = b"" - self.text = "" + self.bytes: "bytes" = b"" print(len(self.tasks)) def connection_made(self, transport: BaseTransport) -> None: @@ -216,6 +216,7 @@ def handle_bytes(self, event: Frame) -> None: self.send_receive_event_to_app() def send_receive_event_to_app(self) -> None: + data: typing.Union[str, bytes] if self.curr_msg_data_type == "text": data = self.bytes.decode() else: @@ -226,8 +227,6 @@ def send_receive_event_to_app(self) -> None: self.curr_msg_data_type: data, } self.queue.put_nowait(msg) - self.bytes = b"" - self.curr_msg_data_type = None if not self.read_paused: self.read_paused = True self.transport.pause_reading() @@ -242,7 +241,10 @@ def handle_pong(self, event: Frame) -> None: def handle_close(self, event: Frame) -> None: if not self.close_sent and not self.transport.is_closing(): self.queue.put_nowait( - {"type": "websocket.disconnect", "code": self.conn.close_rcvd.code} + { + "type": "websocket.disconnect", + "code": self.conn.close_rcvd.code, # type: ignore[union-attr] + } ) output = self.conn.data_to_send() self.transport.writelines(output) From b81f7628fefbc0d06aba7d706ab33f367af0c9b6 Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Thu, 27 Jul 2023 21:41:55 +0530 Subject: [PATCH 07/59] Fix extension tests failing --- uvicorn/protocols/websockets/websockets_sansio_impl.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index 49f8ef8d3..6499a2073 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -7,6 +7,7 @@ from urllib.parse import unquote import websockets +from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory from websockets.frames import Frame from websockets.http11 import Request, Response from websockets.server import ServerConnection @@ -82,7 +83,7 @@ def __init__( # extensions = [] # if self.config.ws_per_message_deflate: # extensions.append(ServerPerMessageDeflateFactory()) - self.conn = ServerConnection() + self.conn = ServerConnection(extensions=[ServerPerMessageDeflateFactory()]) self.request: Request self.response: Response self.curr_msg_data_type: str From 0f59f77f90bba9302adeb6f8e70b3690f0940fdf Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Thu, 27 Jul 2023 22:02:26 +0530 Subject: [PATCH 08/59] Fix extension tests failing --- .../websockets/websockets_sansio_impl.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index 6499a2073..055b42a06 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -80,10 +80,10 @@ def __init__( self.handshake_complete = False self.close_sent = False - # extensions = [] - # if self.config.ws_per_message_deflate: - # extensions.append(ServerPerMessageDeflateFactory()) - self.conn = ServerConnection(extensions=[ServerPerMessageDeflateFactory()]) + extensions = [] + if self.config.ws_per_message_deflate: + extensions.append(ServerPerMessageDeflateFactory()) + self.conn = ServerConnection(extensions=extensions) self.request: Request self.response: Response self.curr_msg_data_type: str @@ -307,11 +307,9 @@ async def send(self, message: "ASGISendEvent") -> None: + list(message.get("headers", [])) ] - self.accepted_subprotocol: str = message.get("subprotocol") - if self.accepted_subprotocol: - headers.append( - ("Sec-WebSocket-Protocol", self.accepted_subprotocol) - ) + accepted_subprotocol: str = message.get("subprotocol") + if accepted_subprotocol: + headers.append(("Sec-WebSocket-Protocol", accepted_subprotocol)) self.handshake_complete = True self.response.headers.update(headers) From 38e16294c9d479c9577dbb89e4e72a1af12339ba Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Sun, 30 Jul 2023 00:26:30 +0530 Subject: [PATCH 09/59] correct types import --- .../websockets/websockets_sansio_impl.py | 36 ++++++++----------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index 055b42a06..0ed8e442a 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -28,23 +28,18 @@ from typing import Literal if typing.TYPE_CHECKING: - from asgiref.typing import ( + from uvicorn._types import ( + ASGIReceiveEvent, ASGISendEvent, - WebSocketCloseEvent, WebSocketConnectEvent, - WebSocketDisconnectEvent, + WebSocketAcceptEvent, WebSocketReceiveEvent, - WebSocketScope, WebSocketSendEvent, + WebSocketCloseEvent, + WebSocketDisconnectEvent, + WebSocketScope, ) - WebSocketEvent = typing.Union[ - "WebSocketReceiveEvent", - "WebSocketDisconnectEvent", - "WebSocketConnectEvent", - ] - - class WebSocketsSansIOProtocol(asyncio.Protocol): def __init__( self, @@ -75,7 +70,7 @@ def __init__( self.scheme: Literal["wss", "ws"] = None # type: ignore[assignment] # WebSocket state - self.queue: asyncio.Queue["WebSocketEvent"] = asyncio.Queue() + self.queue: asyncio.Queue["ASGIReceiveEvent"] = asyncio.Queue() self.handshake_initiated = False self.handshake_complete = False self.close_sent = False @@ -94,7 +89,6 @@ def __init__( # Buffers self.bytes: "bytes" = b"" - print(len(self.tasks)) def connection_made(self, transport: BaseTransport) -> None: """Called when a connection is made.""" @@ -111,7 +105,6 @@ def connection_made(self, transport: BaseTransport) -> None: def connection_lost(self, exc: typing.Optional[Exception]) -> None: self.connections.remove(self) - print("came in connection lost : ", exc) if self.logger.level <= TRACE_LOG_LEVEL: prefix = "%s:%d - " % self.client if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix) @@ -241,12 +234,11 @@ def handle_pong(self, event: Frame) -> None: def handle_close(self, event: Frame) -> None: if not self.close_sent and not self.transport.is_closing(): - self.queue.put_nowait( - { - "type": "websocket.disconnect", - "code": self.conn.close_rcvd.code, # type: ignore[union-attr] - } - ) + disconnect_event: "WebSocketDisconnectEvent" = { + "type": "websocket.disconnect", + "code": self.conn.close_rcvd.code + } + self.queue.put_nowait(disconnect_event) output = self.conn.data_to_send() self.transport.writelines(output) self.close_sent = True @@ -293,6 +285,7 @@ async def send(self, message: "ASGISendEvent") -> None: if not self.handshake_complete: if message_type == "websocket.accept" and not self.transport.is_closing(): + message = typing.cast("WebSocketAcceptEvent", message) self.logger.info( '%s - "WebSocket %s" [accepted]', self.scope["client"], @@ -318,6 +311,7 @@ async def send(self, message: "ASGISendEvent") -> None: self.transport.writelines(output) elif message_type == "websocket.close" and not self.transport.is_closing(): + message = typing.cast("WebSocketCloseEvent", message) self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006}) self.logger.info( '%s - "WebSocket %s" 403', @@ -384,7 +378,7 @@ async def send(self, message: "ASGISendEvent") -> None: msg = "Unexpected ASGI message '%s', after sending 'websocket.close'." raise RuntimeError(msg % message_type) - async def receive(self) -> "WebSocketEvent": + async def receive(self) -> "ASGIReceiveEvent": message = await self.queue.get() if self.read_paused and self.queue.empty(): self.read_paused = False From 29d2d094486d65f41540b860c90e3776ad23c405 Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Sun, 30 Jul 2023 01:24:01 +0530 Subject: [PATCH 10/59] correct types import and mypy issues --- .../websockets/websockets_sansio_impl.py | 29 ++++++++++--------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index 0ed8e442a..3057a50df 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -31,15 +31,15 @@ from uvicorn._types import ( ASGIReceiveEvent, ASGISendEvent, - WebSocketConnectEvent, WebSocketAcceptEvent, - WebSocketReceiveEvent, - WebSocketSendEvent, WebSocketCloseEvent, WebSocketDisconnectEvent, + WebSocketReceiveEvent, WebSocketScope, + WebSocketSendEvent, ) + class WebSocketsSansIOProtocol(asyncio.Protocol): def __init__( self, @@ -184,7 +184,6 @@ def handle_connect(self, event: Request) -> None: "query_string": query_string.encode("ascii"), "headers": headers, "subprotocols": event.headers.get_all("Sec-WebSocket-Protocol"), - "extensions": None, "state": self.app_state.copy(), } self.queue.put_nowait({"type": "websocket.connect"}) @@ -218,7 +217,7 @@ def send_receive_event_to_app(self) -> None: msg: "WebSocketReceiveEvent" = { "type": "websocket.receive", - self.curr_msg_data_type: data, + self.curr_msg_data_type: data, # type: ignore[misc] } self.queue.put_nowait(msg) if not self.read_paused: @@ -236,7 +235,7 @@ def handle_close(self, event: Frame) -> None: if not self.close_sent and not self.transport.is_closing(): disconnect_event: "WebSocketDisconnectEvent" = { "type": "websocket.disconnect", - "code": self.conn.close_rcvd.code + "code": self.conn.close_rcvd.code, # type: ignore[union-attr] } self.queue.put_nowait(disconnect_event) output = self.conn.data_to_send() @@ -300,7 +299,7 @@ async def send(self, message: "ASGISendEvent") -> None: + list(message.get("headers", [])) ] - accepted_subprotocol: str = message.get("subprotocol") + accepted_subprotocol = message.get("subprotocol") if accepted_subprotocol: headers.append(("Sec-WebSocket-Protocol", accepted_subprotocol)) @@ -312,7 +311,12 @@ async def send(self, message: "ASGISendEvent") -> None: elif message_type == "websocket.close" and not self.transport.is_closing(): message = typing.cast("WebSocketCloseEvent", message) - self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006}) + self.queue.put_nowait( + { + "type": "websocket.disconnect", + "code": message.get("code", 1000) or 1000, + } + ) self.logger.info( '%s - "WebSocket %s" 403', self.scope["client"], @@ -324,10 +328,10 @@ async def send(self, message: "ASGISendEvent") -> None: value.decode("ascii", errors="surrogateescape"), ) for key, value in self.default_headers - + list(message.get("headers", [])) ] + response = self.conn.reject( - HTTPStatus.FORBIDDEN, message.get("reason", "") + HTTPStatus.FORBIDDEN, message.get("reason", "") or "" ) response.headers.update(extra_headers) self.conn.send_response(response) @@ -347,10 +351,9 @@ async def send(self, message: "ASGISendEvent") -> None: elif not self.close_sent: if message_type == "websocket.send" and not self.transport.is_closing(): message = typing.cast("WebSocketSendEvent", message) - bytes_data: bytes = message.get("bytes") - text_data: str = message.get("text") + bytes_data = message.get("bytes") + text_data = message.get("text") if text_data: - # need to add the logic of sending fragmented data here self.conn.send_text(text_data.encode()) elif bytes_data: self.conn.send_binary(bytes_data) From 28f2714a909cf9531ce0f134fcc76161b546765e Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Sun, 30 Jul 2023 01:44:56 +0530 Subject: [PATCH 11/59] fix typo --- uvicorn/protocols/websockets/websockets_sansio_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index 3057a50df..b272b56a2 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -123,7 +123,7 @@ def shutdown(self) -> None: if not self.transport.is_closing(): if self.handshake_complete: self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012}) - self.close_send = True + self.close_sent = True self.conn.send_close(1012) output = self.conn.data_to_send() self.transport.writelines(output) From 3a725044a27b0f68c885506b82c086d5ba2a5f4b Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Tue, 1 Aug 2023 22:11:59 +0530 Subject: [PATCH 12/59] Replace ServerConnection with ServerProtocol due to upgradation of websockets version --- uvicorn/protocols/websockets/websockets_sansio_impl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index b272b56a2..269be083e 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -10,7 +10,7 @@ from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory from websockets.frames import Frame from websockets.http11 import Request, Response -from websockets.server import ServerConnection +from websockets.server import ServerProtocol from uvicorn.config import Config from uvicorn.logging import TRACE_LOG_LEVEL @@ -78,7 +78,7 @@ def __init__( extensions = [] if self.config.ws_per_message_deflate: extensions.append(ServerPerMessageDeflateFactory()) - self.conn = ServerConnection(extensions=extensions) + self.conn = ServerProtocol(extensions=extensions) self.request: Request self.response: Response self.curr_msg_data_type: str From 39e3c33c1a14323fb29cbe675abb858ab2d93bbe Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sun, 27 Aug 2023 21:03:36 +0200 Subject: [PATCH 13/59] Remove conditional on imports --- .../websockets/websockets_sansio_impl.py | 29 +++++++------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index 269be083e..2604083e2 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -1,9 +1,9 @@ import asyncio import logging -import sys import typing from asyncio.transports import BaseTransport, Transport from http import HTTPStatus +from typing import Literal from urllib.parse import unquote import websockets @@ -12,6 +12,16 @@ from websockets.http11 import Request, Response from websockets.server import ServerProtocol +from uvicorn._types import ( + ASGIReceiveEvent, + ASGISendEvent, + WebSocketAcceptEvent, + WebSocketCloseEvent, + WebSocketDisconnectEvent, + WebSocketReceiveEvent, + WebSocketScope, + WebSocketSendEvent, +) from uvicorn.config import Config from uvicorn.logging import TRACE_LOG_LEVEL from uvicorn.protocols.utils import ( @@ -22,23 +32,6 @@ ) from uvicorn.server import ServerState -if sys.version_info < (3, 8): - from typing_extensions import Literal -else: - from typing import Literal - -if typing.TYPE_CHECKING: - from uvicorn._types import ( - ASGIReceiveEvent, - ASGISendEvent, - WebSocketAcceptEvent, - WebSocketCloseEvent, - WebSocketDisconnectEvent, - WebSocketReceiveEvent, - WebSocketScope, - WebSocketSendEvent, - ) - class WebSocketsSansIOProtocol(asyncio.Protocol): def __init__( From 931e78e442a371ed25a63d96ffed594b45f0dab3 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sun, 27 Aug 2023 21:12:54 +0200 Subject: [PATCH 14/59] Fix typos, and small details --- .../websockets/websockets_sansio_impl.py | 31 +++++++++---------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index 2604083e2..8613a068d 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -6,9 +6,8 @@ from typing import Literal from urllib.parse import unquote -import websockets from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory -from websockets.frames import Frame +from websockets.frames import Frame, Opcode from websockets.http11 import Request, Response from websockets.server import ServerProtocol @@ -81,7 +80,7 @@ def __init__( self.writable.set() # Buffers - self.bytes: "bytes" = b"" + self.bytes = b"" def connection_made(self, transport: BaseTransport) -> None: """Called when a connection is made.""" @@ -130,17 +129,17 @@ def handle_events(self) -> None: if isinstance(event, Request): self.handle_connect(event) if isinstance(event, Frame): - if event.opcode == websockets.frames.Opcode.CONT: + if event.opcode == Opcode.CONT: self.handle_cont(event) - elif event.opcode == websockets.frames.Opcode.TEXT: + elif event.opcode == Opcode.TEXT: self.handle_text(event) - elif event.opcode == websockets.frames.Opcode.BINARY: + elif event.opcode == Opcode.BINARY: self.handle_bytes(event) - elif event.opcode == websockets.frames.Opcode.PING: + elif event.opcode == Opcode.PING: self.handle_ping(event) - elif event.opcode == websockets.frames.Opcode.PONG: + elif event.opcode == Opcode.PONG: self.handle_pong(event) - elif event.opcode == websockets.frames.Opcode.CLOSE: + elif event.opcode == Opcode.CLOSE: self.handle_close(event) # Event handlers @@ -208,7 +207,7 @@ def send_receive_event_to_app(self) -> None: else: data = self.bytes - msg: "WebSocketReceiveEvent" = { + msg: WebSocketReceiveEvent = { "type": "websocket.receive", self.curr_msg_data_type: data, # type: ignore[misc] } @@ -226,7 +225,7 @@ def handle_pong(self, event: Frame) -> None: def handle_close(self, event: Frame) -> None: if not self.close_sent and not self.transport.is_closing(): - disconnect_event: "WebSocketDisconnectEvent" = { + disconnect_event: WebSocketDisconnectEvent = { "type": "websocket.disconnect", "code": self.conn.close_rcvd.code, # type: ignore[union-attr] } @@ -270,7 +269,7 @@ def send_500_response(self) -> None: ] self.transport.write(b"".join(content)) - async def send(self, message: "ASGISendEvent") -> None: + async def send(self, message: ASGISendEvent) -> None: await self.writable.wait() message_type = message["type"] @@ -330,7 +329,7 @@ async def send(self, message: "ASGISendEvent") -> None: self.conn.send_response(response) output = self.conn.data_to_send() self.close_sent = True - self.hankshake_complete = True + self.handshake_complete = True self.transport.writelines(output) self.transport.close() @@ -343,7 +342,7 @@ async def send(self, message: "ASGISendEvent") -> None: elif not self.close_sent: if message_type == "websocket.send" and not self.transport.is_closing(): - message = typing.cast("WebSocketSendEvent", message) + message = typing.cast(WebSocketSendEvent, message) bytes_data = message.get("bytes") text_data = message.get("text") if text_data: @@ -354,7 +353,7 @@ async def send(self, message: "ASGISendEvent") -> None: self.transport.writelines(output) elif message_type == "websocket.close" and not self.transport.is_closing(): - message = typing.cast("WebSocketCloseEvent", message) + message = typing.cast(WebSocketCloseEvent, message) code = message.get("code", 1000) reason = message.get("reason", "") or "" self.queue.put_nowait({"type": "websocket.disconnect", "code": code}) @@ -374,7 +373,7 @@ async def send(self, message: "ASGISendEvent") -> None: msg = "Unexpected ASGI message '%s', after sending 'websocket.close'." raise RuntimeError(msg % message_type) - async def receive(self) -> "ASGIReceiveEvent": + async def receive(self) -> ASGIReceiveEvent: message = await self.queue.get() if self.read_paused and self.queue.empty(): self.read_paused = False From 3d57661c700b66f345f56fd85dc874ec28a33b40 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Mon, 28 Aug 2023 07:26:11 +0200 Subject: [PATCH 15/59] Refactor small things --- uvicorn/_types.py | 4 +- .../protocols/websockets/websockets_impl.py | 11 +---- .../websockets/websockets_sansio_impl.py | 49 +++++++++---------- uvicorn/protocols/websockets/wsproto_impl.py | 4 +- uvicorn/server.py | 2 +- 5 files changed, 30 insertions(+), 40 deletions(-) diff --git a/uvicorn/_types.py b/uvicorn/_types.py index ecc3bd5c9..8bd2b7c70 100644 --- a/uvicorn/_types.py +++ b/uvicorn/_types.py @@ -162,8 +162,8 @@ class WebSocketAcceptEvent(TypedDict): class WebSocketReceiveEvent(TypedDict): type: Literal["websocket.receive"] - bytes: Optional[bytes] - text: Optional[str] + bytes: NotRequired[bytes] + text: NotRequired[str] class WebSocketSendEvent(TypedDict): diff --git a/uvicorn/protocols/websockets/websockets_impl.py b/uvicorn/protocols/websockets/websockets_impl.py index 089eeb536..94f40f233 100644 --- a/uvicorn/protocols/websockets/websockets_impl.py +++ b/uvicorn/protocols/websockets/websockets_impl.py @@ -364,13 +364,6 @@ async def asgi_receive( return {"type": "websocket.disconnect", "code": 1012} return {"type": "websocket.disconnect", "code": exc.code} - msg: WebSocketReceiveEvent = { # type: ignore[typeddict-item] - "type": "websocket.receive" - } - if isinstance(data, str): - msg["text"] = data - else: - msg["bytes"] = data - - return msg + return {"type": "websocket.receive", "text": data} + return {"type": "websocket.receive", "bytes": data} diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index 8613a068d..980f957f2 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -8,7 +8,7 @@ from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory from websockets.frames import Frame, Opcode -from websockets.http11 import Request, Response +from websockets.http11 import Request from websockets.server import ServerProtocol from uvicorn._types import ( @@ -62,18 +62,19 @@ def __init__( self.scheme: Literal["wss", "ws"] = None # type: ignore[assignment] # WebSocket state - self.queue: asyncio.Queue["ASGIReceiveEvent"] = asyncio.Queue() + self.queue: asyncio.Queue[ASGIReceiveEvent] = asyncio.Queue() self.handshake_initiated = False self.handshake_complete = False self.close_sent = False extensions = [] if self.config.ws_per_message_deflate: - extensions.append(ServerPerMessageDeflateFactory()) - self.conn = ServerProtocol(extensions=extensions) - self.request: Request - self.response: Response - self.curr_msg_data_type: str + extensions = [ServerPerMessageDeflateFactory()] + self.conn = ServerProtocol( + extensions=extensions, + max_size=self.config.ws_max_size, + logger=logging.getLogger("uvicorn.error"), + ) self.read_paused = False self.writable = asyncio.Event() @@ -103,14 +104,6 @@ def connection_lost(self, exc: typing.Optional[Exception]) -> None: if self.handshake_initiated and not self.close_sent: self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006}) - def data_received(self, data: bytes) -> None: - try: - self.conn.receive_data(data) - except Exception: - self.logger.exception("Exception in ASGI server") - self.transport.close() - self.handle_events() - def shutdown(self) -> None: if not self.transport.is_closing(): if self.handshake_complete: @@ -124,6 +117,14 @@ def shutdown(self) -> None: self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006}) self.transport.close() + def data_received(self, data: bytes) -> None: + try: + self.conn.receive_data(data) + except Exception: + self.logger.exception("Exception in ASGI server") + self.transport.close() + self.handle_events() + def handle_events(self) -> None: for event in self.conn.events_received(): if isinstance(event, Request): @@ -190,7 +191,7 @@ def handle_cont(self, event: Frame) -> None: def handle_text(self, event: Frame) -> None: self.bytes = event.data - self.curr_msg_data_type = "text" + self.curr_msg_data_type: Literal["text", "bytes"] = "text" if event.fin: self.send_receive_event_to_app() @@ -201,16 +202,12 @@ def handle_bytes(self, event: Frame) -> None: self.send_receive_event_to_app() def send_receive_event_to_app(self) -> None: - data: typing.Union[str, bytes] - if self.curr_msg_data_type == "text": - data = self.bytes.decode() + data_type = self.curr_msg_data_type + msg: WebSocketReceiveEvent + if data_type == "text": + msg = {"type": "websocket.receive", data_type: self.bytes.decode()} else: - data = self.bytes - - msg: WebSocketReceiveEvent = { - "type": "websocket.receive", - self.curr_msg_data_type: data, # type: ignore[misc] - } + msg = {"type": "websocket.receive", data_type: self.bytes} self.queue.put_nowait(msg) if not self.read_paused: self.read_paused = True @@ -235,7 +232,7 @@ def handle_close(self, event: Frame) -> None: self.close_sent = True self.transport.close() - def on_task_complete(self, task: asyncio.Task) -> None: + def on_task_complete(self, task: asyncio.Task[None]) -> None: self.tasks.discard(task) async def run_asgi(self) -> None: diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index aa4bec8f2..6cdc91a72 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -181,7 +181,7 @@ def handle_connect(self, event: events.Request) -> None: def handle_text(self, event: events.TextMessage) -> None: self.text += event.data if event.message_finished: - msg: "WebSocketReceiveEvent" = { # type: ignore[typeddict-item] + msg: WebSocketReceiveEvent = { "type": "websocket.receive", "text": self.text, } @@ -195,7 +195,7 @@ def handle_bytes(self, event: events.BytesMessage) -> None: self.bytes += event.data # todo: we may want to guard the size of self.bytes and self.text if event.message_finished: - msg: "WebSocketReceiveEvent" = { # type: ignore[typeddict-item] + msg: WebSocketReceiveEvent = { "type": "websocket.receive", "bytes": self.bytes, } diff --git a/uvicorn/server.py b/uvicorn/server.py index 29637058c..c3a88eb2e 100644 --- a/uvicorn/server.py +++ b/uvicorn/server.py @@ -51,7 +51,7 @@ class ServerState: def __init__(self) -> None: self.total_requests = 0 self.connections: Set["Protocols"] = set() - self.tasks: Set[asyncio.Task] = set() + self.tasks: Set[asyncio.Task[None]] = set() self.default_headers: List[Tuple[bytes, bytes]] = [] From d76cdc66757b247860c1f7c127a747bde6afbbf4 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Mon, 28 Aug 2023 07:44:46 +0200 Subject: [PATCH 16/59] Fix linter --- uvicorn/protocols/websockets/websockets_sansio_impl.py | 2 +- uvicorn/server.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index 980f957f2..6e93fffe2 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -232,7 +232,7 @@ def handle_close(self, event: Frame) -> None: self.close_sent = True self.transport.close() - def on_task_complete(self, task: asyncio.Task[None]) -> None: + def on_task_complete(self, task: "asyncio.Task[None]") -> None: self.tasks.discard(task) async def run_asgi(self) -> None: diff --git a/uvicorn/server.py b/uvicorn/server.py index c3a88eb2e..7664409c6 100644 --- a/uvicorn/server.py +++ b/uvicorn/server.py @@ -51,7 +51,7 @@ class ServerState: def __init__(self) -> None: self.total_requests = 0 self.connections: Set["Protocols"] = set() - self.tasks: Set[asyncio.Task[None]] = set() + self.tasks: Set["asyncio.Task[None]"] = set() self.default_headers: List[Tuple[bytes, bytes]] = [] From aed00c8bfe977f9c81e5d5d9c4d6e97d02455573 Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Tue, 29 Aug 2023 22:41:26 +0530 Subject: [PATCH 17/59] Add tests for websocket server for receiving multiple frames --- tests/protocols/test_websocket.py | 86 +++++++++++++++++++++++++++++++ 1 file changed, 86 insertions(+) diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index 17f2a92d1..469cb109a 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -451,6 +451,92 @@ async def send_text(url): assert data == b"abc" +@pytest.mark.anyio +@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) +async def test_send_text_data_to_server_in_multiple_frames( + ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + http_protocol_cls, + unused_tcp_port: int, +): + message = ( + "This is a long message that will be sent in " + "multiple frames and number of frames will be 5." + ) + + class App(WebSocketResponse): + async def websocket_connect(self, message): + await self.send({"type": "websocket.accept"}) + + async def websocket_receive(self, message): + _text = message.get("text") + await self.send({"type": "websocket.send", "text": _text}) + + async def send_text(url): + async with websockets.client.connect(url) as websocket: + assembled_frames = [] + # send this message in 5 frames + for i in range(5): + # divide the message in 5 parts + msg = message[i * len(message) // 5 : (i + 1) * len(message) // 5] + assembled_frames.append(msg) + await websocket.send(assembled_frames) + return await websocket.recv() + + config = Config( + app=App, + ws=ws_protocol_cls, + http=http_protocol_cls, + lifespan="off", + port=unused_tcp_port, + ) + async with run_server(config): + data = await send_text(f"ws://127.0.0.1:{unused_tcp_port}") + assert data == message + + +@pytest.mark.anyio +@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) +async def test_send_binary_data_to_server_in_multiple_frames( + ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + http_protocol_cls, + unused_tcp_port: int, +): + message = ( + b"This is a long message that will be sent in " + b"multiple frames and number of frames will be 5." + ) + + class App(WebSocketResponse): + async def websocket_connect(self, message): + await self.send({"type": "websocket.accept"}) + + async def websocket_receive(self, message): + _bytes = message.get("bytes") + await self.send({"type": "websocket.send", "bytes": _bytes}) + + async def send_bytes(url): + async with websockets.client.connect(url) as websocket: + assembled_frames = [] + # send this message in 5 frames + for i in range(5): + # divide the message in 5 parts + msg = message[i * len(message) // 5 : (i + 1) * len(message) // 5] + assembled_frames.append(msg) + await websocket.send(assembled_frames) + return await websocket.recv() + + config = Config( + app=App, + ws=ws_protocol_cls, + http=http_protocol_cls, + lifespan="off", + port=unused_tcp_port, + ) + async with run_server(config): + data = await send_bytes(f"ws://127.0.0.1:{unused_tcp_port}") + assert data == message + + @pytest.mark.anyio @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_send_after_protocol_close( From 808f9515147d8b40b70fbd7761df097a8540e1a9 Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Tue, 29 Aug 2023 22:51:33 +0530 Subject: [PATCH 18/59] Remove checking of PONG event after receiving data As it won't be propagated as an event by the websockets sansIO protocol --- uvicorn/protocols/websockets/websockets_sansio_impl.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index 6e93fffe2..110c96fd1 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -138,8 +138,6 @@ def handle_events(self) -> None: self.handle_bytes(event) elif event.opcode == Opcode.PING: self.handle_ping(event) - elif event.opcode == Opcode.PONG: - self.handle_pong(event) elif event.opcode == Opcode.CLOSE: self.handle_close(event) @@ -217,9 +215,6 @@ def handle_ping(self, event: Frame) -> None: output = self.conn.data_to_send() self.transport.writelines(output) - def handle_pong(self, event: Frame) -> None: - pass - def handle_close(self, event: Frame) -> None: if not self.close_sent and not self.transport.is_closing(): disconnect_event: WebSocketDisconnectEvent = { From 803100c3b19589313127475887393affa40d82cf Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Tue, 29 Aug 2023 23:24:48 +0530 Subject: [PATCH 19/59] Revert "Remove checking of PONG event after receiving data" This reverts commit 808f9515147d8b40b70fbd7761df097a8540e1a9. --- uvicorn/protocols/websockets/websockets_sansio_impl.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index 110c96fd1..6e93fffe2 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -138,6 +138,8 @@ def handle_events(self) -> None: self.handle_bytes(event) elif event.opcode == Opcode.PING: self.handle_ping(event) + elif event.opcode == Opcode.PONG: + self.handle_pong(event) elif event.opcode == Opcode.CLOSE: self.handle_close(event) @@ -215,6 +217,9 @@ def handle_ping(self, event: Frame) -> None: output = self.conn.data_to_send() self.transport.writelines(output) + def handle_pong(self, event: Frame) -> None: + pass + def handle_close(self, event: Frame) -> None: if not self.close_sent and not self.transport.is_closing(): disconnect_event: WebSocketDisconnectEvent = { From 7519e6bf8f1013c24b17c2be0d59753146a65354 Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Tue, 29 Aug 2023 23:36:25 +0530 Subject: [PATCH 20/59] "Remove checking of PONG event after receiving data" As websockets sansio protocol is not propagating this event after receiving data --- uvicorn/protocols/websockets/websockets_sansio_impl.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index 6e93fffe2..110c96fd1 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -138,8 +138,6 @@ def handle_events(self) -> None: self.handle_bytes(event) elif event.opcode == Opcode.PING: self.handle_ping(event) - elif event.opcode == Opcode.PONG: - self.handle_pong(event) elif event.opcode == Opcode.CLOSE: self.handle_close(event) @@ -217,9 +215,6 @@ def handle_ping(self, event: Frame) -> None: output = self.conn.data_to_send() self.transport.writelines(output) - def handle_pong(self, event: Frame) -> None: - pass - def handle_close(self, event: Frame) -> None: if not self.close_sent and not self.transport.is_closing(): disconnect_event: WebSocketDisconnectEvent = { From 87ad36a962deb7eb422cc85724f1f11eef076207 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Wed, 30 Aug 2023 09:26:11 +0200 Subject: [PATCH 21/59] Create WSType on the test suite --- pyproject.toml | 2 +- tests/middleware/test_logging.py | 9 ++- tests/protocols/test_websocket.py | 70 ++++++++++--------- .../websockets/websockets_sansio_impl.py | 2 +- uvicorn/protocols/websockets/wsproto_impl.py | 2 +- 5 files changed, 49 insertions(+), 36 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 56bd79035..850f7f27f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -99,7 +99,7 @@ omit = [ [tool.coverage.report] precision = 2 -fail_under = 98.35 +fail_under = 98.65 show_missing = true skip_covered = true exclude_lines = [ diff --git a/tests/middleware/test_logging.py b/tests/middleware/test_logging.py index 84e7c8985..db1799b12 100644 --- a/tests/middleware/test_logging.py +++ b/tests/middleware/test_logging.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import contextlib import logging import socket @@ -22,8 +24,13 @@ if typing.TYPE_CHECKING: from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol + from uvicorn.protocols.websockets.websockets_sansio_impl import ( + WebSocketsSansIOProtocol, + ) from uvicorn.protocols.websockets.wsproto_impl import WSProtocol + WSType = typing.Type["WSProtocol | WebSocketProtocol | WebSocketsSansIOProtocol"] + @contextlib.contextmanager def caplog_for_logger(caplog, logger_name): @@ -96,7 +103,7 @@ async def test_trace_logging_on_http_protocol( @pytest.mark.anyio async def test_trace_logging_on_ws_protocol( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + ws_protocol_cls: WSType, caplog, logging_config, unused_tcp_port: int, diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index 469cb109a..cde19ec43 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import typing from copy import deepcopy @@ -14,6 +16,7 @@ from tests.utils import run_server from uvicorn.config import Config from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol +from uvicorn.protocols.websockets.websockets_sansio_impl import WebSocketsSansIOProtocol try: from uvicorn.protocols.websockets.wsproto_impl import WSProtocol @@ -22,6 +25,9 @@ except ModuleNotFoundError: skip_if_no_wsproto = pytest.mark.skipif(True, reason="wsproto is not installed.") +if typing.TYPE_CHECKING: + WSType = typing.Type["WSProtocol | WebSocketProtocol | WebSocketsSansIOProtocol"] + class WebSocketResponse: def __init__(self, scope, receive, send): @@ -46,7 +52,7 @@ async def asgi(self): @pytest.mark.anyio @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_invalid_upgrade( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + ws_protocol_cls: WSType, http_protocol_cls, unused_tcp_port: int, ): @@ -85,7 +91,7 @@ def app(scope): @pytest.mark.anyio @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_accept_connection( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + ws_protocol_cls: WSType, http_protocol_cls, unused_tcp_port: int, ): @@ -112,7 +118,7 @@ async def open_connection(url): @pytest.mark.anyio @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_supports_permessage_deflate_extension( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + ws_protocol_cls: WSType, http_protocol_cls, unused_tcp_port: int, ): @@ -142,7 +148,7 @@ async def open_connection(url): @pytest.mark.anyio @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_can_disable_permessage_deflate_extension( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + ws_protocol_cls: WSType, http_protocol_cls, unused_tcp_port: int, ): @@ -175,7 +181,7 @@ async def open_connection(url): @pytest.mark.anyio @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_close_connection( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + ws_protocol_cls: WSType, http_protocol_cls, unused_tcp_port: int, ): @@ -205,7 +211,7 @@ async def open_connection(url): @pytest.mark.anyio @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_headers( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + ws_protocol_cls: WSType, http_protocol_cls, unused_tcp_port: int, ): @@ -238,7 +244,7 @@ async def open_connection(url): @pytest.mark.anyio @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_extra_headers( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + ws_protocol_cls: WSType, http_protocol_cls, unused_tcp_port: int, ): @@ -267,7 +273,7 @@ async def open_connection(url): @pytest.mark.anyio @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_path_and_raw_path( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + ws_protocol_cls: WSType, http_protocol_cls, unused_tcp_port: int, ): @@ -298,7 +304,7 @@ async def open_connection(url): @pytest.mark.anyio @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_send_text_data_to_client( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + ws_protocol_cls: WSType, http_protocol_cls, unused_tcp_port: int, ): @@ -326,7 +332,7 @@ async def get_data(url): @pytest.mark.anyio @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_send_binary_data_to_client( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + ws_protocol_cls: WSType, http_protocol_cls, unused_tcp_port: int, ): @@ -354,7 +360,7 @@ async def get_data(url): @pytest.mark.anyio @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_send_and_close_connection( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + ws_protocol_cls: WSType, http_protocol_cls, unused_tcp_port: int, ): @@ -390,7 +396,7 @@ async def get_data(url): @pytest.mark.anyio @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_send_text_data_to_server( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + ws_protocol_cls: WSType, http_protocol_cls, unused_tcp_port: int, ): @@ -422,7 +428,7 @@ async def send_text(url): @pytest.mark.anyio @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_send_binary_data_to_server( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + ws_protocol_cls: WSType, http_protocol_cls, unused_tcp_port: int, ): @@ -454,7 +460,7 @@ async def send_text(url): @pytest.mark.anyio @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_send_text_data_to_server_in_multiple_frames( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + ws_protocol_cls: WSType, http_protocol_cls, unused_tcp_port: int, ): @@ -497,7 +503,7 @@ async def send_text(url): @pytest.mark.anyio @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_send_binary_data_to_server_in_multiple_frames( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + ws_protocol_cls: WSType, http_protocol_cls, unused_tcp_port: int, ): @@ -540,7 +546,7 @@ async def send_bytes(url): @pytest.mark.anyio @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_send_after_protocol_close( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + ws_protocol_cls: WSType, http_protocol_cls, unused_tcp_port: int, ): @@ -578,7 +584,7 @@ async def get_data(url): @pytest.mark.anyio @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_missing_handshake( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + ws_protocol_cls: WSType, http_protocol_cls, unused_tcp_port: int, ): @@ -604,7 +610,7 @@ async def connect(url): @pytest.mark.anyio @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_send_before_handshake( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + ws_protocol_cls: WSType, http_protocol_cls, unused_tcp_port: int, ): @@ -630,7 +636,7 @@ async def connect(url): @pytest.mark.anyio @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_duplicate_handshake( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + ws_protocol_cls: WSType, http_protocol_cls, unused_tcp_port: int, ): @@ -658,7 +664,7 @@ async def connect(url): @pytest.mark.anyio @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_asgi_return_value( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + ws_protocol_cls: WSType, http_protocol_cls, unused_tcp_port: int, ): @@ -697,7 +703,7 @@ async def connect(url): ids=["none_as_reason", "normal_reason", "without_reason"], ) async def test_app_close( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + ws_protocol_cls: WSType, http_protocol_cls, unused_tcp_port: int, code, @@ -744,7 +750,7 @@ async def websocket_session(url): @pytest.mark.anyio @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_client_close( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + ws_protocol_cls: WSType, http_protocol_cls, unused_tcp_port: int, ): @@ -777,7 +783,7 @@ async def websocket_session(url): @pytest.mark.anyio @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_client_connection_lost( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + ws_protocol_cls: WSType, http_protocol_cls, unused_tcp_port: int, ): @@ -816,7 +822,7 @@ async def app(scope, receive, send): @pytest.mark.anyio @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_connection_lost_before_handshake_complete( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + ws_protocol_cls: WSType, http_protocol_cls, unused_tcp_port: int, ): @@ -870,7 +876,7 @@ async def websocket_session(uri): @pytest.mark.anyio @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_send_close_on_server_shutdown( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + ws_protocol_cls: WSType, http_protocol_cls, unused_tcp_port: int, ): @@ -921,7 +927,7 @@ async def websocket_session(uri): @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) @pytest.mark.parametrize("subprotocol", ["proto1", "proto2"]) async def test_subprotocols( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + ws_protocol_cls: WSType, http_protocol_cls, subprotocol, unused_tcp_port: int, @@ -1014,7 +1020,7 @@ async def send_text(url): @pytest.mark.anyio @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_server_reject_connection( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + ws_protocol_cls: WSType, http_protocol_cls, unused_tcp_port: int, ): @@ -1055,7 +1061,7 @@ async def websocket_session(url): @pytest.mark.anyio @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_server_can_read_messages_in_buffer_after_close( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + ws_protocol_cls: WSType, http_protocol_cls, unused_tcp_port: int, ): @@ -1100,7 +1106,7 @@ async def send_text(url): @pytest.mark.anyio @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_default_server_headers( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + ws_protocol_cls: WSType, http_protocol_cls, unused_tcp_port: int, ): @@ -1127,7 +1133,7 @@ async def open_connection(url): @pytest.mark.anyio @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_no_server_headers( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + ws_protocol_cls: WSType, http_protocol_cls, unused_tcp_port: int, ): @@ -1183,7 +1189,7 @@ async def open_connection(url): @pytest.mark.anyio @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_multiple_server_header( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + ws_protocol_cls: WSType, http_protocol_cls, unused_tcp_port: int, ): @@ -1218,7 +1224,7 @@ async def open_connection(url): @pytest.mark.anyio @pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) async def test_lifespan_state( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + ws_protocol_cls: WSType, http_protocol_cls, unused_tcp_port: int, ): diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index 110c96fd1..0e12208b5 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -41,7 +41,7 @@ def __init__( _loop: typing.Optional[asyncio.AbstractEventLoop] = None, ) -> None: if not config.loaded: - config.load() + config.load() # pragma: no cover self.config = config self.app = config.loaded_app diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index 6cdc91a72..f7f3af26b 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -39,7 +39,7 @@ def __init__( _loop: typing.Optional[asyncio.AbstractEventLoop] = None, ) -> None: if not config.loaded: - config.load() + config.load() # pragma: no cover self.config = config self.app = config.loaded_app From 1048c18c010816179316acdebfb86b04dc503673 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Wed, 30 Aug 2023 09:32:13 +0200 Subject: [PATCH 22/59] Add WebSocketsSansIOProtocol to the CLI --- docs/deployment.md | 2 +- docs/index.md | 2 +- docs/settings.md | 2 +- uvicorn/config.py | 3 ++- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/docs/deployment.md b/docs/deployment.md index 6826db8c5..f3d72eef9 100644 --- a/docs/deployment.md +++ b/docs/deployment.md @@ -60,7 +60,7 @@ Options: --loop [auto|asyncio|uvloop] Event loop implementation. [default: auto] --http [auto|h11|httptools] HTTP protocol implementation. [default: auto] - --ws [auto|none|websockets|wsproto] + --ws [auto|none|websockets|websockets-sansio|wsproto] WebSocket protocol implementation. [default: auto] --ws-max-size INTEGER WebSocket max size message in bytes diff --git a/docs/index.md b/docs/index.md index 503f92c22..d1915a1ae 100644 --- a/docs/index.md +++ b/docs/index.md @@ -130,7 +130,7 @@ Options: --loop [auto|asyncio|uvloop] Event loop implementation. [default: auto] --http [auto|h11|httptools] HTTP protocol implementation. [default: auto] - --ws [auto|none|websockets|wsproto] + --ws [auto|none|websockets|websockets-sansio|wsproto] WebSocket protocol implementation. [default: auto] --ws-max-size INTEGER WebSocket max size message in bytes diff --git a/docs/settings.md b/docs/settings.md index ed4bef52d..011a0207a 100644 --- a/docs/settings.md +++ b/docs/settings.md @@ -67,7 +67,7 @@ Using Uvicorn with watchfiles will enable the following options (which are other * `--loop ` - Set the event loop implementation. The uvloop implementation provides greater performance, but is not compatible with Windows or PyPy. **Options:** *'auto', 'asyncio', 'uvloop'.* **Default:** *'auto'*. * `--http ` - Set the HTTP protocol implementation. The httptools implementation provides greater performance, but it not compatible with PyPy. **Options:** *'auto', 'h11', 'httptools'.* **Default:** *'auto'*. -* `--ws ` - Set the WebSockets protocol implementation. Either of the `websockets` and `wsproto` packages are supported. Use `'none'` to ignore all websocket requests. **Options:** *'auto', 'none', 'websockets', 'wsproto'.* **Default:** *'auto'*. +* `--ws ` - Set the WebSockets protocol implementation. Either of the `websockets` and `wsproto` packages are supported. Use `'none'` to ignore all websocket requests. **Options:** *'auto', 'none', 'websockets', 'websockets-sansio', 'wsproto'.* **Default:** *'auto'*. * `--ws-max-size ` - Set the WebSockets max message size, in bytes. Please note that this can be used only with the default `websockets` protocol. * `--ws-max-queue ` - Set the maximum length of the WebSocket incoming message queue. Please note that this can be used only with the default `websockets` protocol. * `--ws-ping-interval ` - Set the WebSockets ping interval, in seconds. Please note that this can be used only with the default `websockets` protocol. diff --git a/uvicorn/config.py b/uvicorn/config.py index d1c66a4ff..dcb316c46 100644 --- a/uvicorn/config.py +++ b/uvicorn/config.py @@ -32,7 +32,7 @@ from uvicorn.middleware.wsgi import WSGIMiddleware HTTPProtocolType = Literal["auto", "h11", "httptools"] -WSProtocolType = Literal["auto", "none", "websockets", "wsproto"] +WSProtocolType = Literal["auto", "none", "websockets", "websockets-sansio", "wsproto"] LifespanType = Literal["auto", "on", "off"] LoopSetupType = Literal["none", "auto", "asyncio", "uvloop"] InterfaceType = Literal["auto", "asgi3", "asgi2", "wsgi"] @@ -54,6 +54,7 @@ "auto": "uvicorn.protocols.websockets.auto:AutoWebSocketsProtocol", "none": None, "websockets": "uvicorn.protocols.websockets.websockets_impl:WebSocketProtocol", + "websockets-sansio": "uvicorn.protocols.websockets.websockets_sansio_impl:WebSocketsSansIOProtocol", # noqa: E501 "wsproto": "uvicorn.protocols.websockets.wsproto_impl:WSProtocol", } LIFESPAN: Dict[LifespanType, str] = { From 37a686f8e1f23f3ac85b0b705dbf6e3a68e1f1a0 Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Sat, 2 Sep 2023 01:05:22 +0530 Subject: [PATCH 23/59] Make changes for testing payload max_size limit --- tests/conftest.py | 10 +++++++ tests/protocols/test_websocket.py | 3 ++- .../websockets/websockets_sansio_impl.py | 26 +++++++++++++++---- 3 files changed, 33 insertions(+), 6 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index cc6a35b9a..7990776e5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -257,3 +257,13 @@ def unused_tcp_port() -> int: ) def ws_protocol_cls(request: pytest.FixtureRequest): return import_from_string(request.param) + + +@pytest.fixture( + params=[ + "uvicorn.protocols.websockets.websockets_impl:WebSocketProtocol", + "uvicorn.protocols.websockets.websockets_sansio_impl:WebSocketsSansIOProtocol", + ] +) +def websockets_legay_plus_sansio_protocol_cls(request: pytest.FixtureRequest): + return import_from_string(request.param) diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index cde19ec43..54a82c1f5 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -978,6 +978,7 @@ async def get_subprotocol(url: str): ], ) async def test_send_binary_data_to_server_bigger_than_default_on_websockets( + websockets_legay_plus_sansio_protocol_cls, http_protocol_cls, client_size_sent: int, server_size_max: int, @@ -1001,7 +1002,7 @@ async def send_text(url): config = Config( app=App, - ws=WebSocketProtocol, + ws=websockets_legay_plus_sansio_protocol_cls, http=http_protocol_cls, lifespan="off", ws_max_size=server_size_max, diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index 0e12208b5..4b8e773b3 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -6,9 +6,11 @@ from typing import Literal from urllib.parse import unquote +from websockets.exceptions import PayloadTooBig from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory from websockets.frames import Frame, Opcode from websockets.http11 import Request +from websockets.protocol import State as WebSocketsState from websockets.server import ServerProtocol from uvicorn._types import ( @@ -118,11 +120,13 @@ def shutdown(self) -> None: self.transport.close() def data_received(self, data: bytes) -> None: - try: - self.conn.receive_data(data) - except Exception: - self.logger.exception("Exception in ASGI server") - self.transport.close() + self.conn.receive_data(data) + parser_exc = self.conn.parser_exc + if parser_exc is not None: + self.conn = ServerProtocol(state=WebSocketsState.OPEN) + if isinstance(parser_exc, PayloadTooBig): + self.handle_payloadsize_bigger_than_limit_error() + return self.handle_events() def handle_events(self) -> None: @@ -227,6 +231,18 @@ def handle_close(self, event: Frame) -> None: self.close_sent = True self.transport.close() + def handle_payloadsize_bigger_than_limit_error(self) -> None: + disconnect_event: WebSocketDisconnectEvent = { + "type": "websocket.disconnect", + "code": 1009, + } + self.queue.put_nowait(disconnect_event) + self.conn.send_close(1009) + output = self.conn.data_to_send() + self.transport.writelines(output) + self.close_sent = True + self.transport.close() + def on_task_complete(self, task: "asyncio.Task[None]") -> None: self.tasks.discard(task) From 4f76f62d2b80dbca7dbce32c4fa4e50a1d1b1510 Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Mon, 4 Sep 2023 22:53:14 +0530 Subject: [PATCH 24/59] Make changes for testing payload max_size limit --- uvicorn/protocols/websockets/websockets_sansio_impl.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index 4b8e773b3..84c33ef6a 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -123,9 +123,7 @@ def data_received(self, data: bytes) -> None: self.conn.receive_data(data) parser_exc = self.conn.parser_exc if parser_exc is not None: - self.conn = ServerProtocol(state=WebSocketsState.OPEN) - if isinstance(parser_exc, PayloadTooBig): - self.handle_payloadsize_bigger_than_limit_error() + self.handle_parser_exception() return self.handle_events() @@ -231,13 +229,12 @@ def handle_close(self, event: Frame) -> None: self.close_sent = True self.transport.close() - def handle_payloadsize_bigger_than_limit_error(self) -> None: + def handle_parser_exception(self) -> None: disconnect_event: WebSocketDisconnectEvent = { "type": "websocket.disconnect", - "code": 1009, + "code": self.conn.close_sent.code, } self.queue.put_nowait(disconnect_event) - self.conn.send_close(1009) output = self.conn.data_to_send() self.transport.writelines(output) self.close_sent = True From cbe36ba79119b4d7bf7af33b1e60b189a6f58ea1 Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Mon, 4 Sep 2023 23:00:23 +0530 Subject: [PATCH 25/59] fix lint issue --- uvicorn/protocols/websockets/websockets_sansio_impl.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index 84c33ef6a..704976c44 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -6,11 +6,9 @@ from typing import Literal from urllib.parse import unquote -from websockets.exceptions import PayloadTooBig from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory from websockets.frames import Frame, Opcode from websockets.http11 import Request -from websockets.protocol import State as WebSocketsState from websockets.server import ServerProtocol from uvicorn._types import ( @@ -232,7 +230,7 @@ def handle_close(self, event: Frame) -> None: def handle_parser_exception(self) -> None: disconnect_event: WebSocketDisconnectEvent = { "type": "websocket.disconnect", - "code": self.conn.close_sent.code, + "code": self.conn.close_sent.code, # type: ignore[union-attr] } self.queue.put_nowait(disconnect_event) output = self.conn.data_to_send() From 348b6acc1b585fa7e03684322075424814174df0 Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Mon, 4 Sep 2023 23:08:25 +0530 Subject: [PATCH 26/59] increase msg size from 11 to 32 --- tests/protocols/test_websocket.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index 54a82c1f5..8614fa392 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -968,7 +968,7 @@ async def get_subprotocol(url: str): (MAX_WS_BYTES, MAX_WS_BYTES, 0), (MAX_WS_BYTES_PLUS1, MAX_WS_BYTES, 1009), (10, 10, 0), - (11, 10, 1009), + (32, 10, 1009), ], ids=[ "max=defaults sent=defaults", From 48b1d5ff1349c9788ebf140fc0d8bf033fd98a21 Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Tue, 5 Sep 2023 10:18:01 +0530 Subject: [PATCH 27/59] increase client max_limit --- tests/protocols/test_websocket.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index 8614fa392..79f5d691b 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -968,7 +968,7 @@ async def get_subprotocol(url: str): (MAX_WS_BYTES, MAX_WS_BYTES, 0), (MAX_WS_BYTES_PLUS1, MAX_WS_BYTES, 1009), (10, 10, 0), - (32, 10, 1009), + (11, 10, 1009), ], ids=[ "max=defaults sent=defaults", @@ -994,9 +994,7 @@ async def websocket_receive(self, message): await self.send({"type": "websocket.send", "bytes": _bytes}) async def send_text(url): - async with websockets.client.connect( - url, max_size=client_size_sent - ) as websocket: + async with websockets.client.connect(url, max_size=MAX_WS_BYTES) as websocket: await websocket.send(b"\x01" * client_size_sent) return await websocket.recv() From 12bb2d2f9054635b90de0c4f72c30b6d76c3d6f4 Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Tue, 5 Sep 2023 10:23:52 +0530 Subject: [PATCH 28/59] Empty-Commit-to-trigger-pipeline From 498eaf5874d7d96d2b87cf73843e774c1b744ddc Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Thu, 27 Jul 2023 01:42:37 +0530 Subject: [PATCH 29/59] Implement websockets_sans_impl.py --- tests/conftest.py | 1 + .../websockets/websockets_sansio_impl.py | 384 ++++++++++++++++++ 2 files changed, 385 insertions(+) create mode 100644 uvicorn/protocols/websockets/websockets_sansio_impl.py diff --git a/tests/conftest.py b/tests/conftest.py index a405c3175..0997e8966 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -255,6 +255,7 @@ def unused_tcp_port() -> int: ), ), "uvicorn.protocols.websockets.websockets_impl:WebSocketProtocol", + "uvicorn.protocols.websockets.websockets_sansio_impl:WebSocketsSansIOProtocol" ] ) def ws_protocol_cls(request: pytest.FixtureRequest): diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py new file mode 100644 index 000000000..954fbab28 --- /dev/null +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -0,0 +1,384 @@ +import asyncio +import logging +import sys +import time +import typing +from asyncio.transports import BaseTransport, Transport +from urllib.parse import unquote + +import websockets +from websockets.http11 import Request +from websockets.frames import Frame, Close +from websockets.server import ServerConnection +from websockets.connection import State + +from uvicorn.config import Config +from uvicorn.logging import TRACE_LOG_LEVEL +from uvicorn.protocols.utils import ( + get_local_addr, + get_path_with_query_string, + get_remote_addr, + is_ssl, +) +from uvicorn.server import ServerState +from http import HTTPStatus + +if sys.version_info < (3, 8): + from typing_extensions import Literal +else: + from typing import Literal + +if typing.TYPE_CHECKING: + from asgiref.typing import ( + ASGISendEvent, + WebSocketAcceptEvent, + WebSocketCloseEvent, + WebSocketConnectEvent, + WebSocketDisconnectEvent, + WebSocketReceiveEvent, + WebSocketScope, + WebSocketSendEvent, + ) + + WebSocketEvent = typing.Union[ + "WebSocketReceiveEvent", + "WebSocketDisconnectEvent", + "WebSocketConnectEvent", + ] + + +class WebSocketsSansIOProtocol(asyncio.Protocol): + def __init__( + self, + config: Config, + server_state: ServerState, + app_state: typing.Dict[str, typing.Any], + _loop: typing.Optional[asyncio.AbstractEventLoop] = None, + ) -> None: + if not config.loaded: + config.load() + + self.config = config + self.app = config.loaded_app + self.loop = _loop or asyncio.get_event_loop() + self.logger = logging.getLogger("uvicorn.error") + self.root_path = config.root_path + self.app_state = app_state + + # Shared server state + self.connections = server_state.connections + self.tasks = server_state.tasks + self.default_headers = server_state.default_headers + + # Connection state + self.transport: asyncio.Transport = None # type: ignore[assignment] + self.server: typing.Optional[typing.Tuple[str, int]] = None + self.client: typing.Optional[typing.Tuple[str, int]] = None + self.scheme: Literal["wss", "ws"] = None # type: ignore[assignment] + + # WebSocket state + self.queue: asyncio.Queue["WebSocketEvent"] = asyncio.Queue() + self.handshake_initiated = False + self.handshake_complete = False + self.close_sent = False + + # extensions = [] + # if self.config.ws_per_message_deflate: + # extensions.append(ServerPerMessageDeflateFactory()) + self.conn = ServerConnection() + self.request = None + self.response = None + + self.read_paused = False + self.writable = asyncio.Event() + self.writable.set() + + # Buffers + self.bytes = b"" + self.text = "" + print(len(self.tasks)) + + def connection_made(self, transport: BaseTransport) -> None: + """Called when a connection is made.""" + transport = typing.cast(Transport, transport) + self.connections.add(self) + self.transport = transport + self.server = get_local_addr(transport) + self.client = get_remote_addr(transport) + self.scheme = "wss" if is_ssl(transport) else "ws" + + if self.logger.level <= TRACE_LOG_LEVEL: + prefix = "%s:%d - " % self.client if self.client else "" + self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix) + + def connection_lost(self, exc: typing.Optional[Exception]) -> None: + self.connections.remove(self) + print('came in connection lost : ', exc) + if self.logger.level <= TRACE_LOG_LEVEL: + prefix = "%s:%d - " % self.client if self.client else "" + self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix) + if self.handshake_initiated and not self.close_sent: + self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006}) + + + def data_received(self, data: bytes) -> None: + try: + self.conn.receive_data(data) + except Exception as exc: + self.logger.exception("Exception in ASGI server") + self.transport.close() + self.handle_events() + + def shutdown(self) -> None: + if not self.transport.is_closing(): + if self.handshake_complete: + self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012}) + self.close_send = True + self.conn.send_close(1012) + output = self.conn.data_to_send() + self.transport.writelines(output) + elif self.handshake_initiated: + self.send_500_response() + self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006}) + self.transport.close() + + def handle_events(self) -> None: + for event in self.conn.events_received(): + if isinstance(event, Request): + self.handle_connect(event) + if isinstance(event, Frame): + if event.opcode == websockets.frames.Opcode.CONT: + self.handle_cont(event) + elif event.opcode == websockets.frames.Opcode.TEXT: + self.handle_text(event) + elif event.opcode == websockets.frames.Opcode.BINARY: + self.handle_bytes(event) + elif event.opcode == websockets.frames.Opcode.PING: + self.handle_ping(event) + elif event.opcode == websockets.frames.Opcode.PONG: + self.handle_pong(event) + elif event.opcode == websockets.frames.Opcode.CLOSE: + self.handle_close(event) + + # Event handlers + + def handle_connect(self, event: Request) -> None: + self.request = event + self.response = self.conn.accept(event) + self.handshake_initiated = True + # if status_code is not 101 return response + if self.response.status_code != 101: + self.handshake_complete = True + self.close_sent = True + self.conn.send_response(self.response) + output = self.conn.data_to_send() + self.transport.writelines(output) + self.transport.close() + return + + headers = [ + (key.encode('ascii'), value.encode('ascii', errors='surrogateescape')) + for key, value in event.headers.raw_items() + ] + raw_path, _, query_string = event.path.partition("?") + self.scope: "WebSocketScope" = { # type: ignore[typeddict-item] + "type": "websocket", + "asgi": {"version": self.config.asgi_version, "spec_version": "2.3"}, + "http_version": "1.1", + "scheme": self.scheme, + "server": self.server, + "client": self.client, + "root_path": self.root_path, + "path": unquote(raw_path), + "raw_path": raw_path.encode("ascii"), + "query_string": query_string.encode("ascii"), + "headers": headers, + "subprotocols": event.headers.get_all("Sec-WebSocket-Protocol"), + "extensions": None, + "state": self.app_state.copy(), + } + self.queue.put_nowait({"type": "websocket.connect"}) + task = self.loop.create_task(self.run_asgi()) + task.add_done_callback(self.on_task_complete) + self.tasks.add(task) + + def handle_cont(self, event: Frame) -> None: + self.bytes += event.data + if event.fin: + self.send_receive_event_to_app() + + + def handle_text(self, event: Frame) -> None: + self.bytes = event.data + self.curr_msg_data_type = "text" + if event.fin: + self.send_receive_event_to_app() + + def handle_bytes(self, event: Frame) -> None: + self.bytes = event.data + self.curr_msg_data_type = "bytes" + if event.fin: + self.send_receive_event_to_app() + + def send_receive_event_to_app(self): + if self.curr_msg_data_type == "text": + data = self.bytes.decode() + else: + data = self.bytes + + msg: "WebSocketReceiveEvent" = { # type: ignore[typeddict-item] + "type": "websocket.receive", + self.curr_msg_data_type: data + } + self.queue.put_nowait(msg) + self.bytes = b"" + self.curr_msg_data_type = None + if not self.read_paused: + self.read_paused = True + self.transport.pause_reading() + + def handle_ping(self, event: Frame) -> None: + output = self.conn.data_to_send() + self.transport.writelines(output) + + def handle_pong(self, event: Frame) -> None: + pass + + def handle_close(self, event: Frame) -> None: + if not self.close_sent and not self.transport.is_closing(): + self.queue.put_nowait({"type": "websocket.disconnect", "code": self.conn.close_rcvd.code}) + output = self.conn.data_to_send() + self.transport.writelines(output) + self.close_sent = True + self.transport.close() + + def on_task_complete(self, task: asyncio.Task) -> None: + self.tasks.discard(task) + + async def run_asgi(self) -> None: + try: + result = await self.app(self.scope, self.receive, self.send) + except BaseException: + self.logger.exception("Exception in ASGI application\n") + if not self.handshake_complete: + self.send_500_response() + self.transport.close() + else: + if not self.handshake_complete: + msg = "ASGI callable returned without completing handshake." + self.logger.error(msg) + self.send_500_response() + self.transport.close() + elif result is not None: + msg = "ASGI callable should return None, but returned '%s'." + self.logger.error(msg, result) + self.transport.close() + + def send_500_response(self) -> None: + msg = b"Internal Server Error" + content = [ + b"HTTP/1.1 500 Internal Server Error\r\n" + b"content-type: text/plain; charset=utf-8\r\n", + b"content-length: " + str(len(msg)).encode("ascii") + b"\r\n", + b"connection: close\r\n", + b"\r\n", + msg, + ] + self.transport.write(b"".join(content)) + + async def send(self, message: "ASGISendEvent") -> None: + await self.writable.wait() + + message_type = message["type"] + + if not self.handshake_complete: + if message_type == "websocket.accept" and not self.transport.is_closing(): + self.logger.info( + '%s - "WebSocket %s" [accepted]', + self.scope["client"], + get_path_with_query_string(self.scope), + ) + headers = [ + (key.decode("ascii"), value.decode("ascii", errors="surrogateescape")) + for key, value in self.default_headers + + list(message.get("headers", [])) + ] + + self.accepted_subprotocol : str = message.get("subprotocol") + if self.accepted_subprotocol: + headers.append(('Sec-WebSocket-Protocol', self.accepted_subprotocol)) + + self.handshake_complete = True + self.response.headers.update(headers) + self.conn.send_response(self.response) + output = self.conn.data_to_send() + self.transport.writelines(output) + + elif message_type == "websocket.close" and not self.transport.is_closing(): + self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006}) + self.logger.info( + '%s - "WebSocket %s" 403', + self.scope["client"], + get_path_with_query_string(self.scope), + ) + extra_headers = [ + (key.decode(), value.decode()) + for key, value in self.default_headers + + list(message.get("headers", [])) + ] + response = self.conn.reject(HTTPStatus.FORBIDDEN, message.get('reason', '')) + response.headers.update(extra_headers) + self.conn.send_response(response) + output = self.conn.data_to_send() + self.close_sent = True + self.hankshake_complete = True + self.transport.writelines(output) + self.transport.close() + + else: + msg = ( + "Expected ASGI message 'websocket.accept' or 'websocket.close', " + "but got '%s'." + ) + raise RuntimeError(msg % message_type) + + elif not self.close_sent: + if message_type == "websocket.send" and not self.transport.is_closing(): + message = typing.cast("WebSocketSendEvent", message) + bytes_data : bytes = message.get("bytes") + text_data : str = message.get("text") + if text_data: + # need to add the logic of sending fragmented data here + self.conn.send_text(text_data.encode()) + elif bytes_data: + self.conn.send_binary(bytes_data) + output = self.conn.data_to_send() + self.transport.writelines(output) + + elif message_type == "websocket.close" and not self.transport.is_closing(): + message = typing.cast("WebSocketCloseEvent", message) + code = message.get("code", 1000) + reason = message.get("reason", "") or "" + self.queue.put_nowait({"type": "websocket.disconnect", "code": code}) + self.conn.send_close(code, reason) + output = self.conn.data_to_send() + self.transport.writelines(output) + self.close_sent = True + self.transport.close() + else: + msg = ( + "Expected ASGI message 'websocket.send' or 'websocket.close'," + " but got '%s'." + ) + raise RuntimeError(msg % message_type) + + else: + msg = "Unexpected ASGI message '%s', after sending 'websocket.close'." + raise RuntimeError(msg % message_type) + + async def receive(self) -> "WebSocketEvent": + message = await self.queue.get() + if self.read_paused and self.queue.empty(): + self.read_paused = False + self.transport.resume_reading() + return message \ No newline at end of file From e37877016d6abe5408b9250914f1580c1b01d021 Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Thu, 27 Jul 2023 17:41:57 +0530 Subject: [PATCH 30/59] add surrogate errors in decode --- uvicorn/protocols/websockets/websockets_sansio_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index 954fbab28..abbc80e85 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -322,7 +322,7 @@ async def send(self, message: "ASGISendEvent") -> None: get_path_with_query_string(self.scope), ) extra_headers = [ - (key.decode(), value.decode()) + (key.decode("ascii"), value.decode("ascii", errors="surrogateescape")) for key, value in self.default_headers + list(message.get("headers", [])) ] From 3ce16114ab489ca7c0fd9f2a088fc78d60528279 Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Thu, 27 Jul 2023 18:19:36 +0530 Subject: [PATCH 31/59] fix lint issues --- tests/conftest.py | 2 +- .../websockets/websockets_sansio_impl.py | 65 ++++++++++--------- 2 files changed, 37 insertions(+), 30 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 0997e8966..9d9d89499 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -255,7 +255,7 @@ def unused_tcp_port() -> int: ), ), "uvicorn.protocols.websockets.websockets_impl:WebSocketProtocol", - "uvicorn.protocols.websockets.websockets_sansio_impl:WebSocketsSansIOProtocol" + "uvicorn.protocols.websockets.websockets_sansio_impl:WebSocketsSansIOProtocol", ] ) def ws_protocol_cls(request: pytest.FixtureRequest): diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index abbc80e85..a593832e4 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -1,16 +1,15 @@ import asyncio import logging import sys -import time import typing from asyncio.transports import BaseTransport, Transport +from http import HTTPStatus from urllib.parse import unquote import websockets +from websockets.frames import Frame from websockets.http11 import Request -from websockets.frames import Frame, Close from websockets.server import ServerConnection -from websockets.connection import State from uvicorn.config import Config from uvicorn.logging import TRACE_LOG_LEVEL @@ -21,7 +20,6 @@ is_ssl, ) from uvicorn.server import ServerState -from http import HTTPStatus if sys.version_info < (3, 8): from typing_extensions import Literal @@ -31,7 +29,6 @@ if typing.TYPE_CHECKING: from asgiref.typing import ( ASGISendEvent, - WebSocketAcceptEvent, WebSocketCloseEvent, WebSocketConnectEvent, WebSocketDisconnectEvent, @@ -87,7 +84,7 @@ def __init__( # extensions.append(ServerPerMessageDeflateFactory()) self.conn = ServerConnection() self.request = None - self.response = None + self.response = None self.read_paused = False self.writable = asyncio.Event() @@ -97,7 +94,7 @@ def __init__( self.bytes = b"" self.text = "" print(len(self.tasks)) - + def connection_made(self, transport: BaseTransport) -> None: """Called when a connection is made.""" transport = typing.cast(Transport, transport) @@ -113,24 +110,23 @@ def connection_made(self, transport: BaseTransport) -> None: def connection_lost(self, exc: typing.Optional[Exception]) -> None: self.connections.remove(self) - print('came in connection lost : ', exc) + print("came in connection lost : ", exc) if self.logger.level <= TRACE_LOG_LEVEL: prefix = "%s:%d - " % self.client if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix) if self.handshake_initiated and not self.close_sent: self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006}) - def data_received(self, data: bytes) -> None: try: self.conn.receive_data(data) - except Exception as exc: + except Exception: self.logger.exception("Exception in ASGI server") self.transport.close() self.handle_events() def shutdown(self) -> None: - if not self.transport.is_closing(): + if not self.transport.is_closing(): if self.handshake_complete: self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012}) self.close_send = True @@ -141,7 +137,7 @@ def shutdown(self) -> None: self.send_500_response() self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006}) self.transport.close() - + def handle_events(self) -> None: for event in self.conn.events_received(): if isinstance(event, Request): @@ -165,7 +161,7 @@ def handle_events(self) -> None: def handle_connect(self, event: Request) -> None: self.request = event self.response = self.conn.accept(event) - self.handshake_initiated = True + self.handshake_initiated = True # if status_code is not 101 return response if self.response.status_code != 101: self.handshake_complete = True @@ -177,7 +173,7 @@ def handle_connect(self, event: Request) -> None: return headers = [ - (key.encode('ascii'), value.encode('ascii', errors='surrogateescape')) + (key.encode("ascii"), value.encode("ascii", errors="surrogateescape")) for key, value in event.headers.raw_items() ] raw_path, _, query_string = event.path.partition("?") @@ -206,7 +202,6 @@ def handle_cont(self, event: Frame) -> None: self.bytes += event.data if event.fin: self.send_receive_event_to_app() - def handle_text(self, event: Frame) -> None: self.bytes = event.data @@ -227,9 +222,9 @@ def send_receive_event_to_app(self): data = self.bytes msg: "WebSocketReceiveEvent" = { # type: ignore[typeddict-item] - "type": "websocket.receive", - self.curr_msg_data_type: data - } + "type": "websocket.receive", + self.curr_msg_data_type: data, + } self.queue.put_nowait(msg) self.bytes = b"" self.curr_msg_data_type = None @@ -246,7 +241,9 @@ def handle_pong(self, event: Frame) -> None: def handle_close(self, event: Frame) -> None: if not self.close_sent and not self.transport.is_closing(): - self.queue.put_nowait({"type": "websocket.disconnect", "code": self.conn.close_rcvd.code}) + self.queue.put_nowait( + {"type": "websocket.disconnect", "code": self.conn.close_rcvd.code} + ) output = self.conn.data_to_send() self.transport.writelines(output) self.close_sent = True @@ -254,7 +251,7 @@ def handle_close(self, event: Frame) -> None: def on_task_complete(self, task: asyncio.Task) -> None: self.tasks.discard(task) - + async def run_asgi(self) -> None: try: result = await self.app(self.scope, self.receive, self.send) @@ -299,18 +296,23 @@ async def send(self, message: "ASGISendEvent") -> None: get_path_with_query_string(self.scope), ) headers = [ - (key.decode("ascii"), value.decode("ascii", errors="surrogateescape")) + ( + key.decode("ascii"), + value.decode("ascii", errors="surrogateescape"), + ) for key, value in self.default_headers + list(message.get("headers", [])) ] - self.accepted_subprotocol : str = message.get("subprotocol") + self.accepted_subprotocol: str = message.get("subprotocol") if self.accepted_subprotocol: - headers.append(('Sec-WebSocket-Protocol', self.accepted_subprotocol)) + headers.append( + ("Sec-WebSocket-Protocol", self.accepted_subprotocol) + ) self.handshake_complete = True self.response.headers.update(headers) - self.conn.send_response(self.response) + self.conn.send_response(self.response) output = self.conn.data_to_send() self.transport.writelines(output) @@ -322,11 +324,16 @@ async def send(self, message: "ASGISendEvent") -> None: get_path_with_query_string(self.scope), ) extra_headers = [ - (key.decode("ascii"), value.decode("ascii", errors="surrogateescape")) + ( + key.decode("ascii"), + value.decode("ascii", errors="surrogateescape"), + ) for key, value in self.default_headers + list(message.get("headers", [])) ] - response = self.conn.reject(HTTPStatus.FORBIDDEN, message.get('reason', '')) + response = self.conn.reject( + HTTPStatus.FORBIDDEN, message.get("reason", "") + ) response.headers.update(extra_headers) self.conn.send_response(response) output = self.conn.data_to_send() @@ -345,8 +352,8 @@ async def send(self, message: "ASGISendEvent") -> None: elif not self.close_sent: if message_type == "websocket.send" and not self.transport.is_closing(): message = typing.cast("WebSocketSendEvent", message) - bytes_data : bytes = message.get("bytes") - text_data : str = message.get("text") + bytes_data: bytes = message.get("bytes") + text_data: str = message.get("text") if text_data: # need to add the logic of sending fragmented data here self.conn.send_text(text_data.encode()) @@ -381,4 +388,4 @@ async def receive(self) -> "WebSocketEvent": if self.read_paused and self.queue.empty(): self.read_paused = False self.transport.resume_reading() - return message \ No newline at end of file + return message From 4f601dc01c5f22bcf0819cdbaad0a537a8560cea Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Thu, 27 Jul 2023 18:59:02 +0530 Subject: [PATCH 32/59] fix mypy failing issues --- .../protocols/websockets/websockets_sansio_impl.py | 14 +++++++------- uvicorn/server.py | 3 ++- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index a593832e4..d91667ee6 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -8,7 +8,7 @@ import websockets from websockets.frames import Frame -from websockets.http11 import Request +from websockets.http11 import Request, Response from websockets.server import ServerConnection from uvicorn.config import Config @@ -82,9 +82,9 @@ def __init__( # extensions = [] # if self.config.ws_per_message_deflate: # extensions.append(ServerPerMessageDeflateFactory()) - self.conn = ServerConnection() - self.request = None - self.response = None + self.conn: typing.Optional[ServerConnection] = ServerConnection() + self.request: typing.Optional[Request] = None + self.response: typing.Optional[Response] = None self.read_paused = False self.writable = asyncio.Event() @@ -177,7 +177,7 @@ def handle_connect(self, event: Request) -> None: for key, value in event.headers.raw_items() ] raw_path, _, query_string = event.path.partition("?") - self.scope: "WebSocketScope" = { # type: ignore[typeddict-item] + self.scope: "WebSocketScope" = { "type": "websocket", "asgi": {"version": self.config.asgi_version, "spec_version": "2.3"}, "http_version": "1.1", @@ -215,13 +215,13 @@ def handle_bytes(self, event: Frame) -> None: if event.fin: self.send_receive_event_to_app() - def send_receive_event_to_app(self): + def send_receive_event_to_app(self) -> None: if self.curr_msg_data_type == "text": data = self.bytes.decode() else: data = self.bytes - msg: "WebSocketReceiveEvent" = { # type: ignore[typeddict-item] + msg: "WebSocketReceiveEvent" = { "type": "websocket.receive", self.curr_msg_data_type: data, } diff --git a/uvicorn/server.py b/uvicorn/server.py index 1f0b726f8..325eefd3d 100644 --- a/uvicorn/server.py +++ b/uvicorn/server.py @@ -22,8 +22,9 @@ from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol from uvicorn.protocols.websockets.wsproto_impl import WSProtocol + from uvicorn.protocols.websockets.websockets_sansio_impl import WebSocketsSansIOProtocol - Protocols = Union[H11Protocol, HttpToolsProtocol, WSProtocol, WebSocketProtocol] + Protocols = Union[H11Protocol, HttpToolsProtocol, WSProtocol, WebSocketProtocol, WebSocketsSansIOProtocol] HANDLED_SIGNALS = ( signal.SIGINT, # Unix signal 2. Sent by Ctrl+C. From d48f8c1e504853d678cc0c19863ea673fc821582 Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Thu, 27 Jul 2023 19:03:03 +0530 Subject: [PATCH 33/59] fix lint issues --- .../protocols/websockets/websockets_sansio_impl.py | 2 +- uvicorn/server.py | 12 ++++++++++-- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index d91667ee6..a24be37f0 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -221,7 +221,7 @@ def send_receive_event_to_app(self) -> None: else: data = self.bytes - msg: "WebSocketReceiveEvent" = { + msg: "WebSocketReceiveEvent" = { "type": "websocket.receive", self.curr_msg_data_type: data, } diff --git a/uvicorn/server.py b/uvicorn/server.py index 325eefd3d..e3f13cc51 100644 --- a/uvicorn/server.py +++ b/uvicorn/server.py @@ -21,10 +21,18 @@ from uvicorn.protocols.http.h11_impl import H11Protocol from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol + from uvicorn.protocols.websockets.websockets_sansio_impl import ( + WebSocketsSansIOProtocol, + ) from uvicorn.protocols.websockets.wsproto_impl import WSProtocol - from uvicorn.protocols.websockets.websockets_sansio_impl import WebSocketsSansIOProtocol - Protocols = Union[H11Protocol, HttpToolsProtocol, WSProtocol, WebSocketProtocol, WebSocketsSansIOProtocol] + Protocols = Union[ + H11Protocol, + HttpToolsProtocol, + WSProtocol, + WebSocketProtocol, + WebSocketsSansIOProtocol, + ] HANDLED_SIGNALS = ( signal.SIGINT, # Unix signal 2. Sent by Ctrl+C. From ab12969f82b846649c093f1ea8e7280e010445fb Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Thu, 27 Jul 2023 20:50:30 +0530 Subject: [PATCH 34/59] fix typing issues --- .../websockets/websockets_sansio_impl.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index a24be37f0..49f8ef8d3 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -82,17 +82,17 @@ def __init__( # extensions = [] # if self.config.ws_per_message_deflate: # extensions.append(ServerPerMessageDeflateFactory()) - self.conn: typing.Optional[ServerConnection] = ServerConnection() - self.request: typing.Optional[Request] = None - self.response: typing.Optional[Response] = None + self.conn = ServerConnection() + self.request: Request + self.response: Response + self.curr_msg_data_type: str self.read_paused = False self.writable = asyncio.Event() self.writable.set() # Buffers - self.bytes = b"" - self.text = "" + self.bytes: "bytes" = b"" print(len(self.tasks)) def connection_made(self, transport: BaseTransport) -> None: @@ -216,6 +216,7 @@ def handle_bytes(self, event: Frame) -> None: self.send_receive_event_to_app() def send_receive_event_to_app(self) -> None: + data: typing.Union[str, bytes] if self.curr_msg_data_type == "text": data = self.bytes.decode() else: @@ -226,8 +227,6 @@ def send_receive_event_to_app(self) -> None: self.curr_msg_data_type: data, } self.queue.put_nowait(msg) - self.bytes = b"" - self.curr_msg_data_type = None if not self.read_paused: self.read_paused = True self.transport.pause_reading() @@ -242,7 +241,10 @@ def handle_pong(self, event: Frame) -> None: def handle_close(self, event: Frame) -> None: if not self.close_sent and not self.transport.is_closing(): self.queue.put_nowait( - {"type": "websocket.disconnect", "code": self.conn.close_rcvd.code} + { + "type": "websocket.disconnect", + "code": self.conn.close_rcvd.code, # type: ignore[union-attr] + } ) output = self.conn.data_to_send() self.transport.writelines(output) From 65057b8160b7d9326ddf8e0bbe65a548934b3e3e Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Thu, 27 Jul 2023 21:41:55 +0530 Subject: [PATCH 35/59] Fix extension tests failing --- uvicorn/protocols/websockets/websockets_sansio_impl.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index 49f8ef8d3..6499a2073 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -7,6 +7,7 @@ from urllib.parse import unquote import websockets +from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory from websockets.frames import Frame from websockets.http11 import Request, Response from websockets.server import ServerConnection @@ -82,7 +83,7 @@ def __init__( # extensions = [] # if self.config.ws_per_message_deflate: # extensions.append(ServerPerMessageDeflateFactory()) - self.conn = ServerConnection() + self.conn = ServerConnection(extensions=[ServerPerMessageDeflateFactory()]) self.request: Request self.response: Response self.curr_msg_data_type: str From b81bd5a8de8203c0b77ee90ed5fd8096dcdcfb81 Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Thu, 27 Jul 2023 22:02:26 +0530 Subject: [PATCH 36/59] Fix extension tests failing --- .../websockets/websockets_sansio_impl.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index 6499a2073..055b42a06 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -80,10 +80,10 @@ def __init__( self.handshake_complete = False self.close_sent = False - # extensions = [] - # if self.config.ws_per_message_deflate: - # extensions.append(ServerPerMessageDeflateFactory()) - self.conn = ServerConnection(extensions=[ServerPerMessageDeflateFactory()]) + extensions = [] + if self.config.ws_per_message_deflate: + extensions.append(ServerPerMessageDeflateFactory()) + self.conn = ServerConnection(extensions=extensions) self.request: Request self.response: Response self.curr_msg_data_type: str @@ -307,11 +307,9 @@ async def send(self, message: "ASGISendEvent") -> None: + list(message.get("headers", [])) ] - self.accepted_subprotocol: str = message.get("subprotocol") - if self.accepted_subprotocol: - headers.append( - ("Sec-WebSocket-Protocol", self.accepted_subprotocol) - ) + accepted_subprotocol: str = message.get("subprotocol") + if accepted_subprotocol: + headers.append(("Sec-WebSocket-Protocol", accepted_subprotocol)) self.handshake_complete = True self.response.headers.update(headers) From 63e6f6827c2905bb95e7cb50929f22153fc8a350 Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Sun, 30 Jul 2023 00:26:30 +0530 Subject: [PATCH 37/59] correct types import --- .../websockets/websockets_sansio_impl.py | 36 ++++++++----------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index 055b42a06..0ed8e442a 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -28,23 +28,18 @@ from typing import Literal if typing.TYPE_CHECKING: - from asgiref.typing import ( + from uvicorn._types import ( + ASGIReceiveEvent, ASGISendEvent, - WebSocketCloseEvent, WebSocketConnectEvent, - WebSocketDisconnectEvent, + WebSocketAcceptEvent, WebSocketReceiveEvent, - WebSocketScope, WebSocketSendEvent, + WebSocketCloseEvent, + WebSocketDisconnectEvent, + WebSocketScope, ) - WebSocketEvent = typing.Union[ - "WebSocketReceiveEvent", - "WebSocketDisconnectEvent", - "WebSocketConnectEvent", - ] - - class WebSocketsSansIOProtocol(asyncio.Protocol): def __init__( self, @@ -75,7 +70,7 @@ def __init__( self.scheme: Literal["wss", "ws"] = None # type: ignore[assignment] # WebSocket state - self.queue: asyncio.Queue["WebSocketEvent"] = asyncio.Queue() + self.queue: asyncio.Queue["ASGIReceiveEvent"] = asyncio.Queue() self.handshake_initiated = False self.handshake_complete = False self.close_sent = False @@ -94,7 +89,6 @@ def __init__( # Buffers self.bytes: "bytes" = b"" - print(len(self.tasks)) def connection_made(self, transport: BaseTransport) -> None: """Called when a connection is made.""" @@ -111,7 +105,6 @@ def connection_made(self, transport: BaseTransport) -> None: def connection_lost(self, exc: typing.Optional[Exception]) -> None: self.connections.remove(self) - print("came in connection lost : ", exc) if self.logger.level <= TRACE_LOG_LEVEL: prefix = "%s:%d - " % self.client if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix) @@ -241,12 +234,11 @@ def handle_pong(self, event: Frame) -> None: def handle_close(self, event: Frame) -> None: if not self.close_sent and not self.transport.is_closing(): - self.queue.put_nowait( - { - "type": "websocket.disconnect", - "code": self.conn.close_rcvd.code, # type: ignore[union-attr] - } - ) + disconnect_event: "WebSocketDisconnectEvent" = { + "type": "websocket.disconnect", + "code": self.conn.close_rcvd.code + } + self.queue.put_nowait(disconnect_event) output = self.conn.data_to_send() self.transport.writelines(output) self.close_sent = True @@ -293,6 +285,7 @@ async def send(self, message: "ASGISendEvent") -> None: if not self.handshake_complete: if message_type == "websocket.accept" and not self.transport.is_closing(): + message = typing.cast("WebSocketAcceptEvent", message) self.logger.info( '%s - "WebSocket %s" [accepted]', self.scope["client"], @@ -318,6 +311,7 @@ async def send(self, message: "ASGISendEvent") -> None: self.transport.writelines(output) elif message_type == "websocket.close" and not self.transport.is_closing(): + message = typing.cast("WebSocketCloseEvent", message) self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006}) self.logger.info( '%s - "WebSocket %s" 403', @@ -384,7 +378,7 @@ async def send(self, message: "ASGISendEvent") -> None: msg = "Unexpected ASGI message '%s', after sending 'websocket.close'." raise RuntimeError(msg % message_type) - async def receive(self) -> "WebSocketEvent": + async def receive(self) -> "ASGIReceiveEvent": message = await self.queue.get() if self.read_paused and self.queue.empty(): self.read_paused = False From b89d73205ee7a4f9c5d5995c2e37afbd1d201c97 Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Sun, 30 Jul 2023 01:24:01 +0530 Subject: [PATCH 38/59] correct types import and mypy issues --- .../websockets/websockets_sansio_impl.py | 29 ++++++++++--------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index 0ed8e442a..3057a50df 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -31,15 +31,15 @@ from uvicorn._types import ( ASGIReceiveEvent, ASGISendEvent, - WebSocketConnectEvent, WebSocketAcceptEvent, - WebSocketReceiveEvent, - WebSocketSendEvent, WebSocketCloseEvent, WebSocketDisconnectEvent, + WebSocketReceiveEvent, WebSocketScope, + WebSocketSendEvent, ) + class WebSocketsSansIOProtocol(asyncio.Protocol): def __init__( self, @@ -184,7 +184,6 @@ def handle_connect(self, event: Request) -> None: "query_string": query_string.encode("ascii"), "headers": headers, "subprotocols": event.headers.get_all("Sec-WebSocket-Protocol"), - "extensions": None, "state": self.app_state.copy(), } self.queue.put_nowait({"type": "websocket.connect"}) @@ -218,7 +217,7 @@ def send_receive_event_to_app(self) -> None: msg: "WebSocketReceiveEvent" = { "type": "websocket.receive", - self.curr_msg_data_type: data, + self.curr_msg_data_type: data, # type: ignore[misc] } self.queue.put_nowait(msg) if not self.read_paused: @@ -236,7 +235,7 @@ def handle_close(self, event: Frame) -> None: if not self.close_sent and not self.transport.is_closing(): disconnect_event: "WebSocketDisconnectEvent" = { "type": "websocket.disconnect", - "code": self.conn.close_rcvd.code + "code": self.conn.close_rcvd.code, # type: ignore[union-attr] } self.queue.put_nowait(disconnect_event) output = self.conn.data_to_send() @@ -300,7 +299,7 @@ async def send(self, message: "ASGISendEvent") -> None: + list(message.get("headers", [])) ] - accepted_subprotocol: str = message.get("subprotocol") + accepted_subprotocol = message.get("subprotocol") if accepted_subprotocol: headers.append(("Sec-WebSocket-Protocol", accepted_subprotocol)) @@ -312,7 +311,12 @@ async def send(self, message: "ASGISendEvent") -> None: elif message_type == "websocket.close" and not self.transport.is_closing(): message = typing.cast("WebSocketCloseEvent", message) - self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006}) + self.queue.put_nowait( + { + "type": "websocket.disconnect", + "code": message.get("code", 1000) or 1000, + } + ) self.logger.info( '%s - "WebSocket %s" 403', self.scope["client"], @@ -324,10 +328,10 @@ async def send(self, message: "ASGISendEvent") -> None: value.decode("ascii", errors="surrogateescape"), ) for key, value in self.default_headers - + list(message.get("headers", [])) ] + response = self.conn.reject( - HTTPStatus.FORBIDDEN, message.get("reason", "") + HTTPStatus.FORBIDDEN, message.get("reason", "") or "" ) response.headers.update(extra_headers) self.conn.send_response(response) @@ -347,10 +351,9 @@ async def send(self, message: "ASGISendEvent") -> None: elif not self.close_sent: if message_type == "websocket.send" and not self.transport.is_closing(): message = typing.cast("WebSocketSendEvent", message) - bytes_data: bytes = message.get("bytes") - text_data: str = message.get("text") + bytes_data = message.get("bytes") + text_data = message.get("text") if text_data: - # need to add the logic of sending fragmented data here self.conn.send_text(text_data.encode()) elif bytes_data: self.conn.send_binary(bytes_data) From 526fe561b678f283320222ab904afe68278dd340 Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Sun, 30 Jul 2023 01:44:56 +0530 Subject: [PATCH 39/59] fix typo --- uvicorn/protocols/websockets/websockets_sansio_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index 3057a50df..b272b56a2 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -123,7 +123,7 @@ def shutdown(self) -> None: if not self.transport.is_closing(): if self.handshake_complete: self.queue.put_nowait({"type": "websocket.disconnect", "code": 1012}) - self.close_send = True + self.close_sent = True self.conn.send_close(1012) output = self.conn.data_to_send() self.transport.writelines(output) From 0699a7e996ede292bdeea070545614f27efccd41 Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Tue, 1 Aug 2023 22:11:59 +0530 Subject: [PATCH 40/59] Replace ServerConnection with ServerProtocol due to upgradation of websockets version --- uvicorn/protocols/websockets/websockets_sansio_impl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index b272b56a2..269be083e 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -10,7 +10,7 @@ from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory from websockets.frames import Frame from websockets.http11 import Request, Response -from websockets.server import ServerConnection +from websockets.server import ServerProtocol from uvicorn.config import Config from uvicorn.logging import TRACE_LOG_LEVEL @@ -78,7 +78,7 @@ def __init__( extensions = [] if self.config.ws_per_message_deflate: extensions.append(ServerPerMessageDeflateFactory()) - self.conn = ServerConnection(extensions=extensions) + self.conn = ServerProtocol(extensions=extensions) self.request: Request self.response: Response self.curr_msg_data_type: str From 82f3f6edfe344d3af2025b6b637ae8a09c10fecf Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sun, 27 Aug 2023 21:03:36 +0200 Subject: [PATCH 41/59] Remove conditional on imports --- .../websockets/websockets_sansio_impl.py | 29 +++++++------------ 1 file changed, 11 insertions(+), 18 deletions(-) diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index 269be083e..2604083e2 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -1,9 +1,9 @@ import asyncio import logging -import sys import typing from asyncio.transports import BaseTransport, Transport from http import HTTPStatus +from typing import Literal from urllib.parse import unquote import websockets @@ -12,6 +12,16 @@ from websockets.http11 import Request, Response from websockets.server import ServerProtocol +from uvicorn._types import ( + ASGIReceiveEvent, + ASGISendEvent, + WebSocketAcceptEvent, + WebSocketCloseEvent, + WebSocketDisconnectEvent, + WebSocketReceiveEvent, + WebSocketScope, + WebSocketSendEvent, +) from uvicorn.config import Config from uvicorn.logging import TRACE_LOG_LEVEL from uvicorn.protocols.utils import ( @@ -22,23 +32,6 @@ ) from uvicorn.server import ServerState -if sys.version_info < (3, 8): - from typing_extensions import Literal -else: - from typing import Literal - -if typing.TYPE_CHECKING: - from uvicorn._types import ( - ASGIReceiveEvent, - ASGISendEvent, - WebSocketAcceptEvent, - WebSocketCloseEvent, - WebSocketDisconnectEvent, - WebSocketReceiveEvent, - WebSocketScope, - WebSocketSendEvent, - ) - class WebSocketsSansIOProtocol(asyncio.Protocol): def __init__( From 7a90b8b6cc91a58959c18657bae7aa73cd9ebfef Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sun, 27 Aug 2023 21:12:54 +0200 Subject: [PATCH 42/59] Fix typos, and small details --- .../websockets/websockets_sansio_impl.py | 31 +++++++++---------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index 2604083e2..8613a068d 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -6,9 +6,8 @@ from typing import Literal from urllib.parse import unquote -import websockets from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory -from websockets.frames import Frame +from websockets.frames import Frame, Opcode from websockets.http11 import Request, Response from websockets.server import ServerProtocol @@ -81,7 +80,7 @@ def __init__( self.writable.set() # Buffers - self.bytes: "bytes" = b"" + self.bytes = b"" def connection_made(self, transport: BaseTransport) -> None: """Called when a connection is made.""" @@ -130,17 +129,17 @@ def handle_events(self) -> None: if isinstance(event, Request): self.handle_connect(event) if isinstance(event, Frame): - if event.opcode == websockets.frames.Opcode.CONT: + if event.opcode == Opcode.CONT: self.handle_cont(event) - elif event.opcode == websockets.frames.Opcode.TEXT: + elif event.opcode == Opcode.TEXT: self.handle_text(event) - elif event.opcode == websockets.frames.Opcode.BINARY: + elif event.opcode == Opcode.BINARY: self.handle_bytes(event) - elif event.opcode == websockets.frames.Opcode.PING: + elif event.opcode == Opcode.PING: self.handle_ping(event) - elif event.opcode == websockets.frames.Opcode.PONG: + elif event.opcode == Opcode.PONG: self.handle_pong(event) - elif event.opcode == websockets.frames.Opcode.CLOSE: + elif event.opcode == Opcode.CLOSE: self.handle_close(event) # Event handlers @@ -208,7 +207,7 @@ def send_receive_event_to_app(self) -> None: else: data = self.bytes - msg: "WebSocketReceiveEvent" = { + msg: WebSocketReceiveEvent = { "type": "websocket.receive", self.curr_msg_data_type: data, # type: ignore[misc] } @@ -226,7 +225,7 @@ def handle_pong(self, event: Frame) -> None: def handle_close(self, event: Frame) -> None: if not self.close_sent and not self.transport.is_closing(): - disconnect_event: "WebSocketDisconnectEvent" = { + disconnect_event: WebSocketDisconnectEvent = { "type": "websocket.disconnect", "code": self.conn.close_rcvd.code, # type: ignore[union-attr] } @@ -270,7 +269,7 @@ def send_500_response(self) -> None: ] self.transport.write(b"".join(content)) - async def send(self, message: "ASGISendEvent") -> None: + async def send(self, message: ASGISendEvent) -> None: await self.writable.wait() message_type = message["type"] @@ -330,7 +329,7 @@ async def send(self, message: "ASGISendEvent") -> None: self.conn.send_response(response) output = self.conn.data_to_send() self.close_sent = True - self.hankshake_complete = True + self.handshake_complete = True self.transport.writelines(output) self.transport.close() @@ -343,7 +342,7 @@ async def send(self, message: "ASGISendEvent") -> None: elif not self.close_sent: if message_type == "websocket.send" and not self.transport.is_closing(): - message = typing.cast("WebSocketSendEvent", message) + message = typing.cast(WebSocketSendEvent, message) bytes_data = message.get("bytes") text_data = message.get("text") if text_data: @@ -354,7 +353,7 @@ async def send(self, message: "ASGISendEvent") -> None: self.transport.writelines(output) elif message_type == "websocket.close" and not self.transport.is_closing(): - message = typing.cast("WebSocketCloseEvent", message) + message = typing.cast(WebSocketCloseEvent, message) code = message.get("code", 1000) reason = message.get("reason", "") or "" self.queue.put_nowait({"type": "websocket.disconnect", "code": code}) @@ -374,7 +373,7 @@ async def send(self, message: "ASGISendEvent") -> None: msg = "Unexpected ASGI message '%s', after sending 'websocket.close'." raise RuntimeError(msg % message_type) - async def receive(self) -> "ASGIReceiveEvent": + async def receive(self) -> ASGIReceiveEvent: message = await self.queue.get() if self.read_paused and self.queue.empty(): self.read_paused = False From e72cd54667e0404bd7c118a890dee189389f9655 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Mon, 28 Aug 2023 07:26:11 +0200 Subject: [PATCH 43/59] Refactor small things --- .../websockets/websockets_sansio_impl.py | 49 +++++++++---------- 1 file changed, 23 insertions(+), 26 deletions(-) diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index 8613a068d..980f957f2 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -8,7 +8,7 @@ from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory from websockets.frames import Frame, Opcode -from websockets.http11 import Request, Response +from websockets.http11 import Request from websockets.server import ServerProtocol from uvicorn._types import ( @@ -62,18 +62,19 @@ def __init__( self.scheme: Literal["wss", "ws"] = None # type: ignore[assignment] # WebSocket state - self.queue: asyncio.Queue["ASGIReceiveEvent"] = asyncio.Queue() + self.queue: asyncio.Queue[ASGIReceiveEvent] = asyncio.Queue() self.handshake_initiated = False self.handshake_complete = False self.close_sent = False extensions = [] if self.config.ws_per_message_deflate: - extensions.append(ServerPerMessageDeflateFactory()) - self.conn = ServerProtocol(extensions=extensions) - self.request: Request - self.response: Response - self.curr_msg_data_type: str + extensions = [ServerPerMessageDeflateFactory()] + self.conn = ServerProtocol( + extensions=extensions, + max_size=self.config.ws_max_size, + logger=logging.getLogger("uvicorn.error"), + ) self.read_paused = False self.writable = asyncio.Event() @@ -103,14 +104,6 @@ def connection_lost(self, exc: typing.Optional[Exception]) -> None: if self.handshake_initiated and not self.close_sent: self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006}) - def data_received(self, data: bytes) -> None: - try: - self.conn.receive_data(data) - except Exception: - self.logger.exception("Exception in ASGI server") - self.transport.close() - self.handle_events() - def shutdown(self) -> None: if not self.transport.is_closing(): if self.handshake_complete: @@ -124,6 +117,14 @@ def shutdown(self) -> None: self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006}) self.transport.close() + def data_received(self, data: bytes) -> None: + try: + self.conn.receive_data(data) + except Exception: + self.logger.exception("Exception in ASGI server") + self.transport.close() + self.handle_events() + def handle_events(self) -> None: for event in self.conn.events_received(): if isinstance(event, Request): @@ -190,7 +191,7 @@ def handle_cont(self, event: Frame) -> None: def handle_text(self, event: Frame) -> None: self.bytes = event.data - self.curr_msg_data_type = "text" + self.curr_msg_data_type: Literal["text", "bytes"] = "text" if event.fin: self.send_receive_event_to_app() @@ -201,16 +202,12 @@ def handle_bytes(self, event: Frame) -> None: self.send_receive_event_to_app() def send_receive_event_to_app(self) -> None: - data: typing.Union[str, bytes] - if self.curr_msg_data_type == "text": - data = self.bytes.decode() + data_type = self.curr_msg_data_type + msg: WebSocketReceiveEvent + if data_type == "text": + msg = {"type": "websocket.receive", data_type: self.bytes.decode()} else: - data = self.bytes - - msg: WebSocketReceiveEvent = { - "type": "websocket.receive", - self.curr_msg_data_type: data, # type: ignore[misc] - } + msg = {"type": "websocket.receive", data_type: self.bytes} self.queue.put_nowait(msg) if not self.read_paused: self.read_paused = True @@ -235,7 +232,7 @@ def handle_close(self, event: Frame) -> None: self.close_sent = True self.transport.close() - def on_task_complete(self, task: asyncio.Task) -> None: + def on_task_complete(self, task: asyncio.Task[None]) -> None: self.tasks.discard(task) async def run_asgi(self) -> None: From d9a4ea0a18aa4af5c5e98d0bfdc070440dbec7d3 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Mon, 28 Aug 2023 07:44:46 +0200 Subject: [PATCH 44/59] Fix linter --- uvicorn/protocols/websockets/websockets_sansio_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index 980f957f2..6e93fffe2 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -232,7 +232,7 @@ def handle_close(self, event: Frame) -> None: self.close_sent = True self.transport.close() - def on_task_complete(self, task: asyncio.Task[None]) -> None: + def on_task_complete(self, task: "asyncio.Task[None]") -> None: self.tasks.discard(task) async def run_asgi(self) -> None: From 252bdc155b15f4c46e2779a98e2950311100fbb2 Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Tue, 29 Aug 2023 22:41:26 +0530 Subject: [PATCH 45/59] Add tests for websocket server for receiving multiple frames --- tests/protocols/test_websocket.py | 84 +++++++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index df2415f2e..3b3d63cb3 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -494,6 +494,90 @@ async def send_text(url: str): assert data == b"abc" +@pytest.mark.anyio +async def test_send_text_data_to_server_in_multiple_frames( + ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + unused_tcp_port: int, +): + message = ( + "This is a long message that will be sent in " + "multiple frames and number of frames will be 5." + ) + + class App(WebSocketResponse): + async def websocket_connect(self, message): + await self.send({"type": "websocket.accept"}) + + async def websocket_receive(self, message): + _text = message.get("text") + await self.send({"type": "websocket.send", "text": _text}) + + async def send_text(url): + async with websockets.client.connect(url) as websocket: + assembled_frames = [] + # send this message in 5 frames + for i in range(5): + # divide the message in 5 parts + msg = message[i * len(message) // 5 : (i + 1) * len(message) // 5] + assembled_frames.append(msg) + await websocket.send(assembled_frames) + return await websocket.recv() + + config = Config( + app=App, + ws=ws_protocol_cls, + http=http_protocol_cls, + lifespan="off", + port=unused_tcp_port, + ) + async with run_server(config): + data = await send_text(f"ws://127.0.0.1:{unused_tcp_port}") + assert data == message + + +@pytest.mark.anyio +async def test_send_binary_data_to_server_in_multiple_frames( + ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + unused_tcp_port: int, +): + message = ( + b"This is a long message that will be sent in " + b"multiple frames and number of frames will be 5." + ) + + class App(WebSocketResponse): + async def websocket_connect(self, message): + await self.send({"type": "websocket.accept"}) + + async def websocket_receive(self, message): + _bytes = message.get("bytes") + await self.send({"type": "websocket.send", "bytes": _bytes}) + + async def send_bytes(url): + async with websockets.client.connect(url) as websocket: + assembled_frames = [] + # send this message in 5 frames + for i in range(5): + # divide the message in 5 parts + msg = message[i * len(message) // 5 : (i + 1) * len(message) // 5] + assembled_frames.append(msg) + await websocket.send(assembled_frames) + return await websocket.recv() + + config = Config( + app=App, + ws=ws_protocol_cls, + http=http_protocol_cls, + lifespan="off", + port=unused_tcp_port, + ) + async with run_server(config): + data = await send_bytes(f"ws://127.0.0.1:{unused_tcp_port}") + assert data == message + + @pytest.mark.anyio async def test_send_after_protocol_close( ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", From 9ff1a2e226de2c58d9938c73f18b9654ecb5e1b2 Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Tue, 29 Aug 2023 22:51:33 +0530 Subject: [PATCH 46/59] Remove checking of PONG event after receiving data As it won't be propagated as an event by the websockets sansIO protocol --- uvicorn/protocols/websockets/websockets_sansio_impl.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index 6e93fffe2..110c96fd1 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -138,8 +138,6 @@ def handle_events(self) -> None: self.handle_bytes(event) elif event.opcode == Opcode.PING: self.handle_ping(event) - elif event.opcode == Opcode.PONG: - self.handle_pong(event) elif event.opcode == Opcode.CLOSE: self.handle_close(event) @@ -217,9 +215,6 @@ def handle_ping(self, event: Frame) -> None: output = self.conn.data_to_send() self.transport.writelines(output) - def handle_pong(self, event: Frame) -> None: - pass - def handle_close(self, event: Frame) -> None: if not self.close_sent and not self.transport.is_closing(): disconnect_event: WebSocketDisconnectEvent = { From bc35b4f4e73c4c11b68464f27659a61f30eb85ed Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Tue, 29 Aug 2023 23:24:48 +0530 Subject: [PATCH 47/59] Revert "Remove checking of PONG event after receiving data" This reverts commit 808f9515147d8b40b70fbd7761df097a8540e1a9. --- uvicorn/protocols/websockets/websockets_sansio_impl.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index 110c96fd1..6e93fffe2 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -138,6 +138,8 @@ def handle_events(self) -> None: self.handle_bytes(event) elif event.opcode == Opcode.PING: self.handle_ping(event) + elif event.opcode == Opcode.PONG: + self.handle_pong(event) elif event.opcode == Opcode.CLOSE: self.handle_close(event) @@ -215,6 +217,9 @@ def handle_ping(self, event: Frame) -> None: output = self.conn.data_to_send() self.transport.writelines(output) + def handle_pong(self, event: Frame) -> None: + pass + def handle_close(self, event: Frame) -> None: if not self.close_sent and not self.transport.is_closing(): disconnect_event: WebSocketDisconnectEvent = { From b82a8cee37168df38a01c910e077e0d1133464ee Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Tue, 29 Aug 2023 23:36:25 +0530 Subject: [PATCH 48/59] "Remove checking of PONG event after receiving data" As websockets sansio protocol is not propagating this event after receiving data --- uvicorn/protocols/websockets/websockets_sansio_impl.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index 6e93fffe2..110c96fd1 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -138,8 +138,6 @@ def handle_events(self) -> None: self.handle_bytes(event) elif event.opcode == Opcode.PING: self.handle_ping(event) - elif event.opcode == Opcode.PONG: - self.handle_pong(event) elif event.opcode == Opcode.CLOSE: self.handle_close(event) @@ -217,9 +215,6 @@ def handle_ping(self, event: Frame) -> None: output = self.conn.data_to_send() self.transport.writelines(output) - def handle_pong(self, event: Frame) -> None: - pass - def handle_close(self, event: Frame) -> None: if not self.close_sent and not self.transport.is_closing(): disconnect_event: WebSocketDisconnectEvent = { From 09d30725a2192f39e84122dbba1b6ffa4aa39724 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Wed, 30 Aug 2023 09:26:11 +0200 Subject: [PATCH 49/59] Create WSType on the test suite --- pyproject.toml | 2 +- tests/middleware/test_logging.py | 9 ++++++++- tests/protocols/test_websocket.py | 5 +++++ uvicorn/protocols/websockets/websockets_sansio_impl.py | 2 +- uvicorn/protocols/websockets/wsproto_impl.py | 2 +- 5 files changed, 16 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fc8f0af50..414469b5b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -100,7 +100,7 @@ omit = [ [tool.coverage.report] precision = 2 -fail_under = 98.35 +fail_under = 98.65 show_missing = true skip_covered = true exclude_lines = [ diff --git a/tests/middleware/test_logging.py b/tests/middleware/test_logging.py index bc49f3463..51539e1d2 100644 --- a/tests/middleware/test_logging.py +++ b/tests/middleware/test_logging.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import contextlib import logging import socket @@ -14,8 +16,13 @@ if typing.TYPE_CHECKING: from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol + from uvicorn.protocols.websockets.websockets_sansio_impl import ( + WebSocketsSansIOProtocol, + ) from uvicorn.protocols.websockets.wsproto_impl import WSProtocol + WSType = typing.Type["WSProtocol | WebSocketProtocol | WebSocketsSansIOProtocol"] + @contextlib.contextmanager def caplog_for_logger(caplog, logger_name): @@ -87,7 +94,7 @@ async def test_trace_logging_on_http_protocol( @pytest.mark.anyio async def test_trace_logging_on_ws_protocol( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", + ws_protocol_cls: WSType, caplog, logging_config, unused_tcp_port: int, diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index 3b3d63cb3..0c242710e 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio import typing from copy import deepcopy @@ -23,6 +25,7 @@ ) from uvicorn.config import Config from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol +from uvicorn.protocols.websockets.websockets_sansio_impl import WebSocketsSansIOProtocol try: from uvicorn.protocols.websockets.wsproto_impl import WSProtocol @@ -35,6 +38,8 @@ from uvicorn.protocols.http.h11_impl import H11Protocol from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol + WSType = typing.Type["WSProtocol | WebSocketProtocol | WebSocketsSansIOProtocol"] + class WebSocketResponse: def __init__( diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index 110c96fd1..0e12208b5 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -41,7 +41,7 @@ def __init__( _loop: typing.Optional[asyncio.AbstractEventLoop] = None, ) -> None: if not config.loaded: - config.load() + config.load() # pragma: no cover self.config = config self.app = config.loaded_app diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index 7929f3a91..1c17ee068 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -42,7 +42,7 @@ def __init__( _loop: asyncio.AbstractEventLoop | None = None, ) -> None: if not config.loaded: - config.load() + config.load() # pragma: no cover self.config = config self.app = config.loaded_app From 5a22d00bd542e4a73b03414c9510f37f38d57f49 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Wed, 30 Aug 2023 09:32:13 +0200 Subject: [PATCH 50/59] Add WebSocketsSansIOProtocol to the CLI --- docs/deployment.md | 2 +- docs/index.md | 2 +- docs/settings.md | 2 +- uvicorn/config.py | 3 ++- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/docs/deployment.md b/docs/deployment.md index 58927d95e..8ff0c490c 100644 --- a/docs/deployment.md +++ b/docs/deployment.md @@ -60,7 +60,7 @@ Options: --loop [auto|asyncio|uvloop] Event loop implementation. [default: auto] --http [auto|h11|httptools] HTTP protocol implementation. [default: auto] - --ws [auto|none|websockets|wsproto] + --ws [auto|none|websockets|websockets-sansio|wsproto] WebSocket protocol implementation. [default: auto] --ws-max-size INTEGER WebSocket max size message in bytes diff --git a/docs/index.md b/docs/index.md index 6ce92feb8..9b1d9a8f7 100644 --- a/docs/index.md +++ b/docs/index.md @@ -130,7 +130,7 @@ Options: --loop [auto|asyncio|uvloop] Event loop implementation. [default: auto] --http [auto|h11|httptools] HTTP protocol implementation. [default: auto] - --ws [auto|none|websockets|wsproto] + --ws [auto|none|websockets|websockets-sansio|wsproto] WebSocket protocol implementation. [default: auto] --ws-max-size INTEGER WebSocket max size message in bytes diff --git a/docs/settings.md b/docs/settings.md index 9c62460fe..391cceaf9 100644 --- a/docs/settings.md +++ b/docs/settings.md @@ -67,7 +67,7 @@ Using Uvicorn with watchfiles will enable the following options (which are other * `--loop ` - Set the event loop implementation. The uvloop implementation provides greater performance, but is not compatible with Windows or PyPy. **Options:** *'auto', 'asyncio', 'uvloop'.* **Default:** *'auto'*. * `--http ` - Set the HTTP protocol implementation. The httptools implementation provides greater performance, but it not compatible with PyPy. **Options:** *'auto', 'h11', 'httptools'.* **Default:** *'auto'*. -* `--ws ` - Set the WebSockets protocol implementation. Either of the `websockets` and `wsproto` packages are supported. Use `'none'` to ignore all websocket requests. **Options:** *'auto', 'none', 'websockets', 'wsproto'.* **Default:** *'auto'*. +* `--ws ` - Set the WebSockets protocol implementation. Either of the `websockets` and `wsproto` packages are supported. Use `'none'` to ignore all websocket requests. **Options:** *'auto', 'none', 'websockets', 'websockets-sansio', 'wsproto'.* **Default:** *'auto'*. * `--ws-max-size ` - Set the WebSockets max message size, in bytes. Please note that this can be used only with the default `websockets` protocol. * `--ws-max-queue ` - Set the maximum length of the WebSocket incoming message queue. Please note that this can be used only with the default `websockets` protocol. * `--ws-ping-interval ` - Set the WebSockets ping interval, in seconds. Please note that this can be used only with the default `websockets` protocol. **Default:** *20.0* diff --git a/uvicorn/config.py b/uvicorn/config.py index b0dff4604..702886415 100644 --- a/uvicorn/config.py +++ b/uvicorn/config.py @@ -23,7 +23,7 @@ from uvicorn.middleware.wsgi import WSGIMiddleware HTTPProtocolType = Literal["auto", "h11", "httptools"] -WSProtocolType = Literal["auto", "none", "websockets", "wsproto"] +WSProtocolType = Literal["auto", "none", "websockets", "websockets-sansio", "wsproto"] LifespanType = Literal["auto", "on", "off"] LoopSetupType = Literal["none", "auto", "asyncio", "uvloop"] InterfaceType = Literal["auto", "asgi3", "asgi2", "wsgi"] @@ -45,6 +45,7 @@ "auto": "uvicorn.protocols.websockets.auto:AutoWebSocketsProtocol", "none": None, "websockets": "uvicorn.protocols.websockets.websockets_impl:WebSocketProtocol", + "websockets-sansio": "uvicorn.protocols.websockets.websockets_sansio_impl:WebSocketsSansIOProtocol", # noqa: E501 "wsproto": "uvicorn.protocols.websockets.wsproto_impl:WSProtocol", } LIFESPAN: dict[LifespanType, str] = { From 5f5f5f47195761cb0be213e2ea770b71942f1506 Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Sat, 2 Sep 2023 01:05:22 +0530 Subject: [PATCH 51/59] Make changes for testing payload max_size limit --- tests/protocols/test_websocket.py | 3 ++- .../websockets/websockets_sansio_impl.py | 26 +++++++++++++++---- 2 files changed, 23 insertions(+), 6 deletions(-) diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index 0c242710e..5d0e500e5 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -1006,6 +1006,7 @@ async def get_subprotocol(url: str): ], ) async def test_send_binary_data_to_server_bigger_than_default_on_websockets( + ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", client_size_sent: int, server_size_max: int, @@ -1027,7 +1028,7 @@ async def send_text(url: str): config = Config( app=App, - ws=WebSocketProtocol, + ws=ws_protocol_cls, http=http_protocol_cls, lifespan="off", ws_max_size=server_size_max, diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index 0e12208b5..4b8e773b3 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -6,9 +6,11 @@ from typing import Literal from urllib.parse import unquote +from websockets.exceptions import PayloadTooBig from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory from websockets.frames import Frame, Opcode from websockets.http11 import Request +from websockets.protocol import State as WebSocketsState from websockets.server import ServerProtocol from uvicorn._types import ( @@ -118,11 +120,13 @@ def shutdown(self) -> None: self.transport.close() def data_received(self, data: bytes) -> None: - try: - self.conn.receive_data(data) - except Exception: - self.logger.exception("Exception in ASGI server") - self.transport.close() + self.conn.receive_data(data) + parser_exc = self.conn.parser_exc + if parser_exc is not None: + self.conn = ServerProtocol(state=WebSocketsState.OPEN) + if isinstance(parser_exc, PayloadTooBig): + self.handle_payloadsize_bigger_than_limit_error() + return self.handle_events() def handle_events(self) -> None: @@ -227,6 +231,18 @@ def handle_close(self, event: Frame) -> None: self.close_sent = True self.transport.close() + def handle_payloadsize_bigger_than_limit_error(self) -> None: + disconnect_event: WebSocketDisconnectEvent = { + "type": "websocket.disconnect", + "code": 1009, + } + self.queue.put_nowait(disconnect_event) + self.conn.send_close(1009) + output = self.conn.data_to_send() + self.transport.writelines(output) + self.close_sent = True + self.transport.close() + def on_task_complete(self, task: "asyncio.Task[None]") -> None: self.tasks.discard(task) From 8dcd5057d066362aef85ccd79b98f29238b2f519 Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Mon, 4 Sep 2023 22:53:14 +0530 Subject: [PATCH 52/59] Make changes for testing payload max_size limit --- uvicorn/protocols/websockets/websockets_sansio_impl.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index 4b8e773b3..84c33ef6a 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -123,9 +123,7 @@ def data_received(self, data: bytes) -> None: self.conn.receive_data(data) parser_exc = self.conn.parser_exc if parser_exc is not None: - self.conn = ServerProtocol(state=WebSocketsState.OPEN) - if isinstance(parser_exc, PayloadTooBig): - self.handle_payloadsize_bigger_than_limit_error() + self.handle_parser_exception() return self.handle_events() @@ -231,13 +229,12 @@ def handle_close(self, event: Frame) -> None: self.close_sent = True self.transport.close() - def handle_payloadsize_bigger_than_limit_error(self) -> None: + def handle_parser_exception(self) -> None: disconnect_event: WebSocketDisconnectEvent = { "type": "websocket.disconnect", - "code": 1009, + "code": self.conn.close_sent.code, } self.queue.put_nowait(disconnect_event) - self.conn.send_close(1009) output = self.conn.data_to_send() self.transport.writelines(output) self.close_sent = True From 3f4eecb114709dad31c697c55ad566d39476d234 Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Mon, 4 Sep 2023 23:00:23 +0530 Subject: [PATCH 53/59] fix lint issue --- uvicorn/protocols/websockets/websockets_sansio_impl.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index 84c33ef6a..704976c44 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -6,11 +6,9 @@ from typing import Literal from urllib.parse import unquote -from websockets.exceptions import PayloadTooBig from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory from websockets.frames import Frame, Opcode from websockets.http11 import Request -from websockets.protocol import State as WebSocketsState from websockets.server import ServerProtocol from uvicorn._types import ( @@ -232,7 +230,7 @@ def handle_close(self, event: Frame) -> None: def handle_parser_exception(self) -> None: disconnect_event: WebSocketDisconnectEvent = { "type": "websocket.disconnect", - "code": self.conn.close_sent.code, + "code": self.conn.close_sent.code, # type: ignore[union-attr] } self.queue.put_nowait(disconnect_event) output = self.conn.data_to_send() From 60460f9526ce6f93f4291882cba5bc6c0d123e71 Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Mon, 4 Sep 2023 23:08:25 +0530 Subject: [PATCH 54/59] increase msg size from 11 to 32 --- tests/protocols/test_websocket.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index 5d0e500e5..3963f1aaf 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -996,7 +996,7 @@ async def get_subprotocol(url: str): (MAX_WS_BYTES, MAX_WS_BYTES, 0), (MAX_WS_BYTES_PLUS1, MAX_WS_BYTES, 1009), (10, 10, 0), - (11, 10, 1009), + (32, 10, 1009), ], ids=[ "max=defaults sent=defaults", From 78c69413b0ef5f943533e324c294b650f04c26be Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Tue, 5 Sep 2023 10:18:01 +0530 Subject: [PATCH 55/59] increase client max_limit --- tests/protocols/test_websocket.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index 3963f1aaf..5d0e500e5 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -996,7 +996,7 @@ async def get_subprotocol(url: str): (MAX_WS_BYTES, MAX_WS_BYTES, 0), (MAX_WS_BYTES_PLUS1, MAX_WS_BYTES, 1009), (10, 10, 0), - (32, 10, 1009), + (11, 10, 1009), ], ids=[ "max=defaults sent=defaults", From 0ade3d4266cd8733cb60447a6d3c9fcbe220de9c Mon Sep 17 00:00:00 2001 From: Gourav Kandoria Date: Tue, 5 Sep 2023 10:23:52 +0530 Subject: [PATCH 56/59] Empty-Commit-to-trigger-pipeline From db31c56cc40f9b6c606b963602d18172c2e8d3b4 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 26 Dec 2023 12:32:50 +0100 Subject: [PATCH 57/59] Use WSProtocolType --- tests/protocols/test_websocket.py | 216 +++++++++++++++--------------- 1 file changed, 109 insertions(+), 107 deletions(-) diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index 5d0e500e5..015c5dbdb 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -21,6 +21,7 @@ Scope, WebSocketCloseEvent, WebSocketDisconnectEvent, + WebSocketReceiveEvent, WebSocketResponseStartEvent, ) from uvicorn.config import Config @@ -38,7 +39,8 @@ from uvicorn.protocols.http.h11_impl import H11Protocol from uvicorn.protocols.http.httptools_impl import HttpToolsProtocol - WSType = typing.Type["WSProtocol | WebSocketProtocol | WebSocketsSansIOProtocol"] + HTTPProtocolType = type["H11Protocol | HttpToolsProtocol"] + WSProtocolType = type["WSProtocol | WebSocketProtocol | WebSocketsSansIOProtocol"] class WebSocketResponse: @@ -63,7 +65,7 @@ async def asgi(self): break -async def wsresponse(url): +async def wsresponse(url: str): """ A simple websocket connection request and response helper """ @@ -80,8 +82,8 @@ async def wsresponse(url): @pytest.mark.anyio async def test_invalid_upgrade( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, ): def app(scope: Scope): @@ -121,15 +123,15 @@ def app(scope: Scope): @pytest.mark.anyio async def test_accept_connection( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, ): class App(WebSocketResponse): async def websocket_connect(self, message): await self.send({"type": "websocket.accept"}) - async def open_connection(url): + async def open_connection(url: str): async with websockets.client.connect(url) as websocket: return websocket.open @@ -147,8 +149,8 @@ async def open_connection(url): @pytest.mark.anyio async def test_shutdown( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, ): class App(WebSocketResponse): @@ -170,8 +172,8 @@ async def websocket_connect(self, message): @pytest.mark.anyio async def test_supports_permessage_deflate_extension( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, ): class App(WebSocketResponse): @@ -199,8 +201,8 @@ async def open_connection(url): @pytest.mark.anyio async def test_can_disable_permessage_deflate_extension( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, ): class App(WebSocketResponse): @@ -231,8 +233,8 @@ async def open_connection(url: str): @pytest.mark.anyio async def test_close_connection( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, ): class App(WebSocketResponse): @@ -260,8 +262,8 @@ async def open_connection(url: str): @pytest.mark.anyio async def test_headers( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, ): class App(WebSocketResponse): @@ -292,8 +294,8 @@ async def open_connection(url: str): @pytest.mark.anyio async def test_extra_headers( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, ): class App(WebSocketResponse): @@ -320,8 +322,8 @@ async def open_connection(url: str): @pytest.mark.anyio async def test_path_and_raw_path( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, ): class App(WebSocketResponse): @@ -350,8 +352,8 @@ async def open_connection(url: str): @pytest.mark.anyio async def test_send_text_data_to_client( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, ): class App(WebSocketResponse): @@ -377,8 +379,8 @@ async def get_data(url: str): @pytest.mark.anyio async def test_send_binary_data_to_client( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, ): class App(WebSocketResponse): @@ -404,8 +406,8 @@ async def get_data(url: str): @pytest.mark.anyio async def test_send_and_close_connection( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, ): class App(WebSocketResponse): @@ -439,17 +441,17 @@ async def get_data(url: str): @pytest.mark.anyio async def test_send_text_data_to_server( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, ): class App(WebSocketResponse): async def websocket_connect(self, message): await self.send({"type": "websocket.accept"}) - async def websocket_receive(self, message): + async def websocket_receive(self, message: WebSocketReceiveEvent) -> None: _text = message.get("text") - await self.send({"type": "websocket.send", "text": _text}) + await self.send({"type": "websocket.send", "text": _text}) # type: ignore async def send_text(url: str): async with websockets.client.connect(url) as websocket: @@ -470,17 +472,17 @@ async def send_text(url: str): @pytest.mark.anyio async def test_send_binary_data_to_server( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, ): class App(WebSocketResponse): async def websocket_connect(self, message): await self.send({"type": "websocket.accept"}) - async def websocket_receive(self, message): + async def websocket_receive(self, message: WebSocketReceiveEvent) -> None: _bytes = message.get("bytes") - await self.send({"type": "websocket.send", "bytes": _bytes}) + await self.send({"type": "websocket.send", "bytes": _bytes}) # type: ignore async def send_text(url: str): async with websockets.client.connect(url) as websocket: @@ -501,8 +503,8 @@ async def send_text(url: str): @pytest.mark.anyio async def test_send_text_data_to_server_in_multiple_frames( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, ): message = ( @@ -514,13 +516,13 @@ class App(WebSocketResponse): async def websocket_connect(self, message): await self.send({"type": "websocket.accept"}) - async def websocket_receive(self, message): + async def websocket_receive(self, message: WebSocketReceiveEvent) -> None: _text = message.get("text") - await self.send({"type": "websocket.send", "text": _text}) + await self.send({"type": "websocket.send", "text": _text}) # type: ignore - async def send_text(url): + async def send_text(url: str): async with websockets.client.connect(url) as websocket: - assembled_frames = [] + assembled_frames: list[str] = [] # send this message in 5 frames for i in range(5): # divide the message in 5 parts @@ -543,8 +545,8 @@ async def send_text(url): @pytest.mark.anyio async def test_send_binary_data_to_server_in_multiple_frames( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, ): message = ( @@ -556,13 +558,13 @@ class App(WebSocketResponse): async def websocket_connect(self, message): await self.send({"type": "websocket.accept"}) - async def websocket_receive(self, message): + async def websocket_receive(self, message: WebSocketReceiveEvent) -> None: _bytes = message.get("bytes") - await self.send({"type": "websocket.send", "bytes": _bytes}) + await self.send({"type": "websocket.send", "bytes": _bytes}) # type: ignore - async def send_bytes(url): + async def send_bytes(url: str): async with websockets.client.connect(url) as websocket: - assembled_frames = [] + assembled_frames: list[bytes] = [] # send this message in 5 frames for i in range(5): # divide the message in 5 parts @@ -585,8 +587,8 @@ async def send_bytes(url): @pytest.mark.anyio async def test_send_after_protocol_close( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, ): class App(WebSocketResponse): @@ -622,8 +624,8 @@ async def get_data(url: str): @pytest.mark.anyio async def test_missing_handshake( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, ): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): @@ -647,8 +649,8 @@ async def connect(url: str): @pytest.mark.anyio async def test_send_before_handshake( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, ): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): @@ -672,8 +674,8 @@ async def connect(url: str): @pytest.mark.anyio async def test_duplicate_handshake( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, ): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): @@ -699,8 +701,8 @@ async def connect(url: str): @pytest.mark.anyio async def test_asgi_return_value( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, ): """ @@ -737,11 +739,11 @@ async def connect(url: str): ids=["none_as_reason", "normal_reason", "without_reason"], ) async def test_app_close( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, - code: typing.Optional[int], - reason: typing.Optional[str], + code: int | None, + reason: str | None, ): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): while True: @@ -783,8 +785,8 @@ async def websocket_session(url: str): @pytest.mark.anyio async def test_client_close( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, ): async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): @@ -815,8 +817,8 @@ async def websocket_session(url: str): @pytest.mark.anyio async def test_client_connection_lost( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, ): got_disconnect_event = False @@ -853,8 +855,8 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable @pytest.mark.anyio async def test_connection_lost_before_handshake_complete( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, ): send_accept_task = asyncio.Event() @@ -906,8 +908,8 @@ async def websocket_session(uri: str): @pytest.mark.anyio async def test_send_close_on_server_shutdown( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, ): disconnect_message: WebSocketDisconnectEvent = {} # type: ignore @@ -956,8 +958,8 @@ async def websocket_session(uri: str): @pytest.mark.anyio @pytest.mark.parametrize("subprotocol", ["proto1", "proto2"]) async def test_subprotocols( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, subprotocol: str, unused_tcp_port: int, ): @@ -1006,8 +1008,8 @@ async def get_subprotocol(url: str): ], ) async def test_send_binary_data_to_server_bigger_than_default_on_websockets( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, client_size_sent: int, server_size_max: int, expected_result: int, @@ -1046,8 +1048,8 @@ async def send_text(url: str): @pytest.mark.anyio async def test_server_reject_connection( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, ): disconnected_message: ASGIReceiveEvent = {} # type: ignore @@ -1089,8 +1091,8 @@ async def websocket_session(url): @pytest.mark.anyio async def test_server_reject_connection_with_response( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, ): disconnected_message = {} @@ -1129,8 +1131,8 @@ async def websocket_session(url): @pytest.mark.anyio async def test_server_reject_connection_with_multibody_response( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, ): disconnected_message: ASGIReceiveEvent = {} # type: ignore @@ -1184,8 +1186,8 @@ async def websocket_session(url: str): @pytest.mark.anyio async def test_server_reject_connection_with_invalid_status( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, ): # this test checks that even if there is an error in the response, the server @@ -1228,8 +1230,8 @@ async def websocket_session(url): @pytest.mark.anyio async def test_server_reject_connection_with_body_nolength( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, ): # test that the server can send a response with a body but no content-length @@ -1251,7 +1253,7 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable ) await send({"type": "websocket.http.response.body", "body": b"hardbody"}) - async def websocket_session(url): + async def websocket_session(url: str): response = await wsresponse(url) assert response.status_code == 403 assert response.content == b"hardbody" @@ -1275,19 +1277,19 @@ async def websocket_session(url): @pytest.mark.anyio async def test_server_reject_connection_with_invalid_msg( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, ): - async def app(scope, receive, send): + async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): assert scope["type"] == "websocket" - assert "websocket.http.response" in scope["extensions"] + assert "websocket.http.response" in scope.get("extensions", {}) # Pull up first recv message. - message = await receive() - assert message["type"] == "websocket.connect" + connect_message = await receive() + assert connect_message["type"] == "websocket.connect" - message = { + message: WebSocketResponseStartEvent = { "type": "websocket.http.response.start", "status": 404, "headers": [(b"Content-Length", b"0"), (b"Content-Type", b"text/plain")], @@ -1296,7 +1298,7 @@ async def app(scope, receive, send): # send invalid message. This will raise an exception here await send(message) - async def websocket_session(url): + async def websocket_session(url: str): with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info: async with websockets.client.connect(url): pass # pragma: no cover @@ -1315,8 +1317,8 @@ async def websocket_session(url): @pytest.mark.anyio async def test_server_reject_connection_with_missing_body( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, ): async def app(scope, receive, send): @@ -1335,7 +1337,7 @@ async def app(scope, receive, send): await send(message) # no further message - async def websocket_session(url): + async def websocket_session(url: str): with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info: async with websockets.client.connect(url): pass # pragma: no cover @@ -1354,8 +1356,8 @@ async def websocket_session(url): @pytest.mark.anyio async def test_server_multiple_websocket_http_response_start_events( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, ): """ @@ -1409,8 +1411,8 @@ async def websocket_session(url: str): @pytest.mark.anyio async def test_server_can_read_messages_in_buffer_after_close( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, ): frames = [] @@ -1453,8 +1455,8 @@ async def send_text(url: str): @pytest.mark.anyio async def test_default_server_headers( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, ): class App(WebSocketResponse): @@ -1479,8 +1481,8 @@ async def open_connection(url: str): @pytest.mark.anyio async def test_no_server_headers( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, ): class App(WebSocketResponse): @@ -1533,8 +1535,8 @@ async def open_connection(url: str): @pytest.mark.anyio async def test_multiple_server_header( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, ): class App(WebSocketResponse): @@ -1567,8 +1569,8 @@ async def open_connection(url: str): @pytest.mark.anyio async def test_lifespan_state( - ws_protocol_cls: "typing.Type[WSProtocol | WebSocketProtocol]", - http_protocol_cls: "typing.Type[H11Protocol | HttpToolsProtocol]", + ws_protocol_cls: WSProtocolType, + http_protocol_cls: HTTPProtocolType, unused_tcp_port: int, ): expected_states = [ From bf00adae49cbf9e942f3a82a5580d711f2dc71a8 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 26 Dec 2023 12:41:03 +0100 Subject: [PATCH 58/59] Use future annotations on websocket sansio implementation --- .../websockets/websockets_sansio_impl.py | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index 704976c44..a75243602 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -1,9 +1,10 @@ +from __future__ import annotations + import asyncio import logging -import typing from asyncio.transports import BaseTransport, Transport from http import HTTPStatus -from typing import Literal +from typing import Any, Literal, cast from urllib.parse import unquote from websockets.extensions.permessage_deflate import ServerPerMessageDeflateFactory @@ -37,8 +38,8 @@ def __init__( self, config: Config, server_state: ServerState, - app_state: typing.Dict[str, typing.Any], - _loop: typing.Optional[asyncio.AbstractEventLoop] = None, + app_state: dict[str, Any], + _loop: asyncio.AbstractEventLoop | None = None, ) -> None: if not config.loaded: config.load() # pragma: no cover @@ -57,8 +58,8 @@ def __init__( # Connection state self.transport: asyncio.Transport = None # type: ignore[assignment] - self.server: typing.Optional[typing.Tuple[str, int]] = None - self.client: typing.Optional[typing.Tuple[str, int]] = None + self.server: tuple[str, int] | None = None + self.client: tuple[str, int] | None = None self.scheme: Literal["wss", "ws"] = None # type: ignore[assignment] # WebSocket state @@ -85,7 +86,7 @@ def __init__( def connection_made(self, transport: BaseTransport) -> None: """Called when a connection is made.""" - transport = typing.cast(Transport, transport) + transport = cast(Transport, transport) self.connections.add(self) self.transport = transport self.server = get_local_addr(transport) @@ -96,7 +97,7 @@ def connection_made(self, transport: BaseTransport) -> None: prefix = "%s:%d - " % self.client if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix) - def connection_lost(self, exc: typing.Optional[Exception]) -> None: + def connection_lost(self, exc: Exception | None) -> None: self.connections.remove(self) if self.logger.level <= TRACE_LOG_LEVEL: prefix = "%s:%d - " % self.client if self.client else "" @@ -238,7 +239,7 @@ def handle_parser_exception(self) -> None: self.close_sent = True self.transport.close() - def on_task_complete(self, task: "asyncio.Task[None]") -> None: + def on_task_complete(self, task: asyncio.Task[None]) -> None: self.tasks.discard(task) async def run_asgi(self) -> None: @@ -279,7 +280,7 @@ async def send(self, message: ASGISendEvent) -> None: if not self.handshake_complete: if message_type == "websocket.accept" and not self.transport.is_closing(): - message = typing.cast("WebSocketAcceptEvent", message) + message = cast(WebSocketAcceptEvent, message) self.logger.info( '%s - "WebSocket %s" [accepted]', self.scope["client"], @@ -305,7 +306,7 @@ async def send(self, message: ASGISendEvent) -> None: self.transport.writelines(output) elif message_type == "websocket.close" and not self.transport.is_closing(): - message = typing.cast("WebSocketCloseEvent", message) + message = cast(WebSocketCloseEvent, message) self.queue.put_nowait( { "type": "websocket.disconnect", @@ -345,7 +346,7 @@ async def send(self, message: ASGISendEvent) -> None: elif not self.close_sent: if message_type == "websocket.send" and not self.transport.is_closing(): - message = typing.cast(WebSocketSendEvent, message) + message = cast(WebSocketSendEvent, message) bytes_data = message.get("bytes") text_data = message.get("text") if text_data: @@ -356,7 +357,7 @@ async def send(self, message: ASGISendEvent) -> None: self.transport.writelines(output) elif message_type == "websocket.close" and not self.transport.is_closing(): - message = typing.cast(WebSocketCloseEvent, message) + message = cast(WebSocketCloseEvent, message) code = message.get("code", 1000) reason = message.get("reason", "") or "" self.queue.put_nowait({"type": "websocket.disconnect", "code": code}) From 64d6eb707fac78568aa03f7610fa4401334af4c4 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Tue, 26 Dec 2023 14:00:07 +0100 Subject: [PATCH 59/59] WIP websockets denial response extension --- tests/protocols/test_websocket.py | 8 +- .../protocols/websockets/websockets_impl.py | 9 +- .../websockets/websockets_sansio_impl.py | 104 ++++++++++-------- uvicorn/protocols/websockets/wsproto_impl.py | 2 +- 4 files changed, 70 insertions(+), 53 deletions(-) diff --git a/tests/protocols/test_websocket.py b/tests/protocols/test_websocket.py index 015c5dbdb..a2363e46f 100644 --- a/tests/protocols/test_websocket.py +++ b/tests/protocols/test_websocket.py @@ -1369,8 +1369,7 @@ async def test_server_multiple_websocket_http_response_start_events( async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable): nonlocal exception_message assert scope["type"] == "websocket" - assert "extensions" in scope - assert "websocket.http.response" in scope["extensions"] + assert "websocket.http.response" in scope.get("extensions", {}) # Pull up first recv message. message = await receive() @@ -1385,13 +1384,14 @@ async def app(scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable try: await send(start_event) except Exception as exc: + print(exc) exception_message = str(exc) async def websocket_session(url: str): with pytest.raises(websockets.exceptions.InvalidStatusCode) as exc_info: async with websockets.client.connect(url): pass - assert exc_info.value.status_code == 404 + assert exc_info.value.status_code in (404, 500) config = Config( app=app, @@ -1564,7 +1564,7 @@ async def open_connection(url: str): ) async with run_server(config): headers = await open_connection(f"ws://127.0.0.1:{unused_tcp_port}") - assert headers.get_all("Server") == ["uvicorn", "over-ridden", "another-value"] + assert headers.get_all("server") == ["uvicorn", "over-ridden", "another-value"] @pytest.mark.anyio diff --git a/uvicorn/protocols/websockets/websockets_impl.py b/uvicorn/protocols/websockets/websockets_impl.py index 3f04c1dd5..46bcabcdc 100644 --- a/uvicorn/protocols/websockets/websockets_impl.py +++ b/uvicorn/protocols/websockets/websockets_impl.py @@ -290,8 +290,11 @@ async def asgi_send(self, message: "ASGISendEvent") -> None: self.extra_headers.extend( # ASGI spec requires bytes # But for compatibility we need to convert it to strings - (name.decode("latin-1"), value.decode("latin-1")) - for name, value in message["headers"] + ( + name.decode("latin-1").lower(), + value.decode("latin-1").lower(), + ) + for name, value in list(message.get("headers", [])) ) self.handshake_started_event.set() @@ -317,7 +320,7 @@ async def asgi_send(self, message: "ASGISendEvent") -> None: # websockets requires the status to be an enum. look it up. status = http.HTTPStatus(message["status"]) headers = [ - (name.decode("latin-1"), value.decode("latin-1")) + (name.decode("latin-1").lower(), value.decode("latin-1").lower()) for name, value in message.get("headers", []) ] self.initial_response = (status, headers, b"") diff --git a/uvicorn/protocols/websockets/websockets_sansio_impl.py b/uvicorn/protocols/websockets/websockets_sansio_impl.py index a75243602..2960c7af9 100644 --- a/uvicorn/protocols/websockets/websockets_sansio_impl.py +++ b/uvicorn/protocols/websockets/websockets_sansio_impl.py @@ -19,6 +19,8 @@ WebSocketCloseEvent, WebSocketDisconnectEvent, WebSocketReceiveEvent, + WebSocketResponseBodyEvent, + WebSocketResponseStartEvent, WebSocketScope, WebSocketSendEvent, ) @@ -67,6 +69,7 @@ def __init__( self.handshake_initiated = False self.handshake_complete = False self.close_sent = False + self.initial_response: tuple[int, list[tuple[str, str]], bytes] | None = None extensions = [] if self.config.ws_per_message_deflate: @@ -177,6 +180,7 @@ def handle_connect(self, event: Request) -> None: "headers": headers, "subprotocols": event.headers.get_all("Sec-WebSocket-Protocol"), "state": self.app_state.copy(), + "extensions": {"websocket.http.response": {}}, } self.queue.put_nowait({"type": "websocket.connect"}) task = self.loop.create_task(self.run_asgi()) @@ -262,24 +266,20 @@ async def run_asgi(self) -> None: self.transport.close() def send_500_response(self) -> None: - msg = b"Internal Server Error" - content = [ - b"HTTP/1.1 500 Internal Server Error\r\n" - b"content-type: text/plain; charset=utf-8\r\n", - b"content-length: " + str(len(msg)).encode("ascii") + b"\r\n", - b"connection: close\r\n", - b"\r\n", - msg, - ] - self.transport.write(b"".join(content)) + response = self.conn.reject(500, "Internal Server Error") + self.conn.send_response(response) + output = self.conn.data_to_send() + self.transport.writelines(output) async def send(self, message: ASGISendEvent) -> None: await self.writable.wait() message_type = message["type"] - if not self.handshake_complete: - if message_type == "websocket.accept" and not self.transport.is_closing(): + if not self.handshake_complete or ( + self.handshake_complete and self.initial_response is None + ): + if message_type == "websocket.accept": message = cast(WebSocketAcceptEvent, message) self.logger.info( '%s - "WebSocket %s" [accepted]', @@ -287,64 +287,59 @@ async def send(self, message: ASGISendEvent) -> None: get_path_with_query_string(self.scope), ) headers = [ - ( - key.decode("ascii"), - value.decode("ascii", errors="surrogateescape"), + (name.decode("latin-1").lower(), value.decode("latin-1").lower()) + for name, value in ( + self.default_headers + list(message.get("headers", [])) ) - for key, value in self.default_headers - + list(message.get("headers", [])) ] - accepted_subprotocol = message.get("subprotocol") if accepted_subprotocol: headers.append(("Sec-WebSocket-Protocol", accepted_subprotocol)) - - self.handshake_complete = True self.response.headers.update(headers) - self.conn.send_response(self.response) - output = self.conn.data_to_send() - self.transport.writelines(output) - elif message_type == "websocket.close" and not self.transport.is_closing(): + if not self.transport.is_closing(): + self.handshake_complete = True + self.conn.send_response(self.response) + output = self.conn.data_to_send() + self.transport.writelines(output) + + elif message_type == "websocket.close": message = cast(WebSocketCloseEvent, message) - self.queue.put_nowait( - { - "type": "websocket.disconnect", - "code": message.get("code", 1000) or 1000, - } - ) + self.queue.put_nowait({"type": "websocket.disconnect", "code": 1006}) self.logger.info( '%s - "WebSocket %s" 403', self.scope["client"], get_path_with_query_string(self.scope), ) - extra_headers = [ - ( - key.decode("ascii"), - value.decode("ascii", errors="surrogateescape"), - ) - for key, value in self.default_headers - ] - - response = self.conn.reject( - HTTPStatus.FORBIDDEN, message.get("reason", "") or "" - ) - response.headers.update(extra_headers) + response = self.conn.reject(HTTPStatus.FORBIDDEN, "") self.conn.send_response(response) output = self.conn.data_to_send() self.close_sent = True self.handshake_complete = True self.transport.writelines(output) self.transport.close() - + elif message_type == "websocket.http.response.start": + message = cast(WebSocketResponseStartEvent, message) + self.logger.info( + '%s - "WebSocket %s" %d', + self.scope["client"], + get_path_with_query_string(self.scope), + message["status"], + ) + headers = [ + (name.decode("latin-1"), value.decode("latin-1")) + for name, value in list(message.get("headers", [])) + ] + self.initial_response = (message["status"], headers, b"") else: msg = ( - "Expected ASGI message 'websocket.accept' or 'websocket.close', " + "Expected ASGI message 'websocket.accept', 'websocket.close' " + "or 'websocket.http.response.start' " "but got '%s'." ) raise RuntimeError(msg % message_type) - elif not self.close_sent: + elif not self.close_sent and self.initial_response is None: if message_type == "websocket.send" and not self.transport.is_closing(): message = cast(WebSocketSendEvent, message) bytes_data = message.get("bytes") @@ -372,6 +367,25 @@ async def send(self, message: ASGISendEvent) -> None: " but got '%s'." ) raise RuntimeError(msg % message_type) + elif self.initial_response is not None: + if message_type == "websocket.http.response.body": + message = cast(WebSocketResponseBodyEvent, message) + body = self.initial_response[2] + message["body"] + self.initial_response = self.initial_response[:2] + (body,) + if not message.get("more_body", False): + response = self.conn.reject(self.initial_response[0], body.decode()) + response.headers.update(self.initial_response[1]) + self.conn.send_response(response) + output = self.conn.data_to_send() + self.close_sent = True + self.transport.writelines(output) + self.transport.close() + else: + msg = ( + "Expected ASGI message 'websocket.http.response.body' " + "but got '%s'." + ) + raise RuntimeError(msg % message_type) else: msg = "Unexpected ASGI message '%s', after sending 'websocket.close'." diff --git a/uvicorn/protocols/websockets/wsproto_impl.py b/uvicorn/protocols/websockets/wsproto_impl.py index 1c17ee068..db7967f45 100644 --- a/uvicorn/protocols/websockets/wsproto_impl.py +++ b/uvicorn/protocols/websockets/wsproto_impl.py @@ -155,7 +155,7 @@ def shutdown(self) -> None: self.send_500_response() self.transport.close() - def on_task_complete(self, task: asyncio.Task) -> None: + def on_task_complete(self, task: asyncio.Task[None]) -> None: self.tasks.discard(task) # Event handlers