Skip to content

Commit

Permalink
Rethink memory only cache for non-broadcastable values
Browse files Browse the repository at this point in the history
  • Loading branch information
JDBetteridge committed Aug 8, 2024
1 parent 9e20a3d commit ebb9d36
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 42 deletions.
56 changes: 52 additions & 4 deletions pyop2/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,13 +248,17 @@ def cache_key(self):
"""


def default_parallel_hashkey(comm, *args, **kwargs):
def default_parallel_hashkey(*args, **kwargs):
    """Return the default ``(comm, key)`` pair for parallel cache decorators.

    The communicator is read from the ``comm`` keyword argument (``None``
    when absent); the key is the standard cachetools hash key built from
    all of the wrapped function's arguments.
    """
    return kwargs.get('comm'), cachetools.keys.hashkey(*args, **kwargs)


def parallel_memory_only_cache(key=default_parallel_hashkey):
"""Decorator for wrapping a function to be called over a communiucator in a
cache that stores values in memory.
"""Memory only cache decorator.
Decorator for wrapping a function to be called over a communicator in a
cache that stores broadcastable values in memory. If the value is found in
the cache of rank 0 it is broadcast to all other ranks.
:arg key: Callable returning the cache key for the function inputs. This
function must return a 2-tuple where the first entry is the
Expand All @@ -278,7 +282,6 @@ def wrapper(*args, **kwargs):

# Grab value from rank 0 memory cache and broadcast result
if comm.rank == 0:

v = local_cache.get(k)
if v is None:
debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} memory cache miss')
Expand All @@ -296,6 +299,51 @@ def wrapper(*args, **kwargs):
return decorator


def parallel_memory_only_cache_no_broadcast(key=default_parallel_hashkey):
    """Memory only cache decorator.

    Decorator for wrapping a function to be called over a communicator in a
    cache that stores non-broadcastable values in memory, for instance
    function pointers. If the value is not present on all ranks, all ranks
    repeat the work.

    :arg key: Callable returning the cache key for the function inputs. This
        function must return a 2-tuple where the first entry is the
        communicator to be collective over and the second is the key. This is
        required to ensure that deadlocks do not occur when using different
        subcommunicators.
    """
    def decorator(func):
        # Local import keeps this block self-contained; preserves the
        # wrapped function's name/docstring for debugging.
        from functools import wraps

        @wraps(func)
        def wrapper(*args, **kwargs):
            """Extract the key and then try the memory cache before falling
            back on calling the function and populating the cache.
            """
            comm, mem_key = key(*args, **kwargs)
            k = _as_hexdigest(mem_key)

            # Fetch the per-comm cache or set it up if not present
            local_cache = comm.Get_attr(comm_cache_keyval)
            if local_cache is None:
                local_cache = {}
                comm.Set_attr(comm_cache_keyval, local_cache)

            # Test key *presence* rather than truthiness so that a cached
            # falsy value (0, "", None, ...) still counts as a hit.
            present = k in local_cache
            if present:
                debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} memory cache hit')
            else:
                debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} memory cache miss')

            # Collective vote over the comm: every rank must agree on
            # whether to recompute, otherwise ranks would deadlock.
            all_present = comm.allgather(present)

            # If not present in the cache of all ranks, recompute on all
            # ranks; setdefault keeps the existing entry on ranks that
            # already had one.
            if not all(all_present):
                local_cache.setdefault(k, func(*args, **kwargs))
            return local_cache[k]

        return wrapper
    return decorator


def disk_cached(cache, cachedir=None, key=cachetools.keys.hashkey, collective=False):
"""Decorator for wrapping a function in a cache that stores values in memory and to disk.
Expand Down
62 changes: 37 additions & 25 deletions pyop2/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,11 @@
import shlex
from hashlib import md5
from packaging.version import Version, InvalidVersion
from textwrap import dedent


from pyop2 import mpi
from pyop2.caching import parallel_memory_only_cache_no_broadcast
from pyop2.configuration import configuration
from pyop2.logger import warning, debug, progress, INFO
from pyop2.exceptions import CompilationError
Expand Down Expand Up @@ -317,36 +319,42 @@ def get_so(self, jitmodule, extension):
dirpart, basename = basename[:2], basename[2:]
cachedir = os.path.join(cachedir, dirpart)
pid = os.getpid()
cname = os.path.join(cachedir, "%s_p%d.%s" % (basename, pid, extension))
oname = os.path.join(cachedir, "%s_p%d.o" % (basename, pid))
soname = os.path.join(cachedir, "%s.so" % basename)
cname = os.path.join(cachedir, f"{basename}_p{pid}.{extension}")
oname = os.path.join(cachedir, f"{basename}_p{pid}.o")
soname = os.path.join(cachedir, f"{basename}.so")
# Link into temporary file, then rename to shared library
# atomically (avoiding races).
tmpname = os.path.join(cachedir, "%s_p%d.so.tmp" % (basename, pid))
tmpname = os.path.join(cachedir, f"{basename}_p{pid}.so.tmp")

if configuration['check_src_hashes'] or configuration['debug']:
matching = self.comm.allreduce(basename, op=_check_op)
if matching != basename:
# Dump all src code to disk for debugging
output = os.path.join(configuration["cache_dir"], "mismatching-kernels")
srcfile = os.path.join(output, "src-rank%d.c" % self.comm.rank)
srcfile = os.path.join(output, f"src-rank{self.comm.rank}.{extension}")
if self.comm.rank == 0:
os.makedirs(output, exist_ok=True)
self.comm.barrier()
with open(srcfile, "w") as f:
f.write(jitmodule.code_to_compile)
self.comm.barrier()
raise CompilationError("Generated code differs across ranks (see output in %s)" % output)
raise CompilationError(f"Generated code differs across ranks (see output in {output})")

