Skip to content

Commit

Permalink
Add global memo cache and integrate (#130)
Browse files Browse the repository at this point in the history
Add a global cache dict in `nnbench.types`. Users subclass the
`Memo` class and apply the `@cached_memo` decorator to their
override of the `__call__` method. The call should return the
object that should be cached.

`del` on the subclassed `Memo` also removes the wrapped 
item from the cache. 

Add utility methods for cache management as well as tests.

---------

Co-authored-by: Nicholas Junge <[email protected]>
  • Loading branch information
maxmynter and nicholasjng authored Mar 26, 2024
1 parent 20e9fe8 commit 636d18b
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 14 deletions.
2 changes: 1 addition & 1 deletion src/nnbench/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
from .core import benchmark, parametrize, product
from .reporter import BenchmarkReporter
from .runner import BenchmarkRunner
from .types import Parameters
from .types import Memo, Parameters
104 changes: 91 additions & 13 deletions src/nnbench/types/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import copy
import functools
import inspect
import logging
import threading
from dataclasses import dataclass, field
from types import MappingProxyType
from typing import Any, Callable, Generic, Literal, Mapping, Protocol, TypeVar
Expand All @@ -14,6 +16,80 @@
T = TypeVar("T")
Variable = tuple[str, type, Any]

# Global registry of memoized values, keyed by the id() of the owning Memo
# instance. All access must happen while holding `_cache_lock`.
_memo_cache: dict[int, Any] = {}
# Serializes reads and writes of `_memo_cache` across threads.
_cache_lock = threading.Lock()

logger = logging.getLogger(__name__)


def memo_cache_size() -> int:
    """
    Report how many values are currently memoized.

    Returns
    -------
    int
        The number of entries held in the global memo cache.
    """
    return len(_memo_cache)


def clear_memo_cache() -> None:
    """
    Remove every entry from the memo cache in a thread-safe manner.

    Acquires the global cache lock before clearing, so concurrent
    readers and writers observe a consistent cache.
    """
    with _cache_lock:
        _memo_cache.clear()


def evict_memo(_id: int) -> Any:
    """
    Remove and return the cache entry stored under key `_id`.

    Parameters
    ----------
    _id : int
        The unique identifier (usually the id assigned by the Python
        interpreter) of the item to be evicted.

    Returns
    -------
    Any
        The value that was associated with the removed cache entry.

    Raises
    ------
    KeyError
        If no entry exists for the given `_id`.
    """
    with _cache_lock:
        return _memo_cache.pop(_id)


def cached_memo(fn: Callable) -> Callable:
    """
    Decorator that caches the result of a method call keyed on the instance ID.

    The wrapped method is computed on the first call per instance (per cache
    lifetime); subsequent calls return the value stored in the global memo
    cache. The lock is released while `fn` runs so a slow computation does
    not block unrelated cache access.

    Parameters
    ----------
    fn: Callable
        The method to memoize.

    Returns
    -------
    Callable
        A wrapped version of the method that caches its result.
    """

    @functools.wraps(fn)
    def wrapper(self, *args, **kwargs):
        _tid = id(self)
        with _cache_lock:
            if _tid in _memo_cache:
                logger.debug(f"Returning memoized value from cache with ID {_tid}")
                return _memo_cache[_tid]
        logger.debug(f"Computing value on memo with ID {_tid} (cache miss)")
        value = fn(self, *args, **kwargs)
        with _cache_lock:
            # Two threads may race past the miss check above and both compute
            # `fn`. setdefault keeps the first stored value and discards later
            # duplicates, so every caller observes the same cached object
            # (a plain assignment would let the last writer win instead).
            return _memo_cache.setdefault(_tid, value)

    return wrapper


def NoOp(state: State, params: Mapping[str, Any] = MappingProxyType({})) -> None:
    """Callback that accepts a state and a params mapping and does nothing."""
Expand Down Expand Up @@ -109,22 +185,24 @@ class State:


class Memo(Generic[T]):
    """Abstract base class for memoized values in benchmark runs."""

    # NOTE(review): the page shows leftover pre-change lines (a stale
    # `@functools.cache` decorator and commented-out pseudo-code) interleaved
    # here, which is not valid Python; this is the reconstructed post-change
    # class body.

    # TODO: Make this better than the decorator application
    # -> _Cached metaclass like in fsspec's AbstractFileSystem (maybe vendor with license)

    @cached_memo
    def __call__(self) -> T:
        """Placeholder to override when subclassing. The call should return the to be cached object."""
        raise NotImplementedError

    def __del__(self) -> None:
        """Delete the cached object and clear it from the cache."""
        with _cache_lock:
            sid = id(self)
            if sid in _memo_cache:
                logger.debug(f"Deleting cached value for memo with ID {sid}")
                del _memo_cache[sid]


@dataclass(init=False, frozen=True)
class Parameters:
Expand Down
28 changes: 28 additions & 0 deletions tests/test_memos.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
from typing import Generator

import pytest

from nnbench.types.types import Memo, cached_memo, clear_memo_cache, memo_cache_size


@pytest.fixture
def clear_memos() -> Generator[None, None, None]:
    """Run the test against an empty memo cache, clearing again on teardown."""
    try:
        # start from a clean slate, then hand control to the test
        clear_memo_cache()
        yield
    finally:
        # always scrub the cache so state cannot leak between tests
        clear_memo_cache()


class MyMemo(Memo[int]):
    """Minimal Memo subclass whose memoized value is the constant 0."""

    @cached_memo
    def __call__(self) -> int:
        return 0


def test_memo_caching(clear_memos):
    """First call populates the cache; repeated calls must not add entries."""
    m = MyMemo()
    assert memo_cache_size() == 0
    m()
    assert memo_cache_size() == 1
    m()
    # the second call must hit the cache rather than create another entry;
    # without this assertion a cache-bypassing regression goes undetected
    assert memo_cache_size() == 1

0 comments on commit 636d18b

Please sign in to comment.