From d7f2ad7855cd46d768d4806eedd883112a4ecc23 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Pierre?= Date: Fri, 27 Sep 2024 10:21:43 +1200 Subject: [PATCH] [Fix] client: sending the body properly --- src/py/extra/client.py | 87 ++++++++++++++++++++++++++++++++++---- src/py/extra/http/model.py | 4 ++ 2 files changed, 83 insertions(+), 8 deletions(-) diff --git a/src/py/extra/client.py b/src/py/extra/client.py index 9fd340c..812c153 100644 --- a/src/py/extra/client.py +++ b/src/py/extra/client.py @@ -1,6 +1,7 @@ from typing import NamedTuple, ClassVar, AsyncGenerator, Self, Any, Iterator from urllib.parse import quote_plus, urlparse from contextvars import ContextVar +from .utils.io import asWritable from contextlib import contextmanager from dataclasses import dataclass import asyncio, ssl, time, os @@ -9,12 +10,15 @@ from .http.model import ( HTTPRequest, HTTPResponse, + HTTPResponseStream, + HTTPResponseAsyncStream, + HTTPBodyBlob, + HTTPResponseFile, HTTPHeaders, HTTPRequestBody, HTTPBodyBlob, HTTPAtom, HTTPProcessingStatus, - headername, ) from .http.parser import HTTPParser @@ -259,7 +263,6 @@ async def get( # then we close the connection, or return a new one. while cxn: c = cxn.pop() - print("POOOL CONN", c, c.isValid) if c.isValid: return c else: @@ -335,8 +338,6 @@ async def OnRequest( host: str, cxn: Connection, *, - headers: dict[str, str] | None = None, - body: HTTPRequestBody | HTTPBodyBlob | None = None, timeout: float | None = 2.0, buffer: int = 32_000, streaming: bool | None = None, @@ -345,12 +346,10 @@ async def OnRequest( """Low level function to process HTTP requests with the given connection.""" # We send the line line = f"{request.method} {request.path} HTTP/1.1\r\n".encode() - # We send the headers - head: dict[str, str] = ( - {headername(k): v for k, v in headers.items()} if headers else {} - ) + head: dict[str, str] = request.headers if "Host" not in head: head["Host"] = host + body = request.body if not streaming and "Content-Length" not in head: head["Content-Length"] = ( "0" @@ -358,6 +357,7 @@ async def OnRequest( else ( str(body.length) if isinstance(body, HTTPBodyBlob) + or isinstance(body, HTTPResponseFile) else str(body.expected or "0") ) ) @@ -368,6 +368,39 @@ async def OnRequest( cxn.writer.write(payload) cxn.writer.write(b"\r\n\r\n") await cxn.writer.drain() + # NOTE: This is a common logic shared with the server + # And send the request + if isinstance(body, HTTPBodyBlob): + cxn.writer.write(body.payload) + elif isinstance(body, HTTPResponseFile): + fd: int = -1 + try: + fd = os.open(str(body.path), os.O_RDONLY) + while True: + chunk = os.read(fd, 64_000) + if chunk: + cxn.writer.write(chunk) + else: + break + finally: + if fd > 0: + os.close(fd) + elif isinstance(body, HTTPResponseStream): + # No keep alive with streaming as these are long + # lived requests. + for chunk in body.stream: + cxn.writer.write(asWritable(chunk)) + await cxn.writer.drain() + elif isinstance(body, HTTPResponseAsyncStream): + # No keep alive with streaming as these are long + # lived requests. + async for chunk in body.stream: + cxn.writer.write(asWritable(chunk)) + await cxn.writer.drain() + elif body is None: + pass + else: + raise ValueError(f"Unsupported body format: {body}") iteration: int = 0 # -- @@ -522,6 +555,7 @@ async def Request( path, query=None, headers=HTTPHeaders(headers or {}), + body=body, ), host, cxn, @@ -547,6 +581,43 @@ def pooling(idle: float | int | None = None) -> Iterator[ConnectionPool]: pool.pop().release() +async def request( + method: str, + host: str, + path: str, + *, + port: int | None = None, + headers: dict[str, str] | None = None, + body: HTTPRequestBody | HTTPBodyBlob | None = None, + params: dict[str, str] | str | None = None, + ssl: bool = True, + verified: bool = True, + timeout: float = 10.0, + follow: bool = True, + proxy: tuple[str, int] | bool | None = None, + connection: Connection | None = None, + streaming: bool | None = None, + keepalive: bool = False, +) -> AsyncGenerator[HTTPAtom, None]: + async for atom in HTTPClient.Request( + method, + host, + path, + port=port, + headers=headers, + body=body, + params=params, + ssl=ssl, + verified=verified, + follow=follow, + proxy=proxy, + connection=connection, + streaming=streaming, + keepalive=keepalive, + ): + yield atom + + if __name__ == "__main__": async def main() -> None: diff --git a/src/py/extra/http/model.py b/src/py/extra/http/model.py index 0fd5403..621bc50 100644 --- a/src/py/extra/http/model.py +++ b/src/py/extra/http/model.py @@ -140,6 +140,10 @@ class HTTPResponseFile(NamedTuple): path: Path fd: int | None = None + @property + def length(self) -> int: + return self.path.stat().st_size + class HTTPResponseStream(NamedTuple): stream: Generator[str | bytes | TPrimitive, Any, Any]