Skip to content

Commit

Permalink
Support on_header_begin (#103)
Browse files Browse the repository at this point in the history
  • Loading branch information
Kludex authored Feb 10, 2024
1 parent 6181731 commit 54c4d18
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 8 deletions.
23 changes: 16 additions & 7 deletions multipart/multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from enum import IntEnum
from io import BytesIO
from numbers import Number
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from .decoders import Base64Decoder, QuotedPrintableDecoder
from .exceptions import FileError, FormParserError, MultipartParseError, QuerystringParseError
Expand All @@ -33,7 +33,7 @@ class MultipartCallbacks(TypedDict, total=False):
on_part_begin: Callable[[], None]
on_part_data: Callable[[bytes, int, int], None]
on_part_end: Callable[[], None]
on_headers_begin: Callable[[], None]
on_header_begin: Callable[[], None]
on_header_field: Callable[[bytes, int, int], None]
on_header_value: Callable[[bytes, int, int], None]
on_header_end: Callable[[], None]
Expand Down Expand Up @@ -592,10 +592,12 @@ class BaseParser:
performance.
"""

callbacks: dict[str, Callable[..., Any]]

def __init__(self):
self.logger = logging.getLogger(__name__)

def callback(self, name: str, data=None, start=None, end=None):
def callback(self, name: str, data: bytes | None = None, start: int | None = None, end: int | None = None):
"""This function calls a provided callback with some data. If the
callback is not set, will do nothing.
Expand Down Expand Up @@ -1047,7 +1049,7 @@ def __init__(self, boundary: bytes | str, callbacks: MultipartCallbacks = {}, ma
self._current_size = 0

# Setup marks. These are used to track the state of data received.
self.marks = {}
self.marks: dict[str, int] = {}

# TODO: Actually use this rather than the dumb version we currently use
# # Precompute the skip table for the Boyer-Moore-Horspool algorithm.
Expand Down Expand Up @@ -1118,19 +1120,19 @@ def _internal_write(self, data: bytes, length: int) -> int:
i = 0

# Set a mark.
def set_mark(name):
def set_mark(name: str):
self.marks[name] = i

# Remove a mark.
def delete_mark(name, reset=False):
def delete_mark(name: str, reset: bool = False):
self.marks.pop(name, None)

# Helper function that makes calling a callback with data easier. The
# 'remaining' parameter will callback from the marked value until the
# end of the buffer, and reset the mark, instead of deleting it. This
# is used at the end of the function to call our callbacks with any
# remaining data in this chunk.
def data_callback(name, remaining=False):
def data_callback(name: str, remaining: bool = False):
marked_index = self.marks.get(name)
if marked_index is None:
return
Expand Down Expand Up @@ -1217,6 +1219,13 @@ def data_callback(name, remaining=False):
# Set a mark of our header field.
set_mark("header_field")

# Notify that we're starting a header if the next character is
# not a CR; a CR at the beginning of the header will cause us
# to stop parsing headers in the MultipartState.HEADER_FIELD state,
# below.
if c != CR:
self.callback("header_begin")

# Move to parsing header fields.
state = MultipartState.HEADER_FIELD
i -= 1
Expand Down
31 changes: 30 additions & 1 deletion tests/test_multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -1222,7 +1222,36 @@ def on_file(f):

def test_invalid_max_size_multipart(self):
with self.assertRaises(ValueError):
q = MultipartParser(b"bound", max_size="foo")
MultipartParser(b"bound", max_size="foo")

def test_header_begin_callback(self):
"""
This test verifies we call the `on_header_begin` callback.
See GitHub issue #23
"""
# Load test data.
test_file = "single_field_single_file.http"
with open(os.path.join(http_tests_dir, test_file), "rb") as f:
test_data = f.read()

calls = 0

def on_header_begin() -> None:
nonlocal calls
calls += 1

parser = MultipartParser("boundary", callbacks={"on_header_begin": on_header_begin}, max_size=1000)

# Create multipart parser and feed it
i = parser.write(test_data)
parser.finalize()

# Assert we processed everything.
self.assertEqual(i, len(test_data))

# Assert that we called our 'header_begin' callback three times; once
# for each header in the multipart message.
self.assertEqual(calls, 3)


class TestHelperFunctions(unittest.TestCase):
Expand Down

0 comments on commit 54c4d18

Please sign in to comment.