diff --git a/multipart/multipart.py b/multipart/multipart.py index ac2648e..221bb71 100644 --- a/multipart/multipart.py +++ b/multipart/multipart.py @@ -40,6 +40,21 @@ class MultipartCallbacks(TypedDict, total=False): on_headers_finished: Callable[[], None] on_end: Callable[[], None] + class FormParserConfig(TypedDict, total=False): + UPLOAD_DIR: str | None + UPLOAD_KEEP_FILENAME: bool + UPLOAD_KEEP_EXTENSIONS: bool + UPLOAD_ERROR_ON_BAD_CTE: bool + MAX_MEMORY_FILE_SIZE: int + MAX_BODY_SIZE: float + + class FileConfig(TypedDict, total=False): + UPLOAD_DIR: str | None + UPLOAD_DELETE_TMP: bool + UPLOAD_KEEP_FILENAME: bool + UPLOAD_KEEP_EXTENSIONS: bool + MAX_MEMORY_FILE_SIZE: int + # Unique missing object. _missing = object() @@ -334,7 +349,7 @@ class File: configuration keys and their corresponding values. """ - def __init__(self, file_name, field_name=None, config={}): + def __init__(self, file_name: bytes | None, field_name: bytes | None = None, config: FileConfig = {}): # Save configuration, set other variables default. self.logger = logging.getLogger(__name__) self._config = config @@ -357,14 +372,14 @@ def __init__(self, file_name, field_name=None, config={}): self._ext = ext @property - def field_name(self): + def field_name(self) -> bytes | None: """The form field associated with this file. May be None if there isn't one, for example when we have an application/octet-stream upload. """ return self._field_name @property - def file_name(self): + def file_name(self) -> bytes | None: """The file name given in the upload request.""" return self._file_name @@ -391,13 +406,13 @@ def size(self): return self._bytes_written @property - def in_memory(self): + def in_memory(self) -> bool: """A boolean representing whether or not this file object is currently stored in-memory or on-disk. """ return self._in_memory - def flush_to_disk(self): + def flush_to_disk(self) -> None: """If the file is already on-disk, do nothing. Otherwise, copy from the in-memory buffer to a disk file, and then reassign our internal file object to this new disk file. @@ -495,14 +510,14 @@ def _get_disk_file(self): self._actual_file_name = fname return tmp_file - def write(self, data): + def write(self, data: bytes): """Write some data to the File. :param data: a bytestring """ return self.on_data(data) - def on_data(self, data): + def on_data(self, data: bytes): """This method is a callback that will be called whenever data is written to the File. @@ -534,25 +549,25 @@ def on_data(self, data): # Return the number of bytes written. return bwritten - def on_end(self): + def on_end(self) -> None: """This method is called whenever the Field is finalized.""" # Flush the underlying file object self._fileobj.flush() - def finalize(self): + def finalize(self) -> None: """Finalize the form file. This will not close the underlying file, but simply signal that we are finished writing to the File. """ self.on_end() - def close(self): + def close(self) -> None: """Close the File object. This will actually close the underlying file object (whether it's a :class:`io.BytesIO` or an actual file object). """ self._fileobj.close() - def __repr__(self): + def __repr__(self) -> str: return "{}(file_name={!r}, field_name={!r})".format(self.__class__.__name__, self.file_name, self.field_name) @@ -703,13 +718,13 @@ def write(self, data: bytes): self.callback("data", data, 0, data_len) return data_len - def finalize(self): + def finalize(self) -> None: """Finalize this parser, which signals to that we are finished parsing, and sends the on_end callback. """ self.callback("end") - def __repr__(self): + def __repr__(self) -> str: return "%s()" % self.__class__.__name__ @@ -761,7 +776,7 @@ class QuerystringParser(BaseParser): state: QuerystringState - def __init__(self, callbacks: QuerystringCallbacks = {}, strict_parsing=False, max_size=float("inf")): + def __init__(self, callbacks: QuerystringCallbacks = {}, strict_parsing: bool = False, max_size=float("inf")): super().__init__() self.state = QuerystringState.BEFORE_FIELD self._found_sep = False @@ -777,7 +792,7 @@ def __init__(self, callbacks: QuerystringCallbacks = {}, strict_parsing=False, m # Should parsing be strict? self.strict_parsing = strict_parsing - def write(self, data: bytes): + def write(self, data: bytes) -> int: """Write some data to the parser, which will perform size verification, parse into either a field name or value, and then pass the corresponding data to the underlying callback. If an error is @@ -809,7 +824,7 @@ def write(self, data: bytes): return l - def _internal_write(self, data: bytes, length: int): + def _internal_write(self, data: bytes, length: int) -> int: state = self.state strict_parsing = self.strict_parsing found_sep = self._found_sep @@ -947,7 +962,7 @@ def _internal_write(self, data: bytes, length: int): self._found_sep = found_sep return len(data) - def finalize(self): + def finalize(self) -> None: """Finalize this parser, which signals to that we are finished parsing, if we're still in the middle of a field, an on_field_end callback, and then the on_end callback. @@ -957,7 +972,7 @@ def finalize(self): self.callback("field_end") self.callback("end") - def __repr__(self): + def __repr__(self) -> str: return "{}(strict_parsing={!r}, max_size={!r})".format( self.__class__.__name__, self.strict_parsing, self.max_size ) @@ -1018,7 +1033,7 @@ class MultipartParser(BaseParser): i.e. unbounded. """ - def __init__(self, boundary, callbacks: MultipartCallbacks = {}, max_size=float("inf")): + def __init__(self, boundary: bytes | str, callbacks: MultipartCallbacks = {}, max_size=float("inf")): # Initialize parser state. super().__init__() self.state = MultipartState.START @@ -1057,7 +1072,7 @@ def __init__(self, boundary, callbacks: MultipartCallbacks = {}, max_size=float( # '--\r\n' is 8 bytes. self.lookbehind = [NULL for x in range(len(boundary) + 8)] - def write(self, data): + def write(self, data: bytes) -> int: """Write some data to the parser, which will perform size verification, and then parse the data into the appropriate location (e.g. header, data, etc.), and pass this on to the underlying callback. If an error @@ -1089,7 +1104,7 @@ def write(self, data): return l - def _internal_write(self, data, length): + def _internal_write(self, data: bytes, length: int) -> int: # Get values from locals. boundary = self.boundary @@ -1478,7 +1493,7 @@ def data_callback(name, remaining=False): # all of it. return length - def finalize(self): + def finalize(self) -> None: """Finalize this parser, which signals to that we are finished parsing. Note: It does not currently, but in the future, it will verify that we @@ -1548,7 +1563,7 @@ class FormParser: #: This is the default configuration for our form parser. #: Note: all file sizes should be in bytes. - DEFAULT_CONFIG = { + DEFAULT_CONFIG: FormParserConfig = { "MAX_BODY_SIZE": float("inf"), "MAX_MEMORY_FILE_SIZE": 1 * 1024 * 1024, "UPLOAD_DIR": None, @@ -1568,7 +1583,7 @@ def __init__( file_name=None, FileClass=File, FieldClass=Field, - config={}, + config: FormParserConfig = {}, ): self.logger = logging.getLogger(__name__) @@ -1597,13 +1612,13 @@ def __init__( class vars: f = None - def on_start(): + def on_start() -> None: vars.f = FileClass(file_name, None, config=self.config) - def on_data(data, start, end): + def on_data(data: bytes, start: int, end: int) -> None: vars.f.write(data[start:end]) - def on_end(): + def on_end() -> None: # Finalize the file itself. vars.f.finalize() @@ -1614,30 +1629,31 @@ def on_end(): if self.on_end is not None: self.on_end() - callbacks = {"on_start": on_start, "on_data": on_data, "on_end": on_end} - # Instantiate an octet-stream parser - parser = OctetStreamParser(callbacks, max_size=self.config["MAX_BODY_SIZE"]) + parser = OctetStreamParser( + callbacks={"on_start": on_start, "on_data": on_data, "on_end": on_end}, + max_size=self.config["MAX_BODY_SIZE"], + ) elif content_type == "application/x-www-form-urlencoded" or content_type == "application/x-url-encoded": - name_buffer = [] + name_buffer: list[bytes] = [] class vars: f = None - def on_field_start(): + def on_field_start() -> None: pass - def on_field_name(data, start, end): + def on_field_name(data: bytes, start: int, end: int) -> None: name_buffer.append(data[start:end]) - def on_field_data(data, start, end): + def on_field_data(data: bytes, start: int, end: int) -> None: if vars.f is None: vars.f = FieldClass(b"".join(name_buffer)) del name_buffer[:] vars.f.write(data[start:end]) - def on_field_end(): + def on_field_end() -> None: # Finalize and call callback. if vars.f is None: # If we get here, it's because there was no field data. @@ -1650,29 +1666,29 @@ def on_field_end(): on_field(vars.f) vars.f = None - def on_end(): + def on_end() -> None: if self.on_end is not None: self.on_end() - # Setup callbacks. - callbacks = { - "on_field_start": on_field_start, - "on_field_name": on_field_name, - "on_field_data": on_field_data, - "on_field_end": on_field_end, - "on_end": on_end, - } - # Instantiate parser. - parser = QuerystringParser(callbacks=callbacks, max_size=self.config["MAX_BODY_SIZE"]) + parser = QuerystringParser( + callbacks={ + "on_field_start": on_field_start, + "on_field_name": on_field_name, + "on_field_data": on_field_data, + "on_field_end": on_field_end, + "on_end": on_end, + }, + max_size=self.config["MAX_BODY_SIZE"], + ) elif content_type == "multipart/form-data": if boundary is None: self.logger.error("No boundary given") raise FormParserError("No boundary given") - header_name = [] - header_value = [] + header_name: list[bytes] = [] + header_value: list[bytes] = [] headers = {} # No 'nonlocal' on Python 2 :-( @@ -1684,22 +1700,22 @@ class vars: def on_part_begin(): pass - def on_part_data(data, start, end): + def on_part_data(data: bytes, start: int, end: int): bytes_processed = vars.writer.write(data[start:end]) # TODO: check for error here. return bytes_processed - def on_part_end(): + def on_part_end() -> None: vars.f.finalize() if vars.is_file: on_file(vars.f) else: on_field(vars.f) - def on_header_field(data, start, end): + def on_header_field(data: bytes, start: int, end: int): header_name.append(data[start:end]) - def on_header_value(data, start, end): + def on_header_value(data: bytes, start: int, end: int): header_value.append(data[start:end]) def on_header_end(): @@ -1707,7 +1723,7 @@ def on_header_end(): del header_name[:] del header_value[:] - def on_headers_finished(): + def on_headers_finished() -> None: # Reset the 'is file' flag. vars.is_file = False @@ -1751,25 +1767,26 @@ def on_headers_finished(): # unencoded Content-Transfer-Encoding. vars.writer = vars.f - def on_end(): + def on_end() -> None: vars.writer.finalize() if self.on_end is not None: self.on_end() - # These are our callbacks for the parser. - callbacks = { - "on_part_begin": on_part_begin, - "on_part_data": on_part_data, - "on_part_end": on_part_end, - "on_header_field": on_header_field, - "on_header_value": on_header_value, - "on_header_end": on_header_end, - "on_headers_finished": on_headers_finished, - "on_end": on_end, - } - # Instantiate a multipart parser. - parser = MultipartParser(boundary, callbacks, max_size=self.config["MAX_BODY_SIZE"]) + parser = MultipartParser( + boundary, + callbacks={ + "on_part_begin": on_part_begin, + "on_part_data": on_part_data, + "on_part_end": on_part_end, + "on_header_field": on_header_field, + "on_header_value": on_header_value, + "on_header_end": on_header_end, + "on_headers_finished": on_headers_finished, + "on_end": on_end, + }, + max_size=self.config["MAX_BODY_SIZE"], + ) else: self.logger.warning("Unknown Content-Type: %r", content_type) @@ -1777,7 +1794,7 @@ def on_end(): self.parser = parser - def write(self, data): + def write(self, data: bytes): """Write some data. The parser will forward this to the appropriate underlying parser. @@ -1787,17 +1804,17 @@ def write(self, data): # TODO: check the parser's return value for errors? return self.parser.write(data) - def finalize(self): + def finalize(self) -> None: """Finalize the parser.""" if self.parser is not None and hasattr(self.parser, "finalize"): self.parser.finalize() - def close(self): + def close(self) -> None: """Close the parser.""" if self.parser is not None and hasattr(self.parser, "close"): self.parser.close() - def __repr__(self): + def __repr__(self) -> str: return "{}(content_type={!r}, parser={!r})".format(self.__class__.__name__, self.content_type, self.parser)