Skip to content

Commit

Permalink
Merge branch 'master' into gpu
Browse files Browse the repository at this point in the history
  • Loading branch information
kaushikcfd committed Jun 24, 2022
2 parents e937806 + a33fedf commit b75090a
Show file tree
Hide file tree
Showing 27 changed files with 2,124 additions and 1,797 deletions.
101 changes: 37 additions & 64 deletions pyop2/backends/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
from pyop2.types.set import Set, ExtrudedSet, Subset, MixedSet
from pyop2.types.dataset import DataSet, GlobalDataSet, MixedDataSet
from pyop2.types.map import Map, MixedMap
from pyop2.parloop import AbstractParLoop, AbstractJITModule
from pyop2.parloop import AbstractParLoop
from pyop2.global_kernel import AbstractGlobalKernel
from pyop2.types.mat import Mat
from pyop2.glob import Global
from pyop2.backends import AbstractComputeBackend
from pyop2.datatypes import as_ctypes, IntType
from petsc4py import PETSc
from . import (
compilation,
configuration as conf,
datatypes as dtypes,
mpi,
utils
)
Expand All @@ -35,23 +35,17 @@ def _vec(self):
return vec


class JITModule(AbstractJITModule):
class GlobalKernel(AbstractGlobalKernel):

@utils.cached_property
def code_to_compile(self):
from pyop2.codegen.builder import WrapperBuilder
"""Return the C/C++ source code as a string."""
from pyop2.codegen.rep2loopy import generate

builder = WrapperBuilder(kernel=self._kernel,
iterset=self._iterset,
iteration_region=self._iteration_region,
pass_layer_to_kernel=self._pass_layer_arg)
for arg in self._args:
builder.add_argument(arg)

wrapper = generate(builder)
wrapper = generate(self.builder)
code = lp.generate_code_v2(wrapper)

if self._kernel._cpp:
if self.local_kernel.cpp:
from loopy.codegen.result import process_preambles
preamble = "".join(process_preambles(getattr(code, "device_preambles", [])))
device_code = "\n\n".join(str(dp.ast) for dp in code.device_programs)
Expand All @@ -60,53 +54,38 @@ def code_to_compile(self):

@PETSc.Log.EventDecorator()
@mpi.collective
def compile(self):
# If we weren't in the cache we /must/ have arguments
if not hasattr(self, '_args'):
raise RuntimeError("JITModule has no args associated with it, should never happen")

compiler = conf.configuration["compiler"]
extension = "cpp" if self._kernel._cpp else "c"
cppargs = self._cppargs
cppargs += ["-I%s/include" % d for d in utils.get_petsc_dir()] + \
["-I%s" % d for d in self._kernel._include_dirs] + \
["-I%s" % os.path.abspath(os.path.dirname(__file__))]
ldargs = ["-L%s/lib" % d for d in utils.get_petsc_dir()] + \
["-Wl,-rpath,%s/lib" % d for d in utils.get_petsc_dir()] + \
["-lpetsc", "-lm"] + self._libraries
ldargs += self._kernel._ldargs

self._fun = compilation.load(self,
extension,
self._wrapper_name,
cppargs=cppargs,
ldargs=ldargs,
restype=ctypes.c_int,
compiler=compiler,
comm=self.comm)
# Blow away everything we don't need any more
del self._args
del self._kernel
del self._iterset
def compile(self, comm):
"""Compile the kernel.
:arg comm: The communicator the compilation is collective over.
:returns: A ctypes function pointer for the compiled function.
"""
extension = "cpp" if self.local_kernel.cpp else "c"
cppargs = (
tuple("-I%s/include" % d for d in utils.get_petsc_dir())
+ tuple("-I%s" % d for d in self.local_kernel.include_dirs)
+ ("-I%s" % os.path.abspath(os.path.dirname(__file__)),)
)
ldargs = (
tuple("-L%s/lib" % d for d in utils.get_petsc_dir())
+ tuple("-Wl,-rpath,%s/lib" % d for d in utils.get_petsc_dir())
+ ("-lpetsc", "-lm")
+ tuple(self.local_kernel.ldargs)
)

return compilation.load(self, extension, self.name,
cppargs=cppargs,
ldargs=ldargs,
restype=ctypes.c_int,
comm=comm)

@utils.cached_property
def argtypes(self):
index_type = dtypes.as_ctypes(dtypes.IntType)
argtypes = (index_type, index_type)
argtypes += self._iterset._argtypes_
for arg in self._args:
argtypes += arg._argtypes_
seen = set()
for arg in self._args:
maps = arg.map_tuple
for map_ in maps:
for k, t in zip(map_._kernel_args_, map_._argtypes_):
if k in seen:
continue
argtypes += (t,)
seen.add(k)
return argtypes
...
# The first two arguments to the global kernel are the 'start' and 'stop'
# indices. All other arguments are declared to be void pointers.
dtypes = [as_ctypes(IntType)] * 2
dtypes.extend([ctypes.c_voidp for _ in self.builder.wrapper_args[2:]])
return tuple(dtypes)


