Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Untangle compilation process #3940

Merged
merged 6 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions firedrake/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,8 +775,8 @@ def make_c_evaluate(function, c_name="evaluate", ldargs=None, tolerance=None):
libspatialindex_so = Path(rtree.core.rt._name).absolute()
lsi_runpath = f"-Wl,-rpath,{libspatialindex_so.parent}"
ldargs += [str(libspatialindex_so), lsi_runpath]
return compilation.load(
src, "c", c_name,
dll = compilation.load(
src, "c",
cppargs=[
f"-I{path.dirname(__file__)}",
f"-I{sys.prefix}/include",
Expand All @@ -785,3 +785,4 @@ def make_c_evaluate(function, c_name="evaluate", ldargs=None, tolerance=None):
ldargs=ldargs,
comm=function.comm
)
return getattr(dll, c_name)
6 changes: 3 additions & 3 deletions firedrake/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -2681,8 +2681,8 @@ def _c_locator(self, tolerance=None):

libspatialindex_so = Path(rtree.core.rt._name).absolute()
lsi_runpath = f"-Wl,-rpath,{libspatialindex_so.parent}"
locator = compilation.load(
src, "c", "locator",
dll = compilation.load(
src, "c",
cppargs=[
f"-I{os.path.dirname(__file__)}",
f"-I{sys.prefix}/include",
Expand All @@ -2696,7 +2696,7 @@ def _c_locator(self, tolerance=None):
],
comm=self.comm
)

locator = getattr(dll, "locator")
locator.argtypes = [ctypes.POINTER(function._CFunction),
ctypes.POINTER(ctypes.c_double),
ctypes.POINTER(ctypes.c_double),
Expand Down
57 changes: 13 additions & 44 deletions pyop2/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@
from pyop2.logger import warning, debug, progress, INFO
from pyop2.exceptions import CompilationError
from pyop2.utils import get_petsc_variables
import pyop2.global_kernel
from petsc4py import PETSc


Expand Down Expand Up @@ -424,38 +423,16 @@ def load_hashkey(*args, **kwargs):
@mpi.collective
@memory_cache(hashkey=load_hashkey)
@PETSc.Log.EventDecorator()
def load(jitmodule, extension, fn_name, cppargs=(), ldargs=(),
argtypes=None, restype=None, comm=None):
def load(code, extension, cppargs=(), ldargs=(), comm=None):
"""Build a shared library and return a function pointer from it.

:arg jitmodule: The JIT Module which can generate the code to compile, or
the string representing the source code.
:arg code: The code to compile.
:arg extension: extension of the source file (c, cpp)
:arg fn_name: The name of the function to return from the resulting library
:arg cppargs: A tuple of arguments to the C compiler (optional)
:arg ldargs: A tuple of arguments to the linker (optional)
:arg argtypes: A list of ctypes argument types matching the arguments of
the returned function (optional, pass ``None`` for ``void``). This is
only used when string is passed in instead of JITModule.
:arg restype: The return type of the function (optional, pass
``None`` for ``void``).
:kwarg comm: Optional communicator to compile the code on (only
rank 0 compiles code) (defaults to pyop2.mpi.COMM_WORLD).
"""
if isinstance(jitmodule, str):
class StrCode(object):
def __init__(self, code, argtypes):
self.code_to_compile = code
self.cache_key = (None, code) # We peel off the first
# entry, since for a jitmodule, it's a process-local
# cache key
self.argtypes = argtypes
code = StrCode(jitmodule, argtypes)
elif isinstance(jitmodule, pyop2.global_kernel.GlobalKernel):
code = jitmodule
else:
raise ValueError("Don't know how to compile code of type %r" % type(jitmodule))

global _compiler
if _compiler:
# Use the global compiler if it has been set
Expand All @@ -475,15 +452,7 @@ def __init__(self, code, argtypes):
# This call is cached on disk
so_name = make_so(compiler_instance, code, extension, comm)
# This call might be cached in memory by the OS (system dependent)
dll = ctypes.CDLL(so_name)

if isinstance(jitmodule, pyop2.global_kernel.GlobalKernel):
_add_profiling_events(dll, code.local_kernel.events)

fn = getattr(dll, fn_name)
fn.argtypes = code.argtypes
fn.restype = restype
return fn
return ctypes.CDLL(so_name)


def expandWl(ldflags):
Expand Down Expand Up @@ -519,27 +488,27 @@ def setdefault(self, key, default=None):
return self[key]


def _make_so_hashkey(compiler, jitmodule, extension, comm):
def _make_so_hashkey(compiler, code, extension, comm):
if extension == "cpp":
exe = compiler.cxx
compiler_flags = compiler.cxxflags
else:
exe = compiler.cc
compiler_flags = compiler.cflags
return (compiler, exe, compiler_flags, compiler.ld, compiler.ldflags, jitmodule.cache_key)
return (compiler, code, exe, compiler_flags, compiler.ld, compiler.ldflags)


def check_source_hashes(compiler, jitmodule, extension, comm):
def check_source_hashes(compiler, code, extension, comm):
"""A check to see whether code generated on all ranks is identical.

:arg compiler: The compiler to use to create the shared library.
:arg jitmodule: The JIT Module which can generate the code to compile.
:arg code: The code to compile.
:arg filename: The filename of the library to create.
:arg extension: extension of the source file (c, cpp).
:arg comm: Communicator over which to perform compilation.
"""
# Reconstruct hash from filename
hashval = _as_hexdigest(_make_so_hashkey(compiler, jitmodule, extension, comm))
hashval = _as_hexdigest(_make_so_hashkey(compiler, code, extension, comm))
with mpi.temp_internal_comm(comm) as icomm:
matching = icomm.allreduce(hashval, op=_check_op)
if matching != hashval:
Expand All @@ -550,7 +519,7 @@ def check_source_hashes(compiler, jitmodule, extension, comm):
output.mkdir(parents=True, exist_ok=True)
icomm.barrier()
with open(srcfile, "w") as fh:
fh.write(jitmodule.code_to_compile)
fh.write(code)
icomm.barrier()
raise CompilationError(f"Generated code differs across ranks (see output in {output})")

Expand All @@ -561,11 +530,11 @@ def check_source_hashes(compiler, jitmodule, extension, comm):
cache_factory=lambda: CompilerDiskAccess(configuration['cache_dir'], extension=".so")
)
@PETSc.Log.EventDecorator()
def make_so(compiler, jitmodule, extension, comm, filename=None):
def make_so(compiler, code, extension, comm, filename=None):
"""Build a shared library and load it

:arg compiler: The compiler to use to create the shared library.
:arg jitmodule: The JIT Module which can generate the code to compile.
:arg code: The code to compile.
:arg filename: The filename of the library to create.
:arg extension: extension of the source file (c, cpp).
:arg comm: Communicator over which to perform compilation.
Expand Down Expand Up @@ -605,7 +574,7 @@ def make_so(compiler, jitmodule, extension, comm, filename=None):
with progress(INFO, 'Compiling wrapper'):
# Write source code to disk
with open(cname, "w") as fh:
fh.write(jitmodule.code_to_compile)
fh.write(code)
os.close(descriptor)

if not compiler.ld:
Expand Down Expand Up @@ -650,7 +619,7 @@ def _run(cc, logfile, errfile, step="Compilation", filemode="w"):
"""))


def _add_profiling_events(dll, events):
def add_profiling_events(dll, events):
"""
If PyOP2 is in profiling mode, events are attached to dll to profile the local linear algebra calls.
The event is generated here in python and then set in the shared library,
Expand Down
24 changes: 12 additions & 12 deletions pyop2/global_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from petsc4py import PETSc

from pyop2 import mpi
from pyop2.compilation import load
from pyop2.compilation import add_profiling_events, load
from pyop2.configuration import configuration
from pyop2.datatypes import IntType, as_ctypes
from pyop2.types import IterationRegion, Constant, READ
Expand Down Expand Up @@ -366,8 +366,11 @@ def code_to_compile(self):
"""Return the C/C++ source code as a string."""
from pyop2.codegen.rep2loopy import generate

wrapper = generate(self.builder)
code = lp.generate_code_v2(wrapper)
with PETSc.Log.Event("GlobalKernel: generate loopy"):
wrapper = generate(self.builder)

with PETSc.Log.Event("GlobalKernel: generate device code"):
code = lp.generate_code_v2(wrapper)

if self.local_kernel.cpp:
from loopy.codegen.result import process_preambles
Expand Down Expand Up @@ -397,15 +400,12 @@ def compile(self, comm):
+ tuple(self.local_kernel.ldargs)
)

return load(
self,
extension,
self.name,
cppargs=cppargs,
ldargs=ldargs,
restype=ctypes.c_int,
comm=comm
)
dll = load(self.code_to_compile, extension, cppargs=cppargs, ldargs=ldargs, comm=comm)
add_profiling_events(dll, self.local_kernel.events)
fn = getattr(dll, self.name)
fn.argtypes = self.argtypes
fn.restype = ctypes.c_int
return fn

@cached_property
def argtypes(self):
Expand Down
Loading