Skip to content
This repository has been archived by the owner on Nov 27, 2024. It is now read-only.

Commit

Permalink
Remove get_so from Compiler base class
Browse files Browse the repository at this point in the history
  • Loading branch information
JDBetteridge committed Aug 19, 2024
1 parent be1e58b commit 3b459ce
Showing 1 changed file with 162 additions and 158 deletions.
320 changes: 162 additions & 158 deletions pyop2/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,163 +275,6 @@ def sniff_compiler_version(self, cpp=False):
def bugfix_cflags(self):
return ()

@staticmethod
def expandWl(ldflags):
"""Generator to expand the `-Wl` compiler flags for use as linker flags
:arg ldflags: linker flags for a compiler command
"""
for flag in ldflags:
if flag.startswith('-Wl'):
for f in flag.lstrip('-Wl')[1:].split(','):
yield f
else:
yield flag

@mpi.collective
def get_so(self, jitmodule, extension):
"""Build a shared library and load it
:arg jitmodule: The JIT Module which can generate the code to compile.
:arg extension: extension of the source file (c, cpp).
Returns a :class:`ctypes.CDLL` object of the resulting shared
library."""

# C or C++
if self._cpp:
compiler = self.cxx
compiler_flags = self.cxxflags
else:
compiler = self.cc
compiler_flags = self.cflags

# Determine cache key
hsh = md5(str(jitmodule.cache_key).encode())
hsh.update(compiler.encode())
if self.ld:
hsh.update(self.ld.encode())
hsh.update("".join(compiler_flags).encode())
hsh.update("".join(self.ldflags).encode())

basename = hsh.hexdigest()

cachedir = configuration['cache_dir']

dirpart, basename = basename[:2], basename[2:]
cachedir = os.path.join(cachedir, dirpart)
pid = os.getpid()
cname = os.path.join(cachedir, f"{basename}_p{pid}.{extension}")
oname = os.path.join(cachedir, f"{basename}_p{pid}.o")
soname = os.path.join(cachedir, f"{basename}.so")
# Link into temporary file, then rename to shared library
# atomically (avoiding races).
tmpname = os.path.join(cachedir, f"{basename}_p{pid}.so.tmp")

if configuration['check_src_hashes'] or configuration['debug']:
matching = self.comm.allreduce(basename, op=_check_op)
if matching != basename:
# Dump all src code to disk for debugging
output = os.path.join(configuration["cache_dir"], "mismatching-kernels")
srcfile = os.path.join(output, f"src-rank{self.comm.rank}.{extension}")
if self.comm.rank == 0:
os.makedirs(output, exist_ok=True)
self.comm.barrier()
with open(srcfile, "w") as f:
f.write(jitmodule.code_to_compile)
self.comm.barrier()
raise CompilationError(f"Generated code differs across ranks (see output in {output})")

# Check whether this shared object already written to disk
try:
dll = ctypes.CDLL(soname)
except OSError:
dll = None
got_dll = bool(dll)
all_dll = self.comm.allgather(got_dll)

