From 597563bacdb5b8eec8be84d145f8e00a553251d8 Mon Sep 17 00:00:00 2001
From: Jack Betteridge
Date: Mon, 12 Aug 2024 22:08:14 +0100
Subject: [PATCH] WIP

---
 pyop2/caching.py                  |  15 +-
 test/unit/test_updated_caching.py | 158 ++++++++++++++++++++++++++++++
 2 files changed, 171 insertions(+), 2 deletions(-)
 create mode 100644 test/unit/test_updated_caching.py

diff --git a/pyop2/caching.py b/pyop2/caching.py
index eb90f2258..7522c6978 100644
--- a/pyop2/caching.py
+++ b/pyop2/caching.py
@@ -501,7 +501,7 @@ def wrapper(*args, **kwargs):
             on calling the function and populating the cache.
             """
             comm, mem_key = key(*args, **kwargs)
-            k = _as_hexdigest(mem_key)
+            k = _as_hexdigest(mem_key), func.__qualname__
 
             # Fetch the per-comm cache or set it up if not present
             local_cache = comm.Get_attr(comm_cache_keyval)
@@ -548,7 +548,7 @@ def wrapper(*args, **kwargs):
             on calling the function and populating the cache.
             """
             comm, mem_key = key(*args, **kwargs)
-            k = _as_hexdigest(mem_key)
+            k = _as_hexdigest(mem_key), func.__qualname__
 
             # Fetch the per-comm cache or set it up if not present
             local_cache = comm.Get_attr(comm_cache_keyval)
@@ -573,6 +573,7 @@ def wrapper(*args, **kwargs):
     return decorator
 
 
+# TODO: Change call signature
 def disk_cached(cache, cachedir=None, key=cachetools.keys.hashkey, collective=False):
     """Decorator for wrapping a function in a cache that stores values in memory and to disk.
 
@@ -650,5 +651,15 @@ def _as_hexdigest(key):
     return hashlib.md5(str(key).encode()).hexdigest()
 
 
+def clear_memory_cache(comm):
+    """Discard every entry in the in-memory cache attached to ``comm``.
+
+    :arg comm: Communicator whose cache should be emptied. Nothing happens
+        if no cache has been attached to it yet.
+    """
+    if comm.Get_attr(comm_cache_keyval) is not None:
+        comm.Set_attr(comm_cache_keyval, {})
+
+
 def _disk_cache_get(cachedir, key):
     """Retrieve a value from the disk cache.
