From 4b33593e8f24c77ee29ea6196ba89bd37bf6d237 Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Tue, 20 Aug 2024 00:20:46 +0100 Subject: [PATCH] Use caching.disk_only_cache for make_so --- pyop2/caching.py | 9 +- pyop2/compilation.py | 293 ++++++++++++++++-------------- test/unit/test_updated_caching.py | 3 + 3 files changed, 161 insertions(+), 144 deletions(-) diff --git a/pyop2/caching.py b/pyop2/caching.py index 65954beda..a80cc767b 100644 --- a/pyop2/caching.py +++ b/pyop2/caching.py @@ -267,7 +267,7 @@ def __getitem__(self, key): """ filepath = Path(self.cachedir, key[0][:2], key[0][2:] + key[1]) try: - with self.open(filepath, "rb") as fh: + with self.open(filepath, mode="rb") as fh: value = self.read(fh) except FileNotFoundError: raise KeyError("File not on disk, cache miss") @@ -285,7 +285,7 @@ def __setitem__(self, key, value): tempfile = basedir.joinpath(f"{k2}_p{os.getpid()}.tmp") filepath = basedir.joinpath(k2) - with self.open(tempfile, "wb") as fh: + with self.open(tempfile, mode="wb") as fh: self.write(fh, value) tempfile.rename(filepath) @@ -359,6 +359,8 @@ def get(self, key, default=None): self.hit += 1 return value + # JBTODO: Only instrument get, since we have to use get and get item in wrapper + # OR... find away around the hack in compilation.py def __getitem__(self, key): try: value = super().__getitem__(key) @@ -465,7 +467,8 @@ def wrapper(*args, **kwargs): if value is CACHE_MISS: value = func(*args, **kwargs) - return local_cache.setdefault(key, value) + local_cache[key] = value + return local_cache[key] return wrapper return decorator diff --git a/pyop2/compilation.py b/pyop2/compilation.py index aec360913..2e1954239 100644 --- a/pyop2/compilation.py +++ b/pyop2/compilation.py @@ -44,10 +44,12 @@ from packaging.version import Version, InvalidVersion from textwrap import dedent from functools import partial +from pathlib import Path +from contextlib import contextmanager from pyop2 import mpi -from pyop2.caching import memory_cache, default_parallel_hashkey +from pyop2.caching import parallel_cache, memory_cache, default_parallel_hashkey from pyop2.configuration import configuration from pyop2.logger import warning, debug, progress, INFO from pyop2.exceptions import CompilationError @@ -204,10 +206,8 @@ class Compiler(ABC): (optional, prepended to any flags specified as the ldflags configuration option). The environment variable ``PYOP2_LDFLAGS`` can also be used to extend these options. - :arg cpp: Should we try and use the C++ compiler instead of the C - compiler?. - :kwarg comm: Optional communicator to compile the code on - (defaults to pyop2.mpi.COMM_WORLD). + :arg version: (Optional) usually sniffed by loader. + :arg debug: Whether to use debugging compiler flags. """ _name = "unknown" @@ -222,17 +222,22 @@ class Compiler(ABC): _optflags = () _debugflags = () - def __init__(self, extra_compiler_flags=(), extra_linker_flags=(), cpp=False, version=None): - self.version = version - + def __init__(self, extra_compiler_flags=(), extra_linker_flags=(), version=None, debug=False): self._extra_compiler_flags = tuple(extra_compiler_flags) self._extra_linker_flags = tuple(extra_linker_flags) - - self._cpp = cpp - self._debug = configuration["debug"] + self._version = version + self._debug = debug def __repr__(self): - return f"<{self._name} compiler, version {self.version or 'unknown'}>" + string = f"{self.__class__.__name__}(" + string += f"extra_compiler_flags={self._extra_compiler_flags}, " + string += f"extra_linker_flags={self._extra_linker_flags}, " + string += f"version={self._version!r}, " + string += f"debug={self._debug})" + return string + + def __str__(self): + return f"<{self._name} compiler, version {self._version or 'unknown'}>" @property def cc(self): @@ -319,7 +324,7 @@ class LinuxGnuCompiler(Compiler): @property def bugfix_cflags(self): """Flags to work around bugs in compilers.""" - ver = self.version + ver = self._version cflags = () if Version("4.8.0") <= ver < Version("4.9.0"): # GCC bug https://gcc.gnu.org/bugzilla/show_bug.cgi?id=61068 @@ -448,21 +453,21 @@ def __init__(self, code, argtypes): else: raise ValueError("Don't know how to compile code of type %r" % type(jitmodule)) - cpp = (extension == "cpp") global _compiler if _compiler: # Use the global compiler if it has been set compiler = _compiler else: # Sniff compiler from executable - if cpp: + if extension == "cpp": exe = configuration["cxx"] or "mpicxx" else: exe = configuration["cc"] or "mpicc" compiler = sniff_compiler(exe, comm) - compiler_instance = compiler(cppargs, ldargs, cpp=cpp) - dll = make_so(compiler_instance, code, extension, comm) + debug = configuration["debug"] + compiler_instance = compiler(cppargs, ldargs, debug=debug) + dll = _make_so_wrapper(compiler_instance, code, extension, comm) if isinstance(jitmodule, GlobalKernel): _add_profiling_events(dll, code.local_kernel.events) @@ -485,157 +490,163 @@ def expandWl(ldflags): yield flag +from pyop2.caching import DictLikeDiskAccess + + +class CompilerDiskAccess(DictLikeDiskAccess): + @contextmanager + def open(self, *args, **kwargs): + # In the parent class the `open` method is called by `read` as: + # open(filename, mode="rb") + # and the `write` method as: + # open(tempname, mode="wb") + # Here we bypass this and just return the filename (pathlib.Path object) + # letting the read and write methods handle file opening. + if args[0].suffix: + # Writing: drop PID and extension + args[0].touch() + filename = args[0].with_name(args[0].name.split('_p')[0]) + else: + # Reading: Add extension + filename = args[0].with_suffix(".so") + yield filename + + def write(self, *args, **kwargs): + filename = args[0] + compiler, jitmodule, extension, comm = args[1] + _legacy_make_so(compiler, jitmodule, filename, extension, comm) + + def read(self, filename): + try: + return _legacy_load_so(filename) + except OSError as e: + raise FileNotFoundError(e) + + +def _make_so_hashkey(compiler, jitmodule, extension, comm): + if extension == "cpp": + exe = compiler.cxx + compiler_flags = compiler.cxxflags + else: + exe = compiler.cc + compiler_flags = compiler.cflags + return (compiler, exe, compiler_flags, compiler.ld, compiler.ldflags, jitmodule.cache_key) + + +@mpi.collective +@parallel_cache( + hashkey=_make_so_hashkey, + cache_factory=lambda: CompilerDiskAccess(configuration['cache_dir']), + broadcast=False +) +def _make_so_wrapper(compiler, jitmodule, extension, comm): + # The creation of the shared library is handled by the `write` method of + # `CompilerDiskAccess` above. + # JBTODO: This is a bit of a hack... + return (compiler, jitmodule, extension, comm) + + @mpi.collective -def make_so(compiler, jitmodule, extension, comm): +def _legacy_make_so(compiler, jitmodule, filename, extension, comm): """Build a shared library and load it :arg compiler: The compiler to use to create the shared library. :arg jitmodule: The JIT Module which can generate the code to compile. + :arg filename: The filename of the library to create. :arg extension: extension of the source file (c, cpp). :arg comm: Communicator over which to perform compilation. Returns a :class:`ctypes.CDLL` object of the resulting shared library.""" # Compilation communicators are reference counted on the PyOP2 comm - pcomm = mpi.internal_comm(comm, compiler) - comm = mpi.compilation_comm(pcomm, compiler) + icomm = mpi.internal_comm(comm, compiler) + ccomm = mpi.compilation_comm(icomm, compiler) # C or C++ - if compiler._cpp: + if extension == "cpp": exe = compiler.cxx compiler_flags = compiler.cxxflags else: exe = compiler.cc compiler_flags = compiler.cflags - # Determine cache key - hsh = md5(str(jitmodule.cache_key).encode()) - hsh.update(exe.encode()) - if compiler.ld: - hsh.update(compiler.ld.encode()) - hsh.update("".join(compiler_flags).encode()) - hsh.update("".join(compiler.ldflags).encode()) - - basename = hsh.hexdigest() # This is hash key - - cachedir = configuration['cache_dir'] # This is cachedir - - dirpart, basename = basename[:2], basename[2:] - cachedir = os.path.join(cachedir, dirpart) + base = filename.name + path = filename.parent pid = os.getpid() - cname = os.path.join(cachedir, f"{basename}_p{pid}.{extension}") - oname = os.path.join(cachedir, f"{basename}_p{pid}.o") - soname = os.path.join(cachedir, f"{basename}.so") - # Link into temporary file, then rename to shared library - # atomically (avoiding races). - tmpname = os.path.join(cachedir, f"{basename}_p{pid}.so.tmp") + cname = filename.with_name(f"{base}_p{pid}.{extension}") + oname = filename.with_name(f"{base}_p{pid}.o") + # Link into temporary file, then rename to shared library atomically (avoiding races). + tempname = filename.with_stem(f"{base}_p{pid}.so") + soname = filename.with_suffix(".so") if configuration['check_src_hashes'] or configuration['debug']: - matching = comm.allreduce(basename, op=_check_op) - if matching != basename: + # Reconstruct hash from filename + hashval = "".join(filename.parts[-2:]) + matching = ccomm.allreduce(hashval, op=_check_op) + if matching != hashval: # Dump all src code to disk for debugging - output = os.path.join(configuration["cache_dir"], "mismatching-kernels") - srcfile = os.path.join(output, f"src-rank{comm.rank}.{extension}") - if comm.rank == 0: - os.makedirs(output, exist_ok=True) - comm.barrier() - with open(srcfile, "w") as f: - f.write(jitmodule.code_to_compile) - comm.barrier() + output = Path(configuration["cache_dir"]).joinpath("mismatching-kernels") + srcfile = output.with_name(f"src-rank{comm.rank}.{extension}") + if ccomm.rank == 0: + output.mkdir(exist_ok=True) + ccomm.barrier() + with open(srcfile, "w") as fh: + fh.write(jitmodule.code_to_compile) + ccomm.barrier() raise CompilationError(f"Generated code differs across ranks (see output in {output})") - # Check whether this shared object already written to disk - try: - dll = ctypes.CDLL(soname) - except OSError: - dll = None - got_dll = bool(dll) - all_dll = comm.allgather(got_dll) - - # If the library is not loaded _on all ranks_ build it - if not min(all_dll): - if comm.rank == 0: - # No need to do this on all ranks - os.makedirs(cachedir, exist_ok=True) - logfile = os.path.join(cachedir, f"{basename}_p{pid}.log") - errfile = os.path.join(cachedir, f"{basename}_p{pid}.err") - with progress(INFO, 'Compiling wrapper'): - with open(cname, "w") as f: - f.write(jitmodule.code_to_compile) - # Compiler also links - if not compiler.ld: - cc = (exe,) \ - + compiler_flags \ - + ('-o', tmpname, cname) \ - + compiler.ldflags - debug(f"Compilation command: {' '.join(cc)}") - with open(logfile, "w") as log, open(errfile, "w") as err: - log.write("Compilation command:\n") - log.write(" ".join(cc)) - log.write("\n\n") - try: - if configuration['no_fork_available']: - cc += ["2>", errfile, ">", logfile] - cmd = " ".join(cc) - status = os.system(cmd) - if status != 0: - raise subprocess.CalledProcessError(status, cmd) - else: - subprocess.check_call(cc, stderr=err, stdout=log) - except subprocess.CalledProcessError as e: - raise CompilationError(dedent(f""" - Command "{e.cmd}" return error status {e.returncode}. - Unable to compile code - Compile log in {logfile} - Compile errors in {errfile} - """)) - else: - cc = (exe,) \ - + compiler_flags \ - + ('-c', '-o', oname, cname) - # Extract linker specific "cflags" from ldflags - ld = tuple(shlex.split(compiler.ld)) \ - + ('-o', tmpname, oname) \ - + tuple(expandWl(compiler.ldflags)) - debug(f"Compilation command: {' '.join(cc)}", ) - debug(f"Link command: {' '.join(ld)}") - with open(logfile, "a") as log, open(errfile, "a") as err: - log.write("Compilation command:\n") - log.write(" ".join(cc)) - log.write("\n\n") - log.write("Link command:\n") - log.write(" ".join(ld)) - log.write("\n\n") - try: - if configuration['no_fork_available']: - cc += ["2>", errfile, ">", logfile] - ld += ["2>>", errfile, ">>", logfile] - cccmd = " ".join(cc) - ldcmd = " ".join(ld) - status = os.system(cccmd) - if status != 0: - raise subprocess.CalledProcessError(status, cccmd) - status = os.system(ldcmd) - if status != 0: - raise subprocess.CalledProcessError(status, ldcmd) - else: - subprocess.check_call(cc, stderr=err, stdout=log) - subprocess.check_call(ld, stderr=err, stdout=log) - except subprocess.CalledProcessError as e: - raise CompilationError(dedent(f""" - Command "{e.cmd}" return error status {e.returncode}. - Unable to compile code - Compile log in {logfile} - Compile errors in {errfile} - """)) - # Atomically ensure soname exists - os.rename(tmpname, soname) - # Wait for compilation to complete - comm.barrier() - # Load resulting library - dll = ctypes.CDLL(soname) + # Compile on compilation communicator (ccomm) rank 0 + if comm.rank == 0: + logfile = path.with_name(f"{base}_p{pid}.log") + errfile = path.with_name(f"{base}_p{pid}.err") + with progress(INFO, 'Compiling wrapper'): + with open(cname, "w") as fh: + fh.write(jitmodule.code_to_compile) + # Compiler also links + if not compiler.ld: + cc = (exe,) + compiler_flags + ('-o', str(tempname), str(cname)) + compiler.ldflags + _run(cc, logfile, errfile) + else: + cc = (exe,) + compiler_flags + ('-c', '-o', oname, cname) + _run(cc, logfile, errfile) + # Extract linker specific "cflags" from ldflags + ld = tuple(shlex.split(compiler.ld)) + ('-o', str(tempname), str(oname)) + tuple(expandWl(compiler.ldflags)) + _run(ld, logfile, errfile) + # Atomically ensure soname exists + tempname.rename(soname) + # Wait for compilation to complete + ccomm.barrier() + + +def _legacy_load_so(filename): + # Load library + dll = ctypes.CDLL(filename) return dll +def _run(cc, logfile, errfile): + debug(f"Compilation command: {' '.join(cc)}") + try: + if configuration['no_fork_available']: + cc += ("2>", str(errfile), ">", str(logfile)) + cmd = " ".join(cc) + status = os.system(cmd) + if status != 0: + raise subprocess.CalledProcessError(status, cmd) + else: + with open(logfile, "w") as log, open(errfile, "w") as err: + log.write("Compilation command:\n") + log.write(" ".join(cc)) + log.write("\n\n") + subprocess.check_call(cc, stderr=err, stdout=log) + except subprocess.CalledProcessError as e: + raise CompilationError(dedent(f""" + Command "{e.cmd}" return error status {e.returncode}. + Unable to compile code + Compile log in {logfile!s} + Compile errors in {errfile!s} + """)) + + def _add_profiling_events(dll, events): """ If PyOP2 is in profiling mode, events are attached to dll to profile the local linear algebra calls. diff --git a/test/unit/test_updated_caching.py b/test/unit/test_updated_caching.py index 2c8ee53bf..93af9f46c 100644 --- a/test/unit/test_updated_caching.py +++ b/test/unit/test_updated_caching.py @@ -115,12 +115,15 @@ def test_function_args_different(request, state, decorator, uncached_function, t ]) def test_function_over_different_comms(request, state, decorator, uncached_function, tmpdir): if request.node.callspec.params["decorator"] in {disk_only_cache, memory_and_disk_cache}: + # In parallel different ranks can get different tempdirs, we just want one + tmpdir = COMM_WORLD.bcast(tmpdir, root=0) kwargs = {"cachedir": tmpdir} else: kwargs = {} cached_function = function_factory(state, decorator, uncached_function, **kwargs) assert state.value == 0 + for ii in range(10): color = 0 if COMM_WORLD.rank < 2 else MPI.UNDEFINED comm12 = COMM_WORLD.Split(color=color)