# If the library is not loaded _on all ranks_ build it
if not min(all_dll):
if self.comm.rank == 0:
# No need to do this on all ranks
os.makedirs(cachedir, exist_ok=True)
logfile = os.path.join(cachedir, f"{basename}_p{pid}.log")
errfile = os.path.join(cachedir, f"{basename}_p{pid}.err")
with progress(INFO, 'Compiling wrapper'):
with open(cname, "w") as f:
f.write(jitmodule.code_to_compile)
# Compiler also links
if not self.ld:
cc = (compiler,) \
+ compiler_flags \
+ ('-o', tmpname, cname) \
+ self.ldflags
debug(f"Compilation command: {' '.join(cc)}")
with open(logfile, "w") as log, open(errfile, "w") as err:
log.write("Compilation command:\n")
log.write(" ".join(cc))
log.write("\n\n")
try:
if configuration['no_fork_available']:
cc += ["2>", errfile, ">", logfile]
cmd = " ".join(cc)
status = os.system(cmd)
if status != 0:
raise subprocess.CalledProcessError(status, cmd)
else:
subprocess.check_call(cc, stderr=err, stdout=log)
except subprocess.CalledProcessError as e:
raise CompilationError(dedent(f"""
Command "{e.cmd}" return error status {e.returncode}.
Unable to compile code
Compile log in {logfile}
Compile errors in {errfile}
"""))
else:
cc = (compiler,) \
+ compiler_flags \
+ ('-c', '-o', oname, cname)
# Extract linker specific "cflags" from ldflags
ld = tuple(shlex.split(self.ld)) \
+ ('-o', tmpname, oname) \
+ tuple(self.expandWl(self.ldflags))
debug(f"Compilation command: {' '.join(cc)}", )
debug(f"Link command: {' '.join(ld)}")
with open(logfile, "a") as log, open(errfile, "a") as err:
log.write("Compilation command:\n")
log.write(" ".join(cc))
log.write("\n\n")
log.write("Link command:\n")
log.write(" ".join(ld))
log.write("\n\n")
try:
if configuration['no_fork_available']:
cc += ["2>", errfile, ">", logfile]
ld += ["2>>", errfile, ">>", logfile]
cccmd = " ".join(cc)
ldcmd = " ".join(ld)
status = os.system(cccmd)
if status != 0:
raise subprocess.CalledProcessError(status, cccmd)
status = os.system(ldcmd)
if status != 0:
raise subprocess.CalledProcessError(status, ldcmd)
else:
subprocess.check_call(cc, stderr=err, stdout=log)
subprocess.check_call(ld, stderr=err, stdout=log)
except subprocess.CalledProcessError as e:
raise CompilationError(dedent(f"""
Command "{e.cmd}" return error status {e.returncode}.
Unable to compile code
Compile log in {logfile}
Compile errors in {errfile}
"""))
# Atomically ensure soname exists
os.rename(tmpname, soname)
# Wait for compilation to complete
self.comm.barrier()
# Load resulting library
dll = ctypes.CDLL(soname)
return dll


class MacClangCompiler(Compiler):
"""A compiler for building a shared library on Mac systems."""
Expand Down Expand Up @@ -616,7 +459,9 @@ def __init__(self, code, argtypes):
else:
exe = configuration["cc"] or "mpicc"
compiler = sniff_compiler(exe, comm)
dll = compiler(cppargs, ldargs, cpp=cpp, comm=comm).get_so(code, extension)

compiler_instance = compiler(cppargs, ldargs, cpp=cpp, comm=comm)
dll = make_so(compiler_instance, code, extension)

if isinstance(jitmodule, GlobalKernel):
_add_profiling_events(dll, code.local_kernel.events)
Expand All @@ -627,6 +472,165 @@ def __init__(self, code, argtypes):
return fn


def expandWl(ldflags):
"""Generator to expand the `-Wl` compiler flags for use as linker flags
:arg ldflags: linker flags for a compiler command
"""
for flag in ldflags:
if flag.startswith('-Wl'):
for f in flag.lstrip('-Wl')[1:].split(','):
yield f
else:
yield flag


@mpi.collective
def make_so(compiler, jitmodule, extension):
"""Build a shared library and load it
:arg compiler: The compiler to use to create the shared library.
:arg jitmodule: The JIT Module which can generate the code to compile.
:arg extension: extension of the source file (c, cpp).
Returns a :class:`ctypes.CDLL` object of the resulting shared
library."""

# C or C++
if compiler._cpp:
exe = compiler.cxx
compiler_flags = compiler.cxxflags
else:
exe = compiler.cc
compiler_flags = compiler.cflags

# Determine cache key
hsh = md5(str(jitmodule.cache_key).encode())
hsh.update(exe.encode())
if compiler.ld:
hsh.update(compiler.ld.encode())
hsh.update("".join(compiler_flags).encode())
hsh.update("".join(compiler.ldflags).encode())

basename = hsh.hexdigest()

cachedir = configuration['cache_dir']

dirpart, basename = basename[:2], basename[2:]
cachedir = os.path.join(cachedir, dirpart)
pid = os.getpid()
cname = os.path.join(cachedir, f"{basename}_p{pid}.{extension}")
oname = os.path.join(cachedir, f"{basename}_p{pid}.o")
soname = os.path.join(cachedir, f"{basename}.so")
# Link into temporary file, then rename to shared library
# atomically (avoiding races).
tmpname = os.path.join(cachedir, f"{basename}_p{pid}.so.tmp")

