From b35ff863195659e03c33c0a41c47cddd1a6b34e0 Mon Sep 17 00:00:00 2001 From: Jack Betteridge Date: Mon, 19 Aug 2024 19:01:38 +0100 Subject: [PATCH] Remove comm from Compiler class --- pyop2/compilation.py | 105 ++++++++++++++++++++++--------------------- 1 file changed, 55 insertions(+), 50 deletions(-) diff --git a/pyop2/compilation.py b/pyop2/compilation.py index e0e58a15e..aec360913 100644 --- a/pyop2/compilation.py +++ b/pyop2/compilation.py @@ -43,6 +43,7 @@ from hashlib import md5 from packaging.version import Version, InvalidVersion from textwrap import dedent +from functools import partial from pyop2 import mpi @@ -87,6 +88,36 @@ def set_default_compiler(compiler): ) +def sniff_compiler_version(compiler, cpp=False): + """Attempt to determine the compiler version number. + + :arg compiler: Instance of compiler to sniff the version of + :arg cpp: If set to True will use the C++ compiler rather than + the C compiler to determine the version number. + """ + # Note: + # Sniffing the compiler version for very large numbers of + # MPI ranks is expensive, ensure this is only run on rank 0 + exe = compiler.cxx if cpp else compiler.cc + version = None + # `-dumpversion` is not sufficient to get the whole version string (for some compilers), + # but other compilers do not implement `-dumpfullversion`! + for dumpstring in ["-dumpfullversion", "-dumpversion"]: + try: + output = subprocess.run( + [exe, dumpstring], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=True, + encoding="utf-8" + ).stdout + version = Version(output) + break + except (subprocess.CalledProcessError, UnicodeDecodeError, InvalidVersion): + continue + return version + + def sniff_compiler(exe, comm=mpi.COMM_WORLD): """Obtain the correct compiler class by calling the compiler executable. @@ -153,6 +184,11 @@ def sniff_compiler(exe, comm=mpi.COMM_WORLD): else: compiler = AnonymousCompiler + # Now try and get a version number + temp = Compiler() + version = sniff_compiler_version(temp) + compiler = partial(compiler, version=version) + return comm.bcast(compiler, 0) @@ -186,9 +222,8 @@ class Compiler(ABC): _optflags = () _debugflags = () - def __init__(self, extra_compiler_flags=(), extra_linker_flags=(), cpp=False, comm=None): - # Set compiler version ASAP since it is used in __repr__ - self.version = None + def __init__(self, extra_compiler_flags=(), extra_linker_flags=(), cpp=False, version=None): + self.version = version self._extra_compiler_flags = tuple(extra_compiler_flags) self._extra_linker_flags = tuple(extra_linker_flags) @@ -196,11 +231,6 @@ def __init__(self, extra_compiler_flags=(), extra_linker_flags=(), cpp=False, co self._cpp = cpp self._debug = configuration["debug"] - # Compilation communicators are reference counted on the PyOP2 comm - self.pcomm = mpi.internal_comm(comm, self) - self.comm = mpi.compilation_comm(self.pcomm, self) - self.sniff_compiler_version() - def __repr__(self): return f"<{self._name} compiler, version {self.version or 'unknown'}>" @@ -242,35 +272,6 @@ def ldflags(self): ldflags += tuple(shlex.split(configuration["ldflags"])) return ldflags - def sniff_compiler_version(self, cpp=False): - """Attempt to determine the compiler version number. - - :arg cpp: If set to True will use the C++ compiler rather than - the C compiler to determine the version number. - """ - # Note: - # Sniffing the compiler version for very large numbers of - # MPI ranks is expensive - exe = self.cxx if cpp else self.cc - version = None - if self.comm.rank == 0: - # `-dumpversion` is not sufficient to get the whole version string (for some compilers), - # but other compilers do not implement `-dumpfullversion`! - for dumpstring in ["-dumpfullversion", "-dumpversion"]: - try: - output = subprocess.run( - [exe, dumpstring], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - check=True, - encoding="utf-8" - ).stdout - version = Version(output) - break - except (subprocess.CalledProcessError, UnicodeDecodeError, InvalidVersion): - continue - self.version = self.comm.bcast(version, 0) - @property def bugfix_cflags(self): return () @@ -460,8 +461,8 @@ def __init__(self, code, argtypes): exe = configuration["cc"] or "mpicc" compiler = sniff_compiler(exe, comm) - compiler_instance = compiler(cppargs, ldargs, cpp=cpp, comm=comm) - dll = make_so(compiler_instance, code, extension) + compiler_instance = compiler(cppargs, ldargs, cpp=cpp) + dll = make_so(compiler_instance, code, extension, comm) if isinstance(jitmodule, GlobalKernel): _add_profiling_events(dll, code.local_kernel.events) @@ -485,14 +486,18 @@ def expandWl(ldflags): @mpi.collective -def make_so(compiler, jitmodule, extension): +def make_so(compiler, jitmodule, extension, comm): """Build a shared library and load it :arg compiler: The compiler to use to create the shared library. :arg jitmodule: The JIT Module which can generate the code to compile. :arg extension: extension of the source file (c, cpp). + :arg comm: Communicator over which to perform compilation. Returns a :class:`ctypes.CDLL` object of the resulting shared library.""" + # Compilation communicators are reference counted on the PyOP2 comm + pcomm = mpi.internal_comm(comm, compiler) + comm = mpi.compilation_comm(pcomm, compiler) # C or C++ if compiler._cpp: @@ -510,9 +515,9 @@ def make_so(compiler, jitmodule, extension): hsh.update("".join(compiler_flags).encode()) hsh.update("".join(compiler.ldflags).encode()) - basename = hsh.hexdigest() + basename = hsh.hexdigest() # This is hash key - cachedir = configuration['cache_dir'] + cachedir = configuration['cache_dir'] # This is cachedir dirpart, basename = basename[:2], basename[2:] cachedir = os.path.join(cachedir, dirpart) @@ -525,17 +530,17 @@ def make_so(compiler, jitmodule, extension): tmpname = os.path.join(cachedir, f"{basename}_p{pid}.so.tmp") if configuration['check_src_hashes'] or configuration['debug']: - matching = compiler.comm.allreduce(basename, op=_check_op) + matching = comm.allreduce(basename, op=_check_op) if matching != basename: # Dump all src code to disk for debugging output = os.path.join(configuration["cache_dir"], "mismatching-kernels") - srcfile = os.path.join(output, f"src-rank{compiler.comm.rank}.{extension}") - if compiler.comm.rank == 0: + srcfile = os.path.join(output, f"src-rank{comm.rank}.{extension}") + if comm.rank == 0: os.makedirs(output, exist_ok=True) - compiler.comm.barrier() + comm.barrier() with open(srcfile, "w") as f: f.write(jitmodule.code_to_compile) - compiler.comm.barrier() + comm.barrier() raise CompilationError(f"Generated code differs across ranks (see output in {output})") # Check whether this shared object already written to disk @@ -544,11 +549,11 @@ def make_so(compiler, jitmodule, extension): except OSError: dll = None got_dll = bool(dll) - all_dll = compiler.comm.allgather(got_dll) + all_dll = comm.allgather(got_dll) # If the library is not loaded _on all ranks_ build it if not min(all_dll): - if compiler.comm.rank == 0: + if comm.rank == 0: # No need to do this on all ranks os.makedirs(cachedir, exist_ok=True) logfile = os.path.join(cachedir, f"{basename}_p{pid}.log") @@ -625,7 +630,7 @@ def make_so(compiler, jitmodule, extension): # Atomically ensure soname exists os.rename(tmpname, soname) # Wait for compilation to complete - compiler.comm.barrier() + comm.barrier() # Load resulting library dll = ctypes.CDLL(soname) return dll