From f4479c6b3ae841c6f48931dd406f18d6f35f64c1 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Mon, 12 Feb 2024 23:43:17 +0100 Subject: [PATCH] Improve type hints on `File` (#111) --- multipart/decoders.py | 19 ++++++----- multipart/multipart.py | 74 ++++++++++++++++++----------------------- tests/test_multipart.py | 6 +++- 3 files changed, 48 insertions(+), 51 deletions(-) diff --git a/multipart/decoders.py b/multipart/decoders.py index e401fa0..218abe4 100644 --- a/multipart/decoders.py +++ b/multipart/decoders.py @@ -1,5 +1,6 @@ import base64 import binascii +from io import BufferedWriter from .exceptions import DecodeError @@ -33,11 +34,11 @@ class Base64Decoder: :param underlying: the underlying object to pass writes to """ - def __init__(self, underlying): + def __init__(self, underlying: BufferedWriter): self.cache = bytearray() self.underlying = underlying - def write(self, data): + def write(self, data: bytes) -> int: """Takes any input data provided, decodes it as base64, and passes it on to the underlying object. If the data provided is invalid base64 data, then this method will raise @@ -80,7 +81,7 @@ def close(self) -> None: if hasattr(self.underlying, "close"): self.underlying.close() - def finalize(self): + def finalize(self) -> None: """Finalize this object. This should be called when no more data should be written to the stream. This function can raise a :class:`multipart.exceptions.DecodeError` if there is some remaining @@ -97,7 +98,7 @@ def finalize(self): if hasattr(self.underlying, "finalize"): self.underlying.finalize() - def __repr__(self): + def __repr__(self) -> str: return f"{self.__class__.__name__}(underlying={self.underlying!r})" @@ -111,11 +112,11 @@ class QuotedPrintableDecoder: :param underlying: the underlying object to pass writes to """ - def __init__(self, underlying): + def __init__(self, underlying: BufferedWriter) -> None: self.cache = b"" self.underlying = underlying - def write(self, data): + def write(self, data: bytes) -> int: """Takes any input data provided, decodes it as quoted-printable, and passes it on to the underlying object. @@ -142,14 +143,14 @@ def write(self, data): self.cache = rest return len(data) - def close(self): + def close(self) -> None: """Close this decoder. If the underlying object has a `close()` method, this function will call it. """ if hasattr(self.underlying, "close"): self.underlying.close() - def finalize(self): + def finalize(self) -> None: """Finalize this object. This should be called when no more data should be written to the stream. This function will not raise any exceptions, but it may write more data to the underlying object if @@ -167,5 +168,5 @@ def finalize(self): if hasattr(self.underlying, "finalize"): self.underlying.finalize() - def __repr__(self): + def __repr__(self) -> str: return f"{self.__class__.__name__}(underlying={self.underlying!r})" diff --git a/multipart/multipart.py b/multipart/multipart.py index 0c7c447..9f61eb8 100644 --- a/multipart/multipart.py +++ b/multipart/multipart.py @@ -41,7 +41,7 @@ class MultipartCallbacks(TypedDict, total=False): on_headers_finished: Callable[[], None] on_end: Callable[[], None] - class FormParserConfig(TypedDict, total=False): + class FormParserConfig(TypedDict): UPLOAD_DIR: str | None UPLOAD_KEEP_FILENAME: bool UPLOAD_KEEP_EXTENSIONS: bool @@ -50,7 +50,7 @@ class FormParserConfig(TypedDict, total=False): MAX_BODY_SIZE: float class FileConfig(TypedDict, total=False): - UPLOAD_DIR: str | None + UPLOAD_DIR: str | bytes | None UPLOAD_DELETE_TMP: bool UPLOAD_KEEP_FILENAME: bool UPLOAD_KEEP_EXTENSIONS: bool @@ -374,7 +374,7 @@ class File: configuration keys and their corresponding values. """ - def __init__(self, file_name: bytes | None, field_name: bytes | None = None, config: FileConfig = {}): + def __init__(self, file_name: bytes | None, field_name: bytes | None = None, config: FileConfig = {}) -> None: # Save configuration, set other variables default. self.logger = logging.getLogger(__name__) self._config = config @@ -471,7 +471,7 @@ def flush_to_disk(self) -> None: # Close the old file object. old_fileobj.close() - def _get_disk_file(self): + def _get_disk_file(self) -> io.BufferedRandom | tempfile._TemporaryFileWrapper[bytes]: # type: ignore[reportPrivateUsage] """This function is responsible for getting a file object on-disk for us.""" self.logger.info("Opening a file on disk") @@ -486,9 +486,7 @@ def _get_disk_file(self): # Build our filename. # TODO: what happens if we don't have a filename? - fname = self._file_base - if keep_extensions: - fname = fname + self._ext + fname = self._file_base + self._ext if keep_extensions else self._file_base path = os.path.join(file_dir, fname) try: @@ -503,25 +501,21 @@ def _get_disk_file(self): # Build options array. # Note that on Python 3, tempfile doesn't support byte names. We # encode our paths using the default filesystem encoding. - options = {} - if keep_extensions: - ext = self._ext - if isinstance(ext, bytes): - ext = ext.decode(sys.getfilesystemencoding()) - - options["suffix"] = ext - if file_dir is not None: - d = file_dir - if isinstance(d, bytes): - d = d.decode(sys.getfilesystemencoding()) + suffix = self._ext.decode(sys.getfilesystemencoding()) if keep_extensions else None - options["dir"] = d - options["delete"] = delete_tmp + if file_dir is None: + dir = None + elif isinstance(file_dir, bytes): + dir = file_dir.decode(sys.getfilesystemencoding()) + else: + dir = file_dir # Create a temporary (named) file with the appropriate settings. - self.logger.info("Creating a temporary file with options: %r", options) + self.logger.info( + "Creating a temporary file with options: %r", {"suffix": suffix, "delete": delete_tmp, "dir": dir} + ) try: - tmp_file = tempfile.NamedTemporaryFile(**options) + tmp_file = tempfile.NamedTemporaryFile(suffix=suffix, delete=delete_tmp, dir=dir) except OSError: self.logger.exception("Error creating named temporary file") raise FileError("Error creating named temporary file") @@ -563,11 +557,8 @@ def on_data(self, data: bytes) -> int: self._bytes_written += bwritten # If we're in-memory and are over our limit, we create a file. - if ( - self._in_memory - and self._config.get("MAX_MEMORY_FILE_SIZE") is not None - and (self._bytes_written > self._config.get("MAX_MEMORY_FILE_SIZE")) - ): + max_memory_file_size = self._config.get("MAX_MEMORY_FILE_SIZE") + if self._in_memory and max_memory_file_size is not None and (self._bytes_written > max_memory_file_size): self.logger.info("Flushing to disk") self.flush_to_disk() @@ -617,9 +608,7 @@ class BaseParser: performance. """ - callbacks: dict[str, Callable[..., Any]] - - def __init__(self): + def __init__(self) -> None: self.logger = logging.getLogger(__name__) def callback(self, name: str, data: bytes | None = None, start: int | None = None, end: int | None = None): @@ -706,7 +695,7 @@ class OctetStreamParser(BaseParser): i.e. unbounded. """ - def __init__(self, callbacks: OctetStreamCallbacks = {}, max_size=float("inf")): + def __init__(self, callbacks: OctetStreamCallbacks = {}, max_size: float = float("inf")): super().__init__() self.callbacks = callbacks self._started = False @@ -716,7 +705,7 @@ def __init__(self, callbacks: OctetStreamCallbacks = {}, max_size=float("inf")): self.max_size = max_size self._current_size = 0 - def write(self, data: bytes): + def write(self, data: bytes) -> int: """Write some data to the parser, which will perform size verification, and then pass the data to the underlying callback. @@ -803,7 +792,9 @@ class QuerystringParser(BaseParser): state: QuerystringState - def __init__(self, callbacks: QuerystringCallbacks = {}, strict_parsing: bool = False, max_size=float("inf")): + def __init__( + self, callbacks: QuerystringCallbacks = {}, strict_parsing: bool = False, max_size: float = float("inf") + ) -> None: super().__init__() self.state = QuerystringState.BEFORE_FIELD self._found_sep = False @@ -1060,7 +1051,9 @@ class MultipartParser(BaseParser): i.e. unbounded. """ - def __init__(self, boundary: bytes | str, callbacks: MultipartCallbacks = {}, max_size=float("inf")): + def __init__( + self, boundary: bytes | str, callbacks: MultipartCallbacks = {}, max_size: float = float("inf") + ) -> None: # Initialize parser state. super().__init__() self.state = MultipartState.START @@ -1618,8 +1611,8 @@ def __init__( file_name: bytes | None = None, FileClass: type[FileProtocol] = File, FieldClass: type[FieldProtocol] = Field, - config: FormParserConfig = {}, - ): + config: dict[Any, Any] = {}, + ) -> None: self.logger = logging.getLogger(__name__) # Save variables. @@ -1787,7 +1780,7 @@ def on_headers_finished() -> None: # TODO: check that we properly handle 8bit / 7bit encoding. transfer_encoding = headers.get(b"Content-Transfer-Encoding", b"7bit") - if transfer_encoding == b"binary" or transfer_encoding == b"8bit" or transfer_encoding == b"7bit": + if transfer_encoding in (b"binary", b"8bit", b"7bit"): writer = f elif transfer_encoding == b"base64": @@ -1862,8 +1855,8 @@ def create_form_parser( on_field: OnFieldCallback, on_file: OnFileCallback, trust_x_headers: bool = False, - config={}, -): + config: dict[Any, Any] = {}, +) -> FormParser: """This function is a helper function to aid in creating a FormParser instances. Given a dictionary-like headers object, it will determine the correct information needed, instantiate a FormParser with the @@ -1912,8 +1905,7 @@ def parse_form( on_field: OnFieldCallback, on_file: OnFileCallback, chunk_size: int = 1048576, - **kwargs, -): +) -> None: """This function is useful if you just want to parse a request body, without too much work. Pass it a dictionary-like object of the request's headers, and a file-like object for the input stream, along with two diff --git a/tests/test_multipart.py b/tests/test_multipart.py index 79968e0..93fd38d 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -6,6 +6,7 @@ import tempfile import unittest from io import BytesIO +from typing import TYPE_CHECKING from unittest.mock import Mock import yaml @@ -28,6 +29,9 @@ from .compat import parametrize, parametrize_class, slow_test +if TYPE_CHECKING: + from multipart.multipart import FileConfig + # Get the current directory for our later test cases. curr_dir = os.path.abspath(os.path.dirname(__file__)) @@ -95,7 +99,7 @@ def test_set_none(self): class TestFile(unittest.TestCase): def setUp(self): - self.c = {} + self.c: FileConfig = {} self.d = force_bytes(tempfile.mkdtemp()) self.f = File(b"foo.txt", config=self.c)