if configuration['check_src_hashes'] or configuration['debug']:
matching = compiler.comm.allreduce(basename, op=_check_op)
if matching != basename:
# Dump all src code to disk for debugging
output = os.path.join(configuration["cache_dir"], "mismatching-kernels")
srcfile = os.path.join(output, f"src-rank{compiler.comm.rank}.{extension}")
if compiler.comm.rank == 0:
os.makedirs(output, exist_ok=True)
compiler.comm.barrier()
with open(srcfile, "w") as f:
f.write(jitmodule.code_to_compile)
compiler.comm.barrier()
raise CompilationError(f"Generated code differs across ranks (see output in {output})")

# Check whether this shared object already written to disk
try:
dll = ctypes.CDLL(soname)
except OSError:
dll = None
got_dll = bool(dll)
all_dll = compiler.comm.allgather(got_dll)

# If the library is not loaded _on all ranks_ build it
if not min(all_dll):
if compiler.comm.rank == 0:
# No need to do this on all ranks
os.makedirs(cachedir, exist_ok=True)
logfile = os.path.join(cachedir, f"{basename}_p{pid}.log")
errfile = os.path.join(cachedir, f"{basename}_p{pid}.err")
with progress(INFO, 'Compiling wrapper'):
with open(cname, "w") as f:
f.write(jitmodule.code_to_compile)
# Compiler also links
if not compiler.ld:
cc = (exe,) \
+ compiler_flags \
+ ('-o', tmpname, cname) \
+ compiler.ldflags
debug(f"Compilation command: {' '.join(cc)}")
with open(logfile, "w") as log, open(errfile, "w") as err:
log.write("Compilation command:\n")
log.write(" ".join(cc))
log.write("\n\n")
try:
if configuration['no_fork_available']:
cc += ["2>", errfile, ">", logfile]
cmd = " ".join(cc)
status = os.system(cmd)
if status != 0:
raise subprocess.CalledProcessError(status, cmd)
else:
subprocess.check_call(cc, stderr=err, stdout=log)
except subprocess.CalledProcessError as e:
raise CompilationError(dedent(f"""
Command "{e.cmd}" return error status {e.returncode}.
Unable to compile code
Compile log in {logfile}
Compile errors in {errfile}
"""))
else:
cc = (exe,) \
+ compiler_flags \
+ ('-c', '-o', oname, cname)
# Extract linker specific "cflags" from ldflags
ld = tuple(shlex.split(compiler.ld)) \
+ ('-o', tmpname, oname) \
+ tuple(expandWl(compiler.ldflags))
debug(f"Compilation command: {' '.join(cc)}", )
debug(f"Link command: {' '.join(ld)}")
with open(logfile, "a") as log, open(errfile, "a") as err:
log.write("Compilation command:\n")
log.write(" ".join(cc))
log.write("\n\n")
log.write("Link command:\n")
log.write(" ".join(ld))
log.write("\n\n")
try:
if configuration['no_fork_available']:
cc += ["2>", errfile, ">", logfile]
ld += ["2>>", errfile, ">>", logfile]
cccmd = " ".join(cc)
ldcmd = " ".join(ld)
status = os.system(cccmd)
if status != 0:
raise subprocess.CalledProcessError(status, cccmd)
status = os.system(ldcmd)
if status != 0:
raise subprocess.CalledProcessError(status, ldcmd)
else:
subprocess.check_call(cc, stderr=err, stdout=log)
subprocess.check_call(ld, stderr=err, stdout=log)
except subprocess.CalledProcessError as e:
raise CompilationError(dedent(f"""
Command "{e.cmd}" return error status {e.returncode}.
Unable to compile code
Compile log in {logfile}
Compile errors in {errfile}
"""))
# Atomically ensure soname exists
os.rename(tmpname, soname)
# Wait for compilation to complete
compiler.comm.barrier()
# Load resulting library
dll = ctypes.CDLL(soname)
return dll


def _add_profiling_events(dll, events):
"""
If PyOP2 is in profiling mode, events are attached to dll to profile the local linear algebra calls.
Expand Down

0 comments on commit 3b459ce

Please sign in to comment.