Skip to content

Commit

Permalink
Lock-based thread safety
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Mar 25, 2023
1 parent cde32b0 commit ee658a2
Show file tree
Hide file tree
Showing 18 changed files with 289 additions and 313 deletions.
6 changes: 3 additions & 3 deletions doc/source/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ Changelog
(:pr:`78`) `Guido Imperiale`_
- ``LMDB`` now uses memory-mapped I/O on MacOSX and is usable on Windows.
(:pr:`78`) `Guido Imperiale`_
- The library is now partially thread-safe.
(:pr:`82`, :pr:`90`) `Guido Imperiale`_
- The library is now almost completely thread-safe.
(:pr:`82`, :pr:`90`, :pr:`92`) `Guido Imperiale`_
- :class:`LRU` and :class:`Buffer` now support delayed eviction.
New objects :class:`Accumulator` and :class:`InsertionSortedSet`.
New object :class:`InsertionSortedSet`.
(:pr:`87`) `Guido Imperiale`_


Expand Down
6 changes: 2 additions & 4 deletions doc/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,8 @@ zlib-compressed, directory of files.
Thread-safety
-------------
This library is only partially thread-safe. Refer to the documentation of the individual
mappings for details.
Most classes in this library are thread-safe.
Refer to the documentation of the individual mappings for exceptions.

API
---
Expand All @@ -64,8 +64,6 @@ API

Additionally, **zict** makes available the following general-purpose objects:

.. autoclass:: Accumulator
:members:
.. autoclass:: InsertionSortedSet
:members:
.. autoclass:: WeakValueMapping
Expand Down
1 change: 0 additions & 1 deletion zict/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from zict.lmdb import LMDB as LMDB
from zict.lru import LRU as LRU
from zict.sieve import Sieve as Sieve
from zict.utils import Accumulator as Accumulator
from zict.utils import InsertionSortedSet as InsertionSortedSet
from zict.zip import Zip as Zip

Expand Down
85 changes: 65 additions & 20 deletions zict/buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections.abc import Callable, Iterator, MutableMapping
from itertools import chain

from zict.common import KT, VT, ZictBase, close, flush
from zict.common import KT, VT, ZictBase, close, discard, flush, locked
from zict.lru import LRU


Expand Down Expand Up @@ -33,8 +33,10 @@ class Buffer(ZictBase[KT, VT]):
Notes
-----
``__contains__`` and ``__len__`` are thread-safe if the same methods on both
``fast`` and ``slow`` are thread-safe. All other methods are not thread-safe.
If you call methods of this class from multiple threads, access will be fast as long
as all methods of ``fast``, plus ``slow.__contains__`` and ``slow.__delitem__``, are
fast. ``slow.__getitem__``, and ``slow.__setitem__`` and callbacks are not protected
by locks.
Examples
--------
Expand All @@ -54,6 +56,7 @@ class Buffer(ZictBase[KT, VT]):
weight: Callable[[KT, VT], float]
fast_to_slow_callbacks: list[Callable[[KT, VT], None]]
slow_to_fast_callbacks: list[Callable[[KT, VT], None]]
_cancel_restore: dict[KT, bool]

