diff --git a/firedrake/function.py b/firedrake/function.py index 97a08be816..e1378b21e0 100644 --- a/firedrake/function.py +++ b/firedrake/function.py @@ -770,8 +770,8 @@ def make_c_evaluate(function, c_name="evaluate", ldargs=None, tolerance=None): libspatialindex_so = Path(rtree.core.rt._name).absolute() lsi_runpath = f"-Wl,-rpath,{libspatialindex_so.parent}" ldargs += [str(libspatialindex_so), lsi_runpath] - return compilation.load( - src, "c", c_name, + dll = compilation.load( + src, "c", cppargs=[ f"-I{path.dirname(__file__)}", f"-I{sys.prefix}/include", @@ -780,3 +780,4 @@ def make_c_evaluate(function, c_name="evaluate", ldargs=None, tolerance=None): ldargs=ldargs, comm=function.comm ) + return getattr(dll, c_name) diff --git a/firedrake/mesh.py b/firedrake/mesh.py index 9936efcefc..7e9a61773b 100644 --- a/firedrake/mesh.py +++ b/firedrake/mesh.py @@ -2688,8 +2688,8 @@ def _c_locator(self, tolerance=None): libspatialindex_so = Path(rtree.core.rt._name).absolute() lsi_runpath = f"-Wl,-rpath,{libspatialindex_so.parent}" - locator = compilation.load( - src, "c", "locator", + dll = compilation.load( + src, "c", cppargs=[ f"-I{os.path.dirname(__file__)}", f"-I{sys.prefix}/include", @@ -2703,7 +2703,7 @@ def _c_locator(self, tolerance=None): ], comm=self.comm ) - + locator = getattr(dll, "locator") locator.argtypes = [ctypes.POINTER(function._CFunction), ctypes.POINTER(ctypes.c_double), ctypes.POINTER(ctypes.c_double), diff --git a/firedrake/mg/kernels.py b/firedrake/mg/kernels.py index f892f6260c..6ce9c4eb55 100644 --- a/firedrake/mg/kernels.py +++ b/firedrake/mg/kernels.py @@ -720,7 +720,7 @@ def name_multiindex(multiindex, name): kernel = lp.make_kernel( domains, instructions, kernel_data, name=kernel_name, target=tsfc.parameters.target, lang_version=(2018, 2)) - kernel = lp.merge([kernel, *subkernels]) + kernel = lp.merge([kernel, *subkernels]).with_entrypoints({kernel_name}) return op2.Kernel( kernel, name=kernel_name, include_dirs=Ainv.include_dirs, headers=Ainv.headers, events=Ainv.events) diff --git a/firedrake/preconditioners/fdm.py b/firedrake/preconditioners/fdm.py index e75172dc8e..51a611ad47 100644 --- a/firedrake/preconditioners/fdm.py +++ b/firedrake/preconditioners/fdm.py @@ -1835,13 +1835,17 @@ def setSubMatCSR(comm, triu=False): return cache.setdefault(key, SparseAssembler.load_setSubMatCSR(comm, triu)) @staticmethod - def load_c_code(code, name, **kwargs): + def load_c_code(code, name, comm, argtypes, restype): petsc_dir = get_petsc_dir() cppargs = [f"-I{d}/include" for d in petsc_dir] ldargs = ([f"-L{d}/lib" for d in petsc_dir] + [f"-Wl,-rpath,{d}/lib" for d in petsc_dir] + ["-lpetsc", "-lm"]) - return load(code, "c", name, cppargs=cppargs, ldargs=ldargs, **kwargs) + dll = load(code, "c", cppargs=cppargs, ldargs=ldargs, comm=comm) + fn = getattr(dll, name) + fn.argtypes = argtypes + fn.restype = restype + return fn @staticmethod def load_setSubMatCSR(comm, triu=False): diff --git a/firedrake/preconditioners/patch.py b/firedrake/preconditioners/patch.py index 0a7bad5575..5e9d0d4fa0 100644 --- a/firedrake/preconditioners/patch.py +++ b/firedrake/preconditioners/patch.py @@ -505,12 +505,13 @@ def load_c_function(code, name, comm): ldargs = (["-L%s/lib" % d for d in get_petsc_dir()] + ["-Wl,-rpath,%s/lib" % d for d in get_petsc_dir()] + ["-lpetsc", "-lm"]) - return load(code, "c", name, - argtypes=[ctypes.c_voidp, ctypes.c_int, ctypes.c_voidp, - ctypes.c_voidp, ctypes.c_voidp, ctypes.c_int, - ctypes.c_voidp, ctypes.c_voidp, ctypes.c_voidp], - restype=ctypes.c_int, cppargs=cppargs, ldargs=ldargs, - comm=comm) + dll = load(code, "c", cppargs=cppargs, ldargs=ldargs, comm=comm) + fn = getattr(dll, name) + fn.argtypes = [ctypes.c_voidp, ctypes.c_int, ctypes.c_voidp, + ctypes.c_voidp, ctypes.c_voidp, ctypes.c_int, + ctypes.c_voidp, ctypes.c_voidp, ctypes.c_voidp] + fn.restype = ctypes.c_int + return fn def make_c_arguments(form, kernel, state, get_map, require_state=False, diff --git a/firedrake/supermeshing.py b/firedrake/supermeshing.py index ee576ea6e5..b35cbb0fe4 100644 --- a/firedrake/supermeshing.py +++ b/firedrake/supermeshing.py @@ -434,14 +434,15 @@ def likely(cell_A): includes = ["-I%s/include" % d for d in dirs] libs = ["-L%s/lib" % d for d in dirs] libs = libs + ["-Wl,-rpath,%s/lib" % d for d in dirs] + ["-lpetsc", "-lsupermesh"] - lib = load( - supermesh_kernel_str, "c", "supermesh_kernel", + dll = load( + supermesh_kernel_str, "c", cppargs=includes, ldargs=libs, - argtypes=[ctypes.c_voidp, ctypes.c_voidp, ctypes.c_voidp, ctypes.c_voidp, ctypes.c_voidp, ctypes.c_voidp, ctypes.c_voidp], - restype=ctypes.c_int, comm=mesh_A._comm ) + lib = getattr(dll, "supermesh_kernel") + lib.argtypes = [ctypes.c_voidp, ctypes.c_voidp, ctypes.c_voidp, ctypes.c_voidp, ctypes.c_voidp, ctypes.c_voidp, ctypes.c_voidp] + lib.restype = ctypes.c_int ammm(V_A, V_B, likely, node_locations_A, node_locations_B, M_SS, ctypes.addressof(lib), mat) if orig_value_size == 1: diff --git a/pyop2/compilation.py b/pyop2/compilation.py index d28c945188..a9b724f368 100644 --- a/pyop2/compilation.py +++ b/pyop2/compilation.py @@ -56,7 +56,6 @@ from pyop2.logger import warning, debug, progress, INFO from pyop2.exceptions import CompilationError from pyop2.utils import get_petsc_variables -import pyop2.global_kernel from petsc4py import PETSc @@ -411,51 +410,23 @@ class AnonymousCompiler(Compiler): def load_hashkey(*args, **kwargs): - from pyop2.global_kernel import GlobalKernel - if isinstance(args[0], str): - code_hash = md5(args[0].encode()).hexdigest() - elif isinstance(args[0], GlobalKernel): - code_hash = md5(str(args[0].cache_key).encode()).hexdigest() - else: - pass # This will raise an error in load + code_hash = md5(args[0].encode()).hexdigest() return default_parallel_hashkey(code_hash, *args[1:], **kwargs) @mpi.collective @memory_cache(hashkey=load_hashkey) @PETSc.Log.EventDecorator() -def load(jitmodule, extension, fn_name, cppargs=(), ldargs=(), - argtypes=None, restype=None, comm=None): +def load(code, extension, cppargs=(), ldargs=(), comm=None): """Build a shared library and return a function pointer from it. - :arg jitmodule: The JIT Module which can generate the code to compile, or - the string representing the source code. + :arg code: The code to compile. :arg extension: extension of the source file (c, cpp) - :arg fn_name: The name of the function to return from the resulting library :arg cppargs: A tuple of arguments to the C compiler (optional) :arg ldargs: A tuple of arguments to the linker (optional) - :arg argtypes: A list of ctypes argument types matching the arguments of - the returned function (optional, pass ``None`` for ``void``). This is - only used when string is passed in instead of JITModule. - :arg restype: The return type of the function (optional, pass - ``None`` for ``void``). :kwarg comm: Optional communicator to compile the code on (only rank 0 compiles code) (defaults to pyop2.mpi.COMM_WORLD). """ - if isinstance(jitmodule, str): - class StrCode(object): - def __init__(self, code, argtypes): - self.code_to_compile = code - self.cache_key = (None, code) # We peel off the first - # entry, since for a jitmodule, it's a process-local - # cache key - self.argtypes = argtypes - code = StrCode(jitmodule, argtypes) - elif isinstance(jitmodule, pyop2.global_kernel.GlobalKernel): - code = jitmodule - else: - raise ValueError("Don't know how to compile code of type %r" % type(jitmodule)) - global _compiler if _compiler: # Use the global compiler if it has been set @@ -475,15 +446,7 @@ def __init__(self, code, argtypes): # This call is cached on disk so_name = make_so(compiler_instance, code, extension, comm) # This call might be cached in memory by the OS (system dependent) - dll = ctypes.CDLL(so_name) - - if isinstance(jitmodule, pyop2.global_kernel.GlobalKernel): - _add_profiling_events(dll, code.local_kernel.events) - - fn = getattr(dll, fn_name) - fn.argtypes = code.argtypes - fn.restype = restype - return fn + return ctypes.CDLL(so_name) def expandWl(ldflags): @@ -519,27 +482,27 @@ def setdefault(self, key, default=None): return self[key] -def _make_so_hashkey(compiler, jitmodule, extension, comm): +def _make_so_hashkey(compiler, code, extension, comm): if extension == "cpp": exe = compiler.cxx compiler_flags = compiler.cxxflags else: exe = compiler.cc compiler_flags = compiler.cflags - return (compiler, exe, compiler_flags, compiler.ld, compiler.ldflags, jitmodule.cache_key) + return (compiler, code, exe, compiler_flags, compiler.ld, compiler.ldflags) -def check_source_hashes(compiler, jitmodule, extension, comm): +def check_source_hashes(compiler, code, extension, comm): """A check to see whether code generated on all ranks is identical. :arg compiler: The compiler to use to create the shared library. - :arg jitmodule: The JIT Module which can generate the code to compile. + :arg code: The code to compile. :arg filename: The filename of the library to create. :arg extension: extension of the source file (c, cpp). :arg comm: Communicator over which to perform compilation. """ # Reconstruct hash from filename - hashval = _as_hexdigest(_make_so_hashkey(compiler, jitmodule, extension, comm)) + hashval = _as_hexdigest(_make_so_hashkey(compiler, code, extension, comm)) with mpi.temp_internal_comm(comm) as icomm: matching = icomm.allreduce(hashval, op=_check_op) if matching != hashval: @@ -550,7 +513,7 @@ def check_source_hashes(compiler, jitmodule, extension, comm): output.mkdir(parents=True, exist_ok=True) icomm.barrier() with open(srcfile, "w") as fh: - fh.write(jitmodule.code_to_compile) + fh.write(code) icomm.barrier() raise CompilationError(f"Generated code differs across ranks (see output in {output})") @@ -561,11 +524,11 @@ def check_source_hashes(compiler, jitmodule, extension, comm): cache_factory=lambda: CompilerDiskAccess(configuration['cache_dir'], extension=".so") ) @PETSc.Log.EventDecorator() -def make_so(compiler, jitmodule, extension, comm, filename=None): +def make_so(compiler, code, extension, comm, filename=None): """Build a shared library and load it :arg compiler: The compiler to use to create the shared library. - :arg jitmodule: The JIT Module which can generate the code to compile. + :arg code: The code to compile. :arg filename: The filename of the library to create. :arg extension: extension of the source file (c, cpp). :arg comm: Communicator over which to perform compilation. @@ -605,7 +568,7 @@ def make_so(compiler, jitmodule, extension, comm, filename=None): with progress(INFO, 'Compiling wrapper'): # Write source code to disk with open(cname, "w") as fh: - fh.write(jitmodule.code_to_compile) + fh.write(code) os.close(descriptor) if not compiler.ld: @@ -650,7 +613,7 @@ def _run(cc, logfile, errfile, step="Compilation", filemode="w"): """)) -def _add_profiling_events(dll, events): +def add_profiling_events(dll, events): """ If PyOP2 is in profiling mode, events are attached to dll to profile the local linear algebra calls. The event is generated here in python and then set in the shared library, diff --git a/pyop2/global_kernel.py b/pyop2/global_kernel.py index ae13dc1c59..7edfed0771 100644 --- a/pyop2/global_kernel.py +++ b/pyop2/global_kernel.py @@ -11,7 +11,7 @@ from petsc4py import PETSc from pyop2 import mpi -from pyop2.compilation import load +from pyop2.compilation import add_profiling_events, load from pyop2.configuration import configuration from pyop2.datatypes import IntType, as_ctypes from pyop2.types import IterationRegion, Constant, READ @@ -366,8 +366,11 @@ def code_to_compile(self): """Return the C/C++ source code as a string.""" from pyop2.codegen.rep2loopy import generate - wrapper = generate(self.builder) - code = lp.generate_code_v2(wrapper) + with PETSc.Log.Event("GlobalKernel: generate loopy"): + wrapper = generate(self.builder) + + with PETSc.Log.Event("GlobalKernel: generate device code"): + code = lp.generate_code_v2(wrapper) if self.local_kernel.cpp: from loopy.codegen.result import process_preambles @@ -397,15 +400,12 @@ def compile(self, comm): + tuple(self.local_kernel.ldargs) ) - return load( - self, - extension, - self.name, - cppargs=cppargs, - ldargs=ldargs, - restype=ctypes.c_int, - comm=comm - ) + dll = load(self.code_to_compile, extension, cppargs=cppargs, ldargs=ldargs, comm=comm) + add_profiling_events(dll, self.local_kernel.events) + fn = getattr(dll, self.name) + fn.argtypes = self.argtypes + fn.restype = ctypes.c_int + return fn @cached_property def argtypes(self): diff --git a/tests/pyop2/test_caching.py b/tests/pyop2/test_caching.py index 1298991b3e..cfd9e6ce7f 100644 --- a/tests/pyop2/test_caching.py +++ b/tests/pyop2/test_caching.py @@ -31,7 +31,6 @@ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED # OF THE POSSIBILITY OF SUCH DAMAGE. -import ctypes import os import pytest import tempfile @@ -785,7 +784,8 @@ def test_writing_large_so(): if COMM_WORLD.rank == 1: os.remove("big.c") - fn = load(program, "c", "big", argtypes=(ctypes.c_voidp,), comm=COMM_WORLD) + dll = load(program, "c", comm=COMM_WORLD) + fn = getattr(dll, "big") assert fn is not None @@ -800,7 +800,8 @@ def test_two_comms_compile_the_same_code(): } """) - fn = load(code, "c", "noop", argtypes=(), comm=COMM_WORLD) + dll = load(code, "c", comm=COMM_WORLD) + fn = getattr(dll, "noop") assert fn is not None