Skip to content

Commit

Permalink
Untangle compilation process (#3940)
Browse files Browse the repository at this point in the history
  • Loading branch information
connorjward authored Jan 10, 2025
1 parent 8e1a748 commit a08fc1e
Show file tree
Hide file tree
Showing 9 changed files with 55 additions and 84 deletions.
5 changes: 3 additions & 2 deletions firedrake/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,8 +770,8 @@ def make_c_evaluate(function, c_name="evaluate", ldargs=None, tolerance=None):
libspatialindex_so = Path(rtree.core.rt._name).absolute()
lsi_runpath = f"-Wl,-rpath,{libspatialindex_so.parent}"
ldargs += [str(libspatialindex_so), lsi_runpath]
return compilation.load(
src, "c", c_name,
dll = compilation.load(
src, "c",
cppargs=[
f"-I{path.dirname(__file__)}",
f"-I{sys.prefix}/include",
Expand All @@ -780,3 +780,4 @@ def make_c_evaluate(function, c_name="evaluate", ldargs=None, tolerance=None):
ldargs=ldargs,
comm=function.comm
)
return getattr(dll, c_name)
6 changes: 3 additions & 3 deletions firedrake/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -2688,8 +2688,8 @@ def _c_locator(self, tolerance=None):

libspatialindex_so = Path(rtree.core.rt._name).absolute()
lsi_runpath = f"-Wl,-rpath,{libspatialindex_so.parent}"
locator = compilation.load(
src, "c", "locator",
dll = compilation.load(
src, "c",
cppargs=[
f"-I{os.path.dirname(__file__)}",
f"-I{sys.prefix}/include",
Expand All @@ -2703,7 +2703,7 @@ def _c_locator(self, tolerance=None):
],
comm=self.comm
)

locator = getattr(dll, "locator")
locator.argtypes = [ctypes.POINTER(function._CFunction),
ctypes.POINTER(ctypes.c_double),
ctypes.POINTER(ctypes.c_double),
Expand Down
2 changes: 1 addition & 1 deletion firedrake/mg/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,7 @@ def name_multiindex(multiindex, name):
kernel = lp.make_kernel(
domains, instructions, kernel_data, name=kernel_name,
target=tsfc.parameters.target, lang_version=(2018, 2))
kernel = lp.merge([kernel, *subkernels])
kernel = lp.merge([kernel, *subkernels]).with_entrypoints({kernel_name})
return op2.Kernel(
kernel, name=kernel_name, include_dirs=Ainv.include_dirs,
headers=Ainv.headers, events=Ainv.events)
Expand Down
8 changes: 6 additions & 2 deletions firedrake/preconditioners/fdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1835,13 +1835,17 @@ def setSubMatCSR(comm, triu=False):
return cache.setdefault(key, SparseAssembler.load_setSubMatCSR(comm, triu))

@staticmethod
def load_c_code(code, name, **kwargs):
def load_c_code(code, name, comm, argtypes, restype):
petsc_dir = get_petsc_dir()
cppargs = [f"-I{d}/include" for d in petsc_dir]
ldargs = ([f"-L{d}/lib" for d in petsc_dir]
+ [f"-Wl,-rpath,{d}/lib" for d in petsc_dir]
+ ["-lpetsc", "-lm"])
return load(code, "c", name, cppargs=cppargs, ldargs=ldargs, **kwargs)
dll = load(code, "c", cppargs=cppargs, ldargs=ldargs, comm=comm)
fn = getattr(dll, name)
fn.argtypes = argtypes
fn.restype = restype
return fn

@staticmethod
def load_setSubMatCSR(comm, triu=False):
Expand Down
13 changes: 7 additions & 6 deletions firedrake/preconditioners/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,12 +505,13 @@ def load_c_function(code, name, comm):
ldargs = (["-L%s/lib" % d for d in get_petsc_dir()]
+ ["-Wl,-rpath,%s/lib" % d for d in get_petsc_dir()]
+ ["-lpetsc", "-lm"])
return load(code, "c", name,
argtypes=[ctypes.c_voidp, ctypes.c_int, ctypes.c_voidp,
ctypes.c_voidp, ctypes.c_voidp, ctypes.c_int,
ctypes.c_voidp, ctypes.c_voidp, ctypes.c_voidp],
restype=ctypes.c_int, cppargs=cppargs, ldargs=ldargs,
comm=comm)
dll = load(code, "c", cppargs=cppargs, ldargs=ldargs, comm=comm)
fn = getattr(dll, name)
fn.argtypes = [ctypes.c_voidp, ctypes.c_int, ctypes.c_voidp,
ctypes.c_voidp, ctypes.c_voidp, ctypes.c_int,
ctypes.c_voidp, ctypes.c_voidp, ctypes.c_voidp]
fn.restype = ctypes.c_int
return fn


def make_c_arguments(form, kernel, state, get_map, require_state=False,
Expand Down
9 changes: 5 additions & 4 deletions firedrake/supermeshing.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,14 +434,15 @@ def likely(cell_A):
includes = ["-I%s/include" % d for d in dirs]
libs = ["-L%s/lib" % d for d in dirs]
libs = libs + ["-Wl,-rpath,%s/lib" % d for d in dirs] + ["-lpetsc", "-lsupermesh"]
lib = load(
supermesh_kernel_str, "c", "supermesh_kernel",
dll = load(
supermesh_kernel_str, "c",
cppargs=includes,
ldargs=libs,
argtypes=[ctypes.c_voidp, ctypes.c_voidp, ctypes.c_voidp, ctypes.c_voidp, ctypes.c_voidp, ctypes.c_voidp, ctypes.c_voidp],
restype=ctypes.c_int,
comm=mesh_A._comm
)
lib = getattr(dll, "supermesh_kernel")
lib.argtypes = [ctypes.c_voidp, ctypes.c_voidp, ctypes.c_voidp, ctypes.c_voidp, ctypes.c_voidp, ctypes.c_voidp, ctypes.c_voidp]
lib.restype = ctypes.c_int

ammm(V_A, V_B, likely, node_locations_A, node_locations_B, M_SS, ctypes.addressof(lib), mat)
if orig_value_size == 1:
Expand Down
65 changes: 14 additions & 51 deletions pyop2/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@
from pyop2.logger import warning, debug, progress, INFO
from pyop2.exceptions import CompilationError
from pyop2.utils import get_petsc_variables
import pyop2.global_kernel
from petsc4py import PETSc


Expand Down Expand Up @@ -411,51 +410,23 @@ class AnonymousCompiler(Compiler):


def load_hashkey(*args, **kwargs):
from pyop2.global_kernel import GlobalKernel
if isinstance(args[0], str):
code_hash = md5(args[0].encode()).hexdigest()
elif isinstance(args[0], GlobalKernel):
code_hash = md5(str(args[0].cache_key).encode()).hexdigest()
else:
pass # This will raise an error in load
code_hash = md5(args[0].encode()).hexdigest()
return default_parallel_hashkey(code_hash, *args[1:], **kwargs)


@mpi.collective
@memory_cache(hashkey=load_hashkey)
@PETSc.Log.EventDecorator()
def load(jitmodule, extension, fn_name, cppargs=(), ldargs=(),
argtypes=None, restype=None, comm=None):
def load(code, extension, cppargs=(), ldargs=(), comm=None):
"""Build a shared library and return it loaded as a ``ctypes.CDLL``.
:arg jitmodule: The JIT Module which can generate the code to compile, or
the string representing the source code.
:arg code: The code to compile.
:arg extension: extension of the source file (c, cpp)
:arg fn_name: The name of the function to return from the resulting library
:arg cppargs: A tuple of arguments to the C compiler (optional)
:arg ldargs: A tuple of arguments to the linker (optional)
:arg argtypes: A list of ctypes argument types matching the arguments of
the returned function (optional, pass ``None`` for ``void``). This is
only used when string is passed in instead of JITModule.
:arg restype: The return type of the function (optional, pass
``None`` for ``void``).
:kwarg comm: Optional communicator to compile the code on (only
rank 0 compiles code) (defaults to pyop2.mpi.COMM_WORLD).
"""
if isinstance(jitmodule, str):
class StrCode(object):
def __init__(self, code, argtypes):
self.code_to_compile = code
self.cache_key = (None, code) # We peel off the first
# entry, since for a jitmodule, it's a process-local
# cache key
self.argtypes = argtypes
code = StrCode(jitmodule, argtypes)
elif isinstance(jitmodule, pyop2.global_kernel.GlobalKernel):
code = jitmodule
else:
raise ValueError("Don't know how to compile code of type %r" % type(jitmodule))

global _compiler
if _compiler:
# Use the global compiler if it has been set
Expand All @@ -475,15 +446,7 @@ def __init__(self, code, argtypes):
# This call is cached on disk
so_name = make_so(compiler_instance, code, extension, comm)
# This call might be cached in memory by the OS (system dependent)
dll = ctypes.CDLL(so_name)

if isinstance(jitmodule, pyop2.global_kernel.GlobalKernel):
_add_profiling_events(dll, code.local_kernel.events)

fn = getattr(dll, fn_name)
fn.argtypes = code.argtypes
fn.restype = restype
return fn
return ctypes.CDLL(so_name)


def expandWl(ldflags):
Expand Down Expand Up @@ -519,27 +482,27 @@ def setdefault(self, key, default=None):
return self[key]


def _make_so_hashkey(compiler, jitmodule, extension, comm):
def _make_so_hashkey(compiler, code, extension, comm):
if extension == "cpp":
exe = compiler.cxx
compiler_flags = compiler.cxxflags
else:
exe = compiler.cc
compiler_flags = compiler.cflags
return (compiler, exe, compiler_flags, compiler.ld, compiler.ldflags, jitmodule.cache_key)
return (compiler, code, exe, compiler_flags, compiler.ld, compiler.ldflags)


def check_source_hashes(compiler, jitmodule, extension, comm):
def check_source_hashes(compiler, code, extension, comm):
"""A check to see whether code generated on all ranks is identical.
:arg compiler: The compiler to use to create the shared library.
:arg jitmodule: The JIT Module which can generate the code to compile.
:arg code: The code to compile.
:arg filename: The filename of the library to create.
:arg extension: extension of the source file (c, cpp).
:arg comm: Communicator over which to perform compilation.
"""
# Reconstruct hash from filename
hashval = _as_hexdigest(_make_so_hashkey(compiler, jitmodule, extension, comm))
hashval = _as_hexdigest(_make_so_hashkey(compiler, code, extension, comm))
with mpi.temp_internal_comm(comm) as icomm:
matching = icomm.allreduce(hashval, op=_check_op)
if matching != hashval:
Expand All @@ -550,7 +513,7 @@ def check_source_hashes(compiler, jitmodule, extension, comm):
output.mkdir(parents=True, exist_ok=True)
icomm.barrier()
with open(srcfile, "w") as fh:
fh.write(jitmodule.code_to_compile)
fh.write(code)
icomm.barrier()
raise CompilationError(f"Generated code differs across ranks (see output in {output})")

Expand All @@ -561,11 +524,11 @@ def check_source_hashes(compiler, jitmodule, extension, comm):
cache_factory=lambda: CompilerDiskAccess(configuration['cache_dir'], extension=".so")
)
@PETSc.Log.EventDecorator()
def make_so(compiler, jitmodule, extension, comm, filename=None):
def make_so(compiler, code, extension, comm, filename=None):
"""Build a shared library and load it
:arg compiler: The compiler to use to create the shared library.
:arg jitmodule: The JIT Module which can generate the code to compile.
:arg code: The code to compile.
:arg filename: The filename of the library to create.
:arg extension: extension of the source file (c, cpp).
:arg comm: Communicator over which to perform compilation.
Expand Down Expand Up @@ -605,7 +568,7 @@ def make_so(compiler, jitmodule, extension, comm, filename=None):
with progress(INFO, 'Compiling wrapper'):
# Write source code to disk
with open(cname, "w") as fh:
fh.write(jitmodule.code_to_compile)
fh.write(code)
os.close(descriptor)

if not compiler.ld:
Expand Down Expand Up @@ -650,7 +613,7 @@ def _run(cc, logfile, errfile, step="Compilation", filemode="w"):
"""))


def _add_profiling_events(dll, events):
def add_profiling_events(dll, events):
"""
If PyOP2 is in profiling mode, events are attached to dll to profile the local linear algebra calls.
The event is generated here in python and then set in the shared library,
Expand Down
24 changes: 12 additions & 12 deletions pyop2/global_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from petsc4py import PETSc

from pyop2 import mpi
from pyop2.compilation import load
from pyop2.compilation import add_profiling_events, load
from pyop2.configuration import configuration
from pyop2.datatypes import IntType, as_ctypes
from pyop2.types import IterationRegion, Constant, READ
Expand Down Expand Up @@ -366,8 +366,11 @@ def code_to_compile(self):
"""Return the C/C++ source code as a string."""
from pyop2.codegen.rep2loopy import generate

wrapper = generate(self.builder)
code = lp.generate_code_v2(wrapper)
with PETSc.Log.Event("GlobalKernel: generate loopy"):
wrapper = generate(self.builder)

with PETSc.Log.Event("GlobalKernel: generate device code"):
code = lp.generate_code_v2(wrapper)

if self.local_kernel.cpp:
from loopy.codegen.result import process_preambles
Expand Down Expand Up @@ -397,15 +400,12 @@ def compile(self, comm):
+ tuple(self.local_kernel.ldargs)
)

return load(
self,
extension,
self.name,
cppargs=cppargs,
ldargs=ldargs,
restype=ctypes.c_int,
comm=comm
)
dll = load(self.code_to_compile, extension, cppargs=cppargs, ldargs=ldargs, comm=comm)
add_profiling_events(dll, self.local_kernel.events)
fn = getattr(dll, self.name)
fn.argtypes = self.argtypes
fn.restype = ctypes.c_int
return fn

@cached_property
def argtypes(self):
Expand Down
7 changes: 4 additions & 3 deletions tests/pyop2/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
# OF THE POSSIBILITY OF SUCH DAMAGE.

import ctypes
import os
import pytest
import tempfile
Expand Down Expand Up @@ -785,7 +784,8 @@ def test_writing_large_so():
if COMM_WORLD.rank == 1:
os.remove("big.c")

fn = load(program, "c", "big", argtypes=(ctypes.c_voidp,), comm=COMM_WORLD)
dll = load(program, "c", comm=COMM_WORLD)
fn = getattr(dll, "big")
assert fn is not None


Expand All @@ -800,7 +800,8 @@ def test_two_comms_compile_the_same_code():
}
""")

fn = load(code, "c", "noop", argtypes=(), comm=COMM_WORLD)
dll = load(code, "c", comm=COMM_WORLD)
fn = getattr(dll, "noop")
assert fn is not None


Expand Down

0 comments on commit a08fc1e

Please sign in to comment.