From 9e20a3d3043ff6b1d8dec0ee89248e37b9078bf5 Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Thu, 8 Aug 2024 18:31:12 +0100 Subject: [PATCH] Refactor memory/disk caches in order to remove use of id() in GlobalKernel --- pyop2/caching.py | 76 ++++++++++++++++++++++++++++++++++-------- pyop2/global_kernel.py | 21 ++++++------ 2 files changed, 74 insertions(+), 23 deletions(-) diff --git a/pyop2/caching.py b/pyop2/caching.py index e0d575054..ac2fbe6a2 100644 --- a/pyop2/caching.py +++ b/pyop2/caching.py @@ -41,7 +41,8 @@ from warnings import warn from pyop2.configuration import configuration -from pyop2.mpi import comm_cache_keyval +from pyop2.logger import debug +from pyop2.mpi import comm_cache_keyval, COMM_WORLD from pyop2.utils import cached_property @@ -247,6 +248,54 @@ def cache_key(self): """ +def default_parallel_hashkey(comm, *args, **kwargs): + return comm, cachetools.keys.hashkey(*args, **kwargs) + + +def parallel_memory_only_cache(key=default_parallel_hashkey): + """Decorator for wrapping a function to be called over a communiucator in a + cache that stores values in memory. + + :arg key: Callable returning the cache key for the function inputs. This + function must return a 2-tuple where the first entry is the + communicator to be collective over and the second is the key. This is + required to ensure that deadlocks do not occur when using different + subcommunicators. + """ + def decorator(func): + def wrapper(*args, **kwargs): + """ Extract the key and then try the memory cache before falling back + on calling the function and populating the cache. + """ + comm, mem_key = key(*args, **kwargs) + k = _as_hexdigest(mem_key) + + # Fetch the per-comm cache or set it up if not present + local_cache = comm.Get_attr(comm_cache_keyval) + if local_cache is None: + local_cache = {} + comm.Set_attr(comm_cache_keyval, local_cache) + + # Grab value from rank 0 memory cache and broadcast result + if comm.rank == 0: + + v = local_cache.get(k) + if v is None: + debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} memory cache miss') + else: + debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} memory cache hit') + comm.bcast(v, root=0) + else: + v = comm.bcast(None, root=0) + + if v is None: + v = func(*args, **kwargs) + return local_cache.setdefault(k, v) + + return wrapper + return decorator + + def disk_cached(cache, cachedir=None, key=cachetools.keys.hashkey, collective=False): """Decorator for wrapping a function in a cache that stores values in memory and to disk. @@ -277,33 +326,34 @@ def wrapper(*args, **kwargs): k = _as_hexdigest(key(*args, **kwargs)) try: v = cache[k] + debug(f'Serial: {k} memory cache hit') except KeyError: + debug(f'Serial: {k} memory cache miss') v = _disk_cache_get(cachedir, k) + if v is not None: + debug(f'Serial: {k} disk cache hit') if v is None: + debug(f'Serial: {k} disk cache miss') v = func(*args, **kwargs) _disk_cache_set(cachedir, k, v) return cache.setdefault(k, v) else: # Collective + @parallel_memory_only_cache(key=key) def wrapper(*args, **kwargs): """ Same as above, but in parallel over `comm` """ comm, disk_key = key(*args, **kwargs) k = _as_hexdigest(disk_key) - # Fetch the per-comm cache and set it up if not present - local_cache = comm.Get_attr(comm_cache_keyval) - if local_cache is None: - local_cache = {} - comm.Set_attr(comm_cache_keyval, local_cache) - - # Grab value from rank 0 memory/disk cache and broadcast result + # Grab value from rank 0 disk cache and broadcast result if comm.rank == 0: - try: - v = local_cache[k] - except KeyError: - v = _disk_cache_get(cachedir, k) + v = _disk_cache_get(cachedir, k) + if v is not None: + debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} disk cache hit') + else: + debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} disk cache miss') comm.bcast(v, root=0) else: v = comm.bcast(None, root=0) @@ -313,7 +363,7 @@ def wrapper(*args, **kwargs): # Only write to the disk cache on rank 0 if comm.rank == 0: _disk_cache_set(cachedir, k, v) - return local_cache.setdefault(k, v) + return v return wrapper return decorator diff --git a/pyop2/global_kernel.py b/pyop2/global_kernel.py index 536d717e9..79fbcaeee 100644 --- a/pyop2/global_kernel.py +++ b/pyop2/global_kernel.py @@ -11,7 +11,7 @@ from petsc4py import PETSc from pyop2 import compilation, mpi -from pyop2.caching import Cached +from pyop2.caching import Cached, parallel_memory_only_cache from pyop2.configuration import configuration from pyop2.datatypes import IntType, as_ctypes from pyop2.types import IterationRegion, Constant, READ @@ -334,24 +334,25 @@ def __init__(self, local_kernel, arguments, *, self._initialized = True + @staticmethod + def _call_key(self, comm, *args): + return comm, (0,) + @mpi.collective + @parallel_memory_only_cache(key=_call_key) def __call__(self, comm, *args): """Execute the compiled kernel. :arg comm: Communicator the execution is collective over. :*args: Arguments to pass to the compiled kernel. """ - # If the communicator changes then we cannot safely use the in-memory - # function cache. Note here that we are not using dup_comm to get a - # stable communicator id because we will already be using the internal one. - key = id(comm) - try: - func = self._func_cache[key] - except KeyError: - func = self.compile(comm) - self._func_cache[key] = func + func = self.compile(comm) func(*args) + # This method has to return _something_ for the `@parallel_memory_only_cache` + # to function correctly + return 0 + @property def _wrapper_name(self): import warnings