From 31471a606a852aed250b05574d1fc2a2874eec31 Mon Sep 17 00:00:00 2001 From: Jack Betteridge <43041811+JDBetteridge@users.noreply.github.com> Date: Wed, 9 Oct 2024 16:45:18 +0100 Subject: [PATCH] Remove comm hash and add per-comm caches (#724) --------- Co-authored-by: David A. Ham Co-authored-by: Connor Ward --- .github/workflows/ci.yml | 27 +- pyop2/caching.py | 592 +++++++++++++++++++++++++++----------- pyop2/compilation.py | 518 ++++++++++++++++++--------------- pyop2/configuration.py | 13 +- pyop2/exceptions.py | 10 + pyop2/global_kernel.py | 66 ++--- pyop2/mpi.py | 87 ++++-- pyop2/op2.py | 9 +- pyop2/utils.py | 21 +- requirements-git.txt | 1 + scripts/pyop2-clean | 4 +- test/unit/test_caching.py | 313 ++++++++++++++++---- 12 files changed, 1105 insertions(+), 556 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 788186ac9..9c089aea8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -17,7 +17,7 @@ jobs: # Don't immediately kill all if one Python version fails fail-fast: false matrix: - python-version: ['3.8', '3.9', '3.10', '3.11'] + python-version: ['3.9', '3.10', '3.11', '3.12'] env: CC: mpicc PETSC_DIR: ${{ github.workspace }}/petsc @@ -58,7 +58,7 @@ jobs: working-directory: ${{ env.PETSC_DIR }}/src/binding/petsc4py run: | python -m pip install --upgrade pip - python -m pip install --upgrade wheel 'cython<3' numpy + python -m pip install --upgrade wheel cython numpy python -m pip install --no-deps . - name: Checkout PyOP2 @@ -66,7 +66,7 @@ jobs: with: path: PyOP2 - - name: Install PyOP2 + - name: Install PyOP2 dependencies shell: bash working-directory: PyOP2 run: | @@ -76,7 +76,21 @@ jobs: python -m pip install pulp python -m pip install -U flake8 python -m pip install -U pytest-timeout - python -m pip install . + + - name: Install PyOP2 (Python <3.12) + if: ${{ matrix.python-version != '3.12' }} + shell: bash + working-directory: PyOP2 + run: python -m pip install . + + # Not sure if this is a bug in setuptools or something PyOP2 is doing wrong + - name: Install PyOP2 (Python == 3.12) + if: ${{ matrix.python-version == '3.12' }} + shell: bash + working-directory: PyOP2 + run: | + python -m pip install -U setuptools + python setup.py install - name: Run linting shell: bash @@ -86,7 +100,10 @@ jobs: - name: Run tests shell: bash working-directory: PyOP2 - run: pytest --tb=native --timeout=480 --timeout-method=thread -o faulthandler_timeout=540 -v test + run: | + # Running parallel test cases separately works around a bug in pytest-mpi + pytest -k "not parallel" --tb=native --timeout=480 --timeout-method=thread -o faulthandler_timeout=540 -v test + mpiexec -n 3 pytest -k "parallel[3]" --tb=native --timeout=480 --timeout-method=thread -o faulthandler_timeout=540 -v test timeout-minutes: 10 - name: Build documentation diff --git a/pyop2/caching.py b/pyop2/caching.py index 0f036212f..2948ddede 100644 --- a/pyop2/caching.py +++ b/pyop2/caching.py @@ -32,49 +32,39 @@ # OF THE POSSIBILITY OF SUCH DAMAGE. """Provides common base classes for cached objects.""" - +import atexit +import cachetools import hashlib import os -from pathlib import Path import pickle - -import cachetools +import weakref +from collections.abc import MutableMapping +from pathlib import Path +from warnings import warn # noqa F401 +from collections import defaultdict +from itertools import count +from functools import wraps +from tempfile import mkstemp from pyop2.configuration import configuration -from pyop2.mpi import hash_comm -from pyop2.utils import cached_property +from pyop2.exceptions import CachingError, HashError # noqa: F401 +from pyop2.logger import debug +from pyop2.mpi import ( + MPI, COMM_WORLD, comm_cache_keyval, temp_internal_comm +) +from petsc4py import PETSc -def report_cache(typ): - """Report the size of caches of type ``typ`` - - :arg typ: A class of cached object. For example - :class:`ObjectCached` or :class:`Cached`. - """ - from collections import defaultdict - from inspect import getmodule - from gc import get_objects - typs = defaultdict(lambda: 0) - n = 0 - for x in get_objects(): - if isinstance(x, typ): - typs[type(x)] += 1 - n += 1 - if n == 0: - print("\nNo %s objects in caches" % typ.__name__) - return - print("\n%d %s objects in caches" % (n, typ.__name__)) - print("Object breakdown") - print("================") - for k, v in typs.iteritems(): - mod = getmodule(k) - if mod is not None: - name = "%s.%s" % (mod.__name__, k.__name__) - else: - name = k.__name__ - print('%s: %d' % (name, v)) +# Caches created here are registered as a tuple of +# (creation_index, comm, comm.name, function, cache) +# in _KNOWN_CACHES +_CACHE_CIDX = count() +_KNOWN_CACHES = [] +# Flag for outputting information at the end of testing (do not abuse!) +_running_on_ci = bool(os.environ.get('PYOP2_CI_TESTS')) +# FIXME: (Later) Remove ObjectCached class ObjectCached(object): """Base class for objects that should be cached on another object. @@ -160,179 +150,431 @@ def make_obj(): return obj -class Cached(object): +def cache_filter(comm=None, comm_name=None, alive=True, function=None, cache_type=None): + """ Filter PyOP2 caches based on communicator, function or cache type. + """ + caches = _KNOWN_CACHES + if comm is not None: + with temp_internal_comm(comm) as icomm: + cache_collection = icomm.Get_attr(comm_cache_keyval) + if cache_collection is None: + print(f"Communicator {icomm.name} has no associated caches") + comm_name = icomm.name + if comm_name is not None: + caches = filter(lambda c: c.comm_name == comm_name, caches) + if alive: + caches = filter(lambda c: c.comm != MPI.COMM_NULL, caches) + if function is not None: + if isinstance(function, str): + caches = filter(lambda c: function in c.func_name, caches) + else: + caches = filter(lambda c: c.func is function, caches) + if cache_type is not None: + if isinstance(cache_type, str): + caches = filter(lambda c: cache_type in c.cache_name, caches) + else: + caches = filter(lambda c: c.cache_name == cache_type.__class__.__qualname__, caches) + return [*caches] - """Base class providing global caching of objects. Derived classes need to - implement classmethods :meth:`_process_args` and :meth:`_cache_key` - and define a class attribute :attr:`_cache` of type :class:`dict`. - .. warning:: - The derived class' :meth:`__init__` is still called if the object is - retrieved from cache. If that is not desired, derived classes can set - a flag indicating whether the constructor has already been called and - immediately return from :meth:`__init__` if the flag is set. Otherwise - the object will be re-initialized even if it was returned from cache! +class _CacheRecord: + """ Object for keeping a record of Pyop2 Cache statistics. + """ + def __init__(self, cidx, comm, func, cache): + self.cidx = cidx + self.comm = comm + self.comm_name = comm.name + self.func = func + self.func_module = func.__module__ + self.func_name = func.__qualname__ + self.cache = weakref.ref(cache) + fin = weakref.finalize(cache, self.finalize, cache) + fin.atexit = False + self.cache_name = cache.__class__.__qualname__ + try: + self.cache_loc = cache.cachedir + except AttributeError: + self.cache_loc = "Memory" + + def get_stats(self, cache=None): + if cache is None: + cache = self.cache() + hit = miss = size = maxsize = -1 + if cache is None: + hit, miss, size, maxsize = self.hit, self.miss, self.size, self.maxsize + if isinstance(cache, cachetools.Cache): + size = cache.currsize + maxsize = cache.maxsize + if hasattr(cache, "instrument__"): + hit = cache.hit + miss = cache.miss + if size == -1: + try: + size = len(cache) + except NotImplementedError: + pass + if maxsize is None: + try: + maxsize = cache.max_size + except AttributeError: + pass + return hit, miss, size, maxsize + + def finalize(self, cache): + self.hit, self.miss, self.size, self.maxsize = self.get_stats(cache) + + +def print_cache_stats(*args, **kwargs): + """ Print out the cache hit/miss/size/maxsize stats for PyOP2 caches. """ + data = defaultdict(lambda: defaultdict(list)) + for entry in cache_filter(*args, **kwargs): + active = (entry.comm != MPI.COMM_NULL) + data[(entry.comm_name, active)][(entry.cache_name, entry.cache_loc)].append( + (entry.cidx, entry.func_module, entry.func_name, entry.get_stats()) + ) + + tab = " " + hline = "-"*120 + col = (90, 27) + stats_col = (6, 6, 6, 6) + stats = ("hit", "miss", "size", "max") + no_stats = "|".join(" "*ii for ii in stats_col) + print(hline) + print(f"|{'Cache':^{col[0]}}|{'Stats':^{col[1]}}|") + subtitles = "|".join(f"{st:^{w}}" for st, w in zip(stats, stats_col)) + print("|" + " "*col[0] + f"|{subtitles:{col[1]}}|") + print(hline) + for ecomm, cachedict in data.items(): + active = "Active" if ecomm[1] else "Freed" + comm_title = f"{ecomm[0]} ({active})" + print(f"|{comm_title:{col[0]}}|{no_stats}|") + for ecache, function_list in cachedict.items(): + cache_title = f"{tab}{ecache[0]}" + print(f"|{cache_title:{col[0]}}|{no_stats}|") + cache_location = f"{tab} ↳ {ecache[1]!s}" + if len(cache_location) < col[0]: + print(f"|{cache_location:{col[0]}}|{no_stats}|") + else: + print(f"|{cache_location:78}|") + for entry in function_list: + function_title = f"{tab*2}id={entry[0]} {'.'.join(entry[1:3])}" + stats_row = "|".join(f"{s:{w}}" for s, w in zip(entry[3], stats_col)) + print(f"|{function_title:{col[0]}}|{stats_row:{col[1]}}|") + print(hline) - def __new__(cls, *args, **kwargs): - args, kwargs = cls._process_args(*args, **kwargs) - key = cls._cache_key(*args, **kwargs) - def make_obj(): - obj = super(Cached, cls).__new__(cls) - obj._key = key - obj._initialized = False - # obj.__init__ will be called twice when constructing - # something not in the cache. The first time here, with - # the canonicalised args, the second time directly in the - # subclass. But that one should hit the cache and return - # straight away. - obj.__init__(*args, **kwargs) - return obj +if _running_on_ci: + print_cache_stats = atexit.register(print_cache_stats) - # Don't bother looking in caches if we're not meant to cache - # this object. - if key is None: - return make_obj() - try: - return cls._cache_lookup(key) - except (KeyError, IOError): - obj = make_obj() - cls._cache_store(key, obj) - return obj - @classmethod - def _cache_lookup(cls, key): - return cls._cache[key] +class _CacheMiss: + pass - @classmethod - def _cache_store(cls, key, val): - cls._cache[key] = val - @classmethod - def _process_args(cls, *args, **kwargs): - """Pre-processes the arguments before they are being passed to - :meth:`_cache_key` and the constructor. +CACHE_MISS = _CacheMiss() - :rtype: *must* return a :class:`list` of *args* and a - :class:`dict` of *kwargs*""" - return args, kwargs - @classmethod - def _cache_key(cls, *args, **kwargs): - """Compute the cache key given the preprocessed constructor arguments. +def _as_hexdigest(*args): + hash_ = hashlib.md5() + for a in args: + if isinstance(a, MPI.Comm): + raise HashError("Communicators cannot be hashed, caching will be broken!") + hash_.update(str(a).encode()) + return hash_.hexdigest() + + +class DictLikeDiskAccess(MutableMapping): + """ A Dictionary like interface for storing and retrieving objects from a disk cache. + """ + def __init__(self, cachedir, extension=".pickle"): + """ - :rtype: Cache key to use or ``None`` if the object is not to be cached + :arg cachedir: The cache directory. + :arg extension: Optional extension to use for written files. + """ + self.cachedir = cachedir + self.extension = extension - .. note:: The cache key must be hashable.""" - return tuple(args) + tuple([(k, v) for k, v in kwargs.items()]) + def __getitem__(self, key): + """Retrieve a value from the disk cache. - @cached_property - def cache_key(self): - """Cache key.""" - return self._key + :arg key: The cache key, a 2-tuple of strings. + :returns: The cached object if found. + """ + filepath = Path(self.cachedir, key[0][:2], key[0][2:] + key[1]) + try: + with self.open(filepath.with_suffix(self.extension), mode="rb") as fh: + value = self.read(fh) + except FileNotFoundError: + raise KeyError("File not on disk, cache miss") + return value + def __setitem__(self, key, value): + """Store a new value in the disk cache. -cached = cachetools.cached -"""Cache decorator for functions. See the cachetools documentation for more -information. + :arg key: The cache key, a 2-tuple of strings. + :arg value: The new item to store in the cache. + """ + k1, k2 = key[0][:2], key[0][2:] + key[1] + basedir = Path(self.cachedir, k1) + basedir.mkdir(parents=True, exist_ok=True) -.. note:: - If you intend to use this decorator to cache things that are collective - across a communicator then you must include the communicator as part of - the cache key. Since communicators are themselves not hashable you should - use :func:`pyop2.mpi.hash_comm`. + # Care must be taken here to ensure that the file is created safely as + # the filesystem may be network based. `mkstemp` does so securely without + # race conditions: + # https://docs.python.org/3/library/tempfile.html#tempfile.mkstemp + # The file descriptor must also be closed after use with `os.close()`. + fd, tempfile = mkstemp(suffix=".tmp", prefix=k2, dir=basedir, text=False) + tempfile = Path(tempfile) + # Open using `tempfile` (the filename) rather than the file descriptor + # to allow redefining `self.open` + with self.open(tempfile, mode="wb") as fh: + self.write(fh, value) + os.close(fd) - You should also make sure to use unbounded caches as otherwise some ranks - may evict results leading to deadlocks. -""" + # Renaming (moving) the file is guaranteed by any POSIX compliant + # filesystem to be atomic. This may fail if somehow the destination is + # on another filesystem, but that shouldn't happen here. + filepath = basedir.joinpath(k2) + tempfile.rename(filepath.with_suffix(self.extension)) + def __delitem__(self, key): + raise NotImplementedError(f"Cannot remove items from {self.__class__.__name__}") -def disk_cached(cache, cachedir=None, key=cachetools.keys.hashkey, collective=False): - """Decorator for wrapping a function in a cache that stores values in memory and to disk. + def __iter__(self): + raise NotImplementedError(f"Cannot iterate over keys in {self.__class__.__name__}") - :arg cache: The in-memory cache, usually a :class:`dict`. - :arg cachedir: The location of the cache directory. Defaults to ``PYOP2_CACHE_DIR``. - :arg key: Callable returning the cache key for the function inputs. If ``collective`` - is ``True`` then this function must return a 2-tuple where the first entry is the - communicator to be collective over and the second is the key. This is required to ensure - that deadlocks do not occur when using different subcommunicators. - :arg collective: If ``True`` then cache lookup is done collectively over a communicator. + def __len__(self): + raise NotImplementedError(f"Cannot query length of {self.__class__.__name__}") + + def __repr__(self): + return f"{self.__class__.__name__}(cachedir={self.cachedir}, extension={self.extension})" + + def __eq__(self, other): + # Instances are the same if they have the same cachedir + return (self.cachedir == other.cachedir and self.extension == other.extension) + + def open(self, *args, **kwargs): + return open(*args, **kwargs) + + def read(self, filehandle): + return pickle.load(filehandle) + + def write(self, filehandle, value): + pickle.dump(value, filehandle) + + +def default_comm_fetcher(*args, **kwargs): + """ A sensible default comm fetcher for use with `parallel_cache`. """ - if cachedir is None: - cachedir = configuration["cache_dir"] + comms = filter( + lambda arg: isinstance(arg, MPI.Comm), + args + tuple(kwargs.values()) + ) + try: + comm = next(comms) + except StopIteration: + raise TypeError("No comms found in args or kwargs") + return comm - def decorator(func): - def wrapper(*args, **kwargs): - if collective: - comm, disk_key = key(*args, **kwargs) - disk_key = _as_hexdigest(disk_key) - k = hash_comm(comm), disk_key + +def default_parallel_hashkey(*args, **kwargs): + """ A sensible default hash key for use with `parallel_cache`. + """ + # We now want to actively remove any comms from args and kwargs to get + # the same disk cache key. + hash_args = tuple(filter( + lambda arg: not isinstance(arg, MPI.Comm), + args + )) + hash_kwargs = dict(filter( + lambda arg: not isinstance(arg[1], MPI.Comm), + kwargs.items() + )) + return cachetools.keys.hashkey(*hash_args, **hash_kwargs) + + +def instrument(cls): + """ Class decorator for dict-like objects for counting cache hits/misses. + """ + @wraps(cls, updated=()) + class _wrapper(cls): + instrument__ = True + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.hit = 0 + self.miss = 0 + + def get(self, key, default=None): + value = super().get(key, default) + if value is default: + self.miss += 1 else: - k = _as_hexdigest(key(*args, **kwargs)) + self.hit += 1 + return value - # first try the in-memory cache + def __getitem__(self, key): try: - return cache[k] - except KeyError: - pass - - # then try to retrieve from disk - if collective: - if comm.rank == 0: - v = _disk_cache_get(cachedir, disk_key) - comm.bcast(v, root=0) - else: - v = comm.bcast(None, root=0) - else: - v = _disk_cache_get(cachedir, k) - if v is not None: - return cache.setdefault(k, v) - - # if all else fails call func and populate the caches - v = func(*args, **kwargs) - if collective: - if comm.rank == 0: - _disk_cache_set(cachedir, disk_key, v) - else: - _disk_cache_set(cachedir, k, v) - return cache.setdefault(k, v) - return wrapper - return decorator + value = super().__getitem__(key) + self.hit += 1 + except KeyError as e: + self.miss += 1 + raise e + return value + return _wrapper + + +class DEFAULT_CACHE(dict): + pass -def _as_hexdigest(key): - return hashlib.md5(str(key).encode()).hexdigest() +# Example of how to instrument and use different default caches: +# from functools import partial +# EXOTIC_CACHE = partial(instrument(cachetools.LRUCache), maxsize=100) +# Turn on cache measurements if printing cache info is enabled +if configuration["print_cache_info"] or _running_on_ci: + DEFAULT_CACHE = instrument(DEFAULT_CACHE) + DictLikeDiskAccess = instrument(DictLikeDiskAccess) -def _disk_cache_get(cachedir, key): - """Retrieve a value from the disk cache. - :arg cachedir: The cache directory. - :arg key: The cache key (must be a string). - :returns: The cached object if found, else ``None``. +if configuration["spmd_strict"]: + def parallel_cache( + hashkey=default_parallel_hashkey, + comm_fetcher=default_comm_fetcher, + cache_factory=lambda: DEFAULT_CACHE(), + ): + """Parallel cache decorator (SPMD strict-enabled). + """ + def decorator(func): + @PETSc.Log.EventDecorator("PyOP2 Cache Wrapper") + @wraps(func) + def wrapper(*args, **kwargs): + """ Extract the key and then try the memory cache before falling back + on calling the function and populating the cache. SPMD strict ensures + that all ranks cache hit or miss to ensure that the function evaluation + always occurs in parallel. + """ + k = hashkey(*args, **kwargs) + key = _as_hexdigest(*k), func.__qualname__ + # Create a PyOP2 comm associated with the key, so it is decrefed when the wrapper exits + with temp_internal_comm(comm_fetcher(*args, **kwargs)) as comm: + # Fetch the per-comm cache_collection or set it up if not present + # A collection is required since different types of cache can be set up on the same comm + cache_collection = comm.Get_attr(comm_cache_keyval) + if cache_collection is None: + cache_collection = {} + comm.Set_attr(comm_cache_keyval, cache_collection) + # If this kind of cache is already present on the + # cache_collection, get it, otherwise create it + local_cache = cache_collection.setdefault( + (cf := cache_factory()).__class__.__name__, + cf + ) + local_cache = cache_collection[cf.__class__.__name__] + + # If this is a new cache or function add it to the list of known caches + if (comm, comm.name, func, local_cache) not in [(c.comm, c.comm_name, c.func, c.cache()) for c in _KNOWN_CACHES]: + # When a comm is freed we do not hold a reference to the cache. + # We attach a finalizer that extracts the stats before the cache + # is deleted. + _KNOWN_CACHES.append(_CacheRecord(next(_CACHE_CIDX), comm, func, local_cache)) + + # Grab value from all ranks cache and broadcast cache hit/miss + value = local_cache.get(key, CACHE_MISS) + debug_string = f"{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: " + debug_string += f"key={k} in cache: {local_cache.__class__.__name__} cache " + if value is CACHE_MISS: + debug(debug_string + "miss") + cache_hit = False + else: + debug(debug_string + "hit") + cache_hit = True + all_present = comm.allgather(cache_hit) + + # If not present in the cache of all ranks we force re-evaluation on all ranks + if not min(all_present): + value = CACHE_MISS + + if value is CACHE_MISS: + value = func(*args, **kwargs) + return local_cache.setdefault(key, value) + + return wrapper + return decorator +else: + def parallel_cache( + hashkey=default_parallel_hashkey, + comm_fetcher=default_comm_fetcher, + cache_factory=lambda: DEFAULT_CACHE(), + ): + """Parallel cache decorator. + """ + def decorator(func): + @PETSc.Log.EventDecorator("PyOP2 Cache Wrapper") + @wraps(func) + def wrapper(*args, **kwargs): + """ Extract the key and then try the memory cache before falling back + on calling the function and populating the cache. + """ + k = hashkey(*args, **kwargs) + key = _as_hexdigest(*k), func.__qualname__ + # Create a PyOP2 comm associated with the key, so it is decrefed when the wrapper exits + with temp_internal_comm(comm_fetcher(*args, **kwargs)) as comm: + # Fetch the per-comm cache_collection or set it up if not present + # A collection is required since different types of cache can be set up on the same comm + cache_collection = comm.Get_attr(comm_cache_keyval) + if cache_collection is None: + cache_collection = {} + comm.Set_attr(comm_cache_keyval, cache_collection) + # If this kind of cache is already present on the + # cache_collection, get it, otherwise create it + local_cache = cache_collection.setdefault( + (cf := cache_factory()).__class__.__name__, + cf + ) + local_cache = cache_collection[cf.__class__.__name__] + + # If this is a new cache or function add it to the list of known caches + if (comm, comm.name, func, local_cache) not in [(c.comm, c.comm_name, c.func, c.cache()) for c in _KNOWN_CACHES]: + # When a comm is freed we do not hold a reference to the cache. + # We attach a finalizer that extracts the stats before the cache + # is deleted. + _KNOWN_CACHES.append(_CacheRecord(next(_CACHE_CIDX), comm, func, local_cache)) + + value = local_cache.get(key, CACHE_MISS) + + if value is CACHE_MISS: + value = func(*args, **kwargs) + return local_cache.setdefault(key, value) + + return wrapper + return decorator + + +def clear_memory_cache(comm): + """ Completely remove all PyOP2 caches on a given communicator. """ - filepath = Path(cachedir, key[:2], key[2:]) - try: - with open(filepath, "rb") as f: - return pickle.load(f) - except FileNotFoundError: - return None + with temp_internal_comm(comm) as icomm: + if icomm.Get_attr(comm_cache_keyval) is not None: + icomm.Set_attr(comm_cache_keyval, {}) -def _disk_cache_set(cachedir, key, value): - """Store a new value in the disk cache. +# A small collection of default simple caches +memory_cache = parallel_cache - :arg cachedir: The cache directory. - :arg key: The cache key (must be a string). - :arg value: The new item to store in the cache. - """ - k1, k2 = key[:2], key[2:] - basedir = Path(cachedir, k1) - basedir.mkdir(parents=True, exist_ok=True) - - tempfile = basedir.joinpath(f"{k2}_p{os.getpid()}.tmp") - filepath = basedir.joinpath(k2) - with open(tempfile, "wb") as f: - pickle.dump(value, f) - tempfile.rename(filepath) + +def serial_cache(hashkey, cache_factory=lambda: DEFAULT_CACHE()): + return cachetools.cached(key=hashkey, cache=cache_factory()) + + +def disk_only_cache(*args, cachedir=configuration["cache_dir"], **kwargs): + return parallel_cache(*args, **kwargs, cache_factory=lambda: DictLikeDiskAccess(cachedir)) + + +def memory_and_disk_cache(*args, cachedir=configuration["cache_dir"], **kwargs): + def decorator(func): + return memory_cache(*args, **kwargs)(disk_only_cache(*args, cachedir=cachedir, **kwargs)(func)) + return decorator diff --git a/pyop2/compilation.py b/pyop2/compilation.py index f4a1af36a..5c0ad7b4c 100644 --- a/pyop2/compilation.py +++ b/pyop2/compilation.py @@ -42,12 +42,20 @@ import shlex from hashlib import md5 from packaging.version import Version, InvalidVersion +from textwrap import dedent +from functools import partial +from pathlib import Path +from contextlib import contextmanager +from tempfile import gettempdir, mkstemp +from random import randint from pyop2 import mpi +from pyop2.caching import parallel_cache, memory_cache, default_parallel_hashkey, _as_hexdigest, DictLikeDiskAccess from pyop2.configuration import configuration from pyop2.logger import warning, debug, progress, INFO from pyop2.exceptions import CompilationError +import pyop2.global_kernel from petsc4py import PETSc @@ -60,6 +68,10 @@ def _check_hashes(x, y, datatype): _check_op = mpi.MPI.Op.Create(_check_hashes, commute=True) _compiler = None +# Directory must be unique per VENV for multiple installs +# _and_ per user for shared machines +_EXE_HASH = md5(sys.executable.encode()).hexdigest()[-6:] +MEM_TMP_DIR = Path(gettempdir()).joinpath(f"pyop2-tempcache-uid{os.getuid()}").joinpath(_EXE_HASH) def set_default_compiler(compiler): @@ -85,6 +97,36 @@ def set_default_compiler(compiler): ) +def sniff_compiler_version(compiler, cpp=False): + """Attempt to determine the compiler version number. + + :arg compiler: Instance of compiler to sniff the version of + :arg cpp: If set to True will use the C++ compiler rather than + the C compiler to determine the version number. + """ + # Note: + # Sniffing the compiler version for very large numbers of + # MPI ranks is expensive, ensure this is only run on rank 0 + exe = compiler.cxx if cpp else compiler.cc + version = None + # `-dumpversion` is not sufficient to get the whole version string (for some compilers), + # but other compilers do not implement `-dumpfullversion`! + for dumpstring in ["-dumpfullversion", "-dumpversion"]: + try: + output = subprocess.run( + [exe, dumpstring], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=True, + encoding="utf-8" + ).stdout + version = Version(output) + break + except (subprocess.CalledProcessError, UnicodeDecodeError, InvalidVersion): + continue + return version + + def sniff_compiler(exe, comm=mpi.COMM_WORLD): """Obtain the correct compiler class by calling the compiler executable. @@ -151,7 +193,12 @@ def sniff_compiler(exe, comm=mpi.COMM_WORLD): else: compiler = AnonymousCompiler - return comm.bcast(compiler, 0) + # Now try and get a version number + temp = Compiler() + version = sniff_compiler_version(temp) + compiler = partial(compiler, version=version) + + return comm.bcast(compiler, root=0) class Compiler(ABC): @@ -166,10 +213,8 @@ class Compiler(ABC): (optional, prepended to any flags specified as the ldflags configuration option). The environment variable ``PYOP2_LDFLAGS`` can also be used to extend these options. - :arg cpp: Should we try and use the C++ compiler instead of the C - compiler?. - :kwarg comm: Optional communicator to compile the code on - (defaults to pyop2.mpi.COMM_WORLD). + :arg version: (Optional) usually sniffed by loader. + :arg debug: Whether to use debugging compiler flags. """ _name = "unknown" @@ -184,23 +229,22 @@ class Compiler(ABC): _optflags = () _debugflags = () - def __init__(self, extra_compiler_flags=(), extra_linker_flags=(), cpp=False, comm=None): - # Set compiler version ASAP since it is used in __repr__ - self.version = None - + def __init__(self, extra_compiler_flags=(), extra_linker_flags=(), version=None, debug=False): self._extra_compiler_flags = tuple(extra_compiler_flags) self._extra_linker_flags = tuple(extra_linker_flags) - - self._cpp = cpp - self._debug = configuration["debug"] - - # Compilation communicators are reference counted on the PyOP2 comm - self.pcomm = mpi.internal_comm(comm, self) - self.comm = mpi.compilation_comm(self.pcomm, self) - self.sniff_compiler_version() + self._version = version + self._debug = debug def __repr__(self): - return f"<{self._name} compiler, version {self.version or 'unknown'}>" + string = f"{self.__class__.__name__}(" + string += f"extra_compiler_flags={self._extra_compiler_flags}, " + string += f"extra_linker_flags={self._extra_linker_flags}, " + string += f"version={self._version!r}, " + string += f"debug={self._debug})" + return string + + def __str__(self): + return f"<{self._name} compiler, version {self._version or 'unknown'}>" @property def cc(self): @@ -240,187 +284,10 @@ def ldflags(self): ldflags += tuple(shlex.split(configuration["ldflags"])) return ldflags - def sniff_compiler_version(self, cpp=False): - """Attempt to determine the compiler version number. - - :arg cpp: If set to True will use the C++ compiler rather than - the C compiler to determine the version number. - """ - # Note: - # Sniffing the compiler version for very large numbers of - # MPI ranks is expensive - exe = self.cxx if cpp else self.cc - version = None - if self.comm.rank == 0: - # `-dumpversion` is not sufficient to get the whole version string (for some compilers), - # but other compilers do not implement `-dumpfullversion`! - for dumpstring in ["-dumpfullversion", "-dumpversion"]: - try: - output = subprocess.run( - [exe, dumpstring], - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - check=True, - encoding="utf-8" - ).stdout - version = Version(output) - break - except (subprocess.CalledProcessError, UnicodeDecodeError, InvalidVersion): - continue - self.version = self.comm.bcast(version, 0) - @property def bugfix_cflags(self): return () - @staticmethod - def expandWl(ldflags): - """Generator to expand the `-Wl` compiler flags for use as linker flags - :arg ldflags: linker flags for a compiler command - """ - for flag in ldflags: - if flag.startswith('-Wl'): - for f in flag.lstrip('-Wl')[1:].split(','): - yield f - else: - yield flag - - @mpi.collective - def get_so(self, jitmodule, extension): - """Build a shared library and load it - - :arg jitmodule: The JIT Module which can generate the code to compile. - :arg extension: extension of the source file (c, cpp). - Returns a :class:`ctypes.CDLL` object of the resulting shared - library.""" - - # C or C++ - if self._cpp: - compiler = self.cxx - compiler_flags = self.cxxflags - else: - compiler = self.cc - compiler_flags = self.cflags - - # Determine cache key - hsh = md5(str(jitmodule.cache_key).encode()) - hsh.update(compiler.encode()) - if self.ld: - hsh.update(self.ld.encode()) - hsh.update("".join(compiler_flags).encode()) - hsh.update("".join(self.ldflags).encode()) - - basename = hsh.hexdigest() - - cachedir = configuration['cache_dir'] - - dirpart, basename = basename[:2], basename[2:] - cachedir = os.path.join(cachedir, dirpart) - pid = os.getpid() - cname = os.path.join(cachedir, "%s_p%d.%s" % (basename, pid, extension)) - oname = os.path.join(cachedir, "%s_p%d.o" % (basename, pid)) - soname = os.path.join(cachedir, "%s.so" % basename) - # Link into temporary file, then rename to shared library - # atomically (avoiding races). - tmpname = os.path.join(cachedir, "%s_p%d.so.tmp" % (basename, pid)) - - if configuration['check_src_hashes'] or configuration['debug']: - matching = self.comm.allreduce(basename, op=_check_op) - if matching != basename: - # Dump all src code to disk for debugging - output = os.path.join(configuration["cache_dir"], "mismatching-kernels") - srcfile = os.path.join(output, "src-rank%d.c" % self.comm.rank) - if self.comm.rank == 0: - os.makedirs(output, exist_ok=True) - self.comm.barrier() - with open(srcfile, "w") as f: - f.write(jitmodule.code_to_compile) - self.comm.barrier() - raise CompilationError("Generated code differs across ranks (see output in %s)" % output) - try: - # Are we in the cache? - return ctypes.CDLL(soname) - except OSError: - # No, let's go ahead and build - if self.comm.rank == 0: - # No need to do this on all ranks - os.makedirs(cachedir, exist_ok=True) - logfile = os.path.join(cachedir, "%s_p%d.log" % (basename, pid)) - errfile = os.path.join(cachedir, "%s_p%d.err" % (basename, pid)) - with progress(INFO, 'Compiling wrapper'): - with open(cname, "w") as f: - f.write(jitmodule.code_to_compile) - # Compiler also links - if not self.ld: - cc = (compiler,) \ - + compiler_flags \ - + ('-o', tmpname, cname) \ - + self.ldflags - debug('Compilation command: %s', ' '.join(cc)) - with open(logfile, "w") as log, open(errfile, "w") as err: - log.write("Compilation command:\n") - log.write(" ".join(cc)) - log.write("\n\n") - try: - if configuration['no_fork_available']: - cc += ["2>", errfile, ">", logfile] - cmd = " ".join(cc) - status = os.system(cmd) - if status != 0: - raise subprocess.CalledProcessError(status, cmd) - else: - subprocess.check_call(cc, stderr=err, stdout=log) - except subprocess.CalledProcessError as e: - raise CompilationError( - """Command "%s" return error status %d. -Unable to compile code -Compile log in %s -Compile errors in %s""" % (e.cmd, e.returncode, logfile, errfile)) - else: - cc = (compiler,) \ - + compiler_flags \ - + ('-c', '-o', oname, cname) - # Extract linker specific "cflags" from ldflags - ld = tuple(shlex.split(self.ld)) \ - + ('-o', tmpname, oname) \ - + tuple(self.expandWl(self.ldflags)) - debug('Compilation command: %s', ' '.join(cc)) - debug('Link command: %s', ' '.join(ld)) - with open(logfile, "a") as log, open(errfile, "a") as err: - log.write("Compilation command:\n") - log.write(" ".join(cc)) - log.write("\n\n") - log.write("Link command:\n") - log.write(" ".join(ld)) - log.write("\n\n") - try: - if configuration['no_fork_available']: - cc += ["2>", errfile, ">", logfile] - ld += ["2>>", errfile, ">>", logfile] - cccmd = " ".join(cc) - ldcmd = " ".join(ld) - status = os.system(cccmd) - if status != 0: - raise subprocess.CalledProcessError(status, cccmd) - status = os.system(ldcmd) - if status != 0: - raise subprocess.CalledProcessError(status, ldcmd) - else: - subprocess.check_call(cc, stderr=err, stdout=log) - subprocess.check_call(ld, stderr=err, stdout=log) - except subprocess.CalledProcessError as e: - raise CompilationError( - """Command "%s" return error status %d. -Unable to compile code -Compile log in %s -Compile errors in %s""" % (e.cmd, e.returncode, logfile, errfile)) - # Atomically ensure soname exists - os.rename(tmpname, soname) - # Wait for compilation to complete - self.comm.barrier() - # Load resulting library - return ctypes.CDLL(soname) - class MacClangCompiler(Compiler): """A compiler for building a shared library on Mac systems.""" @@ -464,7 +331,7 @@ class LinuxGnuCompiler(Compiler): @property def bugfix_cflags(self): """Flags to work around bugs in compilers.""" - ver = self.version + ver = self._version cflags = () if Version("4.8.0") <= ver < Version("4.9.0"): # GCC bug https://gcc.gnu.org/bugzilla/show_bug.cgi?id=61068 @@ -546,7 +413,20 @@ class AnonymousCompiler(Compiler): _name = "Unknown" +def load_hashkey(*args, **kwargs): + from pyop2.global_kernel import GlobalKernel + if isinstance(args[0], str): + code_hash = md5(args[0].encode()).hexdigest() + elif isinstance(args[0], GlobalKernel): + code_hash = md5(str(args[0].cache_key).encode()).hexdigest() + else: + pass # This will raise an error in load + return default_parallel_hashkey(code_hash, *args[1:], **kwargs) + + @mpi.collective +@memory_cache(hashkey=load_hashkey) +@PETSc.Log.EventDecorator() def load(jitmodule, extension, fn_name, cppargs=(), ldargs=(), argtypes=None, restype=None, comm=None): """Build a shared library and return a function pointer from it. @@ -565,8 +445,6 @@ def load(jitmodule, extension, fn_name, cppargs=(), ldargs=(), :kwarg comm: Optional communicator to compile the code on (only rank 0 compiles code) (defaults to pyop2.mpi.COMM_WORLD). """ - from pyop2.global_kernel import GlobalKernel - if isinstance(jitmodule, str): class StrCode(object): def __init__(self, code, argtypes): @@ -576,26 +454,33 @@ def __init__(self, code, argtypes): # cache key self.argtypes = argtypes code = StrCode(jitmodule, argtypes) - elif isinstance(jitmodule, GlobalKernel): + elif isinstance(jitmodule, pyop2.global_kernel.GlobalKernel): code = jitmodule else: raise ValueError("Don't know how to compile code of type %r" % type(jitmodule)) - cpp = (extension == "cpp") global _compiler if _compiler: # Use the global compiler if it has been set compiler = _compiler else: # Sniff compiler from executable - if cpp: + if extension == "cpp": exe = configuration["cxx"] or "mpicxx" else: exe = configuration["cc"] or "mpicc" compiler = sniff_compiler(exe, comm) - dll = compiler(cppargs, ldargs, cpp=cpp, comm=comm).get_so(code, extension) - if isinstance(jitmodule, GlobalKernel): + debug = configuration["debug"] + compiler_instance = compiler(cppargs, ldargs, debug=debug) + if configuration['check_src_hashes'] or configuration['debug']: + check_source_hashes(compiler_instance, code, extension, comm) + # This call is cached on disk + so_name = make_so(compiler_instance, code, extension, comm) + # This call might be cached in memory by the OS (system dependent) + dll = ctypes.CDLL(so_name) + + if isinstance(jitmodule, pyop2.global_kernel.GlobalKernel): _add_profiling_events(dll, code.local_kernel.events) fn = getattr(dll, fn_name) @@ -604,12 +489,176 @@ def __init__(self, code, argtypes): return fn +def expandWl(ldflags): + """Generator to expand the `-Wl` compiler flags for use as linker flags + :arg ldflags: linker flags for a compiler command + """ + for flag in ldflags: + if flag.startswith('-Wl'): + for f in flag.lstrip('-Wl')[1:].split(','): + yield f + else: + yield flag + + +class CompilerDiskAccess(DictLikeDiskAccess): + @contextmanager + def open(self, filename, *args, **kwargs): + yield filename + + def write(self, filename, value): + shutil.copy(value, filename) + + def read(self, filename): + if not filename.exists(): + raise FileNotFoundError("File not on disk, cache miss") + return filename + + def setdefault(self, key, default=None): + try: + return self[key] + except KeyError: + self[key] = default + return self[key] + + +def _make_so_hashkey(compiler, jitmodule, extension, comm): + if extension == "cpp": + exe = compiler.cxx + compiler_flags = compiler.cxxflags + else: + exe = compiler.cc + compiler_flags = compiler.cflags + return (compiler, exe, compiler_flags, compiler.ld, compiler.ldflags, jitmodule.cache_key) + + +def check_source_hashes(compiler, jitmodule, extension, comm): + """A check to see whether code generated on all ranks is identical. + + :arg compiler: The compiler to use to create the shared library. + :arg jitmodule: The JIT Module which can generate the code to compile. + :arg filename: The filename of the library to create. + :arg extension: extension of the source file (c, cpp). + :arg comm: Communicator over which to perform compilation. + """ + # Reconstruct hash from filename + hashval = _as_hexdigest(_make_so_hashkey(compiler, jitmodule, extension, comm)) + with mpi.temp_internal_comm(comm) as icomm: + matching = icomm.allreduce(hashval, op=_check_op) + if matching != hashval: + # Dump all src code to disk for debugging + output = Path(configuration["cache_dir"]).joinpath("mismatching-kernels") + srcfile = output.joinpath(f"src-rank{icomm.rank}.{extension}") + if icomm.rank == 0: + output.mkdir(exist_ok=True) + icomm.barrier() + with open(srcfile, "w") as fh: + fh.write(jitmodule.code_to_compile) + icomm.barrier() + raise CompilationError(f"Generated code differs across ranks (see output in {output})") + + +@mpi.collective +@parallel_cache( + hashkey=_make_so_hashkey, + cache_factory=lambda: CompilerDiskAccess(configuration['cache_dir'], extension=".so") +) +@PETSc.Log.EventDecorator() +def make_so(compiler, jitmodule, extension, comm, filename=None): + """Build a shared library and load it + + :arg compiler: The compiler to use to create the shared library. + :arg jitmodule: The JIT Module which can generate the code to compile. + :arg filename: The filename of the library to create. + :arg extension: extension of the source file (c, cpp). + :arg comm: Communicator over which to perform compilation. + :arg filename: Optional + Returns a :class:`ctypes.CDLL` object of the resulting shared + library.""" + # Compilation communicators are reference counted on the PyOP2 comm + icomm = mpi.internal_comm(comm, compiler) + ccomm = mpi.compilation_comm(icomm, compiler) + + # C or C++ + if extension == "cpp": + exe = compiler.cxx + compiler_flags = compiler.cxxflags + else: + exe = compiler.cc + compiler_flags = compiler.cflags + + # Compile on compilation communicator (ccomm) rank 0 + soname = None + if ccomm.rank == 0: + if filename is None: + # Adding random 2-digit hexnum avoids using excessive filesystem inodes + tempdir = MEM_TMP_DIR.joinpath(f"{randint(0, 255):02x}") + tempdir.mkdir(parents=True, exist_ok=True) + # This path + filename should be unique + descriptor, filename = mkstemp(suffix=f".{extension}", dir=tempdir, text=True) + filename = Path(filename) + else: + filename.parent.mkdir(exist_ok=True) + + cname = filename + oname = filename.with_suffix(".o") + soname = filename.with_suffix(".so") + logfile = filename.with_suffix(".log") + errfile = filename.with_suffix(".err") + with progress(INFO, 'Compiling wrapper'): + # Write source code to disk + with open(cname, "w") as fh: + fh.write(jitmodule.code_to_compile) + os.close(descriptor) + + if not compiler.ld: + # Compile and link + cc = (exe,) + compiler_flags + ('-o', str(soname), str(cname)) + compiler.ldflags + _run(cc, logfile, errfile) + else: + # Compile + cc = (exe,) + compiler_flags + ('-c', '-o', oname, cname) + _run(cc, logfile, errfile) + # Extract linker specific "cflags" from ldflags and link + ld = tuple(shlex.split(compiler.ld)) + ('-o', str(soname), str(oname)) + tuple(expandWl(compiler.ldflags)) + _run(ld, logfile, errfile, step="Linker", filemode="a") + + return ccomm.bcast(soname, root=0) + + +def _run(cc, logfile, errfile, step="Compilation", filemode="w"): + """ Run a compilation command and handle logging + errors. + """ + debug(f"{step} command: {' '.join(cc)}") + try: + if configuration['no_fork_available']: + redirect = ">" if filemode == "w" else ">>" + cc += (f"2{redirect}", str(errfile), redirect, str(logfile)) + cmd = " ".join(cc) + status = os.system(cmd) + if status != 0: + raise subprocess.CalledProcessError(status, cmd) + else: + with open(logfile, filemode) as log, open(errfile, filemode) as err: + log.write(f"{step} command:\n") + log.write(" ".join(cc)) + log.write("\n\n") + subprocess.check_call(cc, stderr=err, stdout=log) + except subprocess.CalledProcessError as e: + raise CompilationError(dedent(f""" + Command "{e.cmd}" return error status {e.returncode}. + Unable to compile code + Compile log in {logfile!s} + Compile errors in {errfile!s} + """)) + + def _add_profiling_events(dll, events): """ - If PyOP2 is in profiling mode, events are attached to dll to profile the local linear algebra calls. - The event is generated here in python and then set in the shared library, - so that memory is not allocated over and over again in the C kernel. The naming - convention is that the event ids are named by the event name prefixed by "ID_". + If PyOP2 is in profiling mode, events are attached to dll to profile the local linear algebra calls. + The event is generated here in python and then set in the shared library, + so that memory is not allocated over and over again in the C kernel. The naming + convention is that the event ids are named by the event name prefixed by "ID_". """ if PETSc.Log.isActive(): # also link the events from the linear algebra callables @@ -622,33 +671,34 @@ def _add_profiling_events(dll, events): ctypes.c_int.in_dll(dll, 'ID_'+e).value = PETSc.Log.Event(e).id -def clear_cache(prompt=False): - """Clear the PyOP2 compiler cache. +def clear_compiler_disk_cache(prompt=False): + """Clear the PyOP2 compiler disk cache. :arg prompt: if ``True`` prompt before removing any files """ - cachedir = configuration['cache_dir'] - - if not os.path.exists(cachedir): - print("Cache directory could not be found") - return - if len(os.listdir(cachedir)) == 0: - print("No cached libraries to remove") - return - - remove = True - if prompt: - user = input(f"Remove cached libraries from {cachedir}? [Y/n]: ") - - while user.lower() not in ['', 'y', 'n']: - print("Please answer y or n.") - user = input(f"Remove cached libraries from {cachedir}? [Y/n]: ") - - if user.lower() == 'n': - remove = False - - if remove: - print(f"Removing cached libraries from {cachedir}") - shutil.rmtree(cachedir, ignore_errors=True) - else: - print("Not removing cached libraries") + cachedirs = [configuration['cache_dir'], MEM_TMP_DIR] + + for directory in cachedirs: + if not os.path.exists(directory): + print("Cache directory could not be found") + continue + if len(os.listdir(directory)) == 0: + print("No cached libraries to remove") + continue + + remove = True + if prompt: + user = input(f"Remove cached libraries from {directory}? [Y/n]: ") + + while user.lower() not in ['', 'y', 'n']: + print("Please answer y or n.") + user = input(f"Remove cached libraries from {directory}? [Y/n]: ") + + if user.lower() == 'n': + remove = False + + if remove: + print(f"Removing cached libraries from {directory}") + shutil.rmtree(directory, ignore_errors=True) + else: + print("Not removing cached libraries") diff --git a/pyop2/configuration.py b/pyop2/configuration.py index 29717718c..0005ceeca 100644 --- a/pyop2/configuration.py +++ b/pyop2/configuration.py @@ -67,13 +67,16 @@ class Configuration(dict): to a node-local filesystem too. :param log_level: How chatty should PyOP2 be? Valid values are "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL". - :param print_cache_size: Should PyOP2 print the size of caches at + :param print_cache_size: Should PyOP2 print the cache information at program exit? :param matnest: Should matrices on mixed maps be built as nests? (Default yes) :param block_sparsity: Should sparsity patterns on datasets with cdim > 1 be built as block sparsities, or dof sparsities. The former saves memory but changes which preconditioners are available for the resulting matrices. (Default yes) + :param spmd_strict: Enable barriers for calls marked with @collective and + for cache access. This adds considerable overhead, but is useful for + tracking down deadlocks. (Default no) """ # name, env variable, type, default, write once cache_dir = os.path.join(gettempdir(), "pyop2-cache-uid%s" % os.getuid()) @@ -108,12 +111,14 @@ class Configuration(dict): ("PYOP2_NODE_LOCAL_COMPILATION", bool, True), "no_fork_available": ("PYOP2_NO_FORK_AVAILABLE", bool, False), - "print_cache_size": - ("PYOP2_PRINT_CACHE_SIZE", bool, False), + "print_cache_info": + ("PYOP2_CACHE_INFO", bool, False), "matnest": ("PYOP2_MATNEST", bool, True), "block_sparsity": - ("PYOP2_BLOCK_SPARSITY", bool, True) + ("PYOP2_BLOCK_SPARSITY", bool, True), + "spmd_strict": + ("PYOP2_SPMD_STRICT", bool, False), } """Default values for PyOP2 configuration parameters""" diff --git a/pyop2/exceptions.py b/pyop2/exceptions.py index 9211857d0..eec5eedac 100644 --- a/pyop2/exceptions.py +++ b/pyop2/exceptions.py @@ -146,3 +146,13 @@ class CompilationError(RuntimeError): class SparsityFormatError(ValueError): """Unable to produce a sparsity for this matrix format.""" + + +class CachingError(ValueError): + + """A caching error.""" + + +class HashError(CachingError): + + """Something is wrong with the hash.""" diff --git a/pyop2/global_kernel.py b/pyop2/global_kernel.py index 536d717e9..ae13dc1c5 100644 --- a/pyop2/global_kernel.py +++ b/pyop2/global_kernel.py @@ -1,17 +1,17 @@ import collections.abc import ctypes from dataclasses import dataclass -import itertools import os from typing import Optional, Tuple +import itertools import loopy as lp import numpy as np import pytools from petsc4py import PETSc -from pyop2 import compilation, mpi -from pyop2.caching import Cached +from pyop2 import mpi +from pyop2.compilation import load from pyop2.configuration import configuration from pyop2.datatypes import IntType, as_ctypes from pyop2.types import IterationRegion, Constant, READ @@ -247,7 +247,7 @@ def pack(self): return MatPack -class GlobalKernel(Cached): +class GlobalKernel: """Class representing the generated code for the global computation. :param local_kernel: :class:`pyop2.LocalKernel` instance representing the @@ -271,22 +271,6 @@ class GlobalKernel(Cached): :param pass_layer_arg: Should the wrapper pass the current layer into the kernel (as an `int`). Only makes sense for indirect extruded iteration. """ - - _cache = {} - - @classmethod - def _cache_key(cls, local_knl, arguments, **kwargs): - key = [cls, local_knl.cache_key, - *kwargs.items(), configuration["simd_width"]] - - key.extend([a.cache_key for a in arguments]) - - counter = itertools.count() - seen_maps = collections.defaultdict(lambda: next(counter)) - key.extend([seen_maps[m] for a in arguments for m in a.maps]) - - return tuple(key) - def __init__(self, local_kernel, arguments, *, extruded=False, extruded_periodic=False, @@ -294,9 +278,6 @@ def __init__(self, local_kernel, arguments, *, subset=False, iteration_region=None, pass_layer_arg=False): - if self._initialized: - return - if not len(local_kernel.accesses) == len(arguments): raise ValueError( "Number of arguments passed to the local and global kernels" @@ -320,6 +301,15 @@ def __init__(self, local_kernel, arguments, *, "Cannot request constant_layers argument for non-extruded iteration" ) + counter = itertools.count() + seen_maps = collections.defaultdict(lambda: next(counter)) + self.cache_key = ( + local_kernel.cache_key, + *[a.cache_key for a in arguments], + *[seen_maps[m] for a in arguments for m in a.maps], + extruded, extruded_periodic, constant_layers, subset, + iteration_region, pass_layer_arg, configuration["simd_width"] + ) self.local_kernel = local_kernel self.arguments = arguments self._extruded = extruded @@ -329,11 +319,6 @@ def __init__(self, local_kernel, arguments, *, self._iteration_region = iteration_region self._pass_layer_arg = pass_layer_arg - # Cache for stashing the compiled code - self._func_cache = {} - - self._initialized = True - @mpi.collective def __call__(self, comm, *args): """Execute the compiled kernel. @@ -341,15 +326,8 @@ def __call__(self, comm, *args): :arg comm: Communicator the execution is collective over. :*args: Arguments to pass to the compiled kernel. """ - # If the communicator changes then we cannot safely use the in-memory - # function cache. Note here that we are not using dup_comm to get a - # stable communicator id because we will already be using the internal one. - key = id(comm) - try: - func = self._func_cache[key] - except KeyError: - func = self.compile(comm) - self._func_cache[key] = func + # It is unnecessary to cache this call as it is cached in pyop2/compilation.py + func = self.compile(comm) func(*args) @property @@ -419,11 +397,15 @@ def compile(self, comm): + tuple(self.local_kernel.ldargs) ) - return compilation.load(self, extension, self.name, - cppargs=cppargs, - ldargs=ldargs, - restype=ctypes.c_int, - comm=comm) + return load( + self, + extension, + self.name, + cppargs=cppargs, + ldargs=ldargs, + restype=ctypes.c_int, + comm=comm + ) @cached_property def argtypes(self): diff --git a/pyop2/mpi.py b/pyop2/mpi.py index 554155f20..7e88b8dd0 100644 --- a/pyop2/mpi.py +++ b/pyop2/mpi.py @@ -37,6 +37,7 @@ from petsc4py import PETSc from mpi4py import MPI # noqa from itertools import count +from functools import wraps import atexit import gc import glob @@ -160,13 +161,64 @@ class PyOP2CommError(ValueError): # PYOP2_FINALISED flag. -def collective(fn): - extra = trim(""" - This function is logically collective over MPI ranks, it is an - error to call it on fewer than all the ranks in MPI communicator. - """) - fn.__doc__ = "%s\n\n%s" % (trim(fn.__doc__), extra) if fn.__doc__ else extra - return fn +if configuration["spmd_strict"]: + def collective(fn): + extra = trim(""" + This function is logically collective over MPI ranks, it is an + error to call it on fewer than all the ranks in MPI communicator. + PYOP2_SPMD_STRICT=1 is in your environment and function calls will be + guarded by a barrier where possible. + """) + + @wraps(fn) + def wrapper(*args, **kwargs): + comms = filter( + lambda arg: isinstance(arg, MPI.Comm), + args + tuple(kwargs.values()) + ) + try: + comm = next(comms) + except StopIteration: + if args and hasattr(args[0], "comm"): + comm = args[0].comm + else: + comm = None + + if comm is None: + debug( + "`@collective` wrapper found no communicators in args or kwargs, " + "this means that the call is implicitly collective over an " + "unknown communicator. " + f"The following call to {fn.__module__}.{fn.__qualname__} is " + "not protected by an MPI barrier." + ) + subcomm = ", UNKNOWN Comm" + else: + subcomm = f", {comm.name} R{comm.rank}" + + debug_string_pt1 = f"{COMM_WORLD.name} R{COMM_WORLD.rank}{subcomm}: " + debug_string_pt2 = f" {fn.__module__}.{fn.__qualname__}" + debug(debug_string_pt1 + "Entering" + debug_string_pt2) + if comm is not None: + comm.Barrier() + value = fn(*args, **kwargs) + debug(debug_string_pt1 + "Leaving" + debug_string_pt2) + if comm is not None: + comm.Barrier() + return value + + wrapper.__doc__ = f"{trim(fn.__doc__)}\n\n{extra}" if fn.__doc__ else extra + return wrapper +else: + def collective(fn): + extra = trim(""" + This function is logically collective over MPI ranks, it is an + error to call it on fewer than all the ranks in MPI communicator. + You can set PYOP2_SPMD_STRICT=1 in your environment to try and catch + non-collective calls. + """) + fn.__doc__ = f"{trim(fn.__doc__)}\n\n{extra}" if fn.__doc__ else extra + return fn def delcomm_outer(comm, keyval, icomm): @@ -227,6 +279,7 @@ def delcomm_outer(comm, keyval, icomm): innercomm_keyval = MPI.Comm.Create_keyval(delete_fn=delcomm_outer) outercomm_keyval = MPI.Comm.Create_keyval() compilationcomm_keyval = MPI.Comm.Create_keyval(delete_fn=delcomm_outer) +comm_cache_keyval = MPI.Comm.Create_keyval() def is_pyop2_comm(comm): @@ -539,22 +592,16 @@ def _free_comms(): debug(f"Freeing {comm.name}, with index {key}, which has refcount {refcount[0]}") comm.Free() del _DUPED_COMM_DICT[key] - for kv in [refcount_keyval, - innercomm_keyval, - outercomm_keyval, - compilationcomm_keyval]: + for kv in [ + refcount_keyval, + innercomm_keyval, + outercomm_keyval, + compilationcomm_keyval, + comm_cache_keyval + ]: MPI.Comm.Free_keyval(kv) -def hash_comm(comm): - """Return a hashable identifier for a communicator.""" - if not is_pyop2_comm(comm): - raise PyOP2CommError("`comm` passed to `hash_comm()` must be a PyOP2 communicator") - # `comm` must be a PyOP2 communicator so we can use its id() - # as the hash and this is stable between invocations. - return id(comm) - - # Install an exception hook to MPI Abort if an exception isn't caught # see: https://groups.google.com/d/msg/mpi4py/me2TFzHmmsQ/sSF99LE0t9QJ if COMM_WORLD.size > 1: diff --git a/pyop2/op2.py b/pyop2/op2.py index 85788eafa..35e5649f4 100644 --- a/pyop2/op2.py +++ b/pyop2/op2.py @@ -112,11 +112,10 @@ def init(**kwargs): @collective def exit(): """Exit OP2 and clean up""" - if configuration['print_cache_size'] and COMM_WORLD.rank == 0: - from caching import report_cache, Cached, ObjectCached - print('**** PyOP2 cache sizes at exit ****') - report_cache(typ=ObjectCached) - report_cache(typ=Cached) + if configuration['print_cache_info'] and COMM_WORLD.rank == 0: + from pyop2.caching import print_cache_stats + print(f"{' PyOP2 cache sizes on rank 0 at exit ':*^120}") + print_cache_stats(alive=False) configuration.reset() global _initialised _initialised = False diff --git a/pyop2/utils.py b/pyop2/utils.py index 11b4ead5b..2f26741e1 100644 --- a/pyop2/utils.py +++ b/pyop2/utils.py @@ -40,29 +40,12 @@ from decorator import decorator import argparse +from functools import cached_property # noqa: F401 + from pyop2.exceptions import DataTypeError, DataValueError from pyop2.configuration import configuration -class cached_property(object): - - '''A read-only @property that is only evaluated once. The value is cached - on the object itself rather than the function or class; this should prevent - memory leakage.''' - - def __init__(self, fget, doc=None): - self.fget = fget - self.__doc__ = doc or fget.__doc__ - self.__name__ = fget.__name__ - self.__module__ = fget.__module__ - - def __get__(self, obj, cls): - if obj is None: - return self - obj.__dict__[self.__name__] = result = self.fget(obj) - return result - - def as_tuple(item, type=None, length=None, allow_none=False): # Empty list if we get passed None if item is None: diff --git a/requirements-git.txt b/requirements-git.txt index d6f3d2182..a8f7fb67f 100644 --- a/requirements-git.txt +++ b/requirements-git.txt @@ -1 +1,2 @@ git+https://github.com/firedrakeproject/loopy.git@main#egg=loopy +git+https://github.com/firedrakeproject/pytest-mpi.git@main#egg=pytest-mpi diff --git a/scripts/pyop2-clean b/scripts/pyop2-clean index ab29f1245..52f667ec4 100755 --- a/scripts/pyop2-clean +++ b/scripts/pyop2-clean @@ -1,6 +1,6 @@ #!/usr/bin/env python -from pyop2.compilation import clear_cache +from pyop2.compilation import clear_compiler_disk_cache if __name__ == '__main__': - clear_cache(prompt=True) + clear_compiler_disk_cache(prompt=True) diff --git a/test/unit/test_caching.py b/test/unit/test_caching.py index 40c4256fb..1298991b3 100644 --- a/test/unit/test_caching.py +++ b/test/unit/test_caching.py @@ -31,14 +31,30 @@ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED # OF THE POSSIBILITY OF SUCH DAMAGE. - +import ctypes import os import pytest import tempfile -import cachetools import numpy -from pyop2 import op2, mpi -from pyop2.caching import disk_cached +from itertools import chain +from textwrap import dedent +from pyop2 import op2 +from pyop2.caching import ( + DEFAULT_CACHE, + disk_only_cache, + memory_cache, + memory_and_disk_cache, + clear_memory_cache +) +from pyop2.compilation import load +from pyop2.mpi import ( + MPI, + COMM_WORLD, + COMM_SELF, + comm_cache_keyval, + internal_comm, + temp_internal_comm +) def _seed(): @@ -46,6 +62,7 @@ def _seed(): nelems = 8 +default_cache_name = DEFAULT_CACHE().__class__.__name__ @pytest.fixture @@ -75,7 +92,7 @@ def dindset2(indset): @pytest.fixture def g(): - return op2.Global(1, 0, numpy.uint32, "g", comm=mpi.COMM_WORLD) + return op2.Global(1, 0, numpy.uint32, "g", comm=COMM_WORLD) @pytest.fixture @@ -284,7 +301,14 @@ class TestGeneratedCodeCache: Generated Code Cache Tests. """ - cache = op2.GlobalKernel._cache + @property + def cache(self): + int_comm = internal_comm(COMM_WORLD, self) + _cache_collection = int_comm.Get_attr(comm_cache_keyval) + if _cache_collection is None: + _cache_collection = {default_cache_name: DEFAULT_CACHE()} + int_comm.Set_attr(comm_cache_keyval, _cache_collection) + return _cache_collection[default_cache_name] @pytest.fixture def a(cls, diterset): @@ -448,7 +472,7 @@ def test_change_dat_dtype_matters(self, iterset, diterset): assert len(self.cache) == 2 def test_change_global_dtype_matters(self, iterset, diterset): - g = op2.Global(1, 0, dtype=numpy.uint32, comm=mpi.COMM_WORLD) + g = op2.Global(1, 0, dtype=numpy.uint32, comm=COMM_WORLD) self.cache.clear() assert len(self.cache) == 0 @@ -458,7 +482,7 @@ def test_change_global_dtype_matters(self, iterset, diterset): assert len(self.cache) == 1 - g = op2.Global(1, 0, dtype=numpy.float64, comm=mpi.COMM_WORLD) + g = op2.Global(1, 0, dtype=numpy.float64, comm=COMM_WORLD) op2.par_loop(k, iterset, g(op2.INC)) assert len(self.cache) == 2 @@ -526,70 +550,259 @@ def test_sparsities_different_ordered_map_tuple_cached(self, m1, m2, ds2): class TestDiskCachedDecorator: @staticmethod - def myfunc(arg): + def myfunc(arg, comm): """Example function to cache the outputs of.""" return {arg} - def collective_key(self, *args): - """Return a cache key suitable for use when collective over a communicator.""" - self.comm = mpi.internal_comm(mpi.COMM_SELF, self) - return self.comm, cachetools.keys.hashkey(*args) - @pytest.fixture - def cache(cls): - return {} + def comm(self): + """This fixture provides a temporary comm so that each test gets it's own + communicator and that caches are cleaned on free.""" + temporary_comm = COMM_WORLD.Dup() + temporary_comm.name = "pytest temp COMM_WORLD" + with temp_internal_comm(temporary_comm) as comm: + yield comm + temporary_comm.Free() @pytest.fixture def cachedir(cls): return tempfile.TemporaryDirectory() - def test_decorator_in_memory_cache_reuses_results(self, cache, cachedir): - decorated_func = disk_cached(cache, cachedir.name)(self.myfunc) + def test_decorator_in_memory_cache_reuses_results(self, cachedir, comm): + decorated_func = memory_and_disk_cache( + cachedir=cachedir.name + )(self.myfunc) - obj1 = decorated_func("input1") - assert len(cache) == 1 + obj1 = decorated_func("input1", comm=comm) + mem_cache = comm.Get_attr(comm_cache_keyval)[default_cache_name] + assert len(mem_cache) == 1 assert len(os.listdir(cachedir.name)) == 1 - obj2 = decorated_func("input1") + obj2 = decorated_func("input1", comm=comm) assert obj1 is obj2 - assert len(cache) == 1 - assert len(os.listdir(cachedir.name)) == 1 - - def test_decorator_collective_has_different_in_memory_key(self, cache, cachedir): - decorated_func = disk_cached(cache, cachedir.name)(self.myfunc) - collective_func = disk_cached(cache, cachedir.name, self.collective_key, - collective=True)(self.myfunc) - - obj1 = collective_func("input1") - assert len(cache) == 1 - assert len(os.listdir(cachedir.name)) == 1 - - # The new entry should have a different in-memory key since the communicator - # is not included but the same key on disk. - obj2 = decorated_func("input1") - assert obj1 == obj2 and obj1 is not obj2 - assert len(cache) == 2 + assert len(mem_cache) == 1 assert len(os.listdir(cachedir.name)) == 1 - def test_decorator_disk_cache_reuses_results(self, cache, cachedir): - decorated_func = disk_cached(cache, cachedir.name)(self.myfunc) - - obj1 = decorated_func("input1") - cache.clear() - obj2 = decorated_func("input1") + def test_decorator_uses_different_in_memory_caches_on_different_comms(self, cachedir, comm): + comm_world_func = memory_and_disk_cache( + cachedir=cachedir.name + )(self.myfunc) + + temporary_comm = COMM_SELF.Dup() + temporary_comm.name = "pytest temp COMM_SELF" + with temp_internal_comm(temporary_comm) as comm_self: + comm_self_func = memory_and_disk_cache( + cachedir=cachedir.name + )(self.myfunc) + + # obj1 should be cached on the COMM_WORLD cache + obj1 = comm_world_func("input1", comm=comm) + comm_world_cache = comm.Get_attr(comm_cache_keyval)[default_cache_name] + assert len(comm_world_cache) == 1 + assert len(os.listdir(cachedir.name)) == 1 + + # obj2 should be cached on the COMM_SELF cache + obj2 = comm_self_func("input1", comm=comm_self) + comm_self_cache = comm_self.Get_attr(comm_cache_keyval)[default_cache_name] + assert obj1 == obj2 and obj1 is not obj2 + assert len(comm_world_cache) == 1 + assert len(comm_self_cache) == 1 + assert len(os.listdir(cachedir.name)) == 1 + + temporary_comm.Free() + + def test_decorator_disk_cache_reuses_results(self, cachedir, comm): + decorated_func = memory_and_disk_cache(cachedir=cachedir.name)(self.myfunc) + + obj1 = decorated_func("input1", comm=comm) + clear_memory_cache(comm) + obj2 = decorated_func("input1", comm=comm) + mem_cache = comm.Get_attr(comm_cache_keyval)[default_cache_name] assert obj1 == obj2 and obj1 is not obj2 - assert len(cache) == 1 + assert len(mem_cache) == 1 assert len(os.listdir(cachedir.name)) == 1 - def test_decorator_cache_misses(self, cache, cachedir): - decorated_func = disk_cached(cache, cachedir.name)(self.myfunc) + def test_decorator_cache_misses(self, cachedir, comm): + decorated_func = memory_and_disk_cache(cachedir=cachedir.name)(self.myfunc) - obj1 = decorated_func("input1") - obj2 = decorated_func("input2") + obj1 = decorated_func("input1", comm=comm) + obj2 = decorated_func("input2", comm=comm) + mem_cache = comm.Get_attr(comm_cache_keyval)[default_cache_name] assert obj1 != obj2 - assert len(cache) == 2 + assert len(mem_cache) == 2 assert len(os.listdir(cachedir.name)) == 2 +# Test updated caching functionality +class StateIncrement: + """Simple class for keeping track of the number of times executed + """ + def __init__(self): + self._count = 0 + + def __call__(self): + self._count += 1 + return self._count + + @property + def value(self): + return self._count + + +def twople(x): + return (x, )*2 + + +def threeple(x): + return (x, )*3 + + +def n_comms(n): + return [MPI.COMM_WORLD]*n + + +def n_ops(n): + return [MPI.SUM]*n + + +# decorator = parallel_memory_only_cache, parallel_memory_only_cache_no_broadcast, disk_only_cached +def function_factory(state, decorator, f, **kwargs): + def custom_function(x, comm=COMM_WORLD): + state() + return f(x) + + return decorator(**kwargs)(custom_function) + + +@pytest.fixture +def state(): + return StateIncrement() + + +@pytest.mark.parametrize("decorator, uncached_function", [ + (memory_cache, twople), + (memory_cache, n_comms), + (memory_and_disk_cache, twople), + (disk_only_cache, twople) +]) +def test_function_args_twice_caches(request, state, decorator, uncached_function, tmpdir): + if request.node.callspec.params["decorator"] in {disk_only_cache, memory_and_disk_cache}: + kwargs = {"cachedir": tmpdir} + else: + kwargs = {} + + cached_function = function_factory(state, decorator, uncached_function, **kwargs) + assert state.value == 0 + first = cached_function(2, comm=COMM_WORLD) + assert first == uncached_function(2) + assert state.value == 1 + second = cached_function(2, comm=COMM_WORLD) + assert second == uncached_function(2) + if request.node.callspec.params["decorator"] is not disk_only_cache: + assert second is first + assert state.value == 1 + + clear_memory_cache(COMM_WORLD) + + +@pytest.mark.parametrize("decorator, uncached_function", [ + (memory_cache, twople), + (memory_cache, n_comms), + (memory_and_disk_cache, twople), + (disk_only_cache, twople) +]) +def test_function_args_different(request, state, decorator, uncached_function, tmpdir): + if request.node.callspec.params["decorator"] in {disk_only_cache, memory_and_disk_cache}: + kwargs = {"cachedir": tmpdir} + else: + kwargs = {} + + cached_function = function_factory(state, decorator, uncached_function, **kwargs) + assert state.value == 0 + first = cached_function(2, comm=COMM_WORLD) + assert first == uncached_function(2) + assert state.value == 1 + second = cached_function(3, comm=COMM_WORLD) + assert second == uncached_function(3) + assert state.value == 2 + + clear_memory_cache(COMM_WORLD) + + +@pytest.mark.parallel(nprocs=3) +@pytest.mark.parametrize("decorator, uncached_function", [ + (memory_cache, twople), + (memory_cache, n_comms), + (memory_and_disk_cache, twople), + (disk_only_cache, twople) +]) +def test_function_over_different_comms(request, state, decorator, uncached_function, tmpdir): + if request.node.callspec.params["decorator"] in {disk_only_cache, memory_and_disk_cache}: + # In parallel different ranks can get different tempdirs, we just want one + tmpdir = COMM_WORLD.bcast(tmpdir, root=0) + kwargs = {"cachedir": tmpdir} + else: + kwargs = {} + + cached_function = function_factory(state, decorator, uncached_function, **kwargs) + assert state.value == 0 + + for ii in range(10): + color = 0 if COMM_WORLD.rank < 2 else MPI.UNDEFINED + comm12 = COMM_WORLD.Split(color=color) + if COMM_WORLD.rank < 2: + _ = cached_function(2, comm=comm12) + comm12.Free() + + color = 0 if COMM_WORLD.rank > 0 else MPI.UNDEFINED + comm23 = COMM_WORLD.Split(color=color) + if COMM_WORLD.rank > 0: + _ = cached_function(2, comm=comm23) + comm23.Free() + + clear_memory_cache(COMM_WORLD) + + +# pyop2/compilation.py uses a custom cache which we test here +@pytest.mark.parallel(nprocs=2) +def test_writing_large_so(): + # This test exercises the compilation caching when handling larger files + if COMM_WORLD.rank == 0: + preamble = dedent("""\ + #include \n + void big(double *result){ + """) + variables = (f"v{next(tempfile._get_candidate_names())}" for _ in range(128*1024)) + lines = (f" double {v} = {hash(v)/1000000000};\n *result += {v};\n" for v in variables) + program = "\n".join(chain.from_iterable(((preamble, ), lines, ("}\n", )))) + with open("big.c", "w") as fh: + fh.write(program) + + COMM_WORLD.Barrier() + with open("big.c", "r") as fh: + program = fh.read() + + if COMM_WORLD.rank == 1: + os.remove("big.c") + + fn = load(program, "c", "big", argtypes=(ctypes.c_voidp,), comm=COMM_WORLD) + assert fn is not None + + +@pytest.mark.parallel(nprocs=2) +def test_two_comms_compile_the_same_code(): + new_comm = COMM_WORLD.Split(color=COMM_WORLD.rank) + new_comm.name = "test_two_comms" + code = dedent("""\ + #include \n + void noop(){ + printf("Do nothing!\\n"); + } + """) + + fn = load(code, "c", "noop", argtypes=(), comm=COMM_WORLD) + assert fn is not None + + if __name__ == '__main__': pytest.main(os.path.abspath(__file__))