Skip to content
This repository has been archived by the owner on Nov 27, 2024. It is now read-only.

Commit

Permalink
WIP: may have reintroduced deadlocks
Browse files Browse the repository at this point in the history
  • Loading branch information
JDBetteridge committed Aug 22, 2024
1 parent 395b760 commit a78aafe
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 16 deletions.
39 changes: 30 additions & 9 deletions pyop2/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from pyop2.configuration import configuration
from pyop2.logger import warning, debug, progress, INFO
from pyop2.exceptions import CompilationError
import pyop2.global_kernel
from petsc4py import PETSc


Expand Down Expand Up @@ -420,6 +421,7 @@ def load_hashkey(*args, **kwargs):


# JBTODO: This should not be memory cached
# ...benchmarking disagrees with my assessment
@mpi.collective
@memory_cache(hashkey=load_hashkey, broadcast=False)
def load(jitmodule, extension, fn_name, cppargs=(), ldargs=(),
Expand All @@ -440,8 +442,6 @@ def load(jitmodule, extension, fn_name, cppargs=(), ldargs=(),
:kwarg comm: Optional communicator to compile the code on (only
rank 0 compiles code) (defaults to pyop2.mpi.COMM_WORLD).
"""
from pyop2.global_kernel import GlobalKernel

if isinstance(jitmodule, str):
class StrCode(object):
def __init__(self, code, argtypes):
Expand All @@ -451,7 +451,7 @@ def __init__(self, code, argtypes):
# cache key
self.argtypes = argtypes
code = StrCode(jitmodule, argtypes)
elif isinstance(jitmodule, GlobalKernel):
elif isinstance(jitmodule, pyop2.global_kernel.GlobalKernel):
code = jitmodule
else:
raise ValueError("Don't know how to compile code of type %r" % type(jitmodule))
Expand All @@ -477,7 +477,7 @@ def __init__(self, code, argtypes):
# This call is cached in memory by the OS
dll = ctypes.CDLL(so_name)

if isinstance(jitmodule, GlobalKernel):
if isinstance(jitmodule, pyop2.global_kernel.GlobalKernel):
_add_profiling_events(dll, code.local_kernel.events)

fn = getattr(dll, fn_name)
Expand Down Expand Up @@ -511,6 +511,13 @@ def read(self, filename):
raise FileNotFoundError("File not on disk, cache miss")
return filename

def setdefault(self, key, default=None):
try:
return self[key]
except KeyError:
self[key] = default
return self[key]


def _make_so_hashkey(compiler, jitmodule, extension, comm):
if extension == "cpp":
Expand Down Expand Up @@ -546,7 +553,7 @@ def check_source_hashes(compiler, jitmodule, extension, comm):
@mpi.collective
@parallel_cache(
hashkey=_make_so_hashkey,
cache_factory=lambda: CompilerDiskAccess(configuration['cache_dir'], extension=".so"),
cache_factory=lambda: CompilerDiskAccess(configuration['cache_dir'], extension=".so")
)
def make_so(compiler, jitmodule, extension, comm, filename=None):
"""Build a shared library and load it
Expand All @@ -560,12 +567,15 @@ def make_so(compiler, jitmodule, extension, comm, filename=None):
Returns a :class:`ctypes.CDLL` object of the resulting shared
library."""
if filename is None:
tempdir = Path(gettempdir()).joinpath(f"pyop2-tempcache-uid{os.getuid()}")
tempdir.mkdir(exist_ok=True)
filename = tempdir.joinpath(f"foo{next(FILE_CYCLER)}.c")
# JBTODO: Remove this directory at some point?
pyop2_tempdir = Path(gettempdir()).joinpath(f"pyop2-tempcache-uid{os.getuid()}")
tempdir = pyop2_tempdir.joinpath(f"{os.getpid()}")
# ~ tempdir = Path(mkdtemp(dir=pyop2_tempdir.joinpath(f"{os.getpid()}")))
# This path + filename should be unique
filename = tempdir.joinpath("foo.c")
else:
pyop2_tempdir = None
filename = Path(filename).absolute()
filename.parent.mkdir(exist_ok=True)

# Compilation communicators are reference counted on the PyOP2 comm
icomm = mpi.internal_comm(comm, compiler)
Expand All @@ -590,6 +600,11 @@ def make_so(compiler, jitmodule, extension, comm, filename=None):

# Compile on compilation communicator (ccomm) rank 0
if comm.rank == 0:
if pyop2_tempdir is None:
filename.parent.mkdir(exist_ok=True)
else:
pyop2_tempdir.mkdir(exist_ok=True)
tempdir.mkdir(exist_ok=True)
logfile = path.joinpath(f"{base}_p{pid}.log")
errfile = path.joinpath(f"{base}_p{pid}.err")
with progress(INFO, 'Compiling wrapper'):
Expand All @@ -612,6 +627,12 @@ def make_so(compiler, jitmodule, extension, comm, filename=None):
return soname


# JBTODO: Probably don't want to do this if we fail to compile...
# ~ @atexit
# ~ def _cleanup_tempdir():
# ~ pyop2_tempdir = Path(gettempdir()).joinpath(f"pyop2-tempcache-uid{os.getuid()}")


def _run(cc, logfile, errfile, step="Compilation", filemode="w"):
debug(f"{step} command: {' '.join(cc)}")
try:
Expand Down
17 changes: 11 additions & 6 deletions pyop2/global_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
import pytools
from petsc4py import PETSc

from pyop2 import compilation, mpi
from pyop2 import mpi
from pyop2.compilation import load
from pyop2.configuration import configuration
from pyop2.datatypes import IntType, as_ctypes
from pyop2.types import IterationRegion, Constant, READ
Expand Down Expand Up @@ -397,11 +398,15 @@ def compile(self, comm):
+ tuple(self.local_kernel.ldargs)
)

return compilation.load(self, extension, self.name,
cppargs=cppargs,
ldargs=ldargs,
restype=ctypes.c_int,
comm=comm)
return load(
self,
extension,
self.name,
cppargs=cppargs,
ldargs=ldargs,
restype=ctypes.c_int,
comm=comm
)

@cached_property
def argtypes(self):
Expand Down
2 changes: 1 addition & 1 deletion test/unit/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ def cache(self):
_cache_collection = int_comm.Get_attr(mpi.comm_cache_keyval)
if _cache_collection is None:
_cache_collection = {default_cache_name: DEFAULT_CACHE()}
mpi.COMM_WORLD.Set_attr(mpi.comm_cache_keyval, _cache_collection)
int_comm.Set_attr(mpi.comm_cache_keyval, _cache_collection)
return _cache_collection[default_cache_name]

@pytest.fixture
Expand Down

0 comments on commit a78aafe

Please sign in to comment.