Caching framework. (#933)
* Caching
nikfilippas authored Mar 6, 2023
1 parent fa73ff2 commit 163b319
Showing 4 changed files with 434 additions and 12 deletions.
4 changes: 3 additions & 1 deletion pyccl/__init__.py
@@ -25,6 +25,8 @@

# CCL base
from .base import (
Caching,
cache,
hash_,
)

@@ -166,7 +168,7 @@

__all__ = (
'lib',
'hash_',
'Caching', 'cache', 'hash_',
'CCLParameters', 'spline_params', 'gsl_params', 'physical_constants',
'CCLError', 'CCLWarning', 'CCLDeprecationWarning',
'Cosmology', 'CosmologyVanillaLCDM', 'CosmologyCalculator',
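With this change, the caching utilities are re-exported at the package level alongside the existing `hash_`. A quick sketch of the new public names (assuming a working pyccl installation; note that caching itself is off by default, see `pyccl/base.py` below):

```python
import pyccl

# class-level defaults defined in pyccl/base.py
print(pyccl.Caching.maxsize, pyccl.Caching.policy)   # 128 lru

# caching is disabled by default; switch it on explicitly
pyccl.Caching.enable()
```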
255 changes: 255 additions & 0 deletions pyccl/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import sys
import functools
from collections import OrderedDict
import numpy as np
from inspect import signature
from _thread import RLock


def _to_hashable(obj):
@@ -43,3 +46,255 @@ def hash_(obj):
"""Generic hash method, which changes between processes."""
digest = hash(repr(_to_hashable(obj))) + sys.maxsize + 1
return digest


class _ClassPropertyMeta(type):
"""Implement `property` to a `classmethod`."""
# NOTE: Only in 3.8 < py < 3.11 can `classmethod` wrap `property`.
# https://docs.python.org/3.11/library/functions.html#classmethod
@property
def maxsize(cls):
return cls._maxsize

@maxsize.setter
def maxsize(cls, value):
if value <= 0:
raise ValueError(
"`maxsize` should be larger than zero. "
"To disable caching, use `Caching.disable()`.")
cls._maxsize = value
for func in cls._cached_functions:
func.cache_info.maxsize = value

@property
def policy(cls):
return cls._policy

@policy.setter
def policy(cls, value):
if value not in cls._policies:
raise ValueError("Cache retention policy not recognized.")
if value == "lfu" != cls._policy:
# Reset counter if we change policy to lfu
# otherwise new objects are prone to being discarded immediately.
# Now, the counter is not just used for stats,
# it is part of the retention policy.
for func in cls._cached_functions:
for item in func.cache_info._caches.values():
item.reset()
cls._policy = value
for func in cls._cached_functions:
func.cache_info.policy = value


class Caching(metaclass=_ClassPropertyMeta):
"""Infrastructure to hold cached objects.
Caching is used for pre-computed objects that are expensive to compute.
Attributes:
maxsize (``int``):
Maximum number of caches to store. If the cache is full, old
entries are evicted according to the set cache retention policy.
policy (``'fifo'``, ``'lru'``, ``'lfu'``):
Cache retention policy.
"""
_enabled: bool = False
_policies: list = ['fifo', 'lru', 'lfu']
_default_maxsize: int = 128 # class default maxsize
_default_policy: str = 'lru' # class default policy
_maxsize = _default_maxsize # user-defined maxsize
_policy = _default_policy # user-defined policy
_cached_functions: list = []

@classmethod
def _get_key(cls, func, *args, **kwargs):
"""Calculate the hex hash from arguments and keyword arguments."""
# get a dictionary of default parameters
params = func.cache_info._signature.parameters
# get a dictionary of the passed parameters
passed = {**dict(zip(params, args)), **kwargs}
# discard the values equal to the default
defaults = {param: value.default for param, value in params.items()}
return hex(hash_({**defaults, **passed}))

@classmethod
def _get(cls, dic, key, policy):
"""Get the cached object container
under the implemented caching policy.
"""
obj = dic[key]
if policy == "lru":
dic.move_to_end(key)
# update stats
obj.increment()
return obj

@classmethod
def _pop(cls, dic, policy):
"""Remove one cached item as per the implemented caching policy."""
if policy == "lfu":
keys = list(dic)
idx = np.argmin([item.counter for item in dic.values()])
dic.move_to_end(keys[idx], last=False)
dic.popitem(last=False)

@classmethod
def _decorator(cls, func, maxsize, policy):
# assign caching attributes to decorated function
func.cache_info = CacheInfo(func, maxsize=maxsize, policy=policy)
func.clear_cache = func.cache_info._clear_cache
cls._cached_functions.append(func)

@functools.wraps(func)
def wrapper(*args, **kwargs):
if not cls._enabled:
return func(*args, **kwargs)

key = cls._get_key(func, *args, **kwargs)
# shorthand access
caches = func.cache_info._caches
maxsize = func.cache_info.maxsize
policy = func.cache_info.policy

with RLock():
if key in caches:
# output has been cached; update stats and return it
out = cls._get(caches, key, policy)
func.cache_info.hits += 1
return out.item

with RLock():
while len(caches) >= maxsize:
# output not cached and no space available, so remove
# items as per the caching policy until there is space
cls._pop(caches, policy)

# cache new entry and update stats
out = CachedObject(func(*args, **kwargs))
caches[key] = out
func.cache_info.misses += 1
return out.item

return wrapper

@classmethod
def cache(cls, func=None, *, maxsize=_maxsize, policy=_policy):
"""Cache the output of the decorated function, using the input
arguments as a proxy to build a hash key.
Arguments:
func (``function``):
Function to be decorated.
maxsize (``int``):
Maximum cache size for the decorated function.
policy (``'fifo'``, ``'lru'``, ``'lfu'``):
Cache retention policy. When the cache reaches ``maxsize``, this
policy decides which cached object is deleted. Default is 'lru'.\n
'fifo': first-in-first-out,\n
'lru': least-recently-used,\n
'lfu': least-frequently-used.
"""
if maxsize <= 0:
raise ValueError(
"`maxsize` should be larger than zero. "
"To disable caching, use `Caching.disable()`.")
if policy not in cls._policies:
raise ValueError("Cache retention policy not recognized.")

if func is None:
# `@cache()` called with parentheses
return functools.partial(
cls._decorator, maxsize=maxsize, policy=policy)
# `@cache` used without parentheses
return cls._decorator(func, maxsize=maxsize, policy=policy)

@classmethod
def enable(cls):
cls._enabled = True

@classmethod
def disable(cls):
cls._enabled = False

@classmethod
def reset(cls):
cls.maxsize = cls._default_maxsize
cls.policy = cls._default_policy

@classmethod
def clear_cache(cls):
[func.clear_cache() for func in cls._cached_functions]


cache = Caching.cache


class CacheInfo:
"""Cache info container.
Assigned to the cached function as ``function.cache_info``.
Parameters:
func (``function``):
Function to which an instance of this class will be assigned.
maxsize (``Caching.maxsize``):
Maximum number of caches to store.
policy (``Caching.policy``):
Cache retention policy.
.. note::
To assist in deciding an optimal ``maxsize`` and ``policy``, instances
of this class contain the following attributes:
- ``hits``: number of times the function has been bypassed
- ``misses``: number of times the function has computed something
- ``current_size``: current size of the cache dictionary
"""

def __init__(self, func, maxsize=Caching.maxsize, policy=Caching.policy):
# we store the signature of the function on import
# as it is the most expensive operation (~30x slower)
self._signature = signature(func)
self._caches = OrderedDict()
self.maxsize = maxsize
self.policy = policy
self.hits = self.misses = 0

@property
def current_size(self):
return len(self._caches)

def __repr__(self):
s = f"<{self.__class__.__name__}>"
for par, val in self.__dict__.items():
if not par.startswith("_"):
s += f"\n\t {par} = {val!r}"
s += f"\n\t current_size = {self.current_size!r}"
return s

def _clear_cache(self):
self._caches = OrderedDict()
self.hits = self.misses = 0


class CachedObject:
"""A cached object container.
Attributes:
counter (``int``):
Number of times the cached item has been retrieved.
"""
counter: int = 0

def __init__(self, obj):
self.item = obj

def __repr__(self):
s = f"CachedObject(counter={self.counter})"
return s

def increment(self):
self.counter += 1

def reset(self):
self.counter = 0
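To make the new machinery concrete, here is a minimal usage sketch of the API defined above (the toy function `slow_square` is illustrative and not part of this commit):

```python
import pyccl

pyccl.Caching.enable()          # caching is disabled by default

@pyccl.cache(maxsize=16, policy="lru")
def slow_square(x):
    # stand-in for an expensive computation
    return x * x

slow_square(3)                  # computed -> counted as a miss
slow_square(3)                  # served from the cache -> counted as a hit

print(slow_square.cache_info)   # maxsize, policy, hits, misses, current_size
slow_square.clear_cache()       # empty this function's cache only

pyccl.Caching.maxsize = 64      # propagates to every cached function
pyccl.Caching.policy = "lfu"    # switch the retention policy globally
pyccl.Caching.reset()           # back to the class defaults (128, 'lru')
```

Note that `Caching.maxsize` and `Caching.policy` are class-level settings: assigning them propagates to every already-decorated function through `_cached_functions`, while the values passed to `@cache` only set the per-function initial configuration.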
33 changes: 22 additions & 11 deletions pyccl/core.py
@@ -14,6 +14,7 @@
from .pyutils import check
from .pk2d import Pk2D
from .bcm import bcm_correct_pk2d
from .base import cache
from .parameters import CCLParameters, physical_constants

# Configuration types
@@ -766,11 +767,9 @@ def compute_growth(self):
status = lib.cosmology_compute_growth(self.cosmo, status)
check(status, self)

def compute_linear_power(self):
"""Compute the linear power spectrum."""
if self.has_linear_power:
return

@cache(maxsize=3)
def _compute_linear_power(self):
"""Return the linear power spectrum."""
if (self['N_nu_mass'] > 0 and
self._config_init_kwargs['transfer_function'] in
['bbks', 'eisenstein_hu', 'eisenstein_hu_nowiggles', ]):
@@ -839,6 +838,13 @@ def compute_linear_power(self):
status)
check(status, self)

return pk

def compute_linear_power(self):
"""Compute the linear power spectrum."""
if self.has_linear_power:
return
pk = self._compute_linear_power()
# Assign
self._pk_lin['delta_matter:delta_matter'] = pk

@@ -874,11 +880,9 @@ def _get_halo_model_nonlin_power(self):
hmc = hal.HMCalculator(self, hmf, hbf, mdef)
return hal.halomod_Pk2D(self, hmc, prf, normprof1=True)

def compute_nonlin_power(self):
"""Compute the non-linear power spectrum."""
if self.has_nonlin_power:
return

@cache(maxsize=3)
def _compute_nonlin_power(self):
"""Return the non-linear power spectrum."""
if self._config_init_kwargs['matter_power_spectrum'] != 'linear':
if self._params_init_kwargs['df_mg'] is not None:
warnings.warn(
@@ -915,7 +919,7 @@ def compute_nonlin_power(self):

if mps == "camb" and self.has_nonlin_power:
# Already computed
return
return self._pk_nl['delta_matter:delta_matter']

if mps is None:
raise CCLError("You want to compute the non-linear power "
@@ -942,6 +946,13 @@
if self._config_init_kwargs['baryons_power_spectrum'] == 'bcm':
bcm_correct_pk2d(self, pk)

return pk

def compute_nonlin_power(self):
"""Compute the non-linear power spectrum."""
if self.has_nonlin_power:
return
pk = self._compute_nonlin_power()
# Assign
self._pk_nl['delta_matter:delta_matter'] = pk
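The effect of the refactor above: the expensive body now lives in a cached private method (`_compute_linear_power` / `_compute_nonlin_power`, each with `maxsize=3`), while the public `compute_*_power` methods only assign the result. A rough sketch of the intended behaviour follows; the parameter values are illustrative, and whether two distinct `Cosmology` instances actually share a cache entry depends on how `hash_` digests the object, which is outside this diff:

```python
import pyccl as ccl

ccl.Caching.enable()

pars = dict(Omega_c=0.25, Omega_b=0.05, h=0.67, sigma8=0.8, n_s=0.96)
cosmo1 = ccl.Cosmology(**pars)
cosmo2 = ccl.Cosmology(**pars)

cosmo1.compute_linear_power()   # expensive call; the Pk2D is cached (up to 3 entries)
cosmo2.compute_linear_power()   # intended to be served from the cache

# the CacheInfo object is attached to the underlying function
print(ccl.Cosmology._compute_linear_power.cache_info)
```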
