Skip to content

Commit

Permalink
WIP: shared memory without tmpfs
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Apr 20, 2023
1 parent bb17ff3 commit f883663
Show file tree
Hide file tree
Showing 6 changed files with 365 additions and 0 deletions.
2 changes: 2 additions & 0 deletions doc/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ API
:members:
.. autoclass:: LRU
:members:
.. autoclass:: SharedMemory
:members:
.. autoclass:: Sieve
:members:
.. autoclass:: Zip
Expand Down
1 change: 1 addition & 0 deletions zict/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from zict.func import Func as Func
from zict.lmdb import LMDB as LMDB
from zict.lru import LRU as LRU
from zict.shared_memory import SharedMemory as SharedMemory
from zict.sieve import Sieve as Sieve
from zict.utils import InsertionSortedSet as InsertionSortedSet
from zict.zip import Zip as Zip
Expand Down
1 change: 1 addition & 0 deletions zict/shared_memory/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from zict.shared_memory.shared_memory import SharedMemory
62 changes: 62 additions & 0 deletions zict/shared_memory/_linux.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
"""Linux implementation of :class:`zict.SharedMemory`.
Wraps around glibc ``memfd_create``.
"""
from __future__ import annotations

import ctypes
import mmap
import os
from collections.abc import Iterable

_memfd_create = None


def _setitem(safe_key: str, value: Iterable[bytes | bytearray | memoryview]) -> int:
global _memfd_create
if _memfd_create is None:
libc = ctypes.CDLL("libc.so.6")
_memfd_create = libc.memfd_create

fd = _memfd_create(safe_key.encode("ascii"), 0)
if fd == -1:
raise OSError("Call to memfd_create failed") # pragma: nocover

with os.fdopen(fd, "wb", closefd=False) as fh:
fh.writelines(value)

return fd


def _getitem(fd: int) -> memoryview:
# This opens a second fd for as long as the memory map is referenced.
# Sadly there does not seem a way to extract the fd from the mmap, so we have to
# keep the original fd open for the purpose of exporting.
return memoryview(mmap.mmap(fd, 0))


def _delitem(fd: int) -> None:
    """Release the mapping's own reference to the shared memory area.

    The kernel frees a memfd only once every fd and memory map referencing it
    is gone, on every process.
    """
    # Close the original fd. There may be other fd's still open (e.g. the ones
    # duplicated by mmap in _getitem, or imported by peer processes) if the
    # shared memory is referenced somewhere else.
    # This is also called by SharedMemory.__del__.
    os.close(fd)


def _export(safe_key: str, fd: int) -> tuple:
return safe_key, os.getpid(), fd


def _import(safe_key: str, pid: int, fd: int) -> int:
# if fd has been closed, raise FileNotFoundError
# if fd has been closed and reopened to something else, this may also raise a
# generic OSError, e.g. if this is now a socket
new_fd = os.open(f"/proc/{pid}/fd/{fd}", os.O_RDWR)

expect = f"/memfd:{safe_key} (deleted)"
actual = os.readlink(f"/proc/{os.getpid()}/fd/{new_fd}")
if actual != expect:
# fd has been closed and reopened to something else
os.close(new_fd)
raise OSError()

