From 163b319ebb3ffcbd47796fd20b3d9b82685f5e5f Mon Sep 17 00:00:00 2001 From: Nick Koukoufilippas Date: Mon, 6 Mar 2023 23:42:23 +0200 Subject: [PATCH] Caching framework. (#933) * Caching --- pyccl/__init__.py | 4 +- pyccl/base.py | 255 +++++++++++++++++++++++++++++++++++ pyccl/core.py | 33 +++-- pyccl/tests/test__caching.py | 154 +++++++++++++++++++++ 4 files changed, 434 insertions(+), 12 deletions(-) create mode 100644 pyccl/tests/test__caching.py diff --git a/pyccl/__init__.py b/pyccl/__init__.py index 77db73d2a..72dc953c7 100644 --- a/pyccl/__init__.py +++ b/pyccl/__init__.py @@ -25,6 +25,8 @@ # CCL base from .base import ( + Caching, + cache, hash_, ) @@ -166,7 +168,7 @@ __all__ = ( 'lib', - 'hash_', + 'Caching', 'cache', 'hash_', 'CCLParameters', 'spline_params', 'gsl_params', 'physical_constants', 'CCLError', 'CCLWarning', 'CCLDeprecationWarning', 'Cosmology', 'CosmologyVanillaLCDM', 'CosmologyCalculator', diff --git a/pyccl/base.py b/pyccl/base.py index d2782b77b..d896bd58c 100644 --- a/pyccl/base.py +++ b/pyccl/base.py @@ -1,6 +1,9 @@ import sys +import functools from collections import OrderedDict import numpy as np +from inspect import signature +from _thread import RLock def _to_hashable(obj): @@ -43,3 +46,255 @@ def hash_(obj): """Generic hash method, which changes between processes.""" digest = hash(repr(_to_hashable(obj))) + sys.maxsize + 1 return digest + + +class _ClassPropertyMeta(type): + """Implement `property` to a `classmethod`.""" + # NOTE: Only in 3.8 < py < 3.11 can `classmethod` wrap `property`. + # https://docs.python.org/3.11/library/functions.html#classmethod + @property + def maxsize(cls): + return cls._maxsize + + @maxsize.setter + def maxsize(cls, value): + if value < 0: + raise ValueError( + "`maxsize` should be larger than zero. " + "To disable caching, use `Caching.disable()`.") + cls._maxsize = value + for func in cls._cached_functions: + func.cache_info.maxsize = value + + @property + def policy(cls): + return cls._policy + + @policy.setter + def policy(cls, value): + if value not in cls._policies: + raise ValueError("Cache retention policy not recognized.") + if value == "lfu" != cls._policy: + # Reset counter if we change policy to lfu + # otherwise new objects are prone to being discarded immediately. + # Now, the counter is not just used for stats, + # it is part of the retention policy. + for func in cls._cached_functions: + for item in func.cache_info._caches.values(): + item.reset() + cls._policy = value + for func in cls._cached_functions: + func.cache_info.policy = value + + +class Caching(metaclass=_ClassPropertyMeta): + """Infrastructure to hold cached objects. + + Caching is used for pre-computed objects that are expensive to compute. + + Attributes: + maxsize (``int``): + Maximum number of caches to store. If the dictionary is full, new + caches are assigned according to the set cache retention policy. + policy (``'fifo'``, ``'lru'``, ``'lfu'``): + Cache retention policy. + """ + _enabled: bool = False + _policies: list = ['fifo', 'lru', 'lfu'] + _default_maxsize: int = 128 # class default maxsize + _default_policy: str = 'lru' # class default policy + _maxsize = _default_maxsize # user-defined maxsize + _policy = _default_policy # user-defined policy + _cached_functions: list = [] + + @classmethod + def _get_key(cls, func, *args, **kwargs): + """Calculate the hex hash from arguments and keyword arguments.""" + # get a dictionary of default parameters + params = func.cache_info._signature.parameters + # get a dictionary of the passed parameters + passed = {**dict(zip(params, args)), **kwargs} + # discard the values equal to the default + defaults = {param: value.default for param, value in params.items()} + return hex(hash_({**defaults, **passed})) + + @classmethod + def _get(cls, dic, key, policy): + """Get the cached object container + under the implemented caching policy. + """ + obj = dic[key] + if policy == "lru": + dic.move_to_end(key) + # update stats + obj.increment() + return obj + + @classmethod + def _pop(cls, dic, policy): + """Remove one cached item as per the implemented caching policy.""" + if policy == "lfu": + keys = list(dic) + idx = np.argmin([item.counter for item in dic.values()]) + dic.move_to_end(keys[idx], last=False) + dic.popitem(last=False) + + @classmethod + def _decorator(cls, func, maxsize, policy): + # assign caching attributes to decorated function + func.cache_info = CacheInfo(func, maxsize=maxsize, policy=policy) + func.clear_cache = func.cache_info._clear_cache + cls._cached_functions.append(func) + + @functools.wraps(func) + def wrapper(*args, **kwargs): + if not cls._enabled: + return func(*args, **kwargs) + + key = cls._get_key(func, *args, **kwargs) + # shorthand access + caches = func.cache_info._caches + maxsize = func.cache_info.maxsize + policy = func.cache_info.policy + + with RLock(): + if key in caches: + # output has been cached; update stats and return it + out = cls._get(caches, key, policy) + func.cache_info.hits += 1 + return out.item + + with RLock(): + while len(caches) >= maxsize: + # output not cached and no space available, so remove + # items as per the caching policy until there is space + cls._pop(caches, policy) + + # cache new entry and update stats + out = CachedObject(func(*args, **kwargs)) + caches[key] = out + func.cache_info.misses += 1 + return out.item + + return wrapper + + @classmethod + def cache(cls, func=None, *, maxsize=_maxsize, policy=_policy): + """Cache the output of the decorated function, using the input + arguments as a proxy to build a hash key. + + Arguments: + func (``function``): + Function to be decorated. + maxsize (``int``): + Maximum cache size for the decorated function. + policy (``'fifo'``, ``'lru'``, ``'lfu'``): + Cache retention policy. When the storage reaches maxsize + decide which cached object will be deleted. Default is 'lru'.\n + 'fifo': first-in-first-out,\n + 'lru': least-recently-used,\n + 'lfu': least-frequently-used. + """ + if maxsize < 0: + raise ValueError( + "`maxsize` should be larger than zero. " + "To disable caching, use `Caching.disable()`.") + if policy not in cls._policies: + raise ValueError("Cache retention policy not recognized.") + + if func is None: + # `@cache` with parentheses + return functools.partial( + cls._decorator, maxsize=maxsize, policy=policy) + # `@cache()` without parentheses + return cls._decorator(func, maxsize=maxsize, policy=policy) + + @classmethod + def enable(cls): + cls._enabled = True + + @classmethod + def disable(cls): + cls._enabled = False + + @classmethod + def reset(cls): + cls.maxsize = cls._default_maxsize + cls.policy = cls._default_policy + + @classmethod + def clear_cache(cls): + [func.clear_cache() for func in cls._cached_functions] + + +cache = Caching.cache + + +class CacheInfo: + """Cache info container. + Assigned to cached function as ``function.cache_info``. + + Parameters: + func (``function``): + Function in which an instance of this class will be assigned. + maxsize (``Caching.maxsize``): + Maximum number of caches to store. + policy (``Caching.policy``): + Cache retention policy. + + .. note :: + + To assist in deciding an optimal ``maxsize`` and ``policy``, instances + of this class contain the following attributes: + - ``hits``: number of times the function has been bypassed + - ``misses``: number of times the function has computed something + - ``current_size``: current size of the cache dictionary + """ + + def __init__(self, func, maxsize=Caching.maxsize, policy=Caching.policy): + # we store the signature of the function on import + # as it is the most expensive operation (~30x slower) + self._signature = signature(func) + self._caches = OrderedDict() + self.maxsize = maxsize + self.policy = policy + self.hits = self.misses = 0 + + @property + def current_size(self): + return len(self._caches) + + def __repr__(self): + s = f"<{self.__class__.__name__}>" + for par, val in self.__dict__.items(): + if not par.startswith("_"): + s += f"\n\t {par} = {val!r}" + s += f"\n\t current_size = {self.current_size!r}" + return s + + def _clear_cache(self): + self._caches = OrderedDict() + self.hits = self.misses = 0 + + +class CachedObject: + """A cached object container. + + Attributes: + counter (``int``): + Number of times the cached item has been retrieved. + """ + counter: int = 0 + + def __init__(self, obj): + self.item = obj + + def __repr__(self): + s = f"CachedObject(counter={self.counter})" + return s + + def increment(self): + self.counter += 1 + + def reset(self): + self.counter = 0 diff --git a/pyccl/core.py b/pyccl/core.py index a66802c27..45f83a08e 100644 --- a/pyccl/core.py +++ b/pyccl/core.py @@ -14,6 +14,7 @@ from .pyutils import check from .pk2d import Pk2D from .bcm import bcm_correct_pk2d +from .base import cache from .parameters import CCLParameters, physical_constants # Configuration types @@ -766,11 +767,9 @@ def compute_growth(self): status = lib.cosmology_compute_growth(self.cosmo, status) check(status, self) - def compute_linear_power(self): - """Compute the linear power spectrum.""" - if self.has_linear_power: - return - + @cache(maxsize=3) + def _compute_linear_power(self): + """Return the linear power spectrum.""" if (self['N_nu_mass'] > 0 and self._config_init_kwargs['transfer_function'] in ['bbks', 'eisenstein_hu', 'eisenstein_hu_nowiggles', ]): @@ -839,6 +838,13 @@ def compute_linear_power(self): status) check(status, self) + return pk + + def compute_linear_power(self): + """Compute the linear power spectrum.""" + if self.has_linear_power: + return + pk = self._compute_linear_power() # Assign self._pk_lin['delta_matter:delta_matter'] = pk @@ -874,11 +880,9 @@ def _get_halo_model_nonlin_power(self): hmc = hal.HMCalculator(self, hmf, hbf, mdef) return hal.halomod_Pk2D(self, hmc, prf, normprof1=True) - def compute_nonlin_power(self): - """Compute the non-linear power spectrum.""" - if self.has_nonlin_power: - return - + @cache(maxsize=3) + def _compute_nonlin_power(self): + """Return the non-linear power spectrum.""" if self._config_init_kwargs['matter_power_spectrum'] != 'linear': if self._params_init_kwargs['df_mg'] is not None: warnings.warn( @@ -915,7 +919,7 @@ def compute_nonlin_power(self): if mps == "camb" and self.has_nonlin_power: # Already computed - return + return self._pk_nl['delta_matter:delta_matter'] if mps is None: raise CCLError("You want to compute the non-linear power " @@ -942,6 +946,13 @@ def compute_nonlin_power(self): if self._config_init_kwargs['baryons_power_spectrum'] == 'bcm': bcm_correct_pk2d(self, pk) + return pk + + def compute_nonlin_power(self): + """Compute the non-linear power spectrum.""" + if self.has_nonlin_power: + return + pk = self._compute_nonlin_power() # Assign self._pk_nl['delta_matter:delta_matter'] = pk diff --git a/pyccl/tests/test__caching.py b/pyccl/tests/test__caching.py new file mode 100644 index 000000000..071560cab --- /dev/null +++ b/pyccl/tests/test__caching.py @@ -0,0 +1,154 @@ +# We use double underscore to make it the first test alphabetically. +import pytest +import pyccl as ccl +import numpy as np +from time import time + + +NUM = 3 # number of different Cosmologies we will check +# some unusual numbers that have not occurred before +s8_arr = np.linspace(0.753141592, 0.953141592, NUM) +# a modest speed increase - we are modest in the test to accommodate for slow +# runs; normally this is is expected to be another order of magnitude faster +SPEEDUP = 50 + +# enable caching if not already enabled +DEFAULT_CACHING_STATUS = ccl.Caching._enabled +if not ccl.Caching._enabled: + ccl.Caching.enable() + + +def get_cosmo(sigma8): + return ccl.Cosmology(Omega_c=0.25, Omega_b=0.05, h=0.67, n_s=0.96, + sigma8=sigma8) + + +def cosmo_create_and_compute_linpow(sigma8): + cosmo = get_cosmo(sigma8) + cosmo.compute_linear_power() + return cosmo + + +def timeit_(sigma8): + t0 = time() + cosmo_create_and_compute_linpow(sigma8) + return time() - t0 + + +def test_caching_switches(): + """Test that the Caching switches work as intended.""" + assert ccl.Caching._enabled == DEFAULT_CACHING_STATUS + assert ccl.Caching._maxsize == ccl.Caching._default_maxsize + ccl.Caching.maxsize = 128 + assert ccl.Caching._maxsize == 128 + ccl.Caching.disable() + assert not ccl.Caching._enabled + ccl.Caching.enable() + assert ccl.Caching._enabled + + +def test_times(): + """Verify that caching is happening. + Return time for querying the Boltzmann code goes from O(1s) to O(5ms). + """ + # If we disable caching, t1 and t2 will be of the same order of magnitude. + # No need to run through the entire s8_arr. + ccl.Caching.disable() + t1 = np.array([timeit_(s8) for s8 in s8_arr[:1]]) + t2 = np.array([timeit_(s8) for s8 in s8_arr[:1]]) + assert np.abs(np.log10(t2/t1)) < 1.0 + # But if caching is enabled, the second call will be much faster. + ccl.Caching.enable() + t1 = np.array([timeit_(s8) for s8 in s8_arr]) + t2 = np.array([timeit_(s8) for s8 in s8_arr]) + assert np.all(t1/t2 > SPEEDUP) + + +def test_caching_fifo(): + """Test First-In-First-Out retention policy.""" + # To save time, we test caching by limiting the maximum cache size + # from 64 (default) to 3. We cache Comologies with different sigma8. + # By now, the caching repo will be full. + ccl.Caching.maxsize = NUM + func = ccl.Cosmology._compute_linear_power + assert len(func.cache_info._caches) >= ccl.Caching.maxsize + + ccl.Caching.policy = "fifo" + + t1 = timeit_(sigma8=s8_arr[0]) # this is the oldest cached pk + # create new and discard oldest + cosmo_create_and_compute_linpow(0.42) + t2 = timeit_(sigma8=s8_arr[0]) # cached again + assert t2/t1 > SPEEDUP + + +def test_caching_lru(): + """Test Least-Recently-Used retention policy.""" + # By now the stored Cosmologies are { s8_arr[2], 0.42, s8_arr[0]] } + # from oldest to newest. Here, we show that we can retain s8_arr[2] + # simply by using it and moving it to the end of the stack. + ccl.Caching.policy = "lru" + + t1 = timeit_(sigma8=s8_arr[2]) # moves to the end of the stack + # create new and discard the least recently used + cosmo_create_and_compute_linpow(0.43) + t2 = timeit_(sigma8=s8_arr[2]) # retrieved + assert np.abs(np.log10(t2/t1)) < 1.0 + + +def test_caching_lfu(): + """Test Least-Frequently-Used retention policy.""" + # Now, the stored Cosmologies are { s8_arr[0], 0.43, s8_arr[2] } + # from oldest to newest. Here, we call each a different number of times + # and we check that the one used the least (0.43) is discarded. + ccl.Caching.policy = "lfu" + + t1 = timeit_(sigma8=0.43) # increments counter by 1 + _ = [timeit_(sigma8=s8_arr[0]) for _ in range(5)] + _ = [timeit_(sigma8=s8_arr[2]) for _ in range(3)] + # create new and discard the least frequently used + cosmo_create_and_compute_linpow(0.44) + t2 = timeit_(sigma8=0.43) # cached again + assert t2/t1 > SPEEDUP + + +def test_cache_info(): + """Test that the CacheInfo repr gives us the expected information.""" + info = ccl.Cosmology._compute_linear_power.cache_info + for text in ["maxsize", "policy", "hits", "misses", "current_size"]: + assert text in repr(info) + + obj = list(info._caches.values())[0] + assert "counter" in repr(obj) + + +def test_caching_reset(): + """Test the reset switches.""" + ccl.Caching.reset() + assert ccl.Caching.maxsize == ccl.Caching._default_maxsize + assert ccl.Caching.policy == ccl.Caching._default_policy + ccl.Caching.clear_cache() + func = ccl.Cosmology._compute_linear_power + assert len(func.cache_info._caches) == 0 + + +def test_caching_policy_raises(): + """Test that if the set policy is not correct, it raises an exception.""" + with pytest.raises(ValueError): + @ccl.Caching.cache(maxsize=-1) + def func1(): + return + + with pytest.raises(ValueError): + @ccl.Caching.cache(policy="my_policy") + def func2(): + return + + with pytest.raises(ValueError): + ccl.Caching.maxsize = -1 + + with pytest.raises(ValueError): + ccl.Caching.policy = "my_policy" + + +ccl.Caching._enabled = DEFAULT_CACHING_STATUS