From cb75e06c5b6a7d2262778a9ea6090b2029a8e4a7 Mon Sep 17 00:00:00 2001 From: John Stark Date: Fri, 29 Mar 2024 08:42:37 +0100 Subject: [PATCH 1/8] Adding headers --- python_multipart/multipart.py | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/python_multipart/multipart.py b/python_multipart/multipart.py index a996379..a39f70b 100644 --- a/python_multipart/multipart.py +++ b/python_multipart/multipart.py @@ -68,12 +68,14 @@ def finalize(self) -> None: ... def close(self) -> None: ... class FieldProtocol(_FormProtocol, Protocol): - def __init__(self, name: bytes | None) -> None: ... + def __init__(self, name: bytes, headers: dict[str,bytes]) -> None: + ... def set_none(self) -> None: ... class FileProtocol(_FormProtocol, Protocol): - def __init__(self, file_name: bytes | None, field_name: bytes | None, config: FileConfig) -> None: ... + def __init__(self, file_name: bytes | None, field_name: bytes | None, headers: dict[str,bytes], config: FileConfig) -> None: + ... OnFieldCallback = Callable[[FieldProtocol], None] OnFileCallback = Callable[[FileProtocol], None] @@ -223,9 +225,10 @@ class Field: name: The name of the form field. """ - def __init__(self, name: bytes | None) -> None: + def __init__(self, name: bytes, headers: dict[str,bytes]={}) -> None: self._name = name self._value: list[bytes] = [] + self._headers: dict[str,bytes] = headers # We cache the joined version of _value for speed. self._cache = _missing @@ -317,6 +320,11 @@ def value(self) -> bytes | None: assert isinstance(self._cache, bytes) or self._cache is None return self._cache + @property + def headers(self) -> dict[str,bytes]: + """This property returns the headers of the field.""" + return self._headers + def __eq__(self, other: object) -> bool: if isinstance(other, Field): return self.field_name == other.field_name and self.value == other.value @@ -357,7 +365,7 @@ class File: config: The configuration for this File. See above for valid configuration keys and their corresponding values. """ # noqa: E501 - def __init__(self, file_name: bytes | None, field_name: bytes | None = None, config: FileConfig = {}) -> None: + def __init__(self, file_name: bytes | None, field_name: bytes | None = None, headers: dict[str,bytes] = {}, config: FileConfig = {}) -> None: # Save configuration, set other variables default. self.logger = logging.getLogger(__name__) self._config = config @@ -365,9 +373,10 @@ def __init__(self, file_name: bytes | None, field_name: bytes | None = None, con self._bytes_written = 0 self._fileobj: BytesIO | BufferedRandom = BytesIO() - # Save the provided field/file name. + # Save the provided field/file name and content type. self._field_name = field_name self._file_name = file_name + self._headers = headers # Our actual file name is None by default, since, depending on our # config, we may not actually use the provided name. @@ -420,6 +429,12 @@ def in_memory(self) -> bool: """ return self._in_memory + @property + def headers(self) -> dict[str,bytes]: + """The headers for this part. + """ + return self._headers + def flush_to_disk(self) -> None: """If the file is already on-disk, do nothing. Otherwise, copy from the in-memory buffer to a disk file, and then reassign our internal From d46a9c65372cac3eb5cd9615d81d1421a5d04e17 Mon Sep 17 00:00:00 2001 From: John Stark Date: Tue, 30 Apr 2024 22:41:54 +0200 Subject: [PATCH 2/8] Add trailer test --- tests/test_data/http/single_field_with_trailer.http | 7 +++++++ tests/test_data/http/single_field_with_trailer.yaml | 6 ++++++ 2 files changed, 13 insertions(+) create mode 100644 tests/test_data/http/single_field_with_trailer.http create mode 100644 tests/test_data/http/single_field_with_trailer.yaml diff --git a/tests/test_data/http/single_field_with_trailer.http b/tests/test_data/http/single_field_with_trailer.http new file mode 100644 index 0000000..a570340 --- /dev/null +++ b/tests/test_data/http/single_field_with_trailer.http @@ -0,0 +1,7 @@ +------WebKitFormBoundaryTkr3kCBQlBe1nrhc +Content-Disposition: form-data; name="field" + +This is a test. +------WebKitFormBoundaryTkr3kCBQlBe1nrhc-- +this trailer causes a warning +but should be ignored \ No newline at end of file diff --git a/tests/test_data/http/single_field_with_trailer.yaml b/tests/test_data/http/single_field_with_trailer.yaml new file mode 100644 index 0000000..7690f08 --- /dev/null +++ b/tests/test_data/http/single_field_with_trailer.yaml @@ -0,0 +1,6 @@ +boundary: ----WebKitFormBoundaryTkr3kCBQlBe1nrhc +expected: + - name: field + type: field + data: !!binary | + VGhpcyBpcyBhIHRlc3Qu From 12b3f2631c06ba8f3149216dea50429954bd0d9d Mon Sep 17 00:00:00 2001 From: John Stark Date: Tue, 30 Apr 2024 22:45:08 +0200 Subject: [PATCH 3/8] Add content_type property to File Makes all headers lower case, fixing case sensitivity issues. Exposes jheaders property in Files and Fields. --- python_multipart/multipart.py | 55 +++++++++------- .../test_data/http/almost_match_boundary.yaml | 1 + tests/test_data/http/base64_encoding.yaml | 1 + .../http/case_insensitive_headers.http | 21 ++++++ .../http/case_insensitive_headers.yaml | 26 ++++++++ tests/test_data/http/header_with_number.yaml | 1 + tests/test_data/http/multiple_files.yaml | 2 + .../http/quoted_printable_encoding.yaml | 1 + .../http/single_field_single_file.yaml | 2 + tests/test_data/http/single_file.yaml | 1 + tests/test_data/http/utf8_filename.yaml | 1 + tests/test_multipart.py | 66 +++++++++++++++++-- 12 files changed, 151 insertions(+), 27 deletions(-) create mode 100644 tests/test_data/http/case_insensitive_headers.http create mode 100644 tests/test_data/http/case_insensitive_headers.yaml diff --git a/python_multipart/multipart.py b/python_multipart/multipart.py index a39f70b..c08dec5 100644 --- a/python_multipart/multipart.py +++ b/python_multipart/multipart.py @@ -68,14 +68,14 @@ def finalize(self) -> None: ... def close(self) -> None: ... class FieldProtocol(_FormProtocol, Protocol): - def __init__(self, name: bytes, headers: dict[str,bytes]) -> None: - ... + def __init__(self, name: bytes, headers: dict[str, bytes]) -> None: ... def set_none(self) -> None: ... class FileProtocol(_FormProtocol, Protocol): - def __init__(self, file_name: bytes | None, field_name: bytes | None, headers: dict[str,bytes], config: FileConfig) -> None: - ... + def __init__( + self, file_name: bytes | None, field_name: bytes | None, config: FileConfig, headers: dict[str, bytes] + ) -> None: ... OnFieldCallback = Callable[[FieldProtocol], None] OnFileCallback = Callable[[FileProtocol], None] @@ -225,10 +225,10 @@ class Field: name: The name of the form field. """ - def __init__(self, name: bytes, headers: dict[str,bytes]={}) -> None: + def __init__(self, name: bytes, headers: dict[str, bytes] = {}) -> None: self._name = name self._value: list[bytes] = [] - self._headers: dict[str,bytes] = headers + self._headers: dict[str, bytes] = headers # We cache the joined version of _value for speed. self._cache = _missing @@ -321,7 +321,7 @@ def value(self) -> bytes | None: return self._cache @property - def headers(self) -> dict[str,bytes]: + def headers(self) -> dict[str, bytes]: """This property returns the headers of the field.""" return self._headers @@ -365,7 +365,13 @@ class File: config: The configuration for this File. See above for valid configuration keys and their corresponding values. """ # noqa: E501 - def __init__(self, file_name: bytes | None, field_name: bytes | None = None, headers: dict[str,bytes] = {}, config: FileConfig = {}) -> None: + def __init__( + self, + file_name: bytes | None, + field_name: bytes | None = None, + headers: dict[str, bytes] = {}, + config: FileConfig = {}, + ) -> None: # Save configuration, set other variables default. self.logger = logging.getLogger(__name__) self._config = config @@ -430,11 +436,15 @@ def in_memory(self) -> bool: return self._in_memory @property - def headers(self) -> dict[str,bytes]: - """The headers for this part. - """ + def headers(self) -> dict[str, bytes]: + """The headers for this part.""" return self._headers - + + @property + def content_type(self) -> bytes | None: + """The Content-Type value for this part.""" + return self._headers.get("content-type") + def flush_to_disk(self) -> None: """If the file is already on-disk, do nothing. Otherwise, copy from the in-memory buffer to a disk file, and then reassign our internal @@ -1555,7 +1565,7 @@ def __init__( def on_start() -> None: nonlocal file - file = FileClass(file_name, None, config=cast("FileConfig", self.config)) + file = FileClass(file_name, None, headers={}, config=cast("FileConfig", self.config)) def on_data(data: bytes, start: int, end: int) -> None: nonlocal file @@ -1594,7 +1604,7 @@ def on_field_name(data: bytes, start: int, end: int) -> None: def on_field_data(data: bytes, start: int, end: int) -> None: nonlocal f if f is None: - f = FieldClass(b"".join(name_buffer)) + f = FieldClass(b"".join(name_buffer), headers={}) del name_buffer[:] f.write(data[start:end]) @@ -1604,7 +1614,7 @@ def on_field_end() -> None: if f is None: # If we get here, it's because there was no field data. # We create a field, set it to None, and then continue. - f = FieldClass(b"".join(name_buffer)) + f = FieldClass(b"".join(name_buffer), headers={}) del name_buffer[:] f.set_none() @@ -1636,7 +1646,7 @@ def _on_end() -> None: header_name: list[bytes] = [] header_value: list[bytes] = [] - headers: dict[bytes, bytes] = {} + headers: dict[str, bytes] = {} f_multi: FileProtocol | FieldProtocol | None = None writer = None @@ -1671,7 +1681,7 @@ def on_header_value(data: bytes, start: int, end: int) -> None: header_value.append(data[start:end]) def on_header_end() -> None: - headers[b"".join(header_name)] = b"".join(header_value) + headers[b"".join(header_name).decode().lower()] = b"".join(header_value) del header_name[:] del header_value[:] @@ -1681,26 +1691,25 @@ def on_headers_finished() -> None: is_file = False # Parse the content-disposition header. - # TODO: handle mixed case - content_disp = headers.get(b"Content-Disposition") + content_disp = headers.get("content-disposition") disp, options = parse_options_header(content_disp) # Get the field and filename. - field_name = options.get(b"name") + field_name = options.get(b"name", b"") file_name = options.get(b"filename") # TODO: check for errors # Create the proper class. if file_name is None: - f_multi = FieldClass(field_name) + f_multi = FieldClass(field_name, headers=headers) else: - f_multi = FileClass(file_name, field_name, config=cast("FileConfig", self.config)) + f_multi = FileClass(file_name, field_name, config=cast("FileConfig", self.config), headers=headers) is_file = True # Parse the given Content-Transfer-Encoding to determine what # we need to do with the incoming data. # TODO: check that we properly handle 8bit / 7bit encoding. - transfer_encoding = headers.get(b"Content-Transfer-Encoding", b"7bit") + transfer_encoding = headers.get("content-transfer-encoding", b"7bit") if transfer_encoding in (b"binary", b"8bit", b"7bit"): writer = f_multi diff --git a/tests/test_data/http/almost_match_boundary.yaml b/tests/test_data/http/almost_match_boundary.yaml index 235493e..c114ffe 100644 --- a/tests/test_data/http/almost_match_boundary.yaml +++ b/tests/test_data/http/almost_match_boundary.yaml @@ -3,6 +3,7 @@ expected: - name: file type: file file_name: test.txt + content_type: text/plain data: !!binary | LS1ib3VuZGFyaQ0KLS1ib3VuZGFyeXEtLWJvdW5kYXJ5DXEtLWJvdW5kYXJxDQotLWJvdW5hcnlkLS0NCi0tbm90Ym91bmQtLQ0KLS1taXNtYXRjaA0KLS1taXNtYXRjaC0tDQotLWJvdW5kYXJ5LVENCi0tYm91bmRhcnkNUS0tYm91bmRhcnlR diff --git a/tests/test_data/http/base64_encoding.yaml b/tests/test_data/http/base64_encoding.yaml index 1033150..db227a1 100644 --- a/tests/test_data/http/base64_encoding.yaml +++ b/tests/test_data/http/base64_encoding.yaml @@ -3,5 +3,6 @@ expected: - name: file type: file file_name: test.txt + content_type: text/plain data: !!binary | VGVzdCAxMjM= diff --git a/tests/test_data/http/case_insensitive_headers.http b/tests/test_data/http/case_insensitive_headers.http new file mode 100644 index 0000000..a14cc11 --- /dev/null +++ b/tests/test_data/http/case_insensitive_headers.http @@ -0,0 +1,21 @@ +------WebKitFormBoundarygbACTUR58IyeurVf +Content-Disposition: form-data; name="file1"; filename="test1.txt" +Content-Type: text/plain + +Test file #1 +------WebKitFormBoundarygbACTUR58IyeurVf +CONTENT-DISPOSITION: form-data; name="file2"; filename="test2.txt" +CONTENT-Type: text/plain + +Test file #2 +------WebKitFormBoundarygbACTUR58IyeurVf +content-disposition: form-data; name="file3"; filename="test3.txt" +content-type: text/plain + +Test file #3 +------WebKitFormBoundarygbACTUR58IyeurVf +cOnTenT-DiSpOsItiOn: form-data; name="file4"; filename="test4.txt" +Content-Type: text/plain + +Test file #4 +------WebKitFormBoundarygbACTUR58IyeurVf-- diff --git a/tests/test_data/http/case_insensitive_headers.yaml b/tests/test_data/http/case_insensitive_headers.yaml new file mode 100644 index 0000000..4c9d365 --- /dev/null +++ b/tests/test_data/http/case_insensitive_headers.yaml @@ -0,0 +1,26 @@ +boundary: ----WebKitFormBoundarygbACTUR58IyeurVf +expected: + - name: file1 + type: file + file_name: test1.txt + content_type: text/plain + data: !!binary | + VGVzdCBmaWxlICMx + - name: file2 + type: file + file_name: test2.txt + content_type: text/plain + data: !!binary | + VGVzdCBmaWxlICMy + - name: file3 + type: file + file_name: test3.txt + content_type: text/plain + data: !!binary | + VGVzdCBmaWxlICMz + - name: file4 + type: file + file_name: test4.txt + content_type: text/plain + data: !!binary | + VGVzdCBmaWxlICM0 diff --git a/tests/test_data/http/header_with_number.yaml b/tests/test_data/http/header_with_number.yaml index 493b783..86b4779 100644 --- a/tests/test_data/http/header_with_number.yaml +++ b/tests/test_data/http/header_with_number.yaml @@ -3,5 +3,6 @@ expected: - name: files type: file file_name: secret.txt + content_type: "text/plain; charset=utf-8" data: !!binary | YWFhYWFh diff --git a/tests/test_data/http/multiple_files.yaml b/tests/test_data/http/multiple_files.yaml index 3bf70e2..b372ab2 100644 --- a/tests/test_data/http/multiple_files.yaml +++ b/tests/test_data/http/multiple_files.yaml @@ -3,11 +3,13 @@ expected: - name: file1 type: file file_name: test1.txt + content_type: 'text/plain' data: !!binary | VGVzdCBmaWxlICMx - name: file2 type: file file_name: test2.txt + content_type: 'text/plain' data: !!binary | VGVzdCBmaWxlICMy diff --git a/tests/test_data/http/quoted_printable_encoding.yaml b/tests/test_data/http/quoted_printable_encoding.yaml index 2c6bbfb..6dcbde3 100644 --- a/tests/test_data/http/quoted_printable_encoding.yaml +++ b/tests/test_data/http/quoted_printable_encoding.yaml @@ -3,5 +3,6 @@ expected: - name: file type: file file_name: test.txt + content_type: 'text/plain' data: !!binary | Zm9vPWJhcg== diff --git a/tests/test_data/http/single_field_single_file.yaml b/tests/test_data/http/single_field_single_file.yaml index 47c8d6e..fa7002e 100644 --- a/tests/test_data/http/single_field_single_file.yaml +++ b/tests/test_data/http/single_field_single_file.yaml @@ -2,11 +2,13 @@ boundary: boundary expected: - name: field type: field + content_type: 'text/plain' data: !!binary | dGVzdDE= - name: file type: file file_name: file.txt + content_type: 'text/plain' data: !!binary | dGVzdDI= diff --git a/tests/test_data/http/single_file.yaml b/tests/test_data/http/single_file.yaml index 2a8e005..dbdff51 100644 --- a/tests/test_data/http/single_file.yaml +++ b/tests/test_data/http/single_file.yaml @@ -3,6 +3,7 @@ expected: - name: file type: file file_name: test.txt + content_type: 'text/plain' data: !!binary | VGhpcyBpcyBhIHRlc3QgZmlsZS4= diff --git a/tests/test_data/http/utf8_filename.yaml b/tests/test_data/http/utf8_filename.yaml index 507ba2c..25fab67 100644 --- a/tests/test_data/http/utf8_filename.yaml +++ b/tests/test_data/http/utf8_filename.yaml @@ -3,6 +3,7 @@ expected: - name: file type: file file_name: ???.txt + content_type: 'text/plain' data: !!binary | 44GT44KM44Gv44OG44K544OI44Gn44GZ44CC diff --git a/tests/test_multipart.py b/tests/test_multipart.py index ce92ff4..82f7138 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -758,7 +758,7 @@ def assert_file_data(self, f: File, data: bytes) -> None: file_data = o.read() self.assertEqual(file_data, data) - def assert_file(self, field_name: bytes, file_name: bytes, data: bytes) -> None: + def assert_file(self, field_name: bytes, file_name: bytes, content_type: str, data: bytes) -> None: # Find this file. found = None for f in self.files: @@ -770,6 +770,8 @@ def assert_file(self, field_name: bytes, file_name: bytes, data: bytes) -> None: self.assertIsNotNone(found) assert found is not None + self.assertEqual(found.content_type, content_type.encode()) + try: # Assert about this file. self.assert_file_data(found, data) @@ -839,7 +841,7 @@ def test_http(self, param: TestParams) -> None: self.assert_field(name, e["data"]) elif type == "file": - self.assert_file(name, e["file_name"].encode("latin-1"), e["data"]) + self.assert_file(name, e["file_name"].encode("latin-1"), e["content_type"], e["data"]) else: assert False @@ -870,7 +872,7 @@ def test_random_splitting(self) -> None: # Assert that our file and field are here. self.assert_field(b"field", b"test1") - self.assert_file(b"file", b"file.txt", b"test2") + self.assert_file(b"file", b"file.txt", "text/plain", b"test2") @parametrize("param", [t for t in http_tests if t["name"] in single_byte_tests]) def test_feed_single_bytes(self, param: TestParams) -> None: @@ -909,7 +911,7 @@ def test_feed_single_bytes(self, param: TestParams) -> None: self.assert_field(name, e["data"]) elif type == "file": - self.assert_file(name, e["file_name"].encode("latin-1"), e["data"]) + self.assert_file(name, e["file_name"].encode("latin-1"), "text/plain", e["data"]) else: assert False @@ -947,6 +949,62 @@ def test_feed_blocks(self) -> None: # Assert that our field is here. self.assert_field(b"field", b"0123456789ABCDEFGHIJ0123456789ABCDEFGHIJ") + def test_file_headers(self) -> None: + """ + This test checks headers for a file part are read. + """ + # Load test data. + test_file = "header_with_number.http" + with open(os.path.join(http_tests_dir, test_file), "rb") as f: + test_data = f.read() + + expected_headers = { + "content-disposition": b'form-data; filename="secret.txt"; name="files"', + "content-type": b"text/plain; charset=utf-8", + "x-funky-header-1": b"bar", + "abcdefghijklmnopqrstuvwxyz01234": b"foo", + "abcdefghijklmnopqrstuvwxyz56789": b"bar", + "other!#$%&'*+-.^_`|~": b"baz", + "content-length": b"6", + } + + # Create form parser. + self.make(boundary="b8825ae386be4fdc9644d87e392caad3") + self.f.write(test_data) + self.f.finalize() + + # Assert that our field is here. + self.assertEqual(1, len(self.files)) + actual_headers = self.files[0].headers + self.assertEqual(len(actual_headers), len(expected_headers)) + + for k, v in expected_headers.items(): + self.assertEqual(v, actual_headers[k]) + + def test_field_headers(self) -> None: + """ + This test checks headers for a field part are read. + """ + # Load test data. + test_file = "single_field.http" + with open(os.path.join(http_tests_dir, test_file), "rb") as f: + test_data = f.read() + + expected_headers = {"content-disposition": b'form-data; name="field"'} + + # Create form parser. + self.make(boundary="----WebKitFormBoundaryTkr3kCBQlBe1nrhc") + self.f.write(test_data) + self.f.finalize() + + # Assert that our field is here. + self.assertEqual(1, len(self.fields)) + actual_headers = self.fields[0].headers + self.assertEqual(len(actual_headers), len(expected_headers)) + + for k, v in expected_headers.items(): + self.assertEqual(v, actual_headers[k]) + def test_request_body_fuzz(self) -> None: """ This test randomly fuzzes the request body to ensure that no strange From f1e28d6efcb83e1140d2d34adf9acf101fc64909 Mon Sep 17 00:00:00 2001 From: John Stark Date: Sun, 8 Dec 2024 13:42:47 +0100 Subject: [PATCH 4/8] Use content_type in Fields not header --- python_multipart/multipart.py | 45 ++++++++++++++++++----------------- tests/test_multipart.py | 41 +++++++++++-------------------- 2 files changed, 37 insertions(+), 49 deletions(-) diff --git a/python_multipart/multipart.py b/python_multipart/multipart.py index c08dec5..99a8204 100644 --- a/python_multipart/multipart.py +++ b/python_multipart/multipart.py @@ -68,13 +68,13 @@ def finalize(self) -> None: ... def close(self) -> None: ... class FieldProtocol(_FormProtocol, Protocol): - def __init__(self, name: bytes, headers: dict[str, bytes]) -> None: ... + def __init__(self, name: bytes, content_type: str | None = None) -> None: ... def set_none(self) -> None: ... class FileProtocol(_FormProtocol, Protocol): def __init__( - self, file_name: bytes | None, field_name: bytes | None, config: FileConfig, headers: dict[str, bytes] + self, file_name: bytes | None, field_name: bytes | None, config: FileConfig, content_type: str | None = None ) -> None: ... OnFieldCallback = Callable[[FieldProtocol], None] @@ -223,12 +223,13 @@ class Field: Args: name: The name of the form field. + content_type: The value of the Content-Type header for this field. """ - def __init__(self, name: bytes, headers: dict[str, bytes] = {}) -> None: + def __init__(self, name: bytes, content_type: str | None = None) -> None: self._name = name self._value: list[bytes] = [] - self._headers: dict[str, bytes] = headers + self._content_type = content_type # We cache the joined version of _value for speed. self._cache = _missing @@ -321,9 +322,9 @@ def value(self) -> bytes | None: return self._cache @property - def headers(self) -> dict[str, bytes]: - """This property returns the headers of the field.""" - return self._headers + def content_type(self) -> str | None: + """This property returns the content_type value of the field.""" + return self._content_type def __eq__(self, other: object) -> bool: if isinstance(other, Field): @@ -362,6 +363,7 @@ class File: file_name: The name of the file that this [`File`][python_multipart.File] represents. field_name: The name of the form field that this file was uploaded with. This can be None, if, for example, the file was uploaded with Content-Type application/octet-stream. + content_type: The value of the Content-Type header. config: The configuration for this File. See above for valid configuration keys and their corresponding values. """ # noqa: E501 @@ -369,7 +371,7 @@ def __init__( self, file_name: bytes | None, field_name: bytes | None = None, - headers: dict[str, bytes] = {}, + content_type: str | None = None, config: FileConfig = {}, ) -> None: # Save configuration, set other variables default. @@ -382,7 +384,7 @@ def __init__( # Save the provided field/file name and content type. self._field_name = field_name self._file_name = file_name - self._headers = headers + self._content_type = content_type # Our actual file name is None by default, since, depending on our # config, we may not actually use the provided name. @@ -436,14 +438,9 @@ def in_memory(self) -> bool: return self._in_memory @property - def headers(self) -> dict[str, bytes]: - """The headers for this part.""" - return self._headers - - @property - def content_type(self) -> bytes | None: - """The Content-Type value for this part.""" - return self._headers.get("content-type") + def content_type(self) -> str | None: + """The Content-Type value for this part, if it was set.""" + return self._content_type def flush_to_disk(self) -> None: """If the file is already on-disk, do nothing. Otherwise, copy from @@ -1565,7 +1562,7 @@ def __init__( def on_start() -> None: nonlocal file - file = FileClass(file_name, None, headers={}, config=cast("FileConfig", self.config)) + file = FileClass(file_name, None, content_type=None, config=cast("FileConfig", self.config)) def on_data(data: bytes, start: int, end: int) -> None: nonlocal file @@ -1604,7 +1601,7 @@ def on_field_name(data: bytes, start: int, end: int) -> None: def on_field_data(data: bytes, start: int, end: int) -> None: nonlocal f if f is None: - f = FieldClass(b"".join(name_buffer), headers={}) + f = FieldClass(b"".join(name_buffer), content_type=None) del name_buffer[:] f.write(data[start:end]) @@ -1614,7 +1611,7 @@ def on_field_end() -> None: if f is None: # If we get here, it's because there was no field data. # We create a field, set it to None, and then continue. - f = FieldClass(b"".join(name_buffer), headers={}) + f = FieldClass(b"".join(name_buffer), content_type=None) del name_buffer[:] f.set_none() @@ -1700,10 +1697,14 @@ def on_headers_finished() -> None: # TODO: check for errors # Create the proper class. + content_type_b = headers.get("content-type") + content_type = content_type_b.decode("latin-1") if content_type_b is not None else None if file_name is None: - f_multi = FieldClass(field_name, headers=headers) + f_multi = FieldClass(field_name, content_type=content_type) else: - f_multi = FileClass(file_name, field_name, config=cast("FileConfig", self.config), headers=headers) + f_multi = FileClass( + file_name, field_name, config=cast("FileConfig", self.config), content_type=content_type + ) is_file = True # Parse the given Content-Transfer-Encoding to determine what diff --git a/tests/test_multipart.py b/tests/test_multipart.py index 82f7138..a29b7ae 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -758,7 +758,7 @@ def assert_file_data(self, f: File, data: bytes) -> None: file_data = o.read() self.assertEqual(file_data, data) - def assert_file(self, field_name: bytes, file_name: bytes, content_type: str, data: bytes) -> None: + def assert_file(self, field_name: bytes, file_name: bytes, content_type: str | None, data: bytes) -> None: # Find this file. found = None for f in self.files: @@ -770,7 +770,7 @@ def assert_file(self, field_name: bytes, file_name: bytes, content_type: str, da self.assertIsNotNone(found) assert found is not None - self.assertEqual(found.content_type, content_type.encode()) + self.assertEqual(found.content_type, content_type) try: # Assert about this file. @@ -911,7 +911,8 @@ def test_feed_single_bytes(self, param: TestParams) -> None: self.assert_field(name, e["data"]) elif type == "file": - self.assert_file(name, e["file_name"].encode("latin-1"), "text/plain", e["data"]) + content_type = "text/plain" + self.assert_file(name, e["file_name"].encode("latin-1"), content_type, e["data"]) else: assert False @@ -949,24 +950,16 @@ def test_feed_blocks(self) -> None: # Assert that our field is here. self.assert_field(b"field", b"0123456789ABCDEFGHIJ0123456789ABCDEFGHIJ") - def test_file_headers(self) -> None: + def test_file_content_type_header(self) -> None: """ - This test checks headers for a file part are read. + This test checks the content-type for a file part is passed on. """ # Load test data. test_file = "header_with_number.http" with open(os.path.join(http_tests_dir, test_file), "rb") as f: test_data = f.read() - expected_headers = { - "content-disposition": b'form-data; filename="secret.txt"; name="files"', - "content-type": b"text/plain; charset=utf-8", - "x-funky-header-1": b"bar", - "abcdefghijklmnopqrstuvwxyz01234": b"foo", - "abcdefghijklmnopqrstuvwxyz56789": b"bar", - "other!#$%&'*+-.^_`|~": b"baz", - "content-length": b"6", - } + expected_content_type = "text/plain; charset=utf-8" # Create form parser. self.make(boundary="b8825ae386be4fdc9644d87e392caad3") @@ -975,22 +968,19 @@ def test_file_headers(self) -> None: # Assert that our field is here. self.assertEqual(1, len(self.files)) - actual_headers = self.files[0].headers - self.assertEqual(len(actual_headers), len(expected_headers)) + actual_content_type = self.files[0].content_type + self.assertEqual(actual_content_type, expected_content_type) - for k, v in expected_headers.items(): - self.assertEqual(v, actual_headers[k]) - - def test_field_headers(self) -> None: + def test_field_content_type_header(self) -> None: """ - This test checks headers for a field part are read. + This test checks content-tpye for a field part are read and passed. """ # Load test data. test_file = "single_field.http" with open(os.path.join(http_tests_dir, test_file), "rb") as f: test_data = f.read() - expected_headers = {"content-disposition": b'form-data; name="field"'} + expected_content_type = None # Create form parser. self.make(boundary="----WebKitFormBoundaryTkr3kCBQlBe1nrhc") @@ -999,11 +989,8 @@ def test_field_headers(self) -> None: # Assert that our field is here. self.assertEqual(1, len(self.fields)) - actual_headers = self.fields[0].headers - self.assertEqual(len(actual_headers), len(expected_headers)) - - for k, v in expected_headers.items(): - self.assertEqual(v, actual_headers[k]) + actual_content_type = self.fields[0].content_type + self.assertEqual(actual_content_type, expected_content_type) def test_request_body_fuzz(self) -> None: """ From 37f4c53f3a470626b8c995950b83dfc5195c024b Mon Sep 17 00:00:00 2001 From: John Stark Date: Mon, 16 Dec 2024 19:13:46 +0000 Subject: [PATCH 5/8] Update python_multipart/multipart.py Co-authored-by: Marcelo Trylesinski --- python_multipart/multipart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python_multipart/multipart.py b/python_multipart/multipart.py index 99a8204..bc538a3 100644 --- a/python_multipart/multipart.py +++ b/python_multipart/multipart.py @@ -1562,7 +1562,7 @@ def __init__( def on_start() -> None: nonlocal file - file = FileClass(file_name, None, content_type=None, config=cast("FileConfig", self.config)) + file = FileClass(file_name, config=cast("FileConfig", self.config)) def on_data(data: bytes, start: int, end: int) -> None: nonlocal file From afbf23f443121f17a1850b3e3e8a86bdcaa0350e Mon Sep 17 00:00:00 2001 From: John Stark Date: Mon, 16 Dec 2024 19:14:08 +0000 Subject: [PATCH 6/8] Update python_multipart/multipart.py Co-authored-by: Marcelo Trylesinski --- python_multipart/multipart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python_multipart/multipart.py b/python_multipart/multipart.py index bc538a3..00234cf 100644 --- a/python_multipart/multipart.py +++ b/python_multipart/multipart.py @@ -1601,7 +1601,7 @@ def on_field_name(data: bytes, start: int, end: int) -> None: def on_field_data(data: bytes, start: int, end: int) -> None: nonlocal f if f is None: - f = FieldClass(b"".join(name_buffer), content_type=None) + f = FieldClass(b"".join(name_buffer)) del name_buffer[:] f.write(data[start:end]) From b1c2d3ba73342f80bc026a2f4e31b14e8dfb90df Mon Sep 17 00:00:00 2001 From: John Stark Date: Mon, 16 Dec 2024 19:14:21 +0000 Subject: [PATCH 7/8] Update python_multipart/multipart.py Co-authored-by: Marcelo Trylesinski --- python_multipart/multipart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python_multipart/multipart.py b/python_multipart/multipart.py index 00234cf..7338e67 100644 --- a/python_multipart/multipart.py +++ b/python_multipart/multipart.py @@ -1611,7 +1611,7 @@ def on_field_end() -> None: if f is None: # If we get here, it's because there was no field data. # We create a field, set it to None, and then continue. - f = FieldClass(b"".join(name_buffer), content_type=None) + f = FieldClass(b"".join(name_buffer)) del name_buffer[:] f.set_none() From 954df085a5ac8c71a1cc845ef93712d6b96ee818 Mon Sep 17 00:00:00 2001 From: John Stark Date: Mon, 16 Dec 2024 20:59:24 +0100 Subject: [PATCH 8/8] Check for field name and raise if not found --- python_multipart/multipart.py | 8 +++++--- tests/test_multipart.py | 17 +++++++++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/python_multipart/multipart.py b/python_multipart/multipart.py index 7338e67..ebfeaf3 100644 --- a/python_multipart/multipart.py +++ b/python_multipart/multipart.py @@ -1562,7 +1562,7 @@ def __init__( def on_start() -> None: nonlocal file - file = FileClass(file_name, config=cast("FileConfig", self.config)) + file = FileClass(file_name, b"", config=cast("FileConfig", self.config)) def on_data(data: bytes, start: int, end: int) -> None: nonlocal file @@ -1692,9 +1692,11 @@ def on_headers_finished() -> None: disp, options = parse_options_header(content_disp) # Get the field and filename. - field_name = options.get(b"name", b"") + field_name = options.get(b"name") file_name = options.get(b"filename") - # TODO: check for errors + if field_name is None: + raise FormParserError('Field name not found in Content-Disposition: "{!r}"'.format(content_disp)) + # TODO: check for other errors # Create the proper class. content_type_b = headers.get("content-type") diff --git a/tests/test_multipart.py b/tests/test_multipart.py index a29b7ae..3c5546f 100644 --- a/tests/test_multipart.py +++ b/tests/test_multipart.py @@ -1235,6 +1235,23 @@ def on_file(f: FileProtocol) -> None: f.finalize() self.assert_file_data(files[0], b"Test") + def test_bad_content_disposition(self) -> None: + # Field name is required. + data = ( + b"----boundary\r\nContent-Disposition: form-data;\r\n" + b"Content-Type: text/plain\r\n" + b"Test\r\n----boundary--\r\n" + ) + + on_field = Mock() + on_file = Mock() + + f = FormParser("multipart/form-data", on_field, on_file, boundary="--boundary") + + with self.assertRaises(FormParserError): + f.write(data) + f.finalize() + def test_handles_None_fields(self) -> None: fields: list[Field] = []