# Check whether this shared object already written to disk
try:
# Are we in the cache?
return ctypes.CDLL(soname)
dll = ctypes.CDLL(soname)
except OSError:
# No, let's go ahead and build
dll = None
got_dll = bool(dll)
all_dll = self.comm.allgather(got_dll)

# If the library is not loaded _on all ranks_ build it
if not min(all_dll):
if self.comm.rank == 0:
# No need to do this on all ranks
os.makedirs(cachedir, exist_ok=True)
logfile = os.path.join(cachedir, "%s_p%d.log" % (basename, pid))
errfile = os.path.join(cachedir, "%s_p%d.err" % (basename, pid))
logfile = os.path.join(cachedir, f"{basename}_p{pid}.log")
errfile = os.path.join(cachedir, f"{basename}_p{pid}.err")
with progress(INFO, 'Compiling wrapper'):
with open(cname, "w") as f:
f.write(jitmodule.code_to_compile)
Expand All @@ -356,7 +364,7 @@ def get_so(self, jitmodule, extension):
+ compiler_flags \
+ ('-o', tmpname, cname) \
+ self.ldflags
debug('Compilation command: %s', ' '.join(cc))
debug(f"Compilation command: {' '.join(cc)}")
with open(logfile, "w") as log, open(errfile, "w") as err:
log.write("Compilation command:\n")
log.write(" ".join(cc))
Expand All @@ -371,11 +379,12 @@ def get_so(self, jitmodule, extension):
else:
subprocess.check_call(cc, stderr=err, stdout=log)
except subprocess.CalledProcessError as e:
raise CompilationError(
"""Command "%s" return error status %d.
Unable to compile code
Compile log in %s
Compile errors in %s""" % (e.cmd, e.returncode, logfile, errfile))
raise CompilationError(dedent(f"""
Command "{e.cmd}" return error status {e.returncode}.
Unable to compile code
Compile log in {logfile}
Compile errors in {errfile}
"""))
else:
cc = (compiler,) \
+ compiler_flags \
Expand All @@ -384,8 +393,8 @@ def get_so(self, jitmodule, extension):
ld = tuple(shlex.split(self.ld)) \
+ ('-o', tmpname, oname) \
+ tuple(self.expandWl(self.ldflags))
debug('Compilation command: %s', ' '.join(cc))
debug('Link command: %s', ' '.join(ld))
debug(f"Compilation command: {' '.join(cc)}", )
debug(f"Link command: {' '.join(ld)}")
with open(logfile, "a") as log, open(errfile, "a") as err:
log.write("Compilation command:\n")
log.write(" ".join(cc))
Expand All @@ -409,17 +418,19 @@ def get_so(self, jitmodule, extension):
subprocess.check_call(cc, stderr=err, stdout=log)
subprocess.check_call(ld, stderr=err, stdout=log)
except subprocess.CalledProcessError as e:
raise CompilationError(
"""Command "%s" return error status %d.
Unable to compile code
Compile log in %s
Compile errors in %s""" % (e.cmd, e.returncode, logfile, errfile))
raise CompilationError(dedent(f"""
Command "{e.cmd}" return error status {e.returncode}.
Unable to compile code
Compile log in {logfile}
Compile errors in {errfile}
"""))
# Atomically ensure soname exists
os.rename(tmpname, soname)
# Wait for compilation to complete
self.comm.barrier()
# Load resulting library
return ctypes.CDLL(soname)
dll = ctypes.CDLL(soname)
return dll


class MacClangCompiler(Compiler):
Expand Down Expand Up @@ -547,6 +558,7 @@ class AnonymousCompiler(Compiler):


@mpi.collective
@parallel_memory_only_cache_no_broadcast()
def load(jitmodule, extension, fn_name, cppargs=(), ldargs=(),
argtypes=None, restype=None, comm=None):
"""Build a shared library and return a function pointer from it.
Expand Down
15 changes: 2 additions & 13 deletions pyop2/global_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from petsc4py import PETSc

from pyop2 import compilation, mpi
from pyop2.caching import Cached, parallel_memory_only_cache
from pyop2.caching import Cached
from pyop2.configuration import configuration
from pyop2.datatypes import IntType, as_ctypes
from pyop2.types import IterationRegion, Constant, READ
Expand Down Expand Up @@ -329,30 +329,19 @@ def __init__(self, local_kernel, arguments, *,
self._iteration_region = iteration_region
self._pass_layer_arg = pass_layer_arg

# Cache for stashing the compiled code
self._func_cache = {}

self._initialized = True

@staticmethod
def _call_key(self, comm, *args):
return comm, (0,)

@mpi.collective
@parallel_memory_only_cache(key=_call_key)
def __call__(self, comm, *args):
"""Execute the compiled kernel.
:arg comm: Communicator the execution is collective over.
:*args: Arguments to pass to the compiled kernel.
"""
# It is unnecessary to cache this call as it is cached in pyop2/compilation.py
func = self.compile(comm)
func(*args)

# This method has to return _something_ for the `@parallel_memory_only_cache`
# to function correctly
return 0

@property
def _wrapper_name(self):
import warnings
Expand Down

0 comments on commit ebb9d36

Please sign in to comment.