Skip to content

Commit

Permalink
Improve type hints on FormParser (#104)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex authored Feb 10, 2024
1 parent 54c4d18 commit 59543a4
Showing 1 changed file with 84 additions and 56 deletions.
140 changes: 84 additions & 56 deletions multipart/multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .exceptions import FileError, FormParserError, MultipartParseError, QuerystringParseError

if TYPE_CHECKING: # pragma: no cover
from typing import Callable, TypedDict
from typing import Callable, Protocol, TypedDict

class QuerystringCallbacks(TypedDict, total=False):
on_field_start: Callable[[], None]
Expand Down Expand Up @@ -55,6 +55,30 @@ class FileConfig(TypedDict, total=False):
UPLOAD_KEEP_EXTENSIONS: bool
MAX_MEMORY_FILE_SIZE: int

class _FormProtocol(Protocol):
def write(self, data: bytes) -> int:
...

def finalize(self) -> None:
...

def close(self) -> None:
...

class FieldProtocol(_FormProtocol, Protocol):
def __init__(self, name: bytes) -> None:
...

def set_none(self) -> None:
...

class FileProtocol(_FormProtocol, Protocol):
def __init__(self, file_name: bytes | None, field_name: bytes | None, config: FileConfig) -> None:
...

OnFieldCallback = Callable[[FieldProtocol], None]
OnFileCallback = Callable[[FieldProtocol], None]


# Unique missing object.
_missing = object()
Expand Down Expand Up @@ -190,15 +214,15 @@ class Field:
:param name: the name of the form field
"""

def __init__(self, name: str):
def __init__(self, name: bytes):
self._name = name
self._value: list[bytes] = []

# We cache the joined version of _value for speed.
self._cache = _missing

@classmethod
def from_value(cls, name: str, value: bytes | None) -> Field:
def from_value(cls, name: bytes, value: bytes | None) -> Field:
"""Create an instance of a :class:`Field`, and set the corresponding
value - either None or an actual value. This method will also
finalize the Field itself.
Expand Down Expand Up @@ -260,7 +284,7 @@ def set_none(self) -> None:
self._cache = None

@property
def field_name(self) -> str:
def field_name(self) -> bytes:
"""This property returns the name of the field."""
return self._name

Expand Down Expand Up @@ -1562,6 +1586,7 @@ class FormParser:
field_instance.write(data)
field_instance.finalize()
field_instance.close()
field_instance.set_none()
:param config: Configuration to use for this FormParser. The default
values are taken from the DEFAULT_CONFIG value, and then
Expand All @@ -1584,14 +1609,14 @@ class FormParser:

def __init__(
self,
content_type,
on_field,
on_file,
on_end=None,
boundary=None,
file_name=None,
FileClass=File,
FieldClass=Field,
content_type: str,
on_field: OnFieldCallback,
on_file: OnFileCallback,
on_end: Callable[[], None] | None = None,
boundary: bytes | str | None = None,
file_name: bytes | None = None,
FileClass: type[FileProtocol] = File,
FieldClass: type[FieldProtocol] = Field,
config: FormParserConfig = {},
):
self.logger = logging.getLogger(__name__)
Expand All @@ -1617,38 +1642,37 @@ def __init__(

# Depending on the Content-Type, we instantiate the correct parser.
if content_type == "application/octet-stream":
# Work around the lack of 'nonlocal' in Py2
class vars:
f = None
f: FileProtocol | None = None

def on_start() -> None:
vars.f = FileClass(file_name, None, config=self.config)
nonlocal f
f = FileClass(file_name, None, config=self.config)

def on_data(data: bytes, start: int, end: int) -> None:
vars.f.write(data[start:end])
nonlocal f
f.write(data[start:end])

def on_end() -> None:
def _on_end() -> None:
# Finalize the file itself.
vars.f.finalize()
f.finalize()

# Call our callback.
on_file(vars.f)
on_file(f)

# Call the on-end callback.
if self.on_end is not None:
self.on_end()

# Instantiate an octet-stream parser
parser = OctetStreamParser(
callbacks={"on_start": on_start, "on_data": on_data, "on_end": on_end},
callbacks={"on_start": on_start, "on_data": on_data, "on_end": _on_end},
max_size=self.config["MAX_BODY_SIZE"],
)

elif content_type == "application/x-www-form-urlencoded" or content_type == "application/x-url-encoded":
name_buffer: list[bytes] = []

class vars:
f = None
f: FieldProtocol | None = None

def on_field_start() -> None:
pass
Expand All @@ -1657,25 +1681,27 @@ def on_field_name(data: bytes, start: int, end: int) -> None:
name_buffer.append(data[start:end])

def on_field_data(data: bytes, start: int, end: int) -> None:
if vars.f is None:
vars.f = FieldClass(b"".join(name_buffer))
nonlocal f
if f is None:
f = FieldClass(b"".join(name_buffer))
del name_buffer[:]
vars.f.write(data[start:end])
f.write(data[start:end])

def on_field_end() -> None:
nonlocal f
# Finalize and call callback.
if vars.f is None:
if f is None:
# If we get here, it's because there was no field data.
# We create a field, set it to None, and then continue.
vars.f = FieldClass(b"".join(name_buffer))
f = FieldClass(b"".join(name_buffer))
del name_buffer[:]
vars.f.set_none()
f.set_none()

vars.f.finalize()
on_field(vars.f)
vars.f = None
f.finalize()
on_field(f)
f = None

def on_end() -> None:
def _on_end() -> None:
if self.on_end is not None:
self.on_end()

Expand All @@ -1686,7 +1712,7 @@ def on_end() -> None:
"on_field_name": on_field_name,
"on_field_data": on_field_data,
"on_field_end": on_field_end,
"on_end": on_end,
"on_end": _on_end,
},
max_size=self.config["MAX_BODY_SIZE"],
)
Expand All @@ -1700,26 +1726,26 @@ def on_end() -> None:
header_value: list[bytes] = []
headers = {}