def __init__(
self,
Expand All @@ -68,7 +71,14 @@ def __init__(
| list[Callable[[KT, VT], None]]
| None = None,
):
self.fast = LRU(n, fast, weight=weight, on_evict=[self.fast_to_slow])
super().__init__()
self.fast = LRU(
n,
fast,
weight=weight,
on_evict=[self.fast_to_slow],
on_cancel_evict=[self._cancel_evict],
)
self.slow = slow
self.weight = weight
if callable(fast_to_slow_callbacks):
Expand All @@ -77,6 +87,7 @@ def __init__(
slow_to_fast_callbacks = [slow_to_fast_callbacks]
self.fast_to_slow_callbacks = fast_to_slow_callbacks or []
self.slow_to_fast_callbacks = slow_to_fast_callbacks or []
self._cancel_restore = {}

@property
def n(self) -> float:
Expand All @@ -94,48 +105,74 @@ def fast_to_slow(self, key: KT, value: VT) -> None:
raise

def slow_to_fast(self, key: KT) -> VT:
value = self.slow[key]
self._cancel_restore[key] = False
try:
with self._unlock():
value = self.slow[key]
if self._cancel_restore[key]:
raise KeyError(key)
finally:
self._cancel_restore.pop(key)

# Avoid useless movement for heavy values
w = self.weight(key, value)
if w <= self.n:
# Multithreaded edge case:
# - Thread 1 starts slow_to_fast(x) and puts it at the top of fast
# - This causes the eviction of older key(s)
# - While thread 1 is evicting older keys, thread 2 is loading fast with
# set_noevict()
# - By the time the eviction of the older key(s) is done, there is
# enough weight in fast that thread 1 will spill x
# - If the below code was just `self.fast[key] = value; del
# self.slow[key]` now the key would be in neither slow nor fast!
self.fast.set_noevict(key, value)
del self.slow[key]
self.fast[key] = value
for cb in self.slow_to_fast_callbacks:
cb(key, value)

with self._unlock():
self.fast.evict_until_below_capacity()
for cb in self.slow_to_fast_callbacks:
cb(key, value)

return value

@locked
def __getitem__(self, key: KT) -> VT:
try:
return self.fast[key]
except KeyError:
return self.slow_to_fast(key)

def __setitem__(self, key: KT, value: VT) -> None:
try:
del self.slow[key]
except KeyError:
pass
# This may trigger an eviction from fast to slow of older keys.
# If the weight is individually greater than n, then key/value will be stored
# into self.slow instead (see LRU.__setitem__).
with self._lock:
discard(self.slow, key)
if key in self._cancel_restore:
self._cancel_restore[key] = True
self.fast[key] = value

@locked
def set_noevict(self, key: KT, value: VT) -> None:
"""Variant of ``__setitem__`` that does not move keys from fast to slow if the
total weight exceeds n
"""
try:
del self.slow[key]
except KeyError:
pass
discard(self.slow, key)
if key in self._cancel_restore:
self._cancel_restore[key] = True
self.fast.set_noevict(key, value)

@locked
def __delitem__(self, key: KT) -> None:
if key in self._cancel_restore:
self._cancel_restore[key] = True
try:
del self.fast[key]
except KeyError:
del self.slow[key]

@locked
def _cancel_evict(self, key: KT, value: VT) -> None:
discard(self.slow, key)

# FIXME dictionary views https://github.com/dask/zict/issues/61
def keys(self) -> Iterator[KT]: # type: ignore
return iter(self)
Expand All @@ -147,7 +184,15 @@ def items(self) -> Iterator[tuple[KT, VT]]: # type: ignore
return chain(self.fast.items(), self.slow.items())

def __len__(self) -> int:
return len(self.fast) + len(self.slow)
with self._lock, self.fast._lock:
return (
len(self.fast)
+ len(self.slow)
- sum(
k in self.fast and k in self.slow
for k in chain(self._cancel_restore, self.fast._cancel_evict)
)
)

def __iter__(self) -> Iterator[KT]:
return chain(self.fast, self.slow)
Expand Down
21 changes: 4 additions & 17 deletions zict/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections.abc import Iterator, KeysView, MutableMapping
from typing import TYPE_CHECKING

from zict.common import KT, VT, ZictBase, close, flush
from zict.common import KT, VT, ZictBase, close, discard, flush


class Cache(ZictBase[KT, VT]):
Expand All @@ -22,14 +22,6 @@ class Cache(ZictBase[KT, VT]):
If True (default), the cache will be updated both when writing and reading.
If False, update the cache when reading, but just invalidate it when writing.
Notes
-----
All methods are thread-safe if all methods on both ``data`` and ``cache`` are
thread-safe; however, only one thread can call ``__setitem__`` and ``__delitem__``
at any given time.
``__contains__`` and ``__len__`` are thread-safe if the same methods on ``data`` are
thread-safe.
Examples
--------
Keep the latest 100 accessed values in memory
Expand All @@ -51,6 +43,7 @@ def __init__(
cache: MutableMapping[KT, VT],
update_on_set: bool = True,
):
super().__init__()
self.data = data
self.cache = cache
self.update_on_set = update_on_set
Expand All @@ -68,20 +61,14 @@ def __setitem__(self, key: KT, value: VT) -> None:
# If the item was already in cache and data.__setitem__ fails, e.g. because it's
# a File and the disk is full, make sure that the cache is invalidated.
# FIXME https://github.com/python/mypy/issues/10152
try:
del self.cache[key]
except KeyError:
pass
discard(self.cache, key)

self.data[key] = value
if self.update_on_set:
self.cache[key] = value

def __delitem__(self, key: KT) -> None:
try:
del self.cache[key]
except KeyError:
pass
discard(self.cache, key)
del self.data[key]

def __len__(self) -> int:
Expand Down
67 changes: 64 additions & 3 deletions zict/common.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
from __future__ import annotations

from collections.abc import Iterable, Mapping
import threading
from collections.abc import Callable, Iterable, Iterator, Mapping
from contextlib import contextmanager
from enum import Enum
from functools import wraps
from itertools import chain
from typing import MutableMapping # TODO move to collections.abc (needs Python >=3.9)
from typing import TYPE_CHECKING, Any, TypeVar
from typing import TYPE_CHECKING, Any, TypeVar, cast

T = TypeVar("T")
KT = TypeVar("KT")
VT = TypeVar("VT")

if TYPE_CHECKING:
# TODO import from typing (needs Python >=3.10)
# TODO import from typing (needs Python >=3.11)
from typing_extensions import Self
from typing_extensions import ParamSpec, Self

P = ParamSpec("P")


class NoDefault(Enum):
Expand All @@ -25,6 +31,20 @@ class NoDefault(Enum):
class ZictBase(MutableMapping[KT, VT]):
"""Base class for zict mappings"""

_lock: threading.RLock

def __init__(self) -> None:
self._lock = threading.RLock()

def __getstate__(self) -> dict[str, Any]:
state = self.__dict__.copy()
del state["_lock"]
return state

def __setstate__(self, state: dict[str, Any]) -> None:
self.__dict__ = state
self._lock = threading.RLock()

def update( # type: ignore[override]
self,
other: Mapping[KT, VT] | Iterable[tuple[KT, VT]] = (),
Expand All @@ -41,6 +61,12 @@ def _do_update(self, items: Iterable[tuple[KT, VT]]) -> None:
for k, v in items:
self[k] = v

def discard(self, key: KT) -> None:
"""Flush *key* if possible.
Not the same as ``m.pop(key, None)``, as it doesn't trigger ``__getitem__``.
"""
discard(self, key)

def close(self) -> None:
"""Release any system resources held by this object"""

Expand All @@ -53,6 +79,17 @@ def __exit__(self, *args: Any) -> None:
def __del__(self) -> None:
self.close()

@contextmanager
def _unlock(self) -> Iterator[None]:
"""To be used in a method decorated by ``@locked``.
Temporarily releases the mapping's RLock.
"""
self._lock.release()
try:
yield
finally:
self._lock.acquire()


def close(*z: Any) -> None:
"""Close *z* if possible."""
Expand All @@ -66,3 +103,27 @@ def flush(*z: Any) -> None:
for zi in z:
if hasattr(zi, "flush"):
zi.flush()


def discard(m: MutableMapping[KT, VT], key: KT) -> None:
"""Flush *key* if possible.
Not the same as ``m.pop(key, None)``, as it doesn't trigger ``__getitem__``.
"""
try:
del m[key]
except KeyError:
pass


def locked(func: Callable[P, VT]) -> Callable[P, VT]:
"""Decorator for a method of ZictBase, which wraps the whole method in a
mapping-global rlock.
"""

@wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> VT:
self = cast(ZictBase, args[0])
with self._lock:
return func(*args, **kwargs)

return wrapper
Loading

0 comments on commit ee658a2

Please sign in to comment.