Skip to content

Commit

Permalink
Create object of cache. Use specific imports
Browse files Browse the repository at this point in the history
  • Loading branch information
eltbus committed Jan 21, 2024
1 parent 83f5914 commit dbd31f2
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 18 deletions.
60 changes: 43 additions & 17 deletions multipart/multipart.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .decoders import *
from .exceptions import *
from .decoders import Base64Decoder, QuotedPrintableDecoder
from .exceptions import FormParserError, MultipartParseError, QuerystringParseError, FileError

import os
import re
Expand All @@ -9,10 +9,9 @@
import tempfile
from io import BytesIO
from numbers import Number
from typing import overload, Dict, List, Optional, Tuple, Union
from typing import overload, Dict, Generic, List, Optional, Tuple, TypeVar, Union

# Unique missing object.
_missing = object()
T = TypeVar("T")

# States for the querystring parser.
STATE_BEFORE_FIELD = 0
Expand Down Expand Up @@ -124,6 +123,33 @@ def parse_options_header(value):
return ctype, options


class Cache(Generic[T]):
def __init__(self):
self._value: Optional[T] = None
self._is_set: bool = False

@property
def value(self) -> Optional[T]:
if self.is_set:
return self._value
else:
raise ValueError("Value not yet set")

@value.setter
def value(self, v: T):
self._value = v
self._is_set = True

def clear(self):
"""Reset the cache"""
self._value = None
self._is_set = False

def is_set(self) -> bool:
"""Check if value has been set"""
return self._is_set


class Field:
"""A Field object represents a (parsed) form field. It represents a single
field with a corresponding name and value.
Expand All @@ -144,7 +170,7 @@ def __init__(self, name: Union[bytes, str]):
self._value: List[bytes] = []

# We cache the joined version of _value for speed.
self._cache = _missing
self._cache = Cache()

@classmethod
def from_value(cls, name: Union[bytes, str], value: Optional[bytes]):
Expand Down Expand Up @@ -179,14 +205,14 @@ def on_data(self, data: bytes) -> int:
:param data: a bytestring
"""
self._value.append(data)
self._cache = _missing
self._cache.clear()
return len(data)

def on_end(self):
"""This method is called whenever the Field is finalized.
"""
if self._cache is _missing:
self._cache = b''.join(self._value)
if not self._cache.is_set():
self._cache.value = b''.join(self._value)

def finalize(self):
"""Finalize the form field.
Expand All @@ -197,8 +223,8 @@ def close(self):
"""Close the Field object. This will free any underlying cache.
"""
# Free our value array.
if self._cache is _missing:
self._cache = b''.join(self._value)
if not self._cache.is_set():
self._cache.value = b''.join(self._value)

del self._value

Expand All @@ -209,20 +235,20 @@ def set_none(self):
with name "baz" and value "asdf". Since the write() interface doesn't
support writing None, this function will set the field value to None.
"""
self._cache = None
self._cache.value = None

@property
def field_name(self):
"""This property returns the name of the field."""
return self._name

@property
def value(self):
def value(self) -> Optional[bytes]:
"""This property returns the value of the form field."""
if self._cache is _missing:
self._cache = b''.join(self._value)
if not self._cache.is_set():
self._cache.value = b''.join(self._value)

return self._cache
return self._cache.value

def __eq__(self, other):
if isinstance(other, Field):
Expand All @@ -234,7 +260,7 @@ def __eq__(self, other):
return NotImplemented

def __repr__(self):
if len(self.value) > 97:
if self.value is not None and len(self.value) > 97:
# We get the repr, and then insert three dots before the final
# quote.
v = repr(self.value[:97])[:-1] + "...'"
Expand Down
21 changes: 20 additions & 1 deletion tests/test_multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,26 @@
from io import BytesIO
from unittest.mock import Mock

from multipart.multipart import *
from multipart.multipart import (
BaseParser,
Field,
File,
FormParser,
MultipartParser,
OctetStreamParser,
QuerystringParser,
create_form_parser,
parse_form,
parse_options_header,
)
from multipart.decoders import Base64Decoder, QuotedPrintableDecoder
from multipart.exceptions import (
DecodeError,
FileError,
FormParserError,
MultipartParseError,
QuerystringParseError,
)


# Get the current directory for our later test cases.
Expand Down

0 comments on commit dbd31f2

Please sign in to comment.