# No 'nonlocal' on Python 2 :-(
class vars:
f = None
writer = None
is_file = False
f: FileProtocol | FieldProtocol | None = None
writer = None
is_file = False

def on_part_begin():
pass

def on_part_data(data: bytes, start: int, end: int):
bytes_processed = vars.writer.write(data[start:end])
def on_part_data(data: bytes, start: int, end: int) -> None:
nonlocal writer
bytes_processed = writer.write(data[start:end])
# TODO: check for error here.
return bytes_processed

def on_part_end() -> None:
vars.f.finalize()
if vars.is_file:
on_file(vars.f)
nonlocal f, is_file
f.finalize()
if is_file:
on_file(f)
else:
on_field(vars.f)
on_field(f)

def on_header_field(data: bytes, start: int, end: int):
header_name.append(data[start:end])
Expand All @@ -1733,8 +1759,9 @@ def on_header_end():
del header_value[:]

def on_headers_finished() -> None:
nonlocal is_file, f, writer
# Reset the 'is file' flag.
vars.is_file = False
is_file = False

# Parse the content-disposition header.
# TODO: handle mixed case
Expand All @@ -1748,24 +1775,24 @@ def on_headers_finished() -> None:

# Create the proper class.
if file_name is None:
vars.f = FieldClass(field_name)
f = FieldClass(field_name)
else:
vars.f = FileClass(file_name, field_name, config=self.config)
vars.is_file = True
f = FileClass(file_name, field_name, config=self.config)
is_file = True

# Parse the given Content-Transfer-Encoding to determine what
# we need to do with the incoming data.
# TODO: check that we properly handle 8bit / 7bit encoding.
transfer_encoding = headers.get(b"Content-Transfer-Encoding", b"7bit")

if transfer_encoding == b"binary" or transfer_encoding == b"8bit" or transfer_encoding == b"7bit":
vars.writer = vars.f
writer = f

elif transfer_encoding == b"base64":
vars.writer = Base64Decoder(vars.f)
writer = Base64Decoder(f)

elif transfer_encoding == b"quoted-printable":
vars.writer = QuotedPrintableDecoder(vars.f)
writer = QuotedPrintableDecoder(f)

else:
self.logger.warning("Unknown Content-Transfer-Encoding: %r", transfer_encoding)
Expand All @@ -1774,10 +1801,11 @@ def on_headers_finished() -> None:
else:
# If we aren't erroring, then we just treat this as an
# unencoded Content-Transfer-Encoding.
vars.writer = vars.f
writer = f

def on_end() -> None:
vars.writer.finalize()
def _on_end() -> None:
nonlocal writer
writer.finalize()
if self.on_end is not None:
self.on_end()

Expand All @@ -1792,7 +1820,7 @@ def on_end() -> None:
"on_header_value": on_header_value,
"on_header_end": on_header_end,
"on_headers_finished": on_headers_finished,
"on_end": on_end,
"on_end": _on_end,
},
max_size=self.config["MAX_BODY_SIZE"],
)
Expand Down

0 comments on commit 59543a4

Please sign in to comment.