diff --git a/src/nnbench/__init__.py b/src/nnbench/__init__.py
index a29f67b..d35141e 100644
--- a/src/nnbench/__init__.py
+++ b/src/nnbench/__init__.py
@@ -11,4 +11,4 @@
 from .core import benchmark, parametrize, product
 from .reporter import BenchmarkReporter
 from .runner import BenchmarkRunner
-from .types import Parameters
+from .types import Memo, Parameters
diff --git a/src/nnbench/types/types.py b/src/nnbench/types/types.py
index 047caa2..bf5e884 100644
--- a/src/nnbench/types/types.py
+++ b/src/nnbench/types/types.py
@@ -5,6 +5,8 @@
 import copy
 import functools
 import inspect
+import logging
+import threading
 from dataclasses import dataclass, field
 from types import MappingProxyType
 from typing import Any, Callable, Generic, Literal, Mapping, Protocol, TypeVar
@@ -14,6 +16,80 @@
 T = TypeVar("T")
 Variable = tuple[str, type, Any]

+_memo_cache: dict[int, Any] = {}
+_cache_lock = threading.Lock()
+
+logger = logging.getLogger(__name__)
+
+
+def memo_cache_size() -> int:
+    """
+    Get the current size of the memo cache.
+
+    Returns
+    -------
+    int
+        The number of items currently stored in the memo cache.
+    """
+    return len(_memo_cache)
+
+
+def clear_memo_cache() -> None:
+    """
+    Clear all items from the memo cache in a thread-safe manner.
+    """
+    with _cache_lock:
+        _memo_cache.clear()
+
+
+def evict_memo(_id: int) -> Any:
+    """
+    Pop the cached item with key `_id` from the memo cache.
+
+    Parameters
+    ----------
+    _id : int
+        The unique identifier (usually the id assigned by the Python interpreter) of the item to be evicted.
+
+    Returns
+    -------
+    Any
+        The value that was associated with the removed cache entry. If no item is found with the given `_id`, a KeyError is raised.
+    """
+    with _cache_lock:
+        return _memo_cache.pop(_id)
+
+
+def cached_memo(fn: Callable) -> Callable:
+    """
+    Decorator that caches the result of a method call based on the instance ID.
+
+    Parameters
+    ----------
+    fn: Callable
+        The method to memoize.
+
+    Returns
+    -------
+    Callable
+        A wrapped version of the method that caches its result.
+    """
+
+    @functools.wraps(fn)
+    def wrapper(self, *args, **kwargs):
+        _tid = id(self)
+        with _cache_lock:
+            if _tid in _memo_cache:
+                logger.debug(f"Returning memoized value from cache with ID {_tid}")
+                return _memo_cache[_tid]
+        logger.debug(f"Computing value on memo with ID {_tid} (cache miss)")
+        value = fn(self, *args, **kwargs)
+        with _cache_lock:
+            _memo_cache[_tid] = value
+        return value
+
+    return wrapper
+

 def NoOp(state: State, params: Mapping[str, Any] = MappingProxyType({})) -> None:
     pass
@@ -109,22 +185,24 @@ class State:


 class Memo(Generic[T]):
-    @functools.cache
-    # TODO: Swap this out for a local type-wide memo cache.
-    # Could also be a decorator, should look a bit like this:
-    # global memo_cache, memo_cache_lock
-    # _tid = id(self)
-    # val: T
-    # with memocache_lock:
-    #     if _tid in memo_cache:
-    #         val = memo_cache[_tid]
-    #         return val
-    #     val = self.compute()
-    #     memo_cache[_tid] = val
-    #     return val
+    """Abstract base class for memoized values in benchmark runs."""
+
+    # TODO: Make this better than the decorator application
+    # -> _Cached metaclass like in fsspec's AbstractFileSystem (maybe vendor with license)
+
+    @cached_memo
     def __call__(self) -> T:
+        """Placeholder to override when subclassing. The call should return the object to be cached."""
         raise NotImplementedError

+    def __del__(self) -> None:
+        """Delete the cached object and clear it from the cache."""
+        with _cache_lock:
+            sid = id(self)
+            if sid in _memo_cache:
+                logger.debug(f"Deleting cached value for memo with ID {sid}")
+                del _memo_cache[sid]
+

 @dataclass(init=False, frozen=True)
 class Parameters:
diff --git a/tests/test_memos.py b/tests/test_memos.py
new file mode 100644
index 0000000..fb96d83
--- /dev/null
+++ b/tests/test_memos.py
@@ -0,0 +1,28 @@
+from typing import Generator
+
+import pytest
+
+from nnbench.types.types import Memo, cached_memo, clear_memo_cache, memo_cache_size
+
+
+@pytest.fixture
+def clear_memos() -> Generator[None, None, None]:
+    try:
+        clear_memo_cache()
+        yield
+    finally:
+        clear_memo_cache()
+
+
+class MyMemo(Memo[int]):
+    @cached_memo
+    def __call__(self):
+        return 0
+
+
+def test_memo_caching(clear_memos):
+    m = MyMemo()
+    assert memo_cache_size() == 0
+    m()
+    assert memo_cache_size() == 1
+    m()
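Usage sketch (not part of the patch): a minimal, standalone illustration of how the memo cache added above is meant to be used, relying only on the APIs introduced in this diff. The `ExpensiveMemo` class and its payload are hypothetical stand-ins for a costly artifact such as a model or dataset.

    from nnbench.types.types import Memo, cached_memo, evict_memo, memo_cache_size

    class ExpensiveMemo(Memo[list[int]]):
        """Hypothetical memo wrapping a value that is expensive to build."""

        @cached_memo
        def __call__(self) -> list[int]:
            # Stand-in for loading a model or dataset; runs once per instance.
            return [i * i for i in range(1_000)]

    memo = ExpensiveMemo()
    first = memo()   # cache miss: computes and stores the value under id(memo)
    second = memo()  # cache hit: returns the stored object without recomputing
    assert first is second
    assert memo_cache_size() == 1  # assumes no other memos are live in this process
    evict_memo(id(memo))           # explicit eviction; raises KeyError if absent
    assert memo_cache_size() == 0

Note that `Memo.__del__` also evicts the entry when a memo instance is garbage-collected, so explicit eviction is only needed to release the cached value earlier.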