Skip to content
This repository has been archived by the owner on Nov 27, 2024. It is now read-only.

Commit

Permalink
Remove comm from Compiler class
Browse files Browse the repository at this point in the history
  • Loading branch information
JDBetteridge committed Aug 19, 2024
1 parent 3b459ce commit b35ff86
Showing 1 changed file with 55 additions and 50 deletions.
105 changes: 55 additions & 50 deletions pyop2/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from hashlib import md5
from packaging.version import Version, InvalidVersion
from textwrap import dedent
from functools import partial


from pyop2 import mpi
Expand Down Expand Up @@ -87,6 +88,36 @@ def set_default_compiler(compiler):
)


def sniff_compiler_version(compiler, cpp=False):
"""Attempt to determine the compiler version number.
:arg compiler: Instance of compiler to sniff the version of
:arg cpp: If set to True will use the C++ compiler rather than
the C compiler to determine the version number.
"""
# Note:
# Sniffing the compiler version for very large numbers of
# MPI ranks is expensive, ensure this is only run on rank 0
exe = compiler.cxx if cpp else compiler.cc
version = None
# `-dumpversion` is not sufficient to get the whole version string (for some compilers),
# but other compilers do not implement `-dumpfullversion`!
for dumpstring in ["-dumpfullversion", "-dumpversion"]:
try:
output = subprocess.run(
[exe, dumpstring],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
check=True,
encoding="utf-8"
).stdout
version = Version(output)
break
except (subprocess.CalledProcessError, UnicodeDecodeError, InvalidVersion):
continue
return version