return new_fd
51 changes: 51 additions & 0 deletions zict/shared_memory/_windows.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Windows implementation of :class:`zict.SharedMemory`.
Conveniently, :class:`multiprocessing.shared_memory.SharedMemory` already wraps around
the Windows API we want to use, so this is implemented as a hack on top of it.
"""
from __future__ import annotations

import mmap
import multiprocessing.shared_memory
from collections.abc import Collection
from typing import cast


class _PySharedMemoryNoClose(multiprocessing.shared_memory.SharedMemory):
    """SharedMemory variant that skips automatic cleanup on garbage collection.

    The stdlib parent class closes its memory map when the object is
    collected; here the lifetime is instead tied to the exported memoryview
    (``shm.buf``), so the object itself may be safely dereferenced.
    """

    def __del__(self) -> None:
        # Deliberately do NOT call super().__del__(); see class docstring.
        pass


def _setitem(
    safe_key: str, value: Collection[bytes | bytearray | memoryview]
) -> memoryview:
    """Allocate a named shared memory block and copy *value* into it.

    Returns the memoryview over the block; the view is what keeps the
    underlying memory map (and thus the shared memory) alive.
    """
    total = 0
    for chunk in value:
        total += chunk.nbytes if isinstance(chunk, memoryview) else len(chunk)

    shm = _PySharedMemoryNoClose(safe_key, create=True, size=total)
    # shm.buf is a memoryview over an mmap; write the chunks sequentially
    # through the underlying map, which starts at position 0.
    mapped = cast(mmap.mmap, shm.buf.obj)
    for chunk in value:
        mapped.write(chunk)

    # Returning drops the last reference to shm. This is safe only because
    # _PySharedMemoryNoClose.__del__ is a no-op; the plain stdlib class would
    # close the memory map and deallocate the shared memory here.
    return shm.buf


def _getitem(mm: memoryview) -> memoryview:
    """Return the stored buffer unchanged.

    Nothing to do here; this exists only for signature compatibility with the
    Linux implementation, which instead creates a memory map on the fly.
    """
    return mm


def _delitem(mm: memoryview) -> None:
    """No-op, for signature compatibility with the Linux implementation.

    Nothing to do here: the shared memory is released as soon as the last
    memory map referencing it is destroyed, with no explicit teardown needed.
    """
    pass


def _export(safe_key: str, mm: memoryview) -> tuple:
    """Build the serializable metadata tuple consumed by ``_import``.

    Only the name is needed to re-attach on this platform; *mm* is accepted
    (and ignored) for signature compatibility with the Linux implementation.
    """
    return (safe_key,)


def _import(safe_key: str) -> memoryview:
    """Attach to an already existing shared memory block by name and return
    a buffer over it.

    Raises OSError in case of invalid key.
    """
    return _PySharedMemoryNoClose(safe_key).buf
248 changes: 248 additions & 0 deletions zict/shared_memory/shared_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,248 @@
from __future__ import annotations

import secrets
import sys
from collections.abc import Iterator, KeysView
from typing import Any
from urllib.parse import quote, unquote

from zict.common import ZictBase

if sys.platform == "linux":
from zict.shared_memory._linux import _delitem, _export, _getitem, _import, _setitem
elif sys.platform == "win32":
from zict.shared_memory._windows import (
_delitem,
_export,
_getitem,
_import,
_setitem,
)


class SharedMemory(ZictBase[str, memoryview]):
    """Mutable Mapping interface to shared memory.

    **Supported OSs:** Linux, Windows

    Keys must be strings, values must be buffers.

    Keys are stored in private memory, and other SharedMemory objects by default won't
    see them - even in case of key collision, the two pieces of data remain separate.
    In order to share the same buffer, one SharedMemory object must call
    :meth:`export` and the other :meth:`import_`.

    **Resources usage**

    On Linux, you will hold 1 file descriptor open for every key in the SharedMemory
    mapping, plus 1 file descriptor for every returned memoryview that is referenced
    somewhere else. Please ensure that your ``ulimit`` is high enough to cope with this.

    If you expect to call ``__getitem__`` multiple times on the same key while the
    return value from the previous call is still in use, you should wrap this mapping in
    a :class:`~zict.Cache`:

    >>> import zict
    >>> shm = zict.Cache(
    ...     zict.SharedMemory(),
    ...     zict.WeakValueMapping(),
    ...     update_on_set=False,
    ... )  # doctest: +SKIP

    The above will cap the amount of open file descriptors per key to 2.

    **Lifecycle**

    Memory is released when all the SharedMemory objects that were sharing the key have
    deleted it *and* the buffer returned by ``__getitem__`` is no longer referenced
    anywhere else.
    Process termination, including ungraceful termination (SIGKILL, SIGSEGV), also
    releases the memory; in other words you don't risk leaking memory to the
    OS if all processes that were sharing it crash or are killed.

    Examples
    --------
    In process 1:

    >>> import pickle, numpy, zict  # doctest: +SKIP
    >>> shm = zict.SharedMemory()  # doctest: +SKIP
    >>> a = numpy.random.random(2**27)  # 1 GiB  # doctest: +SKIP
    >>> buffers = []  # doctest: +SKIP
    >>> pik = pickle.dumps(a, protocol=5, buffer_callback=buffers.append)
    ... # doctest: +SKIP
    >>> # This deep-copies the buffer, resulting in 1 GiB private + 1 GiB shared memory.
    >>> shm["a"] = buffers  # doctest: +SKIP
    >>> # Release private memory, leaving only the shared memory allocated
    >>> del a, buffers  # doctest: +SKIP
    >>> # Recreate array from shared memory. This requires no extra memory.
    >>> a = pickle.loads(pik, buffers=[shm["a"]])  # doctest: +SKIP
    >>> # Send trivially-sized metadata (<1 kiB) to the peer process somehow.
    >>> send_to_process_2((pik, shm.export("a")))  # doctest: +SKIP

    In process 2:

    >>> import pickle, zict  # doctest: +SKIP
    >>> shm = zict.SharedMemory()  # doctest: +SKIP
    >>> pik, metadata = receive_from_process_1()  # doctest: +SKIP
    >>> key = shm.import_(metadata)  # returns "a"  # doctest: +SKIP
    >>> a = pickle.loads(pik, buffers=[shm[key]])  # doctest: +SKIP

    Now process 1 and 2 hold a reference to the same memory; in-place changes on one
    process are reflected onto the other. The shared memory is released after you delete
    the key and dereference the buffer returned by ``__getitem__`` on *both* processes:

    >>> del shm["a"]  # doctest: +SKIP
    >>> del a  # doctest: +SKIP

    or alternatively when both processes are terminated.

    **Implementation notes**

    This mapping uses OS-specific shared memory, which

    1. can be shared among already existing processes, e.g. unlike ``mmap(fd=-1)``, and
    2. is automatically cleaned up by the OS in case of ungraceful process termination,
       e.g. unlike ``shm_open`` (which is used by :mod:`multiprocessing.shared_memory`
       on all POSIX OS'es)

    It is implemented on top of ``memfd_create`` on Linux and ``CreateFileMapping`` on
    Windows. Notably, there is no POSIX equivalent for these API calls, as it only
    implements ``shm_open`` which would inevitably cause memory leaks in case of
    ungraceful process termination.
    """

    # {key: (unique safe key, implementation-specific data)}
    # The safe key embeds a random token, so that two SharedMemory objects
    # storing the same user key never collide in the OS-level namespace.
    _data: dict[str, tuple[str, Any]]

    def __init__(self) -> None:
        if sys.platform not in ("linux", "win32"):
            raise NotImplementedError(
                "SharedMemory is only available on Linux and Windows"
            )

        self._data = {}

    def __str__(self) -> str:
        return f"<SharedMemory: {len(self)} elements>"

    __repr__ = __str__

    def __setitem__(
        self,
        key: str,
        value: bytes
        | bytearray
        | memoryview
        | list[bytes | bytearray | memoryview]
        | tuple[bytes | bytearray | memoryview, ...],
    ) -> None:
        """Deep-copy the buffer(s) into a newly allocated shared memory area.

        *value* may be a single buffer or a list/tuple of buffers, which are
        concatenated (convenient with ``pickle.dumps(..., buffer_callback=...)``).
        """
        # Replacing a key must release the previous shared memory area first.
        try:
            del self[key]
        except KeyError:
            pass

        if not isinstance(value, (tuple, list)):
            value = [value]
        # quote() makes the key ASCII-safe; the random hex suffix makes the
        # OS-level name unique across SharedMemory instances, even when they
        # store identical user keys.
        safe_key = quote(key, safe="") + "#" + secrets.token_hex(8)
        impl_data = _setitem(safe_key, value)
        self._data[key] = safe_key, impl_data

    def __getitem__(self, key: str) -> memoryview:
        _, impl_data = self._data[key]
        return _getitem(impl_data)

    def __delitem__(self, key: str) -> None:
        _, impl_data = self._data.pop(key)
        _delitem(impl_data)

    def __del__(self) -> None:
        # Release every key's OS resources; best-effort because this may run
        # during interpreter shutdown, when module globals can be gone already.
        try:
            data_values = self._data.values()
        except Exception:
            # Interpreter shutdown
            return  # pragma: nocover

        for _, impl_data in data_values:
            try:
                _delitem(impl_data)
            except Exception:
                pass  # pragma: nocover

    def close(self) -> None:
        # Implements ZictBase.close(). Also triggered by __exit__.
        self.clear()

    def __contains__(self, key: object) -> bool:
        return key in self._data

    def keys(self) -> KeysView[str]:
        return self._data.keys()

    def __iter__(self) -> Iterator[str]:
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)

    def export(self, key: str) -> tuple:
        """Export metadata for a key, which can be fed into :meth:`import_` on
        another process.

        Returns
        -------
        Opaque metadata object (implementation-specific) to be passed to
        :meth:`import_`. It is serializable with JSON, YAML, and msgpack.

        See Also
        --------
        import_
        """
        return _export(*self._data[key])

    def import_(self, metadata: tuple | list) -> str:
        """Import a key from another process, starting to share the memory area.

        You should treat parameters as implementation details and just unpack the tuple
        that was generated by :meth:`export`.

        Returns
        -------
        Key that was just added to the mapping

        Raises
        ------
        FileNotFoundError
            Either the key or the whole SharedMemory object were deleted on the process
            where you ran :meth:`export`, or the process was terminated.

        Notes
        -----
        On Windows, this method will raise FileNotFoundError if the key has been deleted
        from the other SharedMemory mapping *and* it is no longer referenced anywhere.
        On Linux, this method will raise as soon as the key is deleted from the other
        SharedMemory mapping, even if it's still referenced.
        e.g. this code is not portable, as it will work on Windows but not on Linux:

        >>> buf = shm["x"]  # doctest: +SKIP
        >>> meta = shm.export("x")  # doctest: +SKIP
        >>> del shm["x"]  # doctest: +SKIP

        See Also
        --------
        export
        """
        # The user-visible key is recoverable from the safe key: strip the
        # random suffix and undo the percent-encoding.
        safe_key = metadata[0]
        key = unquote(safe_key.split("#")[0])

        # Replacing a pre-existing key must release its shared memory first.
        try:
            del self[key]
        except KeyError:
            pass

        try:
            impl_data = _import(*metadata)
        except OSError as exc:
            # Chain the original OSError so the low-level cause stays visible.
            raise FileNotFoundError(
                f"Peer process no longer holds the key: {key!r}"
            ) from exc
        self._data[key] = safe_key, impl_data
        return key

0 comments on commit f883663

Please sign in to comment.