From ebb9d36ccba8c2a3fdbf31b58f55f3e0053f917d Mon Sep 17 00:00:00 2001
From: Jack Betteridge
Date: Thu, 8 Aug 2024 21:58:35 +0100
Subject: [PATCH] Rethink memory only cache for non-broadcastable values

---
 pyop2/caching.py       | 56 +++++++++++++++++++++++++++++++++++---
 pyop2/compilation.py   | 62 +++++++++++++++++++++++++-----------------
 pyop2/global_kernel.py | 15 ++--------
 3 files changed, 91 insertions(+), 42 deletions(-)

diff --git a/pyop2/caching.py b/pyop2/caching.py
index ac2fbe6a2..1fd6b876a 100644
--- a/pyop2/caching.py
+++ b/pyop2/caching.py
@@ -248,13 +248,17 @@ def cache_key(self):
     """
 
 
-def default_parallel_hashkey(comm, *args, **kwargs):
+def default_parallel_hashkey(*args, **kwargs):
+    comm = kwargs.get('comm')
     return comm, cachetools.keys.hashkey(*args, **kwargs)
 
 
 def parallel_memory_only_cache(key=default_parallel_hashkey):
-    """Decorator for wrapping a function to be called over a communiucator in a
-    cache that stores values in memory.
+    """Memory only cache decorator.
+
+    Decorator for wrapping a function to be called over a communicator in a
+    cache that stores broadcastable values in memory. If the value is found in
+    the cache of rank 0 it is broadcast to all other ranks.
 
     :arg key: Callable returning the cache key for the function inputs. This
         function must return a 2-tuple where the first entry is the
@@ -278,7 +282,6 @@ def wrapper(*args, **kwargs):
 
             # Grab value from rank 0 memory cache and broadcast result
             if comm.rank == 0:
-                v = local_cache.get(k)
                 if v is None:
                     debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} memory cache miss')
 
@@ -296,6 +299,51 @@ def wrapper(*args, **kwargs):
     return decorator
 
 
+def parallel_memory_only_cache_no_broadcast(key=default_parallel_hashkey):
+    """Memory only cache decorator.
+
+    Decorator for wrapping a function to be called over a communicator in a
+    cache that stores non-broadcastable values in memory, for instance function
+    pointers. If the value is not present on all ranks, all ranks repeat the
+    work.
+
+    :arg key: Callable returning the cache key for the function inputs. This
+        function must return a 2-tuple where the first entry is the
+        communicator to be collective over and the second is the key. This is
+        required to ensure that deadlocks do not occur when using different
+        subcommunicators.
+    """
+    def decorator(func):
+        def wrapper(*args, **kwargs):
+            """ Extract the key and then try the memory cache before falling back
+            on calling the function and populating the cache.
+            """
+            comm, mem_key = key(*args, **kwargs)
+            k = _as_hexdigest(mem_key)
+
+            # Fetch the per-comm cache or set it up if not present
+            local_cache = comm.Get_attr(comm_cache_keyval)
+            if local_cache is None:
+                local_cache = {}
+                comm.Set_attr(comm_cache_keyval, local_cache)
+
+            # Grab value from all ranks memory cache and vote
+            v = local_cache.get(k)
+            if v is None:
+                debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} memory cache miss')
+            else:
+                debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} memory cache hit')
+            all_present = comm.allgather(bool(v))
+
+            # If not present in the cache of all ranks, recompute on all ranks
+            if not min(all_present):
+                v = func(*args, **kwargs)
+            return local_cache.setdefault(k, v)
+
+        return wrapper
+    return decorator
+
+
 def disk_cached(cache, cachedir=None, key=cachetools.keys.hashkey, collective=False):
     """Decorator for wrapping a function in a cache that stores values in
     memory and to disk.
diff --git a/pyop2/compilation.py b/pyop2/compilation.py
index f4a1af36a..80a5b4ccc 100644
--- a/pyop2/compilation.py
+++ b/pyop2/compilation.py
@@ -42,9 +42,11 @@
 import shlex
 from hashlib import md5
 from packaging.version import Version, InvalidVersion