class ParLoop(AbstractParLoop):
Expand All @@ -128,12 +107,6 @@ def prepare_arglist(self, iterset, *args):
seen.add(k)
return arglist

@utils.cached_property
def _jitmodule(self):
return JITModule(self.kernel, self.iterset, *self.args,
iterate=self.iteration_region,
pass_layer_arg=self._pass_layer_arg)

@mpi.collective
def _compute(self, part, fun, *arglist):
with self._compute_event:
Expand Down
113 changes: 113 additions & 0 deletions pyop2/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,15 @@

"""Provides common base classes for cached objects."""

import functools
import hashlib
import os
from pathlib import Path
import pickle

import cachetools

from pyop2.configuration import configuration
from pyop2.mpi import hash_comm
from pyop2.utils import cached_property


Expand Down Expand Up @@ -230,3 +238,108 @@ def _cache_key(cls, *args, **kwargs):
def cache_key(self):
"""Cache key."""
return self._key


cached = cachetools.cached
"""Cache decorator for functions. See the cachetools documentation for more
information.
.. note::
If you intend to use this decorator to cache things that are collective
across a communicator then you must include the communicator as part of
the cache key. Since communicators are themselves not hashable you should
use :func:`pyop2.mpi.hash_comm`.
You should also make sure to use unbounded caches as otherwise some ranks
may evict results leading to deadlocks.
"""


def disk_cached(cache, cachedir=None, key=cachetools.keys.hashkey, collective=False):
    """Decorator for wrapping a function in a cache that stores values in memory and to disk.

    :arg cache: The in-memory cache, usually a :class:`dict`.
    :arg cachedir: The location of the cache directory. Defaults to ``PYOP2_CACHE_DIR``.
    :arg key: Callable returning the cache key for the function inputs. If ``collective``
        is ``True`` then this function must return a 2-tuple where the first entry is the
        communicator to be collective over and the second is the key. This is required to ensure
        that deadlocks do not occur when using different subcommunicators.
    :arg collective: If ``True`` then cache lookup is done collectively over a communicator.
    """
    if cachedir is None:
        cachedir = configuration["cache_dir"]

    def decorator(func):
        # functools.wraps preserves func's __name__/__doc__/__wrapped__ on the
        # returned wrapper (the original wrapper silently discarded them).
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if collective:
                comm, disk_key = key(*args, **kwargs)
                disk_key = _as_hexdigest(disk_key)
                # Communicators are unhashable, so hash_comm stands in for
                # comm in the in-memory key.
                k = hash_comm(comm), disk_key
            else:
                k = _as_hexdigest(key(*args, **kwargs))

            # first try the in-memory cache
            try:
                return cache[k]
            except KeyError:
                pass

            # then try to retrieve from disk; only rank 0 touches the
            # filesystem in collective mode to avoid contention.
            if collective:
                if comm.rank == 0:
                    v = _disk_cache_get(cachedir, disk_key)
                    comm.bcast(v, root=0)
                else:
                    v = comm.bcast(None, root=0)
            else:
                v = _disk_cache_get(cachedir, k)
            if v is not None:
                return cache.setdefault(k, v)

            # if all else fails call func and populate the caches
            v = func(*args, **kwargs)
            if collective:
                if comm.rank == 0:
                    _disk_cache_set(cachedir, disk_key, v)
            else:
                _disk_cache_set(cachedir, k, v)
            return cache.setdefault(k, v)
        return wrapper
    return decorator


def _as_hexdigest(key):
return hashlib.md5(str(key).encode()).hexdigest()


def _disk_cache_get(cachedir, key):
"""Retrieve a value from the disk cache.
:arg cachedir: The cache directory.
:arg key: The cache key (must be a string).
:returns: The cached object if found, else ``None``.
"""
filepath = Path(cachedir, key[:2], key[2:])
try:
with open(filepath, "rb") as f:
return pickle.load(f)
except FileNotFoundError:
return None


def _disk_cache_set(cachedir, key, value):
"""Store a new value in the disk cache.
:arg cachedir: The cache directory.
:arg key: The cache key (must be a string).
:arg value: The new item to store in the cache.
"""
k1, k2 = key[:2], key[2:]
basedir = Path(cachedir, k1)
basedir.mkdir(parents=True, exist_ok=True)

tempfile = basedir.joinpath(f"{k2}_p{os.getpid()}.tmp")
filepath = basedir.joinpath(k2)
with open(tempfile, "wb") as f:
pickle.dump(value, f)
tempfile.rename(filepath)
Loading

0 comments on commit b75090a

Please sign in to comment.