diff --git a/multipart/multipart.py b/multipart/multipart.py index 27aaecb..5f70a8a 100644 --- a/multipart/multipart.py +++ b/multipart/multipart.py @@ -1193,6 +1193,13 @@ def data_callback(name, remaining=False): # Set a mark of our header field. set_mark('header_field') + # Notify that we're starting a header if the next character is + # not a CR; a CR at the beginning of the header will cause us + # to stop parsing headers in the STATE_HEADER_FIELD state, + # below. + if c != CR: + self.callback('header_begin') + # Move to parsing header fields. state = STATE_HEADER_FIELD i -= 1 diff --git a/multipart/tests/test_multipart.py b/multipart/tests/test_multipart.py index 7913cd2..1ef5c57 100644 --- a/multipart/tests/test_multipart.py +++ b/multipart/tests/test_multipart.py @@ -1244,6 +1244,38 @@ def test_invalid_max_size_multipart(self): with self.assertRaises(ValueError): q = MultipartParser(b'bound', max_size='foo') + def test_header_begin_callback(self): + """ + This test verifies we call the `on_header_begin` callback. + + See GitHub issue #23 + """ + # Load test data. + test_file = 'single_field_single_file.http' + with open(os.path.join(http_tests_dir, test_file), 'rb') as f: + test_data = f.read() + + calls = [] + def on_header_begin(*args, **kwargs): + calls.append((args, kwargs)) + + callbacks = { + 'on_header_begin': on_header_begin, + } + parser = MultipartParser('boundary', callbacks, + max_size=1000) + + # Create multipart parser and feed it + i = parser.write(test_data) + parser.finalize() + + # Assert we processed everything. + self.assertEqual(i, len(test_data)) + + # Assert that we called our 'header_begin' callbakc three times; once + # for each header in the multipart message. + self.assertEqual(len(calls), 3) + class TestHelperFunctions(unittest.TestCase): def test_create_form_parser(self):