From 81384a7af1841a186825cf702fcd7a8b00c448c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Pierre?= Date: Wed, 2 Oct 2024 10:15:29 +1300 Subject: [PATCH 1/3] [Refactor] server/client: now use body readers and writers --- examples/client-gzip.py | 17 +-- src/py/extra/client.py | 6 +- src/py/extra/http/model.py | 247 +++++++++++++++++++----------------- src/py/extra/server.py | 91 +++++++------ src/py/extra/utils/codec.py | 116 +++++++++++++++++ src/py/extra/utils/io.py | 6 +- 6 files changed, 300 insertions(+), 183 deletions(-) create mode 100644 src/py/extra/utils/codec.py diff --git a/examples/client-gzip.py b/examples/client-gzip.py index 21441fd..bc95f28 100644 --- a/examples/client-gzip.py +++ b/examples/client-gzip.py @@ -1,25 +1,12 @@ import asyncio from extra.client import HTTPClient from extra.http.model import HTTPBodyBlob - -import zlib - - -class GzipDecoder: - def __init__(self): - self.decompressor = zlib.decompressobj(wbits=zlib.MAX_WBITS | 32) - self.buffer = io.BytesIO() - - def feed(self, chunk: bytes) -> bytes | None: - return self.decompressor.decompress(chunk) - - def flush(self) -> bytes | None: - return self.decompressor.flush() +from extra.utils.codec import GZipDecoder # NOTE: Start "examples/sse.py" async def main(path: str, host: str = "127.0.0.1", port: int = 443, ssl: bool = True): - transform = GzipDecoder() + transform = GZipDecoder() with open("/dev/stdout", "wb") as f: async for atom in HTTPClient.Request( diff --git a/src/py/extra/client.py b/src/py/extra/client.py index 3ea8ffa..6126d88 100644 --- a/src/py/extra/client.py +++ b/src/py/extra/client.py @@ -16,7 +16,7 @@ HTTPBodyFile, HTTPHeaders, HTTPBody, - HTTPReaderBody, + HTTPBodyIO, HTTPAtom, HTTPProcessingStatus, ) @@ -479,7 +479,7 @@ async def Request( *, port: int | None = None, headers: dict[str, str] | None = None, - body: HTTPReaderBody | HTTPBodyBlob | None = None, + body: HTTPBodyIO | HTTPBodyBlob | None = None, params: dict[str, str] | str | None = None, ssl: bool = True, verified: bool = True, @@ -591,7 +591,7 @@ async def request( *, port: int | None = None, headers: dict[str, str] | None = None, - body: HTTPReaderBody | HTTPBodyBlob | None = None, + body: HTTPBodyIO | HTTPBodyBlob | None = None, params: dict[str, str] | str | None = None, ssl: bool = True, verified: bool = True, diff --git a/src/py/extra/http/model.py b/src/py/extra/http/model.py index b5e5691..7a51e22 100644 --- a/src/py/extra/http/model.py +++ b/src/py/extra/http/model.py @@ -4,7 +4,6 @@ Iterable, Literal, Generator, - Iterator, TypeAlias, Union, Callable, @@ -13,16 +12,16 @@ ) from abc import ABC, abstractmethod from functools import cached_property +from tempfile import SpooledTemporaryFile from http.cookies import SimpleCookie, Morsel import os.path import inspect -from gzip import GzipFile -from io import BytesIO from pathlib import Path from enum import Enum from ..utils.primitives import TPrimitive from .status import HTTP_STATUS from ..utils.io import DEFAULT_ENCODING, asWritable +from ..utils.codec import BytesTransform from .api import ResponseFactory # NOTE: MyPyC doesn't support async generators. We're trying without. @@ -38,7 +37,7 @@ def headername(name: str, *, headers: dict[str, str] = {}) -> str: - """Normalizes the header name.""" + """Normalizes the header name as `Kebab-Case`.""" if name in headers: return headers[name] key: str = name.lower() @@ -57,22 +56,9 @@ def headername(name: str, *, headers: dict[str, str] = {}) -> str: # ----------------------------------------------------------------------------- -class HTTPRequestError(Exception): - def __init__( - self, - message: str, - status: int | None = None, - contentType: str | None = None, - payload: TPrimitive | None = None, - ): - super().__init__(message) - self.message: str = message - self.status: int | None = status - self.contentType: str | None = contentType - self.payload: TPrimitive | bytes | None = payload - - class HTTPRequestLine(NamedTuple): + """Represents a request status line""" + method: str path: str query: str @@ -80,18 +66,24 @@ class HTTPRequestLine(NamedTuple): class HTTPResponseLine(NamedTuple): + """Represents a response status line""" + protocol: str status: int message: str class HTTPHeaders(NamedTuple): + """Wraps HTTP headers, keeping key information for response/request processing.""" + headers: dict[str, str] contentType: str | None = None contentLength: int | None = None class HTTPProcessingStatus(Enum): + """Internal parser/processor state management""" + Processing = 0 Body = 1 Complete = 2 @@ -100,6 +92,7 @@ class HTTPProcessingStatus(Enum): BadFormat = 12 +# Type alias for the parser would produce HTTPAtom: TypeAlias = Union[ HTTPRequestLine, HTTPResponseLine, @@ -110,40 +103,45 @@ class HTTPProcessingStatus(Enum): "HTTPResponse", ] - # ----------------------------------------------------------------------------- # -# BODY +# ERRORS # # ----------------------------------------------------------------------------- -BODY_READER_TIMEOUT: float = 1.0 +class HTTPRequestError(Exception): + """To be raised by handlers to generate a 500 error.""" + def __init__( + self, + message: str, + status: int | None = None, + contentType: str | None = None, + payload: TPrimitive | None = None, + ): + super().__init__(message) + self.message: str = message + self.status: int | None = status + self.contentType: str | None = contentType + self.payload: TPrimitive | bytes | None = payload -class HTTPBodyReader(ABC): - """A based class for being able to read a request body, typically from a - socket.""" - @abstractmethod - async def read(self, timeout: float = BODY_READER_TIMEOUT) -> bytes | None: ... +# ----------------------------------------------------------------------------- +# +# BODY +# +# ----------------------------------------------------------------------------- - async def load(self, timeout: float = BODY_READER_TIMEOUT) -> bytes: - data = bytearray() - while True: - chunk = await self.read(timeout) - if not chunk: - break - else: - data += chunk - return data + +BODY_READER_TIMEOUT: float = 1.0 -class HTTPReaderBody: - __slots__ = ("reader", "read", "expected", "remaining") - """Represents a body that is loaded from a reader.""" +class HTTPBodyIO: + __slots__ = ["reader", "read", "expected", "remaining"] + """Represents a body that is loaded from a reader IO.""" - def __init__(self, reader: HTTPBodyReader, expected: int | None = None): + def __init__(self, reader: "HTTPBodyReader", expected: int | None = None): self.reader: HTTPBodyReader = reader self.read: int = 0 self.expected: int | None = expected @@ -163,6 +161,8 @@ async def load( class HTTPBodyBlob(NamedTuple): + """Represents a part (or a whole) body as bytes.""" + payload: bytes = b"" length: int = 0 # NOTE: We don't know how many is remaining @@ -179,6 +179,8 @@ async def load( class HTTPBodyFile(NamedTuple): + """Represents an HTTP body from a file, potentially with a file descriptor.""" + path: Path fd: int | None = None @@ -188,13 +190,18 @@ def length(self) -> int: class HTTPBodyStream(NamedTuple): + """An HTTP body that is generated from a stream.""" + stream: Generator[str | bytes | TPrimitive, Any, Any] class HTTPBodyAsyncStream(NamedTuple): + """An HTTP body that is generated from an asynchronous stream.""" + stream: AsyncGenerator[str | bytes | TPrimitive, Any] +# The different types of bodies that are managed THTTPBody: TypeAlias = ( HTTPBodyBlob | HTTPBodyFile | HTTPBodyStream | HTTPBodyAsyncStream ) @@ -211,85 +218,89 @@ def HasRemaining(body: THTTPBody | None) -> bool: return bool(body.remaining) elif isinstance(body, HTTPBodyStream) or isinstance(body, HTTPBodyAsyncStream): return True - elif isinstance(body, HTTPReaderBody): + elif isinstance(body, HTTPBodyIO): return body.remaining is not None else: return False -# We do separate the body, as typically the head of the request is there -# as a whole, and the body can be loaded through different loaders based -# on use case. - # ----------------------------------------------------------------------------- # # BODY TRANSFORMS # # ----------------------------------------------------------------------------- +# We do separate the body, as typically the head of the request is there +# as a whole, and the body can be loaded through different loaders based +# on use case. -class BytesTransform(ABC): - """An abstract bytes transform.""" - - def open(self) -> bool: - return True - - def close(self) -> bool: - return True - - @abstractmethod - def feed( - self, chunk: bytes, more: bool = False - ) -> bytes | None | Literal[False]: ... - - def __enter__(self) -> Iterator[bool]: - yield self.open() +class HTTPBodyReader(ABC): + """A base class for being able to read a request body, typically from a + socket.""" - def __exit__(self, *args: Any, **kwargs: Any) -> None: - self.close() + __slots__ = ["transform"] + def __init__(self, transform: BytesTransform | None = None) -> None: + self.transform: BytesTransform | None = transform -class GZipEncode(BytesTransform): - """An encoder for gzip byte streams.""" + async def read( + self, timeout: float = BODY_READER_TIMEOUT, size: int | None = None + ) -> bytes | None: + chunk = await self._read(timeout) + if chunk is not None and self.transform: + res = self.transform.feed(chunk) + return res if res else None + else: + return chunk - def __init__(self) -> None: - self.out: BytesIO = BytesIO() - self.comp: GzipFile = GzipFile(mode="wb", fileobj=self.out) + @abstractmethod + async def _read( + self, timeout: float = BODY_READER_TIMEOUT, size: int | None = None + ) -> bytes | None: ... - def flush(self) -> bytes | None | Literal[False]: - return None + # NOTE: This is a dangerous operation, as this way bloat the whole memory. + # Instead, loading should spool the file. + async def load(self, timeout: float = BODY_READER_TIMEOUT) -> bytes: + """Loads the entire body into a bytes array.""" + data = bytearray() + while True: + chunk = await self.read(timeout) + if not chunk: + break + else: + data += chunk + return data - def feed( - self, - chunk: bytes, - more: bool = False, - ) -> bytes | None | Literal[False]: - self.comp.write(chunk) - self.comp.flush() - res = self.out.getvalue() - self.comp.seek(0) - self.comp.truncate() - return res + async def spool( + self, timeout: float = BODY_READER_TIMEOUT + ) -> SpooledTemporaryFile[bytes]: + """The safer way to load a body especially if the file exceeds a given size.""" + with SpooledTemporaryFile(prefix="extra", suffix="raw") as f: + while True: + chunk = await self.read(timeout) + if not chunk: + break + else: + f.write(chunk) + return f class HTTPBodyWriter(ABC): - """ "A generic writer for bodies that supports bytes encoding and decoding.""" + """A generic writer for bodies that supports bytes encoding and decoding.""" - def __init__(self) -> None: - self.transform: BytesTransform | None = None + __slots__ = ["transform"] - async def write( - self, - body: HTTPBodyBlob | HTTPBodyFile | HTTPBodyStream | HTTPBodyAsyncStream | None, - ) -> bool: + def __init__(self, transform: BytesTransform | None) -> None: + self.transform: BytesTransform | None = transform + + async def write(self, body: THTTPBody | bytes | None) -> bool: """Writes the given type of body.""" - if isinstance(body, HTTPBodyBlob): + if isinstance(body, bytes): + return await self._writeBytes(body) + elif isinstance(body, HTTPBodyBlob): return await self._write(body.payload) elif isinstance(body, HTTPBodyFile): - with open(body.path, "rb") as f: - while chunk := f.read(64_000): - await self._write(chunk, bool(chunk)) - return True + return await self._writeFile(body.path) elif isinstance(body, HTTPBodyStream): # No keep alive with streaming as these are long # lived requests. @@ -313,35 +324,30 @@ async def write( else: raise ValueError(f"Unsupported body format: {body}") + async def flush(self) -> bool: + if self.transform: + chunk = self.transform.flush() + if chunk: + await self._writeBytes(chunk) + return True + + async def _writeFile(self, path: Path, size: int = 64_000) -> bool: + with open(path, "rb") as f: + while chunk := f.read(size): + await self._write(chunk, bool(chunk)) + return True + async def _write(self, chunk: bytes, more: bool = False) -> bool: - return await self._send( + return await self._writeBytes( self.transform.feed(chunk, more) if self.transform else chunk, more ) @abstractmethod - async def _send( + async def _writeBytes( self, chunk: bytes | None | Literal[False], more: bool = False ) -> bool: ... -# TODO: We need to find an abstraction that works for all writers that supports: -# - HTTPBodyBlob -# - HTTPBodyFile -# - HTTPBodyStream - -# class GZipBodyEncoding: -# -# def accept(self, request: "HTTPRequest") -> bool: -# return any( -# _ -# for _ in request.headers.get("Accept-Encoding", "").split(",") -# if _.strip() == "gzip" -# ) -# -# def accept(self, request: "HTTPRequest") -> bool: -# pass - - # ----------------------------------------------------------------------------- # # REQUESTS @@ -350,6 +356,8 @@ async def _send( class HTTPRequest(ResponseFactory["HTTPResponse"]): + """Represents an HTTP requests, which also acts as a factory for + responses.""" __slots__ = [ "protocol", @@ -368,7 +376,7 @@ def __init__( path: str, query: dict[str, str] | None, headers: HTTPHeaders, - body: HTTPReaderBody | HTTPBodyBlob | None = None, + body: HTTPBodyIO | HTTPBodyBlob | None = None, protocol: str = "HTTP/1.1", ): super().__init__() @@ -377,7 +385,7 @@ def __init__( self.query: dict[str, str] | None = query self.protocol: str = protocol self._headers: HTTPHeaders = headers - self._body: HTTPReaderBody | HTTPBodyBlob | None = body + self._body: HTTPBodyIO | HTTPBodyBlob | None = body self._reader: HTTPBodyReader | None self._onClose: Callable[[HTTPRequest], None] | None = None @@ -423,11 +431,11 @@ def contentType(self) -> str | None: return self._headers.contentType @property - def body(self) -> HTTPReaderBody | HTTPBodyBlob: + def body(self) -> HTTPBodyIO | HTTPBodyBlob: if self._body is None: if not self._reader: raise RuntimeError("Request has no reader, can't read body") - self._body = HTTPReaderBody(self._reader) + self._body = HTTPBodyIO(self._reader) return self._body @property @@ -475,6 +483,9 @@ def __str__(self) -> str: class HTTPResponse: + """An HTTP response.""" + + __slots__ = ["protocol", "status", "message", "headers", "body"] @staticmethod def Create( @@ -551,8 +562,6 @@ def Create( protocol=protocol, ) - __slots__ = ["protocol", "status", "message", "headers", "body"] - def __init__( self, protocol: str, diff --git a/src/py/extra/server.py b/src/py/extra/server.py index b6df65d..d79a1bb 100644 --- a/src/py/extra/server.py +++ b/src/py/extra/server.py @@ -1,18 +1,16 @@ -from typing import Callable, NamedTuple, Any, Coroutine +from typing import Callable, NamedTuple, Any, Coroutine, Literal +from pathlib import Path import socket import asyncio from .utils.logging import exception, info, warning, event -from .utils.io import asWritable +from .utils.codec import BytesTransform from .utils.limits import LimitType, unlimit from .model import Application, Service, mount from .http.model import ( HTTPRequest, HTTPResponse, - HTTPBodyStream, - HTTPBodyAsyncStream, - HTTPBodyBlob, - HTTPBodyFile, HTTPBodyReader, + HTTPBodyWriter, HTTPProcessingStatus, ) from .http.parser import HTTPParser @@ -60,24 +58,56 @@ class ServerOptions(NamedTuple): class AIOSocketBodyReader(HTTPBodyReader): - __slots__ = ["socket", "loop", "buffer"] + __slots__ = ["socket", "loop", "buffer", "size"] def __init__( self, socket: "socket.socket", loop: asyncio.AbstractEventLoop, size: int = 64_000, + *, + transform: BytesTransform | None = None, ) -> None: + super().__init__(transform) self.socket = socket self.loop = loop + self.size: int = size - async def read(self, timeout: float = 1.0, size: int = 64_000) -> bytes | None: + async def _read( + self, timeout: float = 1.0, size: int | None = None + ) -> bytes | None: return await asyncio.wait_for( - self.loop.sock_recv(self.socket, size), + self.loop.sock_recv(self.socket, size or self.size), timeout=timeout, ) +class AIOSocketBodyWriter(HTTPBodyWriter): + + def __init__( + self, + client: "socket.socket", + loop: asyncio.AbstractEventLoop, + *, + transform: BytesTransform | None = None, + ) -> None: + super().__init__(transform) + self.client: socket.socket = client + self.loop: asyncio.AbstractEventLoop = loop + + async def _writeBytes( + self, chunk: bytes | None | Literal[False], more: bool = False + ) -> bool: + if chunk is None or chunk is False: + pass + return False + + async def _writeFile(self, path: Path, size: int = 64_000) -> bool: + with open(path, "rb") as f: + await self.loop.sock_sendfile(self.client, f) + return True + + # NOTE: Based on benchmarks, this gave the best performance. # NOTE: The caveat is that getting SSL directly is a pain, so we may # need to rewrite this a bit. @@ -112,8 +142,11 @@ async def OnRequest( res_count: int = 0 req_count: int = 0 try: + # TODO: Should reuse parser, reader, writer as these will be on the + # hotpath for requests. These should all be recyclable. parser: HTTPParser = HTTPParser() reader: AIOSocketBodyReader = AIOSocketBodyReader(client, loop) + writer: AIOSocketBodyWriter = AIOSocketBodyWriter(client, loop) # NOTE: Here a load balancer will sustain a single connection and # all the requests will come through this loop, until there's @@ -128,7 +161,7 @@ async def OnRequest( # We may have more than one request in each payload when # HTTP Pipelining is on. try: - # NOTE: THe timeout really doesn't do anything here, the + # NOTE: The timeout really doesn't do anything here, the # socket will return no data, instead of being blocking n = await asyncio.wait_for( loop.sock_recv_into(client, buffer), @@ -167,7 +200,7 @@ async def OnRequest( or req.headers.get("Connection") == "close" ): keep_alive = False - if await cls.SendResponse(req, app, client, loop): + if await cls.SendResponse(req, app, writer): res_count += 1 # We clear what we've read from the buffer del buffer[:n] @@ -225,8 +258,7 @@ async def OnRequest( async def SendResponse( request: HTTPRequest, app: Application, - client: socket.socket, - loop: asyncio.AbstractEventLoop, + writer: HTTPBodyWriter, ) -> HTTPResponse | None: req: HTTPRequest = request res: HTTPResponse | None = None @@ -244,39 +276,14 @@ async def SendResponse( Method=req.method, Path=req.path, ) - await loop.sock_sendall(client, SERVER_NOCONTENT) + await writer.write(SERVER_NOCONTENT) sent = True else: try: # We send the request head - await loop.sock_sendall(client, res.head()) + await writer.write(res.head()) sent = True - # And send the request - if isinstance(res.body, HTTPBodyBlob): - await loop.sock_sendall(client, res.body.payload) - elif isinstance(res.body, HTTPBodyFile): - with open(res.body.path, "rb") as f: - await loop.sock_sendfile(client, f) - elif isinstance(res.body, HTTPBodyStream): - # No keep alive with streaming as these are long - # lived requests. - try: - for chunk in res.body.stream: - await loop.sock_sendall(client, asWritable(chunk)) - finally: - res.body.stream.close() - elif isinstance(res.body, HTTPBodyAsyncStream): - # No keep alive with streaming as these are long - # lived requests. - try: - async for chunk in res.body.stream: - await loop.sock_sendall(client, asWritable(chunk)) - finally: - await res.body.stream.aclose() - elif res.body is None: - pass - else: - raise ValueError(f"Unsupported body format: {res.body}") + await writer.write(res.body) except BrokenPipeError: # Client did an early close sent = True @@ -295,7 +302,7 @@ async def SendResponse( Method=req.method, Path=req.path, ) - await loop.sock_sendall(client, SERVER_ERROR) + await writer.write(SERVER_ERROR) except Exception as e: exception(e) diff --git a/src/py/extra/utils/codec.py b/src/py/extra/utils/codec.py new file mode 100644 index 0000000..72d9622 --- /dev/null +++ b/src/py/extra/utils/codec.py @@ -0,0 +1,116 @@ +import zlib +from typing import Literal +from abc import ABC, abstractmethod + + +class BytesTransform(ABC): + """An abstract bytes transform.""" + + @abstractmethod + def feed(self, chunk: bytes, more: bool = False) -> bytes | None | Literal[False]: + """Feeds bytes to the transform, may return a value.""" + + @abstractmethod + def flush(self) -> bytes | None | Literal[False]: + """Ensures that the bytes transform is flushed, for chunked encodings this will produce a new chunk.""" + + +class GZipDecoder(BytesTransform): + __slots__ = ["decompressor"] + + def __init__(self) -> None: + super().__init__() + self.decompressor = zlib.decompressobj(wbits=zlib.MAX_WBITS | 32) + + def feed(self, chunk: bytes, more: bool = False) -> bytes | None | Literal[False]: + return self.decompressor.decompress(chunk) + + def flush( + self, + ) -> bytes | None | Literal[False]: + return self.decompressor.flush() + + +class GZipEncoder(BytesTransform): + __slots__ = ["compressor"] + + def __init__(self, compression_level: int = 6) -> None: + super().__init__() + self.compressor = zlib.compressobj( + level=compression_level, wbits=zlib.MAX_WBITS | 16 + ) + + def feed(self, chunk: bytes, more: bool = False) -> bytes | None | Literal[False]: + return self.compressor.compress(chunk) + + def flush(self) -> bytes | None | Literal[False]: + return self.compressor.flush() + + +# SEE: https://httpwg.org/specs/rfc9112.html#chunked.encoding +class ChunkedEncoder(BytesTransform): + __slots__ = ["buffer"] + + def __init__(self) -> None: + super().__init__() + self.buffer = bytearray() + + def feed(self, chunk: bytes, more: bool = False) -> bytes | None | Literal[False]: + if not chunk: + return None + self.buffer.extend(chunk) + return None + + def flush(self) -> bytes | None | Literal[False]: + if not self.buffer: + return b"0\r\n\r\n" + result = f"{len(self.buffer):X}\r\n".encode() + self.buffer + b"\r\n" + self.buffer.clear() + return bytes(result) + + +class ChunkedDecoder(BytesTransform): + __slots__ = ["buffer", "chunkSize", "readingSize"] + + def __init__(self) -> None: + super().__init__() + self.buffer = bytearray() + self.chunkSize = 0 + self.readingSize = True + + def feed(self, chunk: bytes, more: bool = False) -> bytes | None | Literal[False]: + self.buffer.extend(chunk) + # TODO: We should probably reuse the bytes array to avoid more allocations + res = bytearray() + + while self.buffer: + if self.readingSize: + # TODO: Faster to use find + if b"\r\n" not in self.buffer: + break + size_line, remaining = self.buffer.split(b"\r\n", 1) + try: + self.chunkSize = int(size_line, 16) + except ValueError: + return False # Invalid chunk size + self.buffer = remaining + self.readingSize = False + if self.chunkSize == 0: + return bytes(res) if res else None # End of chunked data + + if len(self.buffer) < self.chunkSize + 2: + break + + res.extend(self.buffer[: self.chunkSize]) + self.buffer = self.buffer[self.chunkSize + 2 :] # +2 for \r\n + self.readingSize = True + + return bytes(res) if res else None + + def flush(self) -> bytes | None | Literal[False]: + if self.buffer: + return False # Incomplete chunk + return None + + +# EOF diff --git a/src/py/extra/utils/io.py b/src/py/extra/utils/io.py index 52bf90d..27bc47e 100644 --- a/src/py/extra/utils/io.py +++ b/src/py/extra/utils/io.py @@ -3,6 +3,8 @@ from .primitives import TPrimitive DEFAULT_ENCODING: str = "utf8" +EOL: bytes = b"\r\n" +END: int = 1 class Control(NamedTuple): @@ -32,10 +34,6 @@ def asWritable(value: str | bytes | TPrimitive) -> bytes: return json(value) -EOL: bytes = b"\r\n" -END: int = 1 - - class LineParser: __slots__ = ["buffer", "buflen", "line", "eol", "eolsize", "offset"] From a5a0873eb16117906ac547e6239a6a56737f776c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Pierre?= Date: Wed, 2 Oct 2024 12:59:31 +1300 Subject: [PATCH 2/3] [Fix] server: writer was not sending data --- src/py/extra/server.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/py/extra/server.py b/src/py/extra/server.py index d79a1bb..1b7aba1 100644 --- a/src/py/extra/server.py +++ b/src/py/extra/server.py @@ -57,6 +57,7 @@ class ServerOptions(NamedTuple): class AIOSocketBodyReader(HTTPBodyReader): + """Specialized body reader to work with AIO sockets.""" __slots__ = ["socket", "loop", "buffer", "size"] @@ -83,6 +84,7 @@ async def _read( class AIOSocketBodyWriter(HTTPBodyWriter): + """Specialized body writer to work with AIO sockets.""" def __init__( self, @@ -100,6 +102,8 @@ async def _writeBytes( ) -> bool: if chunk is None or chunk is False: pass + else: + await self.loop.sock_sendall(self.client, chunk) return False async def _writeFile(self, path: Path, size: int = 64_000) -> bool: @@ -260,6 +264,7 @@ async def SendResponse( app: Application, writer: HTTPBodyWriter, ) -> HTTPResponse | None: + """Processes the request within the application and sends a response using the given writer.""" req: HTTPRequest = request res: HTTPResponse | None = None sent: bool = False @@ -314,6 +319,7 @@ async def Serve( app: Application, options: ServerOptions = ServerOptions(), ) -> None: + """Main server coroutine.""" server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) server.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) From 59e15be2d74f566dea1240a74ae71449fc9a13c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Pierre?= Date: Wed, 2 Oct 2024 13:12:10 +1300 Subject: [PATCH 3/3] [Update] server: added server state with exception logging --- src/py/extra/server.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/src/py/extra/server.py b/src/py/extra/server.py index 1b7aba1..262f6f8 100644 --- a/src/py/extra/server.py +++ b/src/py/extra/server.py @@ -1,5 +1,7 @@ from typing import Callable, NamedTuple, Any, Coroutine, Literal from pathlib import Path +from signal import SIGINT, SIGTERM +from dataclasses import dataclass import socket import asyncio from .utils.logging import exception, info, warning, event @@ -17,6 +19,21 @@ from .config import HOST, PORT +@dataclass(slots=True) +class ServerState: + isRunning: bool = True + + def stop(self) -> None: + self.isRunning = False + + def onException( + self, loop: asyncio.AbstractEventLoop, context: dict[str, Any] + ) -> None: + e = context.get("exception") + if e: + exception(e) + + class ServerOptions(NamedTuple): host: str = "0.0.0.0" # nosec: B104 port: int = 8000 @@ -336,6 +353,13 @@ async def Serve( except RuntimeError: loop = asyncio.new_event_loop() + # Manage server state + state = ServerState() + # Registers handlers for signals and exception (so that we log them) + loop.add_signal_handler(SIGINT, lambda: state.stop()) + loop.add_signal_handler(SIGTERM, lambda: state.stop()) + loop.set_exception_handler(state.onException) + info( "Extra AIO Server listening", icon="🚀", @@ -344,7 +368,7 @@ async def Serve( ) try: - while True: + while state.isRunning: if options.condition and not options.condition(): break try: