diff --git a/doc/source/changelog.rst b/doc/source/changelog.rst index 01ec609..1bae689 100644 --- a/doc/source/changelog.rst +++ b/doc/source/changelog.rst @@ -14,10 +14,10 @@ Changelog (:pr:`78`) `Guido Imperiale`_ - ``LMDB`` now uses memory-mapped I/O on MacOSX and is usable on Windows. (:pr:`78`) `Guido Imperiale`_ -- The library is now partially thread-safe. - (:pr:`82`, :pr:`90`) `Guido Imperiale`_ +- The library is now almost completely thread-safe. + (:pr:`82`, :pr:`90`, :pr:`92`) `Guido Imperiale`_ - :class:`LRU` and :class:`Buffer` now support delayed eviction. - New objects :class:`Accumulator` and :class:`InsertionSortedSet`. + New object :class:`InsertionSortedSet`. (:pr:`87`) `Guido Imperiale`_ diff --git a/doc/source/index.rst b/doc/source/index.rst index eb38de0..428bcb0 100644 --- a/doc/source/index.rst +++ b/doc/source/index.rst @@ -36,8 +36,8 @@ zlib-compressed, directory of files. Thread-safety ------------- -This library is only partially thread-safe. Refer to the documentation of the individual -mappings for details. +Most classes in this library are thread-safe. +Refer to the documentation of the individual mappings for exceptions. API --- @@ -64,8 +64,6 @@ API Additionally, **zict** makes available the following general-purpose objects: -.. autoclass:: Accumulator - :members: .. autoclass:: InsertionSortedSet :members: .. 
autoclass:: WeakValueMapping diff --git a/zict/__init__.py b/zict/__init__.py index cbcdbd4..a13f0d1 100644 --- a/zict/__init__.py +++ b/zict/__init__.py @@ -6,7 +6,6 @@ from zict.lmdb import LMDB as LMDB from zict.lru import LRU as LRU from zict.sieve import Sieve as Sieve -from zict.utils import Accumulator as Accumulator from zict.utils import InsertionSortedSet as InsertionSortedSet from zict.zip import Zip as Zip diff --git a/zict/buffer.py b/zict/buffer.py index 140aabb..cf12186 100644 --- a/zict/buffer.py +++ b/zict/buffer.py @@ -3,7 +3,7 @@ from collections.abc import Callable, Iterator, MutableMapping from itertools import chain -from zict.common import KT, VT, ZictBase, close, flush +from zict.common import KT, VT, ZictBase, close, discard, flush, locked from zict.lru import LRU @@ -33,8 +33,10 @@ class Buffer(ZictBase[KT, VT]): Notes ----- - ``__contains__`` and ``__len__`` are thread-safe if the same methods on both - ``fast`` and ``slow`` are thread-safe. All other methods are not thread-safe. + If you call methods of this class from multiple threads, access will be fast as long + as all methods of ``fast``, plus ``slow.__contains__`` and ``slow.__delitem__``, are + fast. ``slow.__getitem__``, ``slow.__setitem__``, and callbacks are not protected + by locks.
Examples -------- @@ -54,6 +56,7 @@ class Buffer(ZictBase[KT, VT]): weight: Callable[[KT, VT], float] fast_to_slow_callbacks: list[Callable[[KT, VT], None]] slow_to_fast_callbacks: list[Callable[[KT, VT], None]] + _cancel_restore: dict[KT, bool] def __init__( self, @@ -68,7 +71,14 @@ def __init__( | list[Callable[[KT, VT], None]] | None = None, ): - self.fast = LRU(n, fast, weight=weight, on_evict=[self.fast_to_slow]) + super().__init__() + self.fast = LRU( + n, + fast, + weight=weight, + on_evict=[self.fast_to_slow], + on_cancel_evict=[self._cancel_evict], + ) self.slow = slow self.weight = weight if callable(fast_to_slow_callbacks): @@ -77,6 +87,7 @@ def __init__( slow_to_fast_callbacks = [slow_to_fast_callbacks] self.fast_to_slow_callbacks = fast_to_slow_callbacks or [] self.slow_to_fast_callbacks = slow_to_fast_callbacks or [] + self._cancel_restore = {} @property def n(self) -> float: @@ -94,16 +105,38 @@ def fast_to_slow(self, key: KT, value: VT) -> None: raise def slow_to_fast(self, key: KT) -> VT: - value = self.slow[key] + self._cancel_restore[key] = False + try: + with self._unlock(): + value = self.slow[key] + if self._cancel_restore[key]: + raise KeyError(key) + finally: + self._cancel_restore.pop(key) + # Avoid useless movement for heavy values w = self.weight(key, value) if w <= self.n: + # Multithreaded edge case: + # - Thread 1 starts slow_to_fast(x) and puts it at the top of fast + # - This causes the eviction of older key(s) + # - While thread 1 is evicting older keys, thread 2 is loading fast with + # set_noevict() + # - By the time the eviction of the older key(s) is done, there is + # enough weight in fast that thread 1 will spill x + # - If the below code was just `self.fast[key] = value; del + # self.slow[key]` now the key would be in neither slow nor fast! 
+ self.fast.set_noevict(key, value) del self.slow[key] - self.fast[key] = value - for cb in self.slow_to_fast_callbacks: - cb(key, value) + + with self._unlock(): + self.fast.evict_until_below_capacity() + for cb in self.slow_to_fast_callbacks: + cb(key, value) + return value + @locked def __getitem__(self, key: KT) -> VT: try: return self.fast[key] @@ -111,31 +144,35 @@ def __getitem__(self, key: KT) -> VT: return self.slow_to_fast(key) def __setitem__(self, key: KT, value: VT) -> None: - try: - del self.slow[key] - except KeyError: - pass - # This may trigger an eviction from fast to slow of older keys. - # If the weight is individually greater than n, then key/value will be stored - # into self.slow instead (see LRU.__setitem__). + with self._lock: + discard(self.slow, key) + if key in self._cancel_restore: + self._cancel_restore[key] = True self.fast[key] = value + @locked def set_noevict(self, key: KT, value: VT) -> None: """Variant of ``__setitem__`` that does not move keys from fast to slow if the total weight exceeds n """ - try: - del self.slow[key] - except KeyError: - pass + discard(self.slow, key) + if key in self._cancel_restore: + self._cancel_restore[key] = True self.fast.set_noevict(key, value) + @locked def __delitem__(self, key: KT) -> None: + if key in self._cancel_restore: + self._cancel_restore[key] = True try: del self.fast[key] except KeyError: del self.slow[key] + @locked + def _cancel_evict(self, key: KT, value: VT) -> None: + discard(self.slow, key) + # FIXME dictionary views https://github.com/dask/zict/issues/61 def keys(self) -> Iterator[KT]: # type: ignore return iter(self) @@ -147,7 +184,15 @@ def items(self) -> Iterator[tuple[KT, VT]]: # type: ignore return chain(self.fast.items(), self.slow.items()) def __len__(self) -> int: - return len(self.fast) + len(self.slow) + with self._lock, self.fast._lock: + return ( + len(self.fast) + + len(self.slow) + - sum( + k in self.fast and k in self.slow + for k in chain(self._cancel_restore, 
self.fast._cancel_evict) + ) + ) def __iter__(self) -> Iterator[KT]: return chain(self.fast, self.slow) diff --git a/zict/cache.py b/zict/cache.py index 18f4363..7edcbf8 100644 --- a/zict/cache.py +++ b/zict/cache.py @@ -4,7 +4,7 @@ from collections.abc import Iterator, KeysView, MutableMapping from typing import TYPE_CHECKING -from zict.common import KT, VT, ZictBase, close, flush +from zict.common import KT, VT, ZictBase, close, discard, flush class Cache(ZictBase[KT, VT]): @@ -22,14 +22,6 @@ class Cache(ZictBase[KT, VT]): If True (default), the cache will be updated both when writing and reading. If False, update the cache when reading, but just invalidate it when writing. - Notes - ----- - All methods are thread-safe if all methods on both ``data`` and ``cache`` are - thread-safe; however, only one thread can call ``__setitem__`` and ``__delitem__`` - at any given time. - ``__contains__`` and ``__len__`` are thread-safe if the same methods on ``data`` are - thread-safe. - Examples -------- Keep the latest 100 accessed values in memory @@ -51,6 +43,7 @@ def __init__( cache: MutableMapping[KT, VT], update_on_set: bool = True, ): + super().__init__() self.data = data self.cache = cache self.update_on_set = update_on_set @@ -68,20 +61,14 @@ def __setitem__(self, key: KT, value: VT) -> None: # If the item was already in cache and data.__setitem__ fails, e.g. because it's # a File and the disk is full, make sure that the cache is invalidated. 
# FIXME https://github.com/python/mypy/issues/10152 - try: - del self.cache[key] - except KeyError: - pass + discard(self.cache, key) self.data[key] = value if self.update_on_set: self.cache[key] = value def __delitem__(self, key: KT) -> None: - try: - del self.cache[key] - except KeyError: - pass + discard(self.cache, key) del self.data[key] def __len__(self) -> int: diff --git a/zict/common.py b/zict/common.py index 1907e0e..63bf8db 100644 --- a/zict/common.py +++ b/zict/common.py @@ -1,18 +1,24 @@ from __future__ import annotations -from collections.abc import Iterable, Mapping +import threading +from collections.abc import Callable, Iterable, Iterator, Mapping +from contextlib import contextmanager from enum import Enum +from functools import wraps from itertools import chain from typing import MutableMapping # TODO move to collections.abc (needs Python >=3.9) -from typing import TYPE_CHECKING, Any, TypeVar +from typing import TYPE_CHECKING, Any, TypeVar, cast T = TypeVar("T") KT = TypeVar("KT") VT = TypeVar("VT") if TYPE_CHECKING: + # TODO import from typing (needs Python >=3.10) # TODO import from typing (needs Python >=3.11) - from typing_extensions import Self + from typing_extensions import ParamSpec, Self + + P = ParamSpec("P") class NoDefault(Enum): @@ -25,6 +31,20 @@ class NoDefault(Enum): class ZictBase(MutableMapping[KT, VT]): """Base class for zict mappings""" + _lock: threading.RLock + + def __init__(self) -> None: + self._lock = threading.RLock() + + def __getstate__(self) -> dict[str, Any]: + state = self.__dict__.copy() + del state["_lock"] + return state + + def __setstate__(self, state: dict[str, Any]) -> None: + self.__dict__ = state + self._lock = threading.RLock() + def update( # type: ignore[override] self, other: Mapping[KT, VT] | Iterable[tuple[KT, VT]] = (), @@ -41,6 +61,12 @@ def _do_update(self, items: Iterable[tuple[KT, VT]]) -> None: for k, v in items: self[k] = v + def discard(self, key: KT) -> None: + """Delete *key* if present.
+ Not the same as ``m.pop(key, None)``, as it doesn't trigger ``__getitem__``. + """ + discard(self, key) + def close(self) -> None: """Release any system resources held by this object""" @@ -53,6 +79,17 @@ def __exit__(self, *args: Any) -> None: def __del__(self) -> None: self.close() + @contextmanager + def _unlock(self) -> Iterator[None]: + """To be used in a method decorated by ``@locked``. + Temporarily releases the mapping's RLock. + """ + self._lock.release() + try: + yield + finally: + self._lock.acquire() + def close(*z: Any) -> None: """Close *z* if possible.""" @@ -66,3 +103,27 @@ def flush(*z: Any) -> None: for zi in z: if hasattr(zi, "flush"): zi.flush() + + +def discard(m: MutableMapping[KT, VT], key: KT) -> None: + """Delete *key* if present. + Not the same as ``m.pop(key, None)``, as it doesn't trigger ``__getitem__``. + """ + try: + del m[key] + except KeyError: + pass + + +def locked(func: Callable[P, VT]) -> Callable[P, VT]: + """Decorator for a method of ZictBase, which wraps the whole method in a + mapping-global rlock. + """ + + @wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> VT: + self = cast(ZictBase, args[0]) + with self._lock: + return func(*args, **kwargs) + + return wrapper diff --git a/zict/file.py b/zict/file.py index 2c419e2..503e193 100644 --- a/zict/file.py +++ b/zict/file.py @@ -6,7 +6,7 @@ from collections.abc import Iterator, KeysView from urllib.parse import quote, unquote -from zict.common import ZictBase +from zict.common import ZictBase, locked class File(ZictBase[str, bytes]): @@ -25,11 +25,6 @@ class File(ZictBase[str, bytes]): memmap: bool (optional) If True, use `mmap` for reading. Defaults to False. - Notes - ----- - This class is fully thread-safe, with only one caveat: you can't have two calls to - ``__setitem__``, on the same key, at the same time from two different threads.
- Examples -------- >>> z = File('myfile') # doctest: +SKIP @@ -54,6 +49,7 @@ class File(ZictBase[str, bytes]): _inc: int def __init__(self, directory: str | pathlib.Path, memmap: bool = False): + super().__init__() self.directory = str(directory) self.memmap = memmap self.filenames = {} @@ -89,6 +85,7 @@ def __str__(self) -> str: __repr__ = __str__ + @locked def __getitem__(self, key: str) -> bytearray | memoryview: fn = os.path.join(self.directory, self.filenames[key]) @@ -99,21 +96,19 @@ def __getitem__(self, key: str) -> bytearray | memoryview: # Note that this is a dask-specific feature; vanilla pickle.loads will instead # return an array with flags.writeable=False. - try: - if self.memmap: - with open(fn, "r+b") as fh: - return memoryview(mmap.mmap(fh.fileno(), 0)) - else: - with open(fn, "rb") as fh: - size = os.fstat(fh.fileno()).st_size - buf = bytearray(size) + if self.memmap: + with open(fn, "r+b") as fh: + return memoryview(mmap.mmap(fh.fileno(), 0)) + else: + with open(fn, "rb") as fh: + size = os.fstat(fh.fileno()).st_size + buf = bytearray(size) + with self._unlock(): nread = fh.readinto(buf) - assert nread == size - return buf - - except FileNotFoundError: # pragma: nocover - raise KeyError(key) # Race condition with __setitem__ or __delitem__ + assert nread == size + return buf + @locked def __setitem__( self, key: str, @@ -123,13 +118,9 @@ def __setitem__( | list[bytes | bytearray | memoryview] | tuple[bytes | bytearray | memoryview, ...], ) -> None: - try: - del self[key] - except KeyError: - pass - + self.discard(key) fn = self._safe_key(key) - with open(os.path.join(self.directory, fn), "wb") as fh: + with open(os.path.join(self.directory, fn), "wb") as fh, self._unlock(): if isinstance(value, (tuple, list)): fh.writelines(value) else: @@ -145,12 +136,10 @@ def keys(self) -> KeysView[str]: def __iter__(self) -> Iterator[str]: return iter(self.filenames) + @locked def __delitem__(self, key: str) -> None: fn = self.filenames.pop(key) - try: - 
os.remove(os.path.join(self.directory, fn)) - except FileNotFoundError: # pragma: nocover - raise KeyError(key) # Race condition with __setitem__ or __delitem__ + os.remove(os.path.join(self.directory, fn)) def __len__(self) -> int: return len(self.filenames) diff --git a/zict/func.py b/zict/func.py index 8b1d512..aea70c3 100644 --- a/zict/func.py +++ b/zict/func.py @@ -55,6 +55,7 @@ def __init__( load: Callable[[WT], VT], d: MutableMapping[KT, WT], ): + super().__init__() self.dump = dump self.load = load self.d = d diff --git a/zict/lmdb.py b/zict/lmdb.py index 2a12d74..81f7933 100644 --- a/zict/lmdb.py +++ b/zict/lmdb.py @@ -46,6 +46,7 @@ class LMDB(ZictBase[str, bytes]): def __init__(self, directory: str | pathlib.Path, map_size: int | None = None): import lmdb + super().__init__() if map_size is None: if sys.platform != "win32": map_size = min(2**40, sys.maxsize // 4) diff --git a/zict/lru.py b/zict/lru.py index 4115b53..2c09133 100644 --- a/zict/lru.py +++ b/zict/lru.py @@ -9,8 +9,8 @@ ValuesView, ) -from zict.common import KT, VT, NoDefault, ZictBase, close, flush, nodefault -from zict.utils import Accumulator, InsertionSortedSet +from zict.common import KT, VT, NoDefault, ZictBase, close, flush, locked, nodefault +from zict.utils import InsertionSortedSet class LRU(ZictBase[KT, VT]): @@ -23,23 +23,24 @@ class LRU(ZictBase[KT, VT]): d: MutableMapping Dict-like in which to hold elements. There are no expectations on its internal ordering. Iteration on the LRU follows the order of the underlying mapping. - on_evict: list of callables - Function:: k, v -> action to call on key value pairs prior to eviction + on_evict: callable or list of callables + Function:: k, v -> action to call on key/value pairs prior to eviction If an exception occurs during an on_evict callback (e.g a callback tried storing to disk and raised a disk full error) the key will remain in the LRU. 
+ on_cancel_evict: callable or list of callables + Function:: k, v -> action to call on key/value pairs if they're deleted or + updated from a thread while the on_evict callables are being executed in + another. weight: callable Function:: k, v -> number to determine the size of keeping the item in the mapping. Defaults to ``(k, v) -> 1`` Notes ----- - Most methods are thread-safe if the same methods on ``d`` are thread-safe. - ``__setitem__``, ``__delitem__``, :meth:`evict`, and - :meth:`evict_until_below_capacity` also require all callables in ``on_evict`` to be - thread-safe and should not be called from different threads for the same - key. It's OK to set/delete different keys from different threads, it's OK to set a - key in a thread and read it from many other threads, but it's not OK to set/delete - the same key from different threads at the same time. + If you call methods of this class from multiple threads, access will be fast as long + as all methods of ``d`` are fast. Callbacks are not protected by locks and can be + arbitrarily slow. 
+ Examples -------- @@ -54,38 +55,50 @@ class LRU(ZictBase[KT, VT]): order: InsertionSortedSet[KT] heavy: InsertionSortedSet[KT] on_evict: list[Callable[[KT, VT], None]] + on_cancel_evict: list[Callable[[KT, VT], None]] weight: Callable[[KT, VT], float] n: float weights: dict[KT, float] closed: bool - total_weight: Accumulator + total_weight: float + _cancel_evict: dict[KT, bool] def __init__( self, n: float, d: MutableMapping[KT, VT], + *, on_evict: Callable[[KT, VT], None] | list[Callable[[KT, VT], None]] | None = None, + on_cancel_evict: Callable[[KT, VT], None] + | list[Callable[[KT, VT], None]] + | None = None, weight: Callable[[KT, VT], float] = lambda k, v: 1, ): + super().__init__() self.d = d self.n = n + if callable(on_evict): on_evict = [on_evict] self.on_evict = on_evict or [] + if callable(on_cancel_evict): + on_cancel_evict = [on_cancel_evict] + self.on_cancel_evict = on_cancel_evict or [] + self.weight = weight self.weights = {k: weight(k, v) for k, v in d.items()} - self.total_weight = Accumulator(sum(self.weights.values())) + self.total_weight = sum(self.weights.values()) self.order = InsertionSortedSet(d) self.heavy = InsertionSortedSet(k for k, v in self.weights.items() if v >= n) self.closed = False + self._cancel_evict = {} + @locked def __getitem__(self, key: KT) -> VT: result = self.d[key] - # Don't use .remove() to prevent race condition which can happen during - # multithreaded access - self.order.discard(key) + self.order.remove(key) self.order.add(key) return result @@ -104,18 +117,17 @@ def __setitem__(self, key: KT, value: VT) -> None: pass raise + @locked def set_noevict(self, key: KT, value: VT) -> None: """Variant of ``__setitem__`` that does not evict if the total weight exceeds n. Unlike ``__setitem__``, this method does not depend on the ``on_evict`` functions to be thread-safe for its own thread-safety. It also is not prone to re-raising exceptions from the ``on_evict`` callbacks. 
""" - try: - del self[key] - except KeyError: - pass - + self.discard(key) weight = self.weight(key, value) + if key in self._cancel_evict: + self._cancel_evict[key] = True self.d[key] = value self.order.add(key) if weight > self.n: @@ -128,7 +140,8 @@ def evict_until_below_capacity(self) -> None: while self.total_weight > self.n and not self.closed: self.evict() - def evict(self, key: KT | NoDefault = nodefault) -> tuple[KT, VT, float]: + @locked + def evict(self, key: KT | NoDefault = nodefault) -> tuple[KT, VT, float] | None: """Evict least recently used key, or least recently inserted key with individual weight > n, if any. You may also evict a specific key. @@ -138,43 +151,58 @@ def evict(self, key: KT | NoDefault = nodefault) -> tuple[KT, VT, float]: Returns ------- Tuple of (key, value, weight) + + Or None if the key that was being evicted was updated or deleted from another + thread while the on_evict callbacks were being executed. + This outcome is only possible in multithreaded access. """ # For the purpose of multithreaded access, it's important that the value remains # in self.d until all callbacks are successful. # When this is used inside a Buffer, there must never be a moment when the key # is neither in fast nor in slow. if key is nodefault: - while True: - try: - key = next(iter(self.heavy or self.order)) - value = self.d[key] - break - except StopIteration: - raise KeyError("evict(): dictionary is empty") - except (KeyError, RuntimeError): # pragma: nocover - pass # Race condition caused by multithreading - else: - value = self.d[key] + try: + key = next(iter(self.heavy or self.order)) + except StopIteration: + raise KeyError("evict(): dictionary is empty") + + value = self.d[key] # If we are evicting a heavy key we just inserted and one of the callbacks # fails, put it at the bottom of the LRU instead of the top. This way lighter # keys will have a chance to be evicted first and make space. self.heavy.discard(key) - # This may raise; e.g. 
if a callback tries storing to a full disk - for cb in self.on_evict: - cb(key, value) + self._cancel_evict[key] = False + try: + with self._unlock(): + # This may raise; e.g. if a callback tries storing to a full disk + for cb in self.on_evict: + cb(key, value) + + if self._cancel_evict[key]: + for cb in self.on_cancel_evict: + cb(key, value) + return None + finally: + self._cancel_evict.pop(key) - self.d.pop(key, None) # type: ignore[arg-type] - self.order.discard(key) + try: + del self.d[key] + except KeyError: + return None + self.order.remove(key) weight = self.weights.pop(key) self.total_weight -= weight return key, value, weight + @locked def __delitem__(self, key: KT) -> None: + if key in self._cancel_evict: + self._cancel_evict[key] = True del self.d[key] - self.order.discard(key) + self.order.remove(key) self.heavy.discard(key) self.total_weight -= self.weights.pop(key) diff --git a/zict/sieve.py b/zict/sieve.py index 78f20c0..f56182d 100644 --- a/zict/sieve.py +++ b/zict/sieve.py @@ -5,7 +5,7 @@ from itertools import chain from typing import Generic, TypeVar -from zict.common import KT, VT, ZictBase, close, flush +from zict.common import KT, VT, ZictBase, close, flush, locked MKT = TypeVar("MKT") @@ -23,12 +23,6 @@ class Sieve(ZictBase[KT, VT], Generic[KT, VT, MKT]): mappings: dict of {mapping key: MutableMapping} selector: callable (key, value) -> mapping key - Notes - ----- - ``__contains__`` is thread-safe. - ``__len__`` is thread-safe if the same method on all mappings is thread-safe. - All other methods are not thread-safe. - Examples -------- >>> small = {} @@ -37,10 +31,6 @@ class Sieve(ZictBase[KT, VT], Generic[KT, VT, MKT]): >>> def is_small(key, value): # doctest: +SKIP ... 
return sys.getsizeof(value) < 10000 # doctest: +SKIP >>> d = Sieve(mappings, is_small) # doctest: +SKIP - - See Also - -------- - Buffer """ mappings: Mapping[MKT, MutableMapping[KT, VT]] @@ -52,6 +42,7 @@ def __init__( mappings: Mapping[MKT, MutableMapping[KT, VT]], selector: Callable[[KT, VT], MKT], ): + super().__init__() self.mappings = mappings self.selector = selector self.key_to_mapping = {} @@ -60,37 +51,40 @@ def __getitem__(self, key: KT) -> VT: return self.key_to_mapping[key][key] def __setitem__(self, key: KT, value: VT) -> None: - old_mapping = self.key_to_mapping.get(key) - mkey = self.selector(key, value) - mapping = self.mappings[mkey] - if old_mapping is not None and old_mapping is not mapping: - del old_mapping[key] + with self._lock: + old_mapping = self.key_to_mapping.get(key) + mkey = self.selector(key, value) + mapping = self.mappings[mkey] + if old_mapping is not None and old_mapping is not mapping: + del old_mapping[key] + self.key_to_mapping[key] = mapping + mapping[key] = value - self.key_to_mapping[key] = mapping + @locked def __delitem__(self, key: KT) -> None: del self.key_to_mapping.pop(key)[key] def _do_update(self, items: Iterable[tuple[KT, VT]]) -> None: # Optimized update() implementation issuing a single update() # call per underlying mapping. 
- updates = defaultdict(list) - mapping_ids = {id(m): m for m in self.mappings.values()} - - for key, value in items: - old_mapping = self.key_to_mapping.get(key) - mkey = self.selector(key, value) - mapping = self.mappings[mkey] - if old_mapping is not None and old_mapping is not mapping: - del old_mapping[key] - # Can't hash a mutable mapping, so use its id() instead - updates[id(mapping)].append((key, value)) + with self._lock: + updates = defaultdict(list) + mapping_ids = {id(m): m for m in self.mappings.values()} + + for key, value in items: + old_mapping = self.key_to_mapping.get(key) + mkey = self.selector(key, value) + mapping = self.mappings[mkey] + if old_mapping is not None and old_mapping is not mapping: + del old_mapping[key] + # Can't hash a mutable mapping, so use its id() instead + updates[id(mapping)].append((key, value)) + self.key_to_mapping[key] = mapping for mid, mitems in updates.items(): mapping = mapping_ids[mid] mapping.update(mitems) - for key, _ in mitems: - self.key_to_mapping[key] = mapping # FIXME dictionary views https://github.com/dask/zict/issues/61 def keys(self) -> Iterator[KT]: # type: ignore @@ -102,6 +96,7 @@ def values(self) -> Iterator[VT]: # type: ignore def items(self) -> Iterator[tuple[KT, VT]]: # type: ignore return chain.from_iterable(m.items() for m in self.mappings.values()) + @locked def __len__(self) -> int: return sum(map(len, self.mappings.values())) diff --git a/zict/tests/test_buffer.py b/zict/tests/test_buffer.py index 7ba346a..9fa411b 100644 --- a/zict/tests/test_buffer.py +++ b/zict/tests/test_buffer.py @@ -185,7 +185,7 @@ def s2f_cb(k, v): assert b == {"x": 1} # Add key > n, again total weight > n this will move everything to slow except w - # that stays in fast due after callback raise + # that stays in fast due to callback raising with pytest.raises(MyError): buff["w"] = 11 diff --git a/zict/tests/test_common.py b/zict/tests/test_common.py index e10f7b5..938232a 100644 --- a/zict/tests/test_common.py +++ 
b/zict/tests/test_common.py @@ -1,12 +1,12 @@ -from collections import UserDict +import pickle -from zict.common import ZictBase +from zict.tests.utils_test import SimpleDict def test_close_on_del(): closed = False - class D(ZictBase, UserDict): + class D(SimpleDict): def close(self): nonlocal closed closed = True @@ -19,7 +19,7 @@ def close(self): def test_context(): closed = False - class D(ZictBase, UserDict): + class D(SimpleDict): def close(self): nonlocal closed closed = True @@ -33,7 +33,7 @@ def close(self): def test_update(): items = [] - class D(ZictBase, UserDict): + class D(SimpleDict): def _do_update(self, items_): nonlocal items items = items_ @@ -51,3 +51,23 @@ def _do_update(self, items_): # Special kwargs can't overwrite positional-only parameters d.update(self=1, other=2) assert list(items) == [("self", 1), ("other", 2)] + + +def test_discard(): + class D(SimpleDict): + def __getitem__(self, key): + raise AssertionError() + + d = D() + d["x"] = 1 + d["z"] = 2 + d.discard("x") + d.discard("y") + assert d.data == {"z": 2} + + +def test_pickle(): + d = SimpleDict() + d["x"] = 1 + d2 = pickle.loads(pickle.dumps(d)) + assert d2.data == {"x": 1} diff --git a/zict/tests/test_lru.py b/zict/tests/test_lru.py index 2cfe323..75a5391 100644 --- a/zict/tests/test_lru.py +++ b/zict/tests/test_lru.py @@ -1,11 +1,9 @@ -from collections import UserDict from concurrent.futures import ThreadPoolExecutor from threading import Barrier import pytest from zict import LRU -from zict.common import ZictBase from zict.tests import utils_test @@ -283,7 +281,7 @@ def test_getitem_is_threasafe(): def f(_): barrier.wait() - for _ in range(5_000_000): + for _ in range(500_000): assert lru["x"] == 1 barrier = Barrier(2) @@ -316,7 +314,7 @@ def test_flush_close(): flushed = 0 closed = False - class D(ZictBase, UserDict): + class D(utils_test.SimpleDict): def flush(self): nonlocal flushed flushed += 1 diff --git a/zict/tests/test_utils.py b/zict/tests/test_utils.py index 
434afb2..a34757d 100644 --- a/zict/tests/test_utils.py +++ b/zict/tests/test_utils.py @@ -3,8 +3,7 @@ import pytest -from zict import Accumulator, InsertionSortedSet -from zict.utils import ATOMIC_INT_IADD +from zict import InsertionSortedSet def test_insertion_sorted_set(): @@ -108,66 +107,3 @@ def t(): # On Windows, we've seen as little as 2300. assert f1.result() > 100 assert f2.result() > 100 - - -def test_accumulator(): - acc = Accumulator() - assert acc == 0 - acc = Accumulator(123) - assert acc == 123 - assert repr(acc) == "123" - acc += 1 - assert acc == 124 - acc -= 1 - assert acc == 123 - acc += 0.5 - assert acc == 123.5 - - # Test operators - assert int(acc) == 123 - assert float(acc) == 123.5 - assert not acc != 123.5 - assert acc >= 123.5 - assert not acc >= 124 - assert acc > 123 - assert not acc > 123.5 - assert acc <= 123.5 - assert not acc <= 123 - assert acc < 124 - assert not acc < 123 - assert acc + 1 == 124.5 - assert acc - 1 == 122.5 - assert acc * 2 == 247 - assert acc / 2 == 61.75 - assert hash(acc) == hash(123.5) - - -@pytest.mark.parametrize("dtype", [int, float]) -def test_accumulator_threadsafe(dtype): - acc = Accumulator(dtype(2)) - if ATOMIC_INT_IADD: - # CPython >= 3.10 - assert isinstance(acc, dtype) - N = 10_000_000 - expect = 99999970000002 - else: - assert isinstance(acc, Accumulator) - N = 1_000_000 - expect = 999997000002 - - barrier = Barrier(2) - - def t(): - nonlocal acc - barrier.wait() - for i in range(N): - acc += i - acc -= 1 - assert acc >= 0 - - with ThreadPoolExecutor(2) as ex: - f1 = ex.submit(t) - f2 = ex.submit(t) - f1.result() - f2.result() - assert acc == expect diff --git a/zict/tests/utils_test.py b/zict/tests/utils_test.py index 549414b..55de5d2 100644 --- a/zict/tests/utils_test.py +++ b/zict/tests/utils_test.py @@ -1,9 +1,12 @@ import random import string +from collections import UserDict from collections.abc import MutableMapping import pytest +from zict.common import ZictBase + def 
generate_random_strings(n, min_len, max_len): r = random.Random(42) @@ -115,3 +118,9 @@ def check_mapping(z): def check_closing(z): z.close() + + +class SimpleDict(ZictBase, UserDict): + def __init__(self): + ZictBase.__init__(self) + UserDict.__init__(self) diff --git a/zict/utils.py b/zict/utils.py index 4f2033d..438310b 100644 --- a/zict/utils.py +++ b/zict/utils.py @@ -1,11 +1,6 @@ from __future__ import annotations -import platform -import sys -import threading -from collections import defaultdict from collections.abc import Iterable, Iterator -from numbers import Number from typing import MutableSet # TODO import from collections.abc (needs Python >=3.9) from zict.common import T @@ -70,90 +65,3 @@ def popright(self) -> T: def clear(self) -> None: self._d.clear() - - -ATOMIC_INT_IADD = ( - platform.python_implementation() == "CPython" and sys.version_info >= (3, 10) -) - - -class Accumulator(Number): - """A lockless thread-safe accumulator""" - - _values: defaultdict[int, float] - __slots__ = ("_values",) - - def __new__(cls, value: float = 0) -> Accumulator: - if ATOMIC_INT_IADD: - # int.__iadd__ and float.__iadd__ are GIL-atomic starting from CPython 3.10. - # We can get rid of the whole class and just use them instead. - # This is an implementation detail. - return value # type: ignore[return-value] - - self = object.__new__(cls) - # Don't return float unless you actually added floats. - # This behaviour is consistent with sum(). - self._values = defaultdict(int) - self._values[threading.get_ident()] = value - return self - - def _value(self) -> float: - """Return accumulator total across all threads. - The return type is float if any float elements were added, otherwise it's int. 
- """ - while True: - try: - return sum(self._values.values()) - except RuntimeError: # dictionary changed size during iteration - pass # pragma: nocover - - def __iadd__(self, other: float) -> Accumulator: - self._values[threading.get_ident()] += other - return self - - def __isub__(self, other: float) -> Accumulator: - self._values[threading.get_ident()] -= other - return self - - # Trivial wrappers around self._value(). - # Since they are magic methods, they can't be implemented with __getattr__ - # or with accessor classes. - - def __repr__(self) -> str: - return repr(self._value()) - - def __int__(self) -> int: - return int(self._value()) - - def __float__(self) -> float: - return float(self._value()) - - def __eq__(self, other: object) -> bool: - return self._value() == other - - def __gt__(self, other: float) -> bool: - return self._value() > other - - def __ge__(self, other: float) -> bool: - return self._value() >= other - - def __lt__(self, other: float) -> bool: - return self._value() < other - - def __le__(self, other: float) -> bool: - return self._value() <= other - - def __add__(self, other: float) -> float: - return self._value() + other - - def __sub__(self, other: float) -> float: - return self._value() - other - - def __mul__(self, other: float) -> float: - return self._value() * other - - def __truediv__(self, other: float) -> float: - return self._value() / other - - def __hash__(self) -> int: - return hash(self._value()) diff --git a/zict/zip.py b/zict/zip.py index 9ae73c3..486ca95 100644 --- a/zict/zip.py +++ b/zict/zip.py @@ -41,6 +41,7 @@ class Zip(MutableMapping[str, bytes]): _file: zipfile.ZipFile | None def __init__(self, filename: str, mode: FileMode = "a"): + super().__init__() self.filename = filename self.mode = mode self._file = None