A better solution for wrapping make_so in a disk cache
JDBetteridge committed Aug 20, 2024
1 parent 4b33593 commit 8c9ca07
Showing 3 changed files with 72 additions and 83 deletions.
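For orientation, the substance of the change: compilation is no longer hidden behind a wrapper that returns a ctypes.CDLL; instead make_so is disk-cached and returns the path of the compiled library, and loading is a separate, uncached step. A minimal sketch of the new flow (names follow the diff below; error handling omitted), not the verbatim pyop2 code:

```python
# Sketch of the new two-stage flow introduced by this commit.
import ctypes

def load_shared_object(compiler, jitmodule, extension, comm):
    # make_so is wrapped in a parallel disk cache (CompilerDiskAccess),
    # so it compiles only on a cache miss and always returns a .so path.
    so_name = make_so(compiler, jitmodule, extension, comm)
    # Loading is deliberately left uncached: repeated CDLL loads of the
    # same file are cheap because the OS caches the mapping.
    return ctypes.CDLL(so_name)
```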
27 changes: 13 additions & 14 deletions pyop2/caching.py
@@ -244,20 +244,22 @@ class _CacheMiss:
def _as_hexdigest(*args):
hash_ = hashlib.md5()
for a in args:
# TODO: Remove or edit this check!
# JBTODO: Remove or edit this check!
if isinstance(a, MPI.Comm) or isinstance(a, cachetools.keys._HashedTuple):
breakpoint()
hash_.update(str(a).encode())
return hash_.hexdigest()
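To make the hashing concrete, a tiny check (assuming `_as_hexdigest` as defined above):

```python
# Every argument is str()-ed into a single running md5, so the digest of
# ("a", 1) equals that of the concatenated bytes b"a1".
import hashlib

assert _as_hexdigest("a", 1) == hashlib.md5(b"a1").hexdigest()
```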


class DictLikeDiskAccess(MutableMapping):
def __init__(self, cachedir):
def __init__(self, cachedir, extension=".pickle"):
"""
:arg cachedir: The cache directory.
:arg extension: Optional extension to use for written files.
"""
self.cachedir = cachedir
self.extension = extension

def __getitem__(self, key):
"""Retrieve a value from the disk cache.
@@ -267,7 +269,7 @@ def __getitem__(self, key):
"""
filepath = Path(self.cachedir, key[0][:2], key[0][2:] + key[1])
try:
with self.open(filepath, mode="rb") as fh:
with self.open(filepath.with_suffix(self.extension), mode="rb") as fh:
value = self.read(fh)
except FileNotFoundError:
raise KeyError("File not on disk, cache miss")
@@ -287,7 +289,7 @@ def __setitem__(self, key, value):
filepath = basedir.joinpath(k2)
with self.open(tempfile, mode="wb") as fh:
self.write(fh, value)
tempfile.rename(filepath)
tempfile.rename(filepath.with_suffix(self.extension))

def __delitem__(self, key):
raise NotImplementedError(f"Cannot remove items from {self.__class__.__name__}")
@@ -299,11 +301,11 @@ def __len__(self):
raise NotImplementedError(f"Cannot query length of {self.__class__.__name__}")

def __repr__(self):
return f"{self.__class__.__name__}(cachedir={self.cachedir})"
return f"{self.__class__.__name__}(cachedir={self.cachedir}, extension={self.extension})"

def __eq__(self, other):
# Instances are the same if they have the same cachedir and extension
return self.cachedir == other.cachedir
return (self.cachedir == other.cachedir and self.extension == other.extension)

def open(self, *args, **kwargs):
return open(*args, **kwargs)
@@ -359,8 +361,6 @@ def get(self, key, default=None):
self.hit += 1
return value

# JBTODO: Only instrument get, since we have to use get and __getitem__ in the wrapper
# OR... find a way around the hack in compilation.py
def __getitem__(self, key):
try:
value = super().__getitem__(key)
@@ -441,7 +441,7 @@ def wrapper(*args, **kwargs):
)
else:
debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} {local_cache.__class__.__name__} cache hit')
# TODO: Add communication tags to avoid cross-broadcasting
# JBTODO: Add communication tags to avoid cross-broadcasting
comm.bcast(value, root=0)
else:
value = comm.bcast(CACHE_MISS, root=0)
@@ -467,8 +467,7 @@ def wrapper(*args, **kwargs):

if value is CACHE_MISS:
value = func(*args, **kwargs)
local_cache[key] = value
return local_cache[key]
return local_cache.setdefault(key, value)

return wrapper
return decorator
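The move from store-then-read to `setdefault` above is more than cosmetic for a disk-backed mapping. A self-contained sketch of the `MutableMapping.setdefault` semantics the wrapper now relies on; `TinyCache` is hypothetical and only counts successful reads:

```python
# Sketch: why `return local_cache.setdefault(key, value)` beats
# `local_cache[key] = value; return local_cache[key]`. On a miss,
# MutableMapping.setdefault stores the value and returns it directly,
# so a disk-backed cache avoids re-reading (unpickling) the file it
# just wrote.
from collections.abc import MutableMapping

class TinyCache(MutableMapping):
    def __init__(self):
        self._d = {}
        self.reads = 0

    def __getitem__(self, key):
        value = self._d[key]   # raises KeyError on a miss
        self.reads += 1        # count successful reads only
        return value

    def __setitem__(self, key, value):
        self._d[key] = value

    def __delitem__(self, key):
        del self._d[key]

    def __iter__(self):
        return iter(self._d)

    def __len__(self):
        return len(self._d)

cache = TinyCache()
assert cache.setdefault("k", 1) == 1 and cache.reads == 0   # miss: no read-back
assert cache.setdefault("k", 2) == 1 and cache.reads == 1   # hit: one read
```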
@@ -497,13 +496,13 @@ def decorator(func):
return memory_cache(*args, **kwargs)(disk_only_cache(*args, cachedir=cachedir, **kwargs)(func))
return decorator

# TODO: (Wishlist)
# JBTODO: (Wishlist)
# * Try more exotic caches, e.g.: memory_cache = partial(parallel_cache, cache_factory=lambda: cachetools.LRUCache(maxsize=1000)) ✓ (sketch after this list)
# * Add some sort of cache reporting ✓
# * Add some sort of cache statistics ✓
# * Refactor compilation.py to use @mem_and_disk_cached, where get_so just uses DictLikeDiskAccess with an overloaded self.write() method
# * Systematic investigation into cache sizes/types for Firedrake
# - Is a mem cache needed for DLLs?
# - Is a mem cache needed for DLLs? No
# - Is LRUCache better than a simple dict? (memory profile test suite)
# - What is the optimal maxsize?
# * Add some docstrings and maybe some exposition!
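A sketch of what the ticked first wishlist item looks like in practice, assuming `parallel_cache` and `default_parallel_hashkey` are importable as in this commit (`cachetools` is already used by caching.py):

```python
# Hypothetical sketch: the storage behind a parallel cache is pluggable
# via cache_factory, so a bounded LRU can replace the default dict
# without changing any call sites.
from functools import partial
import cachetools

from pyop2.caching import parallel_cache, default_parallel_hashkey

lru_memory_cache = partial(
    parallel_cache,
    cache_factory=lambda: cachetools.LRUCache(maxsize=1000),
)

@lru_memory_cache(hashkey=default_parallel_hashkey)
def expensive_kernel_setup(comm, spec):
    ...  # computed at most once per key, then evicted least-recently-used
```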
126 changes: 58 additions & 68 deletions pyop2/compilation.py
@@ -46,10 +46,12 @@
from functools import partial
from pathlib import Path
from contextlib import contextmanager
from tempfile import gettempdir
from itertools import cycle


from pyop2 import mpi
from pyop2.caching import parallel_cache, memory_cache, default_parallel_hashkey
from pyop2.caching import parallel_cache, memory_cache, default_parallel_hashkey, _as_hexdigest, DictLikeDiskAccess
from pyop2.configuration import configuration
from pyop2.logger import warning, debug, progress, INFO
from pyop2.exceptions import CompilationError
@@ -417,6 +419,7 @@ def load_hashkey(*args, **kwargs):
return default_parallel_hashkey(code_hash, *args[1:], **kwargs)


# JBTODO: This should not be memory cached
@mpi.collective
@memory_cache(hashkey=load_hashkey, broadcast=False)
def load(jitmodule, extension, fn_name, cppargs=(), ldargs=(),
@@ -467,7 +470,12 @@ def __init__(self, code, argtypes):

debug = configuration["debug"]
compiler_instance = compiler(cppargs, ldargs, debug=debug)
dll = _make_so_wrapper(compiler_instance, code, extension, comm)
if configuration['check_src_hashes'] or configuration['debug']:
check_source_hashes(compiler_instance, code, extension, comm)
# This call is cached on disk
so_name = make_so(compiler_instance, code, extension, comm)
# This call is cached in memory by the OS
dll = ctypes.CDLL(so_name)

if isinstance(jitmodule, GlobalKernel):
_add_profiling_events(dll, code.local_kernel.events)
@@ -490,37 +498,18 @@ def expandWl(ldflags):
yield flag


from pyop2.caching import DictLikeDiskAccess


class CompilerDiskAccess(DictLikeDiskAccess):
@contextmanager
def open(self, *args, **kwargs):
# In the parent class the `open` method is called by `read` as:
# open(filename, mode="rb")
# and the `write` method as:
# open(tempname, mode="wb")
# Here we bypass this and just return the filename (pathlib.Path object)
# letting the read and write methods handle file opening.
if args[0].suffix:
# Writing: drop PID and extension
args[0].touch()
filename = args[0].with_name(args[0].name.split('_p')[0])
else:
# Reading: Add extension
filename = args[0].with_suffix(".so")
def open(self, filename, *args, **kwargs):
yield filename

def write(self, *args, **kwargs):
filename = args[0]
compiler, jitmodule, extension, comm = args[1]
_legacy_make_so(compiler, jitmodule, filename, extension, comm)
def write(self, filename, value):
shutil.copy(value, filename)

def read(self, filename):
try:
return _legacy_load_so(filename)
except OSError as e:
raise FileNotFoundError(e)
if not filename.exists():
raise FileNotFoundError("File not on disk, cache miss")
return filename
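With `read` and `write` reduced to an existence check and a `shutil.copy`, the cache is just a content-addressed directory of `.so` files. A sketch of the key-to-path mapping, following `DictLikeDiskAccess` above; the key values here are purely illustrative:

```python
# Sketch of where a cached library lands: the first two characters of
# the key fan out into a subdirectory and the configured extension marks
# the payload type (".pickle" for pickled values, ".so" for libraries).
from pathlib import Path

def cache_path(cachedir, key, extension=".so"):
    # key is the pair built by the parallel_cache wrapper; treated here
    # as (hexdigest, tag) purely for illustration.
    return Path(cachedir, key[0][:2], key[0][2:] + key[1]).with_suffix(extension)

key = ("d41d8cd98f00b204e9800998ecf8427e", "make_so")  # illustrative values
print(cache_path("/home/user/.cache/pyop2", key))
# /home/user/.cache/pyop2/d4/1d8cd98f00b204e9800998ecf8427emake_so.so
```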


def _make_so_hashkey(compiler, jitmodule, extension, comm):
Expand All @@ -533,30 +522,51 @@ def _make_so_hashkey(compiler, jitmodule, extension, comm):
return (compiler, exe, compiler_flags, compiler.ld, compiler.ldflags, jitmodule.cache_key)


def check_source_hashes(compiler, jitmodule, extension, comm):
# Recompute the cache-key hash on every rank
hashval = _as_hexdigest(_make_so_hashkey(compiler, jitmodule, extension, comm))
with mpi.temp_internal_comm(comm) as icomm:
matching = icomm.allreduce(hashval, op=_check_op)
if matching != hashval:
# Dump all src code to disk for debugging
output = Path(configuration["cache_dir"]).joinpath("mismatching-kernels")
srcfile = output.with_name(f"src-rank{icomm.rank}.{extension}")
if icomm.rank == 0:
output.mkdir(exist_ok=True)
icomm.barrier()
with open(srcfile, "w") as fh:
fh.write(jitmodule.code_to_compile)
icomm.barrier()
raise CompilationError(f"Generated code differs across ranks (see output in {output})")


FILE_CYCLER = cycle(f"{ii:02x}" for ii in range(256))


@mpi.collective
@parallel_cache(
hashkey=_make_so_hashkey,
cache_factory=lambda: CompilerDiskAccess(configuration['cache_dir']),
broadcast=False
cache_factory=lambda: CompilerDiskAccess(configuration['cache_dir'], extension=".so"),
)
def _make_so_wrapper(compiler, jitmodule, extension, comm):
# The creation of the shared library is handled by the `write` method of
# `CompilerDiskAccess` above.
# JBTODO: This is a bit of a hack...
return (compiler, jitmodule, extension, comm)


@mpi.collective
def _legacy_make_so(compiler, jitmodule, filename, extension, comm):
def make_so(compiler, jitmodule, extension, comm, filename=None):
"""Build a shared library and load it
:arg compiler: The compiler to use to create the shared library.
:arg jitmodule: The JIT Module which can generate the code to compile.
:arg filename: The filename of the library to create.
:arg extension: extension of the source file (c, cpp).
:arg comm: Communicator over which to perform compilation.
:arg filename: Optional
Returns a :class:`ctypes.CDLL` object of the resulting shared
library."""
if filename is None:
tempdir = Path(gettempdir()).joinpath(f"pyop2-tempcache-uid{os.getuid()}")
tempdir.mkdir(exist_ok=True)
filename = tempdir.joinpath(f"foo{next(FILE_CYCLER)}.c")
else:
filename = Path(filename).absolute()
filename.parent.mkdir(exist_ok=True)

# Compilation communicators are reference counted on the PyOP2 comm
icomm = mpi.internal_comm(comm, compiler)
ccomm = mpi.compilation_comm(icomm, compiler)
@@ -578,26 +588,10 @@ def _legacy_make_so(compiler, jitmodule, filename, extension, comm):
tempname = filename.with_stem(f"{base}_p{pid}.so")
soname = filename.with_suffix(".so")

if configuration['check_src_hashes'] or configuration['debug']:
# Reconstruct hash from filename
hashval = "".join(filename.parts[-2:])
matching = ccomm.allreduce(hashval, op=_check_op)
if matching != hashval:
# Dump all src code to disk for debugging
output = Path(configuration["cache_dir"]).joinpath("mismatching-kernels")
srcfile = output.with_name(f"src-rank{comm.rank}.{extension}")
if ccomm.rank == 0:
output.mkdir(exist_ok=True)
ccomm.barrier()
with open(srcfile, "w") as fh:
fh.write(jitmodule.code_to_compile)
ccomm.barrier()
raise CompilationError(f"Generated code differs across ranks (see output in {output})")

# Compile on compilation communicator (ccomm) rank 0
if comm.rank == 0:
logfile = path.with_name(f"{base}_p{pid}.log")
errfile = path.with_name(f"{base}_p{pid}.err")
logfile = path.joinpath(f"{base}_p{pid}.log")
errfile = path.joinpath(f"{base}_p{pid}.err")
with progress(INFO, 'Compiling wrapper'):
with open(cname, "w") as fh:
fh.write(jitmodule.code_to_compile)
@@ -610,31 +604,27 @@ def _legacy_make_so(compiler, jitmodule, filename, extension, comm):
_run(cc, logfile, errfile)
# Extract linker specific "cflags" from ldflags
ld = tuple(shlex.split(compiler.ld)) + ('-o', str(tempname), str(oname)) + tuple(expandWl(compiler.ldflags))
_run(ld, logfile, errfile)
_run(ld, logfile, errfile, step="Linker", filemode="a")
# Atomically ensure soname exists
tempname.rename(soname)
# Wait for compilation to complete
ccomm.barrier()
return soname


def _legacy_load_so(filename):
# Load library
dll = ctypes.CDLL(filename)
return dll


def _run(cc, logfile, errfile):
debug(f"Compilation command: {' '.join(cc)}")
def _run(cc, logfile, errfile, step="Compilation", filemode="w"):
debug(f"{step} command: {' '.join(cc)}")
try:
if configuration['no_fork_available']:
cc += ("2>", str(errfile), ">", str(logfile))
redirect = ">" if filemode == "w" else ">>"
cc += (f"2{redirect}", str(errfile), redirect, str(logfile))
cmd = " ".join(cc)
status = os.system(cmd)
if status != 0:
raise subprocess.CalledProcessError(status, cmd)
else:
with open(logfile, "w") as log, open(errfile, "w") as err:
log.write("Compilation command:\n")
with open(logfile, filemode) as log, open(errfile, filemode) as err:
log.write(f"{step} command:\n")
log.write(" ".join(cc))
log.write("\n\n")
subprocess.check_call(cc, stderr=err, stdout=log)
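The new `step`/`filemode` parameters let the compile and link stages share one log pair. A sketch of the calling convention, with illustrative commands:

```python
# Illustrative only: the compile step truncates the shared logs
# (filemode="w") and the link step appends (filemode="a"), so one
# chronological log survives per compilation; in the no-fork branch the
# shell redirect (">" vs ">>") is chosen to match.
from pathlib import Path

def compile_and_link(run=_run, logdir="."):
    logfile = Path(logdir, "wrapper.log")
    errfile = Path(logdir, "wrapper.err")
    cc = ("cc", "-fPIC", "-c", "wrapper.c", "-o", "wrapper.o")
    ld = ("cc", "-shared", "wrapper.o", "-o", "wrapper.so")
    run(cc, logfile, errfile)                               # "Compilation", "w"
    run(ld, logfile, errfile, step="Linker", filemode="a")
```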
2 changes: 1 addition & 1 deletion pyop2/global_kernel.py
@@ -360,13 +360,13 @@ def builder(self):
builder.add_argument(arg)
return builder

# TODO: Wrap with parallel_cached_property
@cached_property
def code_to_compile(self):
"""Return the C/C++ source code as a string."""
from pyop2.codegen.rep2loopy import generate

wrapper = generate(self.builder)
# JBTODO: Expensive? Can this be wrapped with a cache?
code = lp.generate_code_v2(wrapper)

if self.local_kernel.cpp:
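A hypothetical take on the two JBTODOs above, reusing this commit's memory cache and keyed on the kernel's cache_key (used elsewhere in this diff); the communicator handling that parallel_cache performs is glossed over, so this is a sketch rather than a drop-in patch:

```python
# Hypothetical sketch only: memoise the expensive loopy code generation
# with the memory_cache from pyop2.caching, keyed on the global kernel's
# cache_key. Real pyop2 would need to thread a communicator through.
from pyop2.caching import memory_cache

@memory_cache(hashkey=lambda self: (self.cache_key,), broadcast=False)
def _code_to_compile(self):
    import loopy as lp
    from pyop2.codegen.rep2loopy import generate

    wrapper = generate(self.builder)
    return lp.generate_code_v2(wrapper)
```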
