This repository has been archived by the owner on Nov 27, 2024. It is now read-only.

Add instrumentation to caches
JDBetteridge committed Aug 19, 2024
1 parent 5a1476c commit be1e58b
Showing 4 changed files with 227 additions and 102 deletions.
291 changes: 208 additions & 83 deletions pyop2/caching.py
@@ -32,52 +32,34 @@
# OF THE POSSIBILITY OF SUCH DAMAGE.

"""Provides common base classes for cached objects."""

import cachetools
import hashlib
import os
import pickle
from collections.abc import MutableMapping
from pathlib import Path
from warnings import warn # noqa F401
-from functools import wraps
+from collections import defaultdict
+from itertools import count
+from functools import partial, wraps

from pyop2.configuration import configuration
from pyop2.logger import debug
-from pyop2.mpi import MPI, COMM_WORLD, comm_cache_keyval
+from pyop2.mpi import (
+    MPI, COMM_WORLD, comm_cache_keyval, temp_internal_comm
+)


-# TODO: Remove this? Rewrite?
-def report_cache(typ):
-    """Report the size of caches of type ``typ``
-    :arg typ: A class of cached object. For example
-        :class:`ObjectCached` or :class:`Cached`.
-    """
-    from collections import defaultdict
-    from inspect import getmodule
-    from gc import get_objects
-    typs = defaultdict(lambda: 0)
-    n = 0
-    for x in get_objects():
-        if isinstance(x, typ):
-            typs[type(x)] += 1
-            n += 1
-    if n == 0:
-        print("\nNo %s objects in caches" % typ.__name__)
-        return
-    print("\n%d %s objects in caches" % (n, typ.__name__))
-    print("Object breakdown")
-    print("================")
-    for k, v in typs.iteritems():
-        mod = getmodule(k)
-        if mod is not None:
-            name = "%s.%s" % (mod.__name__, k.__name__)
-        else:
-            name = k.__name__
-        print('%s: %d' % (name, v))
+# Caches created here are registered as a tuple of
+# (creation_index, comm, comm.name, function, cache)
+# in _KNOWN_CACHES
+_CACHE_CIDX = count()
+_KNOWN_CACHES = []
+# Flag for outputting information at the end of testing (do not abuse!)
+_running_on_ci = bool(os.environ.get('PYOP2_CI_TESTS'))


# FIXME: (Later) Remove ObjectCached
class ObjectCached(object):
"""Base class for objects that should be cached on another object.
Expand Down Expand Up @@ -163,6 +145,95 @@ def make_obj():
return obj


def cache_stats(comm=None, comm_name=None, alive=True, function=None, cache_type=None):
    caches = _KNOWN_CACHES
    if comm is not None:
        with temp_internal_comm(comm) as icomm:
            cache_collection = icomm.Get_attr(comm_cache_keyval)
            if cache_collection is None:
                print(f"Communicator {icomm.name} has no associated caches")
            comm_name = icomm.name
    if comm_name is not None:
        caches = filter(lambda c: c[2] == comm_name, caches)
    if alive:
        caches = filter(lambda c: c[1] != MPI.COMM_NULL, caches)
    if function is not None:
        if isinstance(function, str):
            caches = filter(lambda c: function in c[3].__qualname__, caches)
        else:
            caches = filter(lambda c: c[3] is function, caches)
    if cache_type is not None:
        if isinstance(cache_type, str):
            # Match against the class, since cache instances carry no __qualname__
            caches = filter(lambda c: cache_type in c[4].__class__.__qualname__, caches)
        else:
            caches = filter(lambda c: isinstance(c[4], cache_type), caches)
    return [*caches]


def get_stats(cache):
    hit = miss = size = maxsize = -1
    if isinstance(cache, cachetools.Cache):
        size = cache.currsize
        maxsize = cache.maxsize
    if hasattr(cache, "instrument__"):
        hit = cache.hit
        miss = cache.miss
    if size == -1:
        try:
            size = len(cache)
        except NotImplementedError:
            pass
    if maxsize == -1:
        try:
            maxsize = cache.max_size
        except AttributeError:
            pass
    return hit, miss, size, maxsize
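
# A minimal usage sketch (illustrative; instrument() is defined further down
# this file):
#
#     >>> cache = instrument(dict)()        # counting dict subclass
#     >>> cache["a"] = 1
#     >>> _ = cache.get("a")                # hit  -> cache.hit == 1
#     >>> _ = cache.get("b")                # miss -> cache.miss == 1
#     >>> get_stats(cache)
#     (1, 1, 1, -1)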


def print_cache_stats(*args, **kwargs):
    data = defaultdict(lambda: defaultdict(list))
    for entry in cache_stats(*args, **kwargs):
        ecid, ecomm, ecomm_name, efunction, ecache = entry
        active = (ecomm != MPI.COMM_NULL)
        data[(ecomm_name, active)][ecache.__class__.__name__].append(
            (ecid, efunction.__module__, efunction.__name__, ecache)
        )

    tab = " "
    hline = "-"*120
    col = (90, 27)
    stats_col = (6, 6, 6, 6)
    stats = ("hit", "miss", "size", "max")
    no_stats = "|".join(" "*ii for ii in stats_col)
    print(hline)
    print(f"|{'Cache':^{col[0]}}|{'Stats':^{col[1]}}|")
    subtitles = "|".join(f"{st:^{w}}" for st, w in zip(stats, stats_col))
    print("|" + " "*col[0] + f"|{subtitles:{col[1]}}|")
    print(hline)
    for ecomm, cachedict in data.items():
        active = "Active" if ecomm[1] else "Freed"
        comm_title = f"{ecomm[0]} ({active})"
        print(f"|{comm_title:{col[0]}}|{no_stats}|")
        for ecache, function_list in cachedict.items():
            cache_title = f"{tab}{ecache}"
            print(f"|{cache_title:{col[0]}}|{no_stats}|")
            try:
                loc = function_list[0][-1].cachedir
            except AttributeError:
                loc = "Memory"
            cache_location = f"{tab}{loc!s}"
            if len(str(loc)) < col[0] - 5:
                print(f"|{cache_location:{col[0]}}|{no_stats}|")
            else:
                print(f"|{cache_location:78}|")
            for entry in function_list:
                function_title = f"{tab*2}id={entry[0]} {'.'.join(entry[1:3])}"
                # stats_row avoids shadowing the header tuple above
                stats_row = "|".join(f"{s:{w}}" for s, w in zip(get_stats(entry[3]), stats_col))
                print(f"|{function_title:{col[0]}}|{stats_row:{col[1]}}|")
    print(hline)
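
# Both helpers accept the same filters (sketch; the communicator name below is
# made up):
#
#     >>> print_cache_stats(comm_name="PYOP2_COMM_WORLD_1", alive=True)
#
# which tabulates only caches registered against a live internal communicator
# of that exact name, while function="compile" would keep entries whose
# function __qualname__ contains "compile".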


class _CacheMiss:
    pass

@@ -180,11 +251,6 @@ def _as_hexdigest(*args):
    return hash_.hexdigest()


-def clear_memory_cache(comm):
-    if comm.Get_attr(comm_cache_keyval) is not None:
-        comm.Set_attr(comm_cache_keyval, {})

class DictLikeDiskAccess(MutableMapping):
    def __init__(self, cachedir):
        """
@@ -224,17 +290,21 @@ def __setitem__(self, key, value):
        tempfile.rename(filepath)

    def __delitem__(self, key):
-        raise ValueError(f"Cannot remove items from {self.__class__.__name__}")
+        raise NotImplementedError(f"Cannot remove items from {self.__class__.__name__}")

    def __iter__(self):
-        raise ValueError(f"Cannot iterate over keys in {self.__class__.__name__}")
+        raise NotImplementedError(f"Cannot iterate over keys in {self.__class__.__name__}")

    def __len__(self):
-        raise ValueError(f"Cannot query length of {self.__class__.__name__}")
+        raise NotImplementedError(f"Cannot query length of {self.__class__.__name__}")

    def __repr__(self):
        return f"{self.__class__.__name__}(cachedir={self.cachedir})"

+    def __eq__(self, other):
+        # Instances are the same if they have the same cachedir
+        return self.cachedir == other.cachedir
+
    def open(self, *args, **kwargs):
        return open(*args, **kwargs)

Expand Down Expand Up @@ -271,10 +341,47 @@ def default_parallel_hashkey(*args, **kwargs):
    return cachetools.keys.hashkey(*hash_args, **hash_kwargs)


def instrument(cls):
    @wraps(cls, updated=())
    class _wrapper(cls):
        instrument__ = True

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.hit = 0
            self.miss = 0

        def get(self, key, default=None):
            value = super().get(key, default)
            if value is default:
                self.miss += 1
            else:
                self.hit += 1
            return value

        def __getitem__(self, key):
            try:
                value = super().__getitem__(key)
                self.hit += 1
            except KeyError as e:
                self.miss += 1
                raise e
            return value
    return _wrapper
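
# Note: @wraps(cls, updated=()) copies the wrapped class's __name__ onto the
# subclass (updated=() because a class __dict__ is a read-only mappingproxy),
# so an instrumented cache is still keyed under the original class name in the
# per-comm cache_collection, e.g.:
#
#     >>> instrument(cachetools.LRUCache)(maxsize=2).__class__.__name__
#     'LRUCache'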


class DEFAULT_CACHE(dict):
    pass


# Examples of how to instrument and use different default caches:
# - DEFAULT_CACHE = instrument(DEFAULT_CACHE)
# - DEFAULT_CACHE = instrument(cachetools.LRUCache)
# - DEFAULT_CACHE = partial(DEFAULT_CACHE, maxsize=100)
EXOTIC_CACHE = partial(instrument(cachetools.LRUCache), maxsize=100)
# - DictLikeDiskAccess = instrument(DictLikeDiskAccess)
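
# Such a cache can be handed to parallel_cache as a factory, along the lines
# of (sketch, mirroring the wishlist at the end of this file):
#
#     >>> lru_memory_cache = partial(parallel_cache, cache_factory=EXOTIC_CACHE)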


def parallel_cache(
    hashkey=default_parallel_hashkey,
    comm_fetcher=default_comm_fetcher,
@@ -299,54 +406,62 @@ def wrapper(*args, **kwargs):
""" Extract the key and then try the memory cache before falling back
on calling the function and populating the cache.
"""
comm = comm_fetcher(*args, **kwargs)
k = hashkey(*args, **kwargs)
key = _as_hexdigest(*k), func.__qualname__

# Fetch the per-comm cache_collection or set it up if not present
# A collection is required since different types of cache can be set up on the same comm
cache_collection = comm.Get_attr(comm_cache_keyval)
if cache_collection is None:
cache_collection = {}
comm.Set_attr(comm_cache_keyval, cache_collection)
# If this kind of cache is already present on the
# cache_collection, get it, otherwise create it
local_cache = cache_collection.setdefault(
(cf := cache_factory()).__class__.__name__,
cf
)

if broadcast:
# Grab value from rank 0 memory cache and broadcast result
if comm.rank == 0:
# Create a PyOP2 comm associated with the key, so it is decrefed when the wrapper exits
with temp_internal_comm(comm_fetcher(*args, **kwargs)) as comm:
# Fetch the per-comm cache_collection or set it up if not present
# A collection is required since different types of cache can be set up on the same comm
cache_collection = comm.Get_attr(comm_cache_keyval)
if cache_collection is None:
cache_collection = {}
comm.Set_attr(comm_cache_keyval, cache_collection)
# If this kind of cache is already present on the
# cache_collection, get it, otherwise create it
local_cache = cache_collection.setdefault(
(cf := cache_factory()).__class__.__name__,
cf
)
local_cache = cache_collection[cf.__class__.__name__]

# If this is a new cache or function add it to the list of known caches
if (comm, comm.name, func, local_cache) not in [k[1:] for k in _KNOWN_CACHES]:
_KNOWN_CACHES.append((next(_CACHE_CIDX), comm, comm.name, func, local_cache))

if broadcast:
# Grab value from rank 0 memory cache and broadcast result
if comm.rank == 0:
value = local_cache.get(key, CACHE_MISS)
if value is CACHE_MISS:
debug(
f"{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: "
f"{k} {local_cache.__class__.__name__} cache miss"
)
else:
debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} {local_cache.__class__.__name__} cache hit')
# TODO: Add communication tags to avoid cross-broadcasting
comm.bcast(value, root=0)
else:
value = comm.bcast(CACHE_MISS, root=0)
if isinstance(value, _CacheMiss):
# We might have the CACHE_MISS from rank 0 and
# `(value is CACHE_MISS) == False` which is confusing,
# so we set it back to the local value
value = CACHE_MISS
else:
# Grab value from all ranks cache and broadcast cache hit/miss
value = local_cache.get(key, CACHE_MISS)
if value is CACHE_MISS:
debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} {local_cache.__class__.__name__} cache miss')
cache_hit = False
else:
debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} {local_cache.__class__.__name__} cache hit')
# TODO: Add communication tags to avoid cross-broadcasting
comm.bcast(value, root=0)
else:
value = comm.bcast(CACHE_MISS, root=0)
if isinstance(value, _CacheMiss):
# We might have the CACHE_MISS from rank 0 and
# `(value is CACHE_MISS) == False` which is confusing,
# so we set it back to the local value
value = CACHE_MISS
else:
# Grab value from all ranks cache and broadcast cache hit/miss
value = local_cache.get(key, CACHE_MISS)
if value is CACHE_MISS:
debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} {local_cache.__class__.__name__} cache miss')
cache_hit = False
else:
debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} {local_cache.__class__.__name__} cache hit')
cache_hit = True
all_present = comm.allgather(cache_hit)
cache_hit = True
all_present = comm.allgather(cache_hit)

# If not present in the cache of all ranks we need to recompute on all ranks
if not min(all_present):
value = CACHE_MISS
# If not present in the cache of all ranks we need to recompute on all ranks
if not min(all_present):
value = CACHE_MISS

            if value is CACHE_MISS:
                value = func(*args, **kwargs)
@@ -356,6 +471,12 @@ def wrapper(*args, **kwargs):
    return decorator
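
# Summary of the two lookup modes implemented by the wrapper above:
# * broadcast=True: only rank 0 consults its cache and broadcasts the value
#   (or CACHE_MISS) to the remaining ranks, so a hit costs a single bcast.
# * broadcast=False: every rank consults its own cache and the hit/miss flags
#   are allgathered; if any rank missed, the value is recomputed on all ranks,
#   keeping the caches consistent across the communicator.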


+def clear_memory_cache(comm):
+    with temp_internal_comm(comm) as icomm:
+        if icomm.Get_attr(comm_cache_keyval) is not None:
+            icomm.Set_attr(comm_cache_keyval, {})


# A small collection of default simple caches
memory_cache = parallel_cache
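
# Sketch of decorating a function (assumes the default comm_fetcher can find
# a communicator among the arguments; expensive_assembly is hypothetical):
#
#     >>> @memory_cache()
#     ... def assemble_thing(form, *, comm=None):
#     ...     return expensive_assembly(form, comm)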

@@ -374,8 +495,12 @@ def decorator(func):
    return decorator

# TODO: (Wishlist)
# * Try more exotic caches ie: memory_cache = partial(parallel_cache, cache_factory=lambda: cachetools.LRUCache(maxsize=1000))
# * Add some sort of cache reporting
# * Add some sort of cache statistics
# * Refactor compilation.py to use @mem_and_disk_cached, where get_so just uses DictLikeDiskAccess with an overloaded self.write() method
# * Systematic investigation into cache sizes/types for Firedrake
#   - Is a mem cache needed for DLLs?
#   - Is LRUCache better than a simple dict? (memory profile test suite)
#   - What is the optimal maxsize?
# * Add some docstrings and maybe some exposition!
4 changes: 2 additions & 2 deletions pyop2/configuration.py
@@ -108,8 +108,8 @@ class Configuration(dict):
("PYOP2_NODE_LOCAL_COMPILATION", bool, True),
"no_fork_available":
("PYOP2_NO_FORK_AVAILABLE", bool, False),
"print_cache_size":
("PYOP2_PRINT_CACHE_SIZE", bool, False),
"print_cache_info":
("PYOP2_CACHE_INFO", bool, False),
"matnest":
("PYOP2_MATNEST", bool, True),
"block_sparsity":
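# Sketch (assuming PyOP2's usual environment-variable override for
# Configuration entries):
#
#     >>> os.environ["PYOP2_CACHE_INFO"] = "1"   # set before PyOP2 configures
#     >>> configuration["print_cache_info"]
#     True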
