From 9092ff1ed3e0def1bfa1ffb22cdc5cbd4751bfa4 Mon Sep 17 00:00:00 2001 From: Connor Ward Date: Thu, 8 Jun 2023 11:31:23 +0100 Subject: [PATCH 1/2] WIP --- pyop2/caching.py | 150 ++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 148 insertions(+), 2 deletions(-) diff --git a/pyop2/caching.py b/pyop2/caching.py index 24a3f5513..3605d9cb1 100644 --- a/pyop2/caching.py +++ b/pyop2/caching.py @@ -33,15 +33,19 @@ """Provides common base classes for cached objects.""" +import abc +import collections import hashlib import os -from pathlib import Path import pickle +from collections import defaultdict +from functools import partial +from pathlib import Path import cachetools from pyop2.configuration import configuration -from pyop2.mpi import hash_comm +from pyop2.mpi import COMM_WORLD, hash_comm, temp_internal_comm from pyop2.utils import cached_property @@ -343,3 +347,145 @@ def _disk_cache_set(cachedir, key, value): with open(tempfile, "wb") as f: pickle.dump(value, f) tempfile.rename(filepath) + + +class CountedCache(collections.abc.MutableMapping, abc.ABC): + def __init__(self): + self.naccesses = 0 + self.nhits = 0 + + def __getitem__(self, key): + self.naccesses += 1 + value = self._data[key] + self.nhits += 1 + return value + + def __setitem__(self, key, value): + self._data[key] = value + + def __delitem__(self, key): + del self._data[key] + + def __len__(self): + return len(self._data) + + def __iter__(self): + return iter(self._data) + + def __str__(self) -> str: + return f"{type(self)}({self._data})" + + def clear(self) -> None: + self._data.clear() + self.naccesses = 0 + self.nhits = 0 + + @property + def nmisses(self) -> int: + return self.naccesses - self.nhits + + @property + def hit_rate(self) -> float: + try: + return self.nhits / self.naccesses + except ZeroDivisionError: + return 0. + + @property + def miss_rate(self) -> float: + try: + return self.nmisses / self.naccesses + except ZeroDivisionError: + return 0. + + +class CountedNoEvictCache(CountedCache): + def __init__(self): + super().__init__() + self._data = {} + + +class CountedLRUCache(CountedCache): + def __init__(self, maxsize=32): + super().__init__() + self._data = cachetools.LRUCache(maxsize) + + +class PCache(abc.ABC): + """Parallel-safe cache.""" + + def __getitem__(self, key): + comm, key = key + return self.cache(comm)[key] + + def __setitem__(self, key, value): + comm, key = key + self.cache(comm)[key] = value + + def __delitem__(self, key): + comm, key = key + del self.cache(comm)[key] + + def __str__(self) -> str: + return f"{type(self)}({self._caches})" + + def clear(self, comm): + self.cache(comm).clear() + + def currsize(self, comm): + return len(self.cache(comm)) + + def cache(self, comm): + with temp_internal_comm(comm) as icomm: + return self._caches[hash_comm(icomm)] + + +class PNoEvictCache(PCache): + """Parallel-safe cache that does not evict entries.""" + + def __init__(self): + super().__init__() + self._caches = defaultdict(CountedNoEvictCache) + + +class PLRUCache(PCache): + """Parallel-safe LRU cache.""" + + def __init__(self, maxsize=32): + super().__init__() + self._caches = defaultdict(partial(CountedLRUCache, maxsize)) + self.maxsize = maxsize + + + +class CacheManager(dict): + """Object that keeps track of multiple global caches.""" + + def __init__(self, name): + super().__init__() + self.name = name + + def add_cache(self, cache_id, cache=None): + if cache_id in self: + raise ValueError(f"A cache has already been registered under {cache_id}") + + if cache is None: + cache = {} + self[cache_id] = cache + return cache + + def clear(self, cache_id=None, error_if_missing=True, **kwargs): + if cache_id is None: + cache_ids = self.keys() + else: + cache_ids = [cache_id] + + for cache_id in cache_ids: + if cache_id in self: + self[cache_id].clear(**kwargs) + else: + if error_if_missing: + raise ValueError("{cache_id} not found") + + +cache_manager = CacheManager("pyop2") From aebe31251d05c0379d514a8a4a82538ec737cdcf Mon Sep 17 00:00:00 2001 From: Connor Ward Date: Fri, 9 Jun 2023 10:10:17 +0100 Subject: [PATCH 2/2] Add TODO comment --- pyop2/caching.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyop2/caching.py b/pyop2/caching.py index 3605d9cb1..6f84c9907 100644 --- a/pyop2/caching.py +++ b/pyop2/caching.py @@ -349,6 +349,8 @@ def _disk_cache_set(cachedir, key, value): tempfile.rename(filepath) +# TODO LRU caches should probably emit a warning when they overflow, would make +# development easier class CountedCache(collections.abc.MutableMapping, abc.ABC): def __init__(self): self.naccesses = 0