Skip to content
This repository has been archived by the owner on Nov 27, 2024. It is now read-only.

Fix for massively parallel performance regression #720

Merged
merged 3 commits into from
May 8, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 85 additions & 89 deletions pyop2/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def _check_hashes(x, y, datatype):


def set_default_compiler(compiler):
"""Set the PyOP2 default compiler, globally.
"""Set the PyOP2 default compiler, globally over COMM_WORLD.

:arg compiler: String with name or path to compiler executable
OR a subclass of the Compiler class
Expand All @@ -85,66 +85,73 @@ def set_default_compiler(compiler):
)


def sniff_compiler(exe):
def sniff_compiler(exe, comm=mpi.COMM_WORLD):
"""Obtain the correct compiler class by calling the compiler executable.

:arg exe: String with name or path to compiler executable
JDBetteridge marked this conversation as resolved.
Show resolved Hide resolved
:arg comm: Comm over which we want to determine the compiler type
:returns: A compiler class
"""
try:
output = subprocess.run(
[exe, "--version"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
check=True,
encoding="utf-8"
).stdout
except (subprocess.CalledProcessError, UnicodeDecodeError):
output = ""

# Find the name of the compiler family
if output.startswith("gcc") or output.startswith("g++"):
name = "GNU"
elif output.startswith("clang"):
name = "clang"
elif output.startswith("Apple LLVM") or output.startswith("Apple clang"):
name = "clang"
elif output.startswith("icc"):
name = "Intel"
elif "Cray" in output.split("\n")[0]:
# Cray is more awkward eg:
# Cray clang version 11.0.4 (<some_hash>)
# gcc (GCC) 9.3.0 20200312 (Cray Inc.)
name = "Cray"
else:
name = "unknown"

# Set the compiler instance based on the platform (and architecture)
if sys.platform.find("linux") == 0:
if name == "Intel":
compiler = LinuxIntelCompiler
elif name == "GNU":
compiler = LinuxGnuCompiler
elif name == "clang":
compiler = LinuxClangCompiler
elif name == "Cray":
compiler = LinuxCrayCompiler
compiler = None
if comm.rank == 0:
# Note:
# Sniffing compiler for very large numbers of MPI ranks is
# expensive so we do this on one rank and broadcast
try:
output = subprocess.run(
[exe, "--version"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
check=True,
encoding="utf-8"
).stdout
except (subprocess.CalledProcessError, UnicodeDecodeError):
output = ""

# Find the name of the compiler family
if output.startswith("gcc") or output.startswith("g++"):
name = "GNU"
elif output.startswith("clang"):
name = "clang"
elif output.startswith("Apple LLVM") or output.startswith("Apple clang"):
name = "clang"
elif output.startswith("icc"):
name = "Intel"
elif "Cray" in output.split("\n")[0]:
# Cray is more awkward eg:
# Cray clang version 11.0.4 (<some_hash>)
# gcc (GCC) 9.3.0 20200312 (Cray Inc.)
name = "Cray"
else:
compiler = AnonymousCompiler
elif sys.platform.find("darwin") == 0:
if name == "clang":
machine = platform.uname().machine
if machine == "arm64":
compiler = MacClangARMCompiler
elif machine == "x86_64":
compiler = MacClangCompiler
elif name == "GNU":
compiler = MacGNUCompiler
name = "unknown"

# Set the compiler instance based on the platform (and architecture)
if sys.platform.find("linux") == 0:
if name == "Intel":
compiler = LinuxIntelCompiler
elif name == "GNU":
compiler = LinuxGnuCompiler
elif name == "clang":
compiler = LinuxClangCompiler
elif name == "Cray":
compiler = LinuxCrayCompiler
else:
compiler = AnonymousCompiler
elif sys.platform.find("darwin") == 0:
if name == "clang":
machine = platform.uname().machine
if machine == "arm64":
compiler = MacClangARMCompiler
elif machine == "x86_64":
compiler = MacClangCompiler
elif name == "GNU":
compiler = MacGNUCompiler
else:
compiler = AnonymousCompiler
else:
compiler = AnonymousCompiler
else:
compiler = AnonymousCompiler
return compiler

return comm.bcast(compiler, 0)


class Compiler(ABC):
Expand Down Expand Up @@ -178,8 +185,8 @@ class Compiler(ABC):
_debugflags = ()

def __init__(self, extra_compiler_flags=(), extra_linker_flags=(), cpp=False, comm=None):
# Get compiler version ASAP since it is used in __repr__
self.sniff_compiler_version()
# Set compiler version ASAP since it is used in __repr__
self.version = None
JDBetteridge marked this conversation as resolved.
Show resolved Hide resolved

self._extra_compiler_flags = tuple(extra_compiler_flags)
self._extra_linker_flags = tuple(extra_linker_flags)
Expand All @@ -190,6 +197,7 @@ def __init__(self, extra_compiler_flags=(), extra_linker_flags=(), cpp=False, co
# Compilation communicators are reference counted on the PyOP2 comm
self.pcomm = mpi.internal_comm(comm, self)
self.comm = mpi.compilation_comm(self.pcomm, self)
self.sniff_compiler_version()

def __repr__(self):
return f"<{self._name} compiler, version {self.version or 'unknown'}>"
Expand Down Expand Up @@ -238,23 +246,28 @@ def sniff_compiler_version(self, cpp=False):
:arg cpp: If set to True will use the C++ compiler rather than
the C compiler to determine the version number.
"""
# Note:
# Sniffing the compiler version for very large numbers of
# MPI ranks is expensive
exe = self.cxx if cpp else self.cc
self.version = None
# `-dumpversion` is not sufficient to get the whole version string (for some compilers),
# but other compilers do not implement `-dumpfullversion`!
for dumpstring in ["-dumpfullversion", "-dumpversion"]:
try:
output = subprocess.run(
[exe, dumpstring],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
check=True,
encoding="utf-8"
).stdout
self.version = Version(output)
break
except (subprocess.CalledProcessError, UnicodeDecodeError, InvalidVersion):
continue
version = None
if self.comm.rank == 0:
# `-dumpversion` is not sufficient to get the whole version string (for some compilers),
# but other compilers do not implement `-dumpfullversion`!
for dumpstring in ["-dumpfullversion", "-dumpversion"]:
try:
output = subprocess.run(
[exe, dumpstring],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
check=True,
encoding="utf-8"
).stdout
version = Version(output)
break
except (subprocess.CalledProcessError, UnicodeDecodeError, InvalidVersion):
continue
self.version = self.comm.bcast(version, 0)

@property
def bugfix_cflags(self):
Expand Down Expand Up @@ -448,23 +461,6 @@ class LinuxGnuCompiler(Compiler):
_optflags = ("-march=native", "-O3", "-ffast-math")
_debugflags = ("-O0", "-g")

def sniff_compiler_version(self, cpp=False):
super(LinuxGnuCompiler, self).sniff_compiler_version()
if self.version >= Version("7.0"):
try:
# gcc-7 series only spits out patch level on dumpfullversion.
exe = self.cxx if cpp else self.cc
output = subprocess.run(
[exe, "-dumpfullversion"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
check=True,
encoding="utf-8"
).stdout
self.version = Version(output)
except (subprocess.CalledProcessError, UnicodeDecodeError, InvalidVersion):
pass

@property
def bugfix_cflags(self):
"""Flags to work around bugs in compilers."""
Expand Down Expand Up @@ -596,7 +592,7 @@ def __init__(self, code, argtypes):
exe = configuration["cxx"] or "mpicxx"
else:
exe = configuration["cc"] or "mpicc"
compiler = sniff_compiler(exe)
compiler = sniff_compiler(exe, comm)
dll = compiler(cppargs, ldargs, cpp=cpp, comm=comm).get_so(code, extension)

if isinstance(jitmodule, GlobalKernel):
Expand Down
Loading