diff --git a/test/unit/test_updated_caching.py b/test/unit/test_updated_caching.py
new file mode 100644
index 000000000..bdd6583bd
--- /dev/null
+++ b/test/unit/test_updated_caching.py
@@ -0,0 +1,164 @@
+import pytest
+from tempfile import gettempdir
+from functools import partial
+
+from pyop2.caching import (
+    disk_cached,
+    parallel_memory_only_cache,
+    parallel_memory_only_cache_no_broadcast,
+    DiskCachedObject,
+    MemoryAndDiskCachedObject,
+    default_parallel_hashkey,
+    clear_memory_cache
+)
+from pyop2.mpi import MPI, COMM_WORLD, comm_cache_keyval
+
+
+# For new disk_cached API
+disk_cached = partial(disk_cached, None, key=default_parallel_hashkey, collective=True)
+
+
+class StateIncrement:
+    """Callable counter recording how many times it has been called."""
+
+    def __init__(self):
+        self._count = 0
+
+    def __call__(self):
+        self._count += 1
+        return self._count
+
+    @property
+    def value(self):
+        """Number of calls made so far."""
+        return self._count
+
+
+def twople(x):
+    return (x, )*2
+
+
+def threeple(x):
+    return (x, )*3
+
+
+def n_comms(n):
+    return [MPI.COMM_WORLD]*n
+
+
+def n_ops(n):
+    return [MPI.SUM]*n
+
+
+# Every function-cache test runs against the same decorator/function pairs
+DECORATED_FUNCTIONS = [
+    (parallel_memory_only_cache, twople),
+    (parallel_memory_only_cache_no_broadcast, n_comms),
+    (disk_cached, twople)
+]
+
+
+def decorator_kwargs(decorator, cachedir):
+    """Return the keyword arguments ``decorator`` needs: only the disk
+    cache takes a ``cachedir``."""
+    if decorator is disk_cached:
+        return {"cachedir": cachedir}
+    return {}
+
+
+# decorator = parallel_memory_only_cache, parallel_memory_only_cache_no_broadcast, disk_cached
+def function_factory(state, decorator, f, **kwargs):
+    """Wrap ``f`` with ``decorator(**kwargs)``; ``state`` is called on every
+    uncached execution."""
+    def custom_function(x, comm=COMM_WORLD):
+        state()
+        return f(x)
+
+    return decorator(**kwargs)(custom_function)
+
+
+# parent_class = DiskCachedObject, MemoryAndDiskCachedObject
+# f(x) = x**2, x**3
+def object_factory(state, parent_class, f, **kwargs):
+    """Build a ``parent_class`` subclass whose ``__init__`` calls ``state``
+    and stores ``f(x)``."""
+    class CustomObject(parent_class, **kwargs):
+        def __init__(self, x, comm=COMM_WORLD):
+            state()
+            self.x = f(x)
+
+    return CustomObject
+
+
+@pytest.fixture
+def state():
+    return StateIncrement()
+
+
+@pytest.fixture
+def unique_tempdir():
+    """Return the system temporary directory.
+
+    NOTE: ``gettempdir()`` is shared by every test, so this is not yet a
+    per-test directory despite the fixture's name.
+    """
+    return gettempdir()
+
+
+@pytest.fixture(autouse=True)
+def clean_memory_cache():
+    """Empty the COMM_WORLD memory cache after every test, even a failing
+    one, so that stale entries cannot leak into the next test."""
+    yield
+    clear_memory_cache(COMM_WORLD)
+
+
+@pytest.mark.parametrize("decorator, uncached_function", DECORATED_FUNCTIONS)
+def test_function_args_twice_caches(state, decorator, uncached_function, tmpdir):
+    kwargs = decorator_kwargs(decorator, tmpdir)
+    cached_function = function_factory(state, decorator, uncached_function, **kwargs)
+    assert state.value == 0
+    first = cached_function(2, comm=COMM_WORLD)
+    assert first == uncached_function(2)
+    assert state.value == 1
+    # Same argument again: served from the cache, function not re-executed
+    second = cached_function(2, comm=COMM_WORLD)
+    assert second == uncached_function(2)
+    assert second is first
+    assert state.value == 1
+
+
+@pytest.mark.parametrize("decorator, uncached_function", DECORATED_FUNCTIONS)
+def test_function_args_different(state, decorator, uncached_function, tmpdir):
+    kwargs = decorator_kwargs(decorator, tmpdir)
+    cached_function = function_factory(state, decorator, uncached_function, **kwargs)
+    assert state.value == 0
+    first = cached_function(2, comm=COMM_WORLD)
+    assert first == uncached_function(2)
+    assert state.value == 1
+    # Different argument: cache miss, function executed a second time
+    second = cached_function(3, comm=COMM_WORLD)
+    assert second == uncached_function(3)
+    assert state.value == 2
+
+
+@pytest.mark.parallel(nprocs=3)
+@pytest.mark.parametrize("decorator, uncached_function", DECORATED_FUNCTIONS)
+def test_function_over_different_comms(state, decorator, uncached_function, tmpdir):
+    kwargs = decorator_kwargs(decorator, tmpdir)
+    cached_function = function_factory(state, decorator, uncached_function, **kwargs)
+    assert state.value == 0
+    for _ in range(10):
+        # Ranks 0 and 1 share one subcommunicator ...
+        color = 0 if COMM_WORLD.rank < 2 else MPI.UNDEFINED
+        comm12 = COMM_WORLD.Split(color=color)
+        if COMM_WORLD.rank < 2:
+            cached_function(2, comm=comm12)
+            comm12.Free()
+
+        # ... and ranks 1 and 2 share another, overlapping one
+        color = 0 if COMM_WORLD.rank > 0 else MPI.UNDEFINED
+        comm23 = COMM_WORLD.Split(color=color)
+        if COMM_WORLD.rank > 0:
+            cached_function(2, comm=comm23)
+            comm23.Free()