def sniff_compiler(exe, comm=mpi.COMM_WORLD):
"""Obtain the correct compiler class by calling the compiler executable.
Expand Down Expand Up @@ -153,6 +184,11 @@ def sniff_compiler(exe, comm=mpi.COMM_WORLD):
else:
compiler = AnonymousCompiler

# Now try and get a version number
temp = Compiler()
version = sniff_compiler_version(temp)
compiler = partial(compiler, version=version)

return comm.bcast(compiler, 0)


Expand Down Expand Up @@ -186,21 +222,15 @@ class Compiler(ABC):
_optflags = ()
_debugflags = ()

def __init__(self, extra_compiler_flags=(), extra_linker_flags=(), cpp=False, comm=None):
# Set compiler version ASAP since it is used in __repr__
self.version = None
def __init__(self, extra_compiler_flags=(), extra_linker_flags=(), cpp=False, version=None):
self.version = version

self._extra_compiler_flags = tuple(extra_compiler_flags)
self._extra_linker_flags = tuple(extra_linker_flags)

self._cpp = cpp
self._debug = configuration["debug"]

# Compilation communicators are reference counted on the PyOP2 comm
self.pcomm = mpi.internal_comm(comm, self)
self.comm = mpi.compilation_comm(self.pcomm, self)
self.sniff_compiler_version()

def __repr__(self):
return f"<{self._name} compiler, version {self.version or 'unknown'}>"

Expand Down Expand Up @@ -242,35 +272,6 @@ def ldflags(self):
ldflags += tuple(shlex.split(configuration["ldflags"]))
return ldflags

def sniff_compiler_version(self, cpp=False):
"""Attempt to determine the compiler version number.
:arg cpp: If set to True will use the C++ compiler rather than
the C compiler to determine the version number.
"""
# Note:
# Sniffing the compiler version for very large numbers of
# MPI ranks is expensive
exe = self.cxx if cpp else self.cc
version = None
if self.comm.rank == 0:
# `-dumpversion` is not sufficient to get the whole version string (for some compilers),
# but other compilers do not implement `-dumpfullversion`!
for dumpstring in ["-dumpfullversion", "-dumpversion"]:
try:
output = subprocess.run(
[exe, dumpstring],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
check=True,
encoding="utf-8"
).stdout
version = Version(output)
break
except (subprocess.CalledProcessError, UnicodeDecodeError, InvalidVersion):
continue
self.version = self.comm.bcast(version, 0)

@property
def bugfix_cflags(self):
return ()
Expand Down Expand Up @@ -460,8 +461,8 @@ def __init__(self, code, argtypes):
exe = configuration["cc"] or "mpicc"
compiler = sniff_compiler(exe, comm)

compiler_instance = compiler(cppargs, ldargs, cpp=cpp, comm=comm)
dll = make_so(compiler_instance, code, extension)
compiler_instance = compiler(cppargs, ldargs, cpp=cpp)
dll = make_so(compiler_instance, code, extension, comm)

if isinstance(jitmodule, GlobalKernel):
_add_profiling_events(dll, code.local_kernel.events)
Expand All @@ -485,14 +486,18 @@ def expandWl(ldflags):


@mpi.collective
def make_so(compiler, jitmodule, extension):
def make_so(compiler, jitmodule, extension, comm):
"""Build a shared library and load it
:arg compiler: The compiler to use to create the shared library.
:arg jitmodule: The JIT Module which can generate the code to compile.
:arg extension: extension of the source file (c, cpp).
:arg comm: Communicator over which to perform compilation.
Returns a :class:`ctypes.CDLL` object of the resulting shared
library."""
# Compilation communicators are reference counted on the PyOP2 comm
pcomm = mpi.internal_comm(comm, compiler)
comm = mpi.compilation_comm(pcomm, compiler)

# C or C++
if compiler._cpp:
Expand All @@ -510,9 +515,9 @@ def make_so(compiler, jitmodule, extension):
hsh.update("".join(compiler_flags).encode())
hsh.update("".join(compiler.ldflags).encode())

basename = hsh.hexdigest()
basename = hsh.hexdigest() # This is hash key

cachedir = configuration['cache_dir']
cachedir = configuration['cache_dir'] # This is cachedir

dirpart, basename = basename[:2], basename[2:]
cachedir = os.path.join(cachedir, dirpart)
Expand All @@ -525,17 +530,17 @@ def make_so(compiler, jitmodule, extension):
tmpname = os.path.join(cachedir, f"{basename}_p{pid}.so.tmp")

if configuration['check_src_hashes'] or configuration['debug']:
matching = compiler.comm.allreduce(basename, op=_check_op)
matching = comm.allreduce(basename, op=_check_op)
if matching != basename:
# Dump all src code to disk for debugging
output = os.path.join(configuration["cache_dir"], "mismatching-kernels")
srcfile = os.path.join(output, f"src-rank{compiler.comm.rank}.{extension}")
if compiler.comm.rank == 0:
srcfile = os.path.join(output, f"src-rank{comm.rank}.{extension}")
if comm.rank == 0:
os.makedirs(output, exist_ok=True)
compiler.comm.barrier()
comm.barrier()
with open(srcfile, "w") as f:
f.write(jitmodule.code_to_compile)
compiler.comm.barrier()
comm.barrier()
raise CompilationError(f"Generated code differs across ranks (see output in {output})")

# Check whether this shared object already written to disk
Expand All @@ -544,11 +549,11 @@ def make_so(compiler, jitmodule, extension):
except OSError:
dll = None
got_dll = bool(dll)
all_dll = compiler.comm.allgather(got_dll)
all_dll = comm.allgather(got_dll)

# If the library is not loaded _on all ranks_ build it
if not min(all_dll):
if compiler.comm.rank == 0:
if comm.rank == 0:
# No need to do this on all ranks
os.makedirs(cachedir, exist_ok=True)
logfile = os.path.join(cachedir, f"{basename}_p{pid}.log")
Expand Down Expand Up @@ -625,7 +630,7 @@ def make_so(compiler, jitmodule, extension):
# Atomically ensure soname exists
os.rename(tmpname, soname)
# Wait for compilation to complete
compiler.comm.barrier()
comm.barrier()
# Load resulting library
dll = ctypes.CDLL(soname)
return dll
Expand Down

0 comments on commit b35ff86

Please sign in to comment.