Skip to content

Commit

Permalink
Refactor memory/disk caches in order to remove use of id() in GlobalK…
Browse files Browse the repository at this point in the history
…ernel
  • Loading branch information
JDBetteridge committed Aug 8, 2024
1 parent b0780e6 commit 9e20a3d
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 23 deletions.
76 changes: 63 additions & 13 deletions pyop2/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@
from warnings import warn

from pyop2.configuration import configuration
from pyop2.mpi import comm_cache_keyval
from pyop2.logger import debug
from pyop2.mpi import comm_cache_keyval, COMM_WORLD
from pyop2.utils import cached_property


Expand Down Expand Up @@ -247,6 +248,54 @@ def cache_key(self):
"""


def default_parallel_hashkey(comm, *args, **kwargs):
    """Default key function for parallel caches.

    Returns a 2-tuple of the communicator to be collective over and a
    hashable key derived from the remaining call arguments.
    """
    hash_key = cachetools.keys.hashkey(*args, **kwargs)
    return comm, hash_key


def parallel_memory_only_cache(key=default_parallel_hashkey):
    """Decorator for wrapping a function to be called over a communicator in a
    cache that stores values in memory.

    :arg key: Callable returning the cache key for the function inputs. This
        function must return a 2-tuple where the first entry is the
        communicator to be collective over and the second is the key. This is
        required to ensure that deadlocks do not occur when using different
        subcommunicators.
    """
    def decorator(func):
        # Local import keeps the file-level import block untouched.
        from functools import wraps

        @wraps(func)  # preserve the wrapped function's name and docstring
        def wrapper(*args, **kwargs):
            """Extract the key and then try the memory cache before falling back
            on calling the function and populating the cache.
            """
            comm, mem_key = key(*args, **kwargs)
            k = _as_hexdigest(mem_key)

            # Fetch the per-comm cache, creating and attaching it on first use.
            local_cache = comm.Get_attr(comm_cache_keyval)
            if local_cache is None:
                local_cache = {}
                comm.Set_attr(comm_cache_keyval, local_cache)

            # Only rank 0 consults its memory cache; the result is broadcast
            # so that every rank agrees on hit/miss (collective safety).
            if comm.rank == 0:
                v = local_cache.get(k)
                if v is None:
                    debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} memory cache miss')
                else:
                    debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} memory cache hit')
                comm.bcast(v, root=0)
            else:
                v = comm.bcast(None, root=0)

            # Miss on every rank: compute collectively, then populate the
            # local cache on all ranks.
            if v is None:
                v = func(*args, **kwargs)
            return local_cache.setdefault(k, v)

        return wrapper
    return decorator


def disk_cached(cache, cachedir=None, key=cachetools.keys.hashkey, collective=False):
"""Decorator for wrapping a function in a cache that stores values in memory and to disk.
Expand Down Expand Up @@ -277,33 +326,34 @@ def wrapper(*args, **kwargs):
k = _as_hexdigest(key(*args, **kwargs))
try:
v = cache[k]
debug(f'Serial: {k} memory cache hit')
except KeyError:
debug(f'Serial: {k} memory cache miss')
v = _disk_cache_get(cachedir, k)
if v is not None:
debug(f'Serial: {k} disk cache hit')

if v is None:
debug(f'Serial: {k} disk cache miss')
v = func(*args, **kwargs)
_disk_cache_set(cachedir, k, v)
return cache.setdefault(k, v)

else: # Collective
@parallel_memory_only_cache(key=key)
def wrapper(*args, **kwargs):
""" Same as above, but in parallel over `comm`
"""
comm, disk_key = key(*args, **kwargs)
k = _as_hexdigest(disk_key)

# Fetch the per-comm cache and set it up if not present
local_cache = comm.Get_attr(comm_cache_keyval)
if local_cache is None:
local_cache = {}
comm.Set_attr(comm_cache_keyval, local_cache)

# Grab value from rank 0 memory/disk cache and broadcast result
# Grab value from rank 0 disk cache and broadcast result
if comm.rank == 0:
try:
v = local_cache[k]
except KeyError:
v = _disk_cache_get(cachedir, k)
v = _disk_cache_get(cachedir, k)
if v is not None:
debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} disk cache hit')
else:
debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} disk cache miss')
comm.bcast(v, root=0)
else:
v = comm.bcast(None, root=0)
Expand All @@ -313,7 +363,7 @@ def wrapper(*args, **kwargs):
# Only write to the disk cache on rank 0
if comm.rank == 0:
_disk_cache_set(cachedir, k, v)
return local_cache.setdefault(k, v)
return v

return wrapper
return decorator
Expand Down
21 changes: 11 additions & 10 deletions pyop2/global_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from petsc4py import PETSc

from pyop2 import compilation, mpi
from pyop2.caching import Cached
from pyop2.caching import Cached, parallel_memory_only_cache
from pyop2.configuration import configuration
from pyop2.datatypes import IntType, as_ctypes
from pyop2.types import IterationRegion, Constant, READ
Expand Down Expand Up @@ -334,24 +334,25 @@ def __init__(self, local_kernel, arguments, *,

self._initialized = True

@staticmethod
# NOTE: declared @staticmethod yet takes ``self`` explicitly — the
# @parallel_memory_only_cache decorator forwards *all* call arguments to
# its key function, so the instance arrives positionally as the first one.
def _call_key(self, comm, *args):
    # Constant key ``(0,)`` per communicator: exactly one cached value is
    # kept for each comm, regardless of the runtime ``args``.
    return comm, (0,)

@mpi.collective
@parallel_memory_only_cache(key=_call_key)
def __call__(self, comm, *args):
    """Execute the compiled kernel.

    :arg comm: Communicator the execution is collective over.
    :*args: Arguments to pass to the compiled kernel.
    """
    # Compilation results are cached per-communicator by the decorator
    # above (keyed via _call_key), so no id(comm)-keyed function cache is
    # needed here: the earlier ``self._func_cache[id(comm)]`` lookup was
    # dead code, its result unconditionally overwritten by compile().
    func = self.compile(comm)
    func(*args)

    # This method has to return _something_ for the
    # `@parallel_memory_only_cache` to function correctly.
    return 0

@property
def _wrapper_name(self):
import warnings
Expand Down

0 comments on commit 9e20a3d

Please sign in to comment.