async_buffer

crusaderky committed Mar 29, 2023
1 parent 0d39c19 commit 3f20267
Showing 12 changed files with 369 additions and 212 deletions.
5 changes: 4 additions & 1 deletion .pre-commit-config.yaml
@@ -63,4 +63,7 @@ repos:
- tornado
- pyarrow
- git+https://github.com/dask/dask
- git+https://github.com/dask/zict
# DO NOT MERGE
- git+https://github.com/crusaderky/zict@async_buffer2

# clear cache +1
3 changes: 2 additions & 1 deletion continuous_integration/environment-3.10.yaml
@@ -45,7 +45,8 @@ dependencies:
- pip:
- git+https://github.com/dask/dask
- git+https://github.com/dask/s3fs
- git+https://github.com/dask/zict
# DO NOT MERGE
- git+https://github.com/crusaderky/zict@async_buffer2
- git+https://github.com/fsspec/filesystem_spec
- keras
- gilknocker>=0.3.0
3 changes: 2 additions & 1 deletion continuous_integration/environment-3.11.yaml
@@ -45,7 +45,8 @@ dependencies:
- pip:
- git+https://github.com/dask/dask
- git+https://github.com/dask/s3fs
- git+https://github.com/dask/zict
# DO NOT MERGE
- git+https://github.com/crusaderky/zict@async_buffer2
- git+https://github.com/fsspec/filesystem_spec
- keras
- gilknocker>=0.3.0
2 changes: 1 addition & 1 deletion distributed/http/worker/prometheus/core.py
@@ -73,7 +73,7 @@ def collect(self) -> Iterator[Metric]:
)

try:
spilled_memory, spilled_disk = self.server.data.spilled_total # type: ignore
spilled_memory, spilled_disk = self.server.data.spilled_total() # type: ignore
except AttributeError:
spilled_memory, spilled_disk = 0, 0 # spilling is disabled
process_memory = self.server.monitor.get_process_memory()
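For context, spilled_total changes from a property to a method in this commit, which is why the collector now calls it. It returns the SpilledSize named tuple defined in distributed/spill.py, so the unpacking yields two byte counts; a brief hedged illustration (buf stands for an assumed SpillBuffer instance):

    # SpilledSize(memory, disk): sizeof() bytes of the spilled objects vs. the
    # pickled bytes actually written to disk.
    spilled_memory, spilled_disk = buf.spilled_total()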
241 changes: 172 additions & 69 deletions distributed/spill.py
@@ -1,11 +1,20 @@
from __future__ import annotations

import asyncio
import logging
from collections import defaultdict
from collections.abc import Hashable, Iterator, Mapping, Sized
from collections.abc import (
Awaitable,
Hashable,
Iterator,
Mapping,
MutableMapping,
Sized,
)
from contextlib import contextmanager
from functools import partial
from typing import Any, Literal, NamedTuple, Protocol, cast
from typing import Collection # TODO import from collections.abc (requires Python >=3.9)
from typing import TYPE_CHECKING, Literal, NamedTuple, Protocol, cast, runtime_checkable

from packaging.version import parse as parse_version

@@ -14,7 +23,11 @@
from distributed.metrics import context_meter
from distributed.protocol import deserialize_bytes, serialize_bytelist
from distributed.sizeof import safe_sizeof
from distributed.utils import RateLimiterFilter
from distributed.utils import RateLimiterFilter, empty_context

if TYPE_CHECKING:
# TODO import from typing (requires Python >=3.10)
from typing_extensions import TypeAlias

logger = logging.getLogger(__name__)
logger.addFilter(RateLimiterFilter("Spill file on disk reached capacity"))
@@ -39,7 +52,9 @@ def __sub__(self, other: SpilledSize) -> SpilledSize:


class ManualEvictProto(Protocol):
"""Duck-type API that a third-party alternative to SpillBuffer must respect (in
"""**DEPRECATED**; please upgrade to AsyncBufferProto
Duck-type API that a third-party alternative to SpillBuffer must respect (in
addition to MutableMapping) if it wishes to support spilling when the
``distributed.worker.memory.spill`` threshold is surpassed.
@@ -68,8 +83,54 @@ def evict(self) -> int:
... # pragma: nocover
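For illustration, a minimal legacy store satisfying this deprecated duck type might look like the sketch below (hypothetical class, not part of the commit; a real implementation would write the evicted value to disk before dropping it from memory):

    from distributed.sizeof import safe_sizeof

    class LegacyEvictingStore(dict):
        """MutableMapping plus a synchronous evict(); see ManualEvictProto."""

        def evict(self) -> int:
            if not self:
                return -1  # nothing left to evict
            key = next(iter(self))     # oldest key in insertion order
            value = self.pop(key)      # a real store would spill this to disk
            return safe_sizeof(value)  # bytes freed from fast memory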


# zict.Buffer[str, Any] requires zict >= 2.2.0
class SpillBuffer(zict.Buffer):
@runtime_checkable
class AsyncBufferProto(Protocol, Collection[str]):
"""Duck-type API that a third-party alternative to SpillBuffer must respect if it
wishes to support spilling.
Notes
-----
``async_get`` must raise KeyError if and only if a key is not in the collection at
some point during the call. ``__setitem__`` immediately followed by ``async_get``
*never* raises KeyError; ``__delitem__`` immediately followed by ``async_get``
*always* raises KeyError. Likewise, ``__contains__`` and ``__len__`` must
immediately reflect the changes wrought by ``__setitem__`` / ``__delitem__``.
This is public API.
"""

def __setitem__(self, key: str, value: object) -> None:
...

def __delitem__(self, key: str) -> None:
...

def async_get(
self, keys: Collection[str], missing: Literal["raise", "omit"] = "raise"
) -> Awaitable[dict[str, object]]:
"""Fetch one or more key/value pairs"""
... # pragma: nocover

def memory_total(self) -> int:
"""(estimated) bytes currently held in fast memory"""
... # pragma: nocover

def async_evict_until_below_target(self, n: int) -> None:
"""Asynchronously start spilling keys until memory_total <= n"""


# TODO remove quotes (requires Python >=3.10)
WorkerDataProto: TypeAlias = "MutableMapping[str, object] | AsyncBufferProto"
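To make the contract above concrete, here is a minimal in-memory object that satisfies AsyncBufferProto (a hypothetical sketch, not part of the commit; it never actually spills, so eviction is a no-op):

    from typing import Collection, Iterator, Literal

    class InMemoryAsyncBuffer:
        def __init__(self) -> None:
            self._data: dict[str, object] = {}

        def __setitem__(self, key: str, value: object) -> None:
            self._data[key] = value

        def __delitem__(self, key: str) -> None:
            del self._data[key]

        def __contains__(self, key: object) -> bool:
            return key in self._data

        def __iter__(self) -> Iterator[str]:
            return iter(self._data)

        def __len__(self) -> int:
            return len(self._data)

        async def async_get(
            self, keys: Collection[str], missing: Literal["raise", "omit"] = "raise"
        ) -> dict[str, object]:
            # Raises KeyError iff a requested key is absent and missing="raise"
            if missing == "omit":
                return {k: self._data[k] for k in keys if k in self._data}
            return {k: self._data[k] for k in keys}

        def memory_total(self) -> int:
            return 0  # a real store would track sizeof() of the held values

        def async_evict_until_below_target(self, n: int) -> None:
            pass  # nothing to spill in this toy sketch

Because the protocol is @runtime_checkable, an isinstance() check against AsyncBufferProto should succeed for this object, which is presumably how a worker can tell an async-capable store apart from a plain MutableMapping.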


if has_zict_230:
from zict import AsyncBuffer
else:
from zict import Buffer as AsyncBuffer # type: ignore[assignment]


# zict.Buffer[str, object] requires zict >= 2.2.0
class SpillBuffer(AsyncBuffer):
"""MutableMapping that automatically spills out dask key/value pairs to disk when
the total size of the stored data exceeds the target. If max_spill is provided the
key/value pairs won't be spilled once this threshold has been reached.
@@ -103,6 +164,10 @@ def __init__(
self.logged_pickle_errors = set() # keys logged with pickle error
self.cumulative_metrics = defaultdict(float)

if not has_zict_230:
self.lock = empty_context # type: ignore
self.fast.lock = empty_context # type: ignore
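A hedged usage sketch of the buffer described above (directory and sizes are illustrative; constructor parameters assumed from the class docstring):

    from distributed.spill import SpillBuffer

    buf = SpillBuffer("/tmp/spill", target=1_000_000)  # spill past ~1 MB
    buf["small"] = b"x" * 1_000      # sizeof < target: stays in buf.fast (RAM)
    buf["huge"] = b"x" * 2_000_000   # sizeof >= target: written straight to slow
    assert "huge" in buf.disk        # .memory / .disk are inspection-only views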

@contextmanager
def _capture_metrics(self) -> Iterator[None]:
"""Capture metrics re. disk read/write, serialize/deserialize, and
@@ -164,60 +229,87 @@ def _handle_errors(self, key: str | None) -> Iterator[None]:
self.logged_pickle_errors.add(key_e)
raise HandledError()

def __setitem__(self, key: str, value: Any) -> None:
"""If sizeof(value) < target, write key/value pair to self.fast; this may in
turn cause older keys to be spilled from fast to slow.
If sizeof(value) >= target, write key/value pair directly to self.slow instead.
Raises
------
Exception
sizeof(value) >= target, and value failed to pickle.
The key/value pair has been forgotten.
In all other cases:
- an older value was evicted and failed to pickle,
- this value or an older one caused the disk to fill and raise OSError,
- this value or an older one caused the max_spill threshold to be exceeded,
this method does not raise and guarantees that the key/value that caused the
issue remained in fast.
"""
try:
with self._capture_metrics(), self._handle_errors(key):
super().__setitem__(key, value)
self.logged_pickle_errors.discard(key)
except HandledError:
assert key in self.fast
assert key not in self.slow

def evict(self) -> int:
"""Implementation of :meth:`ManualEvictProto.evict`.
Manually evict the oldest key/value pair, even if target has not been
reached. Returns sizeof(value).
If the eviction failed (value failed to pickle, disk full, or max_spill
exceeded), return -1; the key/value pair that caused the issue will remain in
fast. The exception has been logged internally.
This method never raises.
"""
try:
with self._capture_metrics(), self._handle_errors(None):
_, _, weight = self.fast.evict()
return cast(int, weight)
except HandledError:
return -1
if has_zict_230:

def async_get(
self, keys: Collection[str], missing: Literal["raise", "omit"] = "raise"
) -> asyncio.Future[dict[str, object]]:
with self.lock, self.fast.lock, self._capture_metrics():
f = super().async_get(keys, missing=missing)
if f.done():
for key in keys:
nbytes = cast(int, self.fast.weights[key])
context_meter.digest_metric("memory-read", 1, "count")
context_meter.digest_metric("memory-read", nbytes, "bytes")
return f

def evict_until_below_target(self, n: float | None = None) -> None:
try:
with self._capture_metrics(), self._handle_errors(None):
super().evict_until_below_target(n)
except HandledError:
pass
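A hedged sketch of how a caller on the worker event loop might drive these two methods when zict >= 2.3.0 is available (hypothetical helper; the disk reads happen in zict's offloaded thread, not on the event loop):

    async def fetch_for_compute(
        buf: SpillBuffer, keys: list[str]
    ) -> dict[str, object]:
        # Unspill without blocking the event loop; KeyError semantics follow
        # AsyncBufferProto (missing="omit" silently drops absent keys).
        data = await buf.async_get(keys, missing="omit")
        # Start spilling back down toward the configured target (n=None
        # presumably falls back to the buffer's own target).
        buf.evict_until_below_target()
        return data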

def __getitem__(self, key: str) -> Any:
else:

def __setitem__(self, key: str, value: object) -> None:
"""If sizeof(value) < target, write key/value pair to self.fast; this may in
turn cause older keys to be spilled from fast to slow.
If sizeof(value) >= target, write key/value pair directly to self.slow instead.
Raises
------
Exception
sizeof(value) >= target, and value failed to pickle.
The key/value pair has been forgotten.
In all other cases:
- an older value was evicted and failed to pickle,
- this value or an older one caused the disk to fill and raise OSError,
- this value or an older one caused the max_spill threshold to be exceeded,
this method does not raise and guarantees that the key/value that caused the
issue remained in fast.
"""
try:
with self._capture_metrics(), self._handle_errors(key):
super().__setitem__(key, value)
self.logged_pickle_errors.discard(key)
except HandledError:
assert key in self.fast
assert key not in self.slow

def evict(self) -> int:
"""Implementation of :meth:`ManualEvictProto.evict`.
Manually evict the oldest key/value pair, even if target has not been
reached. Returns sizeof(value).
If the eviction failed (value failed to pickle, disk full, or max_spill
exceeded), return -1; the key/value pair that caused the issue will remain in
fast. The exception has been logged internally.
This method never raises.
"""
try:
with self._capture_metrics(), self._handle_errors(None):
_, _, weight = self.fast.evict()
return cast(int, weight)
except HandledError:
return -1
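To make the evict() contract concrete, a hedged sketch of a caller in the style of the worker memory monitor (hypothetical helper, pre-zict-2.3 code path):

    def shrink_to(buf: SpillBuffer, target: int) -> None:
        # Spill oldest keys until fast memory is at or below target; stop
        # early if the oldest key can't be spilled (evict() returned -1 and
        # already logged the reason internally).
        while buf.fast.total_weight > target:
            if buf.evict() == -1:
                break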

# In zict >=2.3.0, this is always called from an offloaded thread, except at most
# in unit tests (see zict.AsyncBuffer.async_get)
def __getitem__(self, key: str) -> object:
# Note: don't log from self.fast.__getitem__, because that's called every time a
# key is evicted, and we don't want to count those events here.
# This is logged not only by the internal metrics callback but also by those
# installed by gather_dep, get_data, and execute
with self._capture_metrics():
if key in self.fast:
# Note: don't log from self.fast.__getitem__, because that's called
# every time a key is evicted, and we don't want to count those events
# here.
try:
nbytes = cast(int, self.fast.weights[key])
# This is logged not only by the internal metrics callback but also by
# those installed by gather_dep, get_data, and execute
except KeyError:
pass
else:
context_meter.digest_metric("memory-read", 1, "count")
context_meter.digest_metric("memory-read", nbytes, "bytes")

@@ -227,21 +319,21 @@ def __delitem__(self, key: str) -> None:
super().__delitem__(key)
self.logged_pickle_errors.discard(key)

def pop(self, key: str, default: Any = None) -> Any:
def pop(self, key: str, default: object = None) -> object:
raise NotImplementedError(
"Are you calling .pop(key, None) as a way to discard a key if it exists?"
"It may cause data to be read back from disk! Please use `del` instead."
)
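The message above points to the safe pattern; a brief sketch of the intended replacement:

    # Discard a key if it exists, without ever reading its value back from disk:
    try:
        del buf["x"]
    except KeyError:
        pass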

@property
def memory(self) -> Mapping[str, Any]:
def memory(self) -> Mapping[str, object]:
"""Key/value pairs stored in RAM. Alias of zict.Buffer.fast.
For inspection only - do not modify directly!
"""
return self.fast

@property
def disk(self) -> Mapping[str, Any]:
def disk(self) -> Mapping[str, object]:
"""Key/value pairs spilled out to disk. Alias of zict.Buffer.slow.
For inspection only - do not modify directly!
"""
@@ -252,7 +344,10 @@ def _slow_uncached(self) -> Slow:
cache = cast(zict.Cache, self.slow)
return cast(Slow, cache.data)

@property
def memory_total(self) -> int:
"""Number of bytes in memory (output of sizeof())"""
return cast(int, self.fast.total_weight)

def spilled_total(self) -> SpilledSize:
"""Number of bytes spilled to disk. Tuple of
@@ -265,7 +360,7 @@ def spilled_total(self) -> SpilledSize:
return self._slow_uncached.total_weight


def _in_memory_weight(key: str, value: Any) -> int:
def _in_memory_weight(key: str, value: object) -> int:
return safe_sizeof(value)


@@ -282,7 +377,7 @@ class HandledError(Exception):
pass


# zict.Func[str, Any] requires zict >= 2.2.0
# zict.Func[str, object] requires zict >= 2.2.0
class Slow(zict.Func):
max_weight: int | Literal[False]
weight_by_key: dict[str, SpilledSize]
@@ -298,15 +393,18 @@ def __init__(self, spill_directory: str, max_weight: int | Literal[False] = Fals
self.weight_by_key = {}
self.total_weight = SpilledSize(0, 0)

def __getitem__(self, key: str) -> Any:
if not has_zict_230:
self.lock = empty_context # type: ignore
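When zict predates 2.3.0 there is no real lock to take, so the class swaps in distributed's no-op context manager; behaviorally that is roughly the following (a sketch of the assumed semantics, not the actual distributed.utils source):

    from contextlib import nullcontext

    empty_context = nullcontext()  # "with empty_context:" does nothing, reusably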

def __getitem__(self, key: str) -> object:
with context_meter.meter("disk-read", "seconds"):
pickled = self.d[key]
context_meter.digest_metric("disk-read", 1, "count")
context_meter.digest_metric("disk-read", len(pickled), "bytes")
out = self.load(pickled)
return out

def __setitem__(self, key: str, value: Any) -> None:
def __setitem__(self, key: str, value: object) -> None:
try:
pickled = self.dump(value)
except Exception as e:
Expand Down Expand Up @@ -341,9 +439,14 @@ def __setitem__(self, key: str, value: Any) -> None:
context_meter.digest_metric("disk-write", pickled_size, "bytes")

weight = SpilledSize(safe_sizeof(value), pickled_size)
self.weight_by_key[key] = weight
self.total_weight += weight

with self.lock:
# 2 threads call Slow.__delitem__, but only one calls Slow.__setitem__
assert key not in self.weight_by_key
self.weight_by_key[key] = weight
self.total_weight += weight

def __delitem__(self, key: str) -> None:
super().__delitem__(key)
self.total_weight -= self.weight_by_key.pop(key)
with self.lock:
super().__delitem__(key)
self.total_weight -= self.weight_by_key.pop(key)