+from textwrap import dedent
 
 from pyop2 import mpi
+from pyop2.caching import parallel_memory_only_cache_no_broadcast
 from pyop2.configuration import configuration
 from pyop2.logger import warning, debug, progress, INFO
 from pyop2.exceptions import CompilationError
 
@@ -317,36 +319,42 @@ def get_so(self, jitmodule, extension):
         dirpart, basename = basename[:2], basename[2:]
         cachedir = os.path.join(cachedir, dirpart)
         pid = os.getpid()
-        cname = os.path.join(cachedir, "%s_p%d.%s" % (basename, pid, extension))
-        oname = os.path.join(cachedir, "%s_p%d.o" % (basename, pid))
-        soname = os.path.join(cachedir, "%s.so" % basename)
+        cname = os.path.join(cachedir, f"{basename}_p{pid}.{extension}")
+        oname = os.path.join(cachedir, f"{basename}_p{pid}.o")
+        soname = os.path.join(cachedir, f"{basename}.so")
         # Link into temporary file, then rename to shared library
         # atomically (avoiding races).
-        tmpname = os.path.join(cachedir, "%s_p%d.so.tmp" % (basename, pid))
+        tmpname = os.path.join(cachedir, f"{basename}_p{pid}.so.tmp")
 
         if configuration['check_src_hashes'] or configuration['debug']:
             matching = self.comm.allreduce(basename, op=_check_op)
             if matching != basename:
                 # Dump all src code to disk for debugging
                 output = os.path.join(configuration["cache_dir"], "mismatching-kernels")
-                srcfile = os.path.join(output, "src-rank%d.c" % self.comm.rank)
+                srcfile = os.path.join(output, f"src-rank{self.comm.rank}.{extension}")
                 if self.comm.rank == 0:
                     os.makedirs(output, exist_ok=True)
                 self.comm.barrier()
                 with open(srcfile, "w") as f:
                     f.write(jitmodule.code_to_compile)
                 self.comm.barrier()
-                raise CompilationError("Generated code differs across ranks (see output in %s)" % output)
+                raise CompilationError(f"Generated code differs across ranks (see output in {output})")
+
+        # Check whether this shared object already written to disk
         try:
-            # Are we in the cache?
-            return ctypes.CDLL(soname)
+            dll = ctypes.CDLL(soname)
         except OSError:
-            # No, let's go ahead and build
+            dll = None
+        got_dll = bool(dll)
+        all_dll = self.comm.allgather(got_dll)
+
+        # If the library is not loaded _on all ranks_ build it
+        if not min(all_dll):
             if self.comm.rank == 0:
                 # No need to do this on all ranks
                 os.makedirs(cachedir, exist_ok=True)
-                logfile = os.path.join(cachedir, "%s_p%d.log" % (basename, pid))
-                errfile = os.path.join(cachedir, "%s_p%d.err" % (basename, pid))
+                logfile = os.path.join(cachedir, f"{basename}_p{pid}.log")
+                errfile = os.path.join(cachedir, f"{basename}_p{pid}.err")
                 with progress(INFO, 'Compiling wrapper'):
                     with open(cname, "w") as f:
                         f.write(jitmodule.code_to_compile)
@@ -356,7 +364,7 @@ def get_so(self, jitmodule, extension):
                         + compiler_flags \
                         + ('-o', tmpname, cname) \
                         + self.ldflags
-                    debug('Compilation command: %s', ' '.join(cc))
+                    debug(f"Compilation command: {' '.join(cc)}")
                     with open(logfile, "w") as log, open(errfile, "w") as err:
                         log.write("Compilation command:\n")
                         log.write(" ".join(cc))
@@ -371,11 +379,12 @@ def get_so(self, jitmodule, extension):
                     else:
                         subprocess.check_call(cc, stderr=err, stdout=log)
                 except subprocess.CalledProcessError as e:
-                    raise CompilationError(
-                        """Command "%s" return error status %d.
-Unable to compile code
-Compile log in %s
-Compile errors in %s""" % (e.cmd, e.returncode, logfile, errfile))
+                    raise CompilationError(dedent(f"""
+                        Command "{e.cmd}" return error status {e.returncode}.
+                        Unable to compile code
+                        Compile log in {logfile}
+                        Compile errors in {errfile}
+                        """))
                 else:
                     cc = (compiler,) \
                         + compiler_flags \
@@ -384,8 +393,8 @@ def get_so(self, jitmodule, extension):
                     ld = tuple(shlex.split(self.ld)) \
                         + ('-o', tmpname, oname) \
                         + tuple(self.expandWl(self.ldflags))
-                    debug('Compilation command: %s', ' '.join(cc))
-                    debug('Link command: %s', ' '.join(ld))
+                    debug(f"Compilation command: {' '.join(cc)}", )
+                    debug(f"Link command: {' '.join(ld)}")
                     with open(logfile, "a") as log, open(errfile, "a") as err:
                         log.write("Compilation command:\n")
                         log.write(" ".join(cc))
@@ -409,17 +418,19 @@ def get_so(self, jitmodule, extension):
                         subprocess.check_call(cc, stderr=err, stdout=log)
                         subprocess.check_call(ld, stderr=err, stdout=log)
                 except subprocess.CalledProcessError as e:
-                    raise CompilationError(
-                        """Command "%s" return error status %d.
-Unable to compile code
-Compile log in %s
-Compile errors in %s""" % (e.cmd, e.returncode, logfile, errfile))
+                    raise CompilationError(dedent(f"""
+                        Command "{e.cmd}" return error status {e.returncode}.
+                        Unable to compile code
+                        Compile log in {logfile}
+                        Compile errors in {errfile}
+                        """))
             # Atomically ensure soname exists
             os.rename(tmpname, soname)
         # Wait for compilation to complete
         self.comm.barrier()
         # Load resulting library
-        return ctypes.CDLL(soname)
+        dll = ctypes.CDLL(soname)
+        return dll
 
 
 class MacClangCompiler(Compiler):
@@ -547,6 +558,7 @@ class AnonymousCompiler(Compiler):
 
 
 @mpi.collective
+@parallel_memory_only_cache_no_broadcast()
 def load(jitmodule, extension, fn_name, cppargs=(), ldargs=(),
          argtypes=None, restype=None, comm=None):
     """Build a shared library and return a function pointer from it.
diff --git a/pyop2/global_kernel.py b/pyop2/global_kernel.py
index 79fbcaeee..8f35038ff 100644
--- a/pyop2/global_kernel.py
+++ b/pyop2/global_kernel.py
@@ -11,7 +11,7 @@
 from petsc4py import PETSc
 
 from pyop2 import compilation, mpi
-from pyop2.caching import Cached, parallel_memory_only_cache
+from pyop2.caching import Cached
 from pyop2.configuration import configuration
 from pyop2.datatypes import IntType, as_ctypes
 from pyop2.types import IterationRegion, Constant, READ
@@ -329,30 +329,19 @@ def __init__(self, local_kernel, arguments, *,
         self._iteration_region = iteration_region
         self._pass_layer_arg = pass_layer_arg
 
-        # Cache for stashing the compiled code
-        self._func_cache = {}
-
         self._initialized = True
 
-    @staticmethod
-    def _call_key(self, comm, *args):
-        return comm, (0,)
-
     @mpi.collective
-    @parallel_memory_only_cache(key=_call_key)
     def __call__(self, comm, *args):
         """Execute the compiled kernel.
 
         :arg comm: Communicator the execution is collective over.
         :*args: Arguments to pass to the compiled kernel.
        """
+        # It is unnecessary to cache this call as it is cached in pyop2/compilation.py
         func = self.compile(comm)
         func(*args)
 
-        # This method has to return _something_ for the `@parallel_memory_only_cache`
-        # to function correctly
-        return 0
-
     @property
     def _wrapper_name(self):
         import warnings