Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
JDBetteridge committed Aug 12, 2024
1 parent bc6e966 commit 597563b
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 2 deletions.
10 changes: 8 additions & 2 deletions pyop2/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ def wrapper(*args, **kwargs):
on calling the function and populating the cache.
"""
comm, mem_key = key(*args, **kwargs)
k = _as_hexdigest(mem_key)
k = _as_hexdigest(mem_key), func.__qualname__

# Fetch the per-comm cache or set it up if not present
local_cache = comm.Get_attr(comm_cache_keyval)
Expand Down Expand Up @@ -548,7 +548,7 @@ def wrapper(*args, **kwargs):
on calling the function and populating the cache.
"""
comm, mem_key = key(*args, **kwargs)
k = _as_hexdigest(mem_key)
k = _as_hexdigest(mem_key), func.__qualname__

# Fetch the per-comm cache or set it up if not present
local_cache = comm.Get_attr(comm_cache_keyval)
Expand All @@ -573,6 +573,7 @@ def wrapper(*args, **kwargs):
return decorator


# TODO: Change call signature
def disk_cached(cache, cachedir=None, key=cachetools.keys.hashkey, collective=False):
"""Decorator for wrapping a function in a cache that stores values in memory and to disk.
Expand Down Expand Up @@ -650,6 +651,11 @@ def _as_hexdigest(key):
return hashlib.md5(str(key).encode()).hexdigest()


def clear_memory_cache(comm):
if comm.Get_attr(comm_cache_keyval) is not None:
comm.Set_attr(comm_cache_keyval, {})


def _disk_cache_get(cachedir, key):
"""Retrieve a value from the disk cache.
Expand Down
158 changes: 158 additions & 0 deletions test/unit/test_updated_caching.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
import pytest
from tempfile import gettempdir
from functools import partial

from pyop2.caching import (
disk_cached,
parallel_memory_only_cache,
parallel_memory_only_cache_no_broadcast,
DiskCachedObject,
MemoryAndDiskCachedObject,
default_parallel_hashkey,
clear_memory_cache
)
from pyop2.mpi import MPI, COMM_WORLD, comm_cache_keyval


# For new disk_cached API
disk_cached = partial(disk_cached, None, key=default_parallel_hashkey, collective=True)


class StateIncrement:
"""Simple class for keeping track of the number of times executed
"""
def __init__(self):
self._count = 0

def __call__(self):
self._count += 1
return self._count

@property
def value(self):
return self._count


def twople(x):
return (x, )*2


def threeple(x):
return (x, )*3


def n_comms(n):
return [MPI.COMM_WORLD]*n


def n_ops(n):
return [MPI.SUM]*n


# decorator = parallel_memory_only_cache, parallel_memory_only_cache_no_broadcast, disk_cached
def function_factory(state, decorator, f, **kwargs):
def custom_function(x, comm=COMM_WORLD):
state()
return f(x)

return decorator(**kwargs)(custom_function)


# parent_class = DiskCachedObject, MemoryAndDiskCachedObject
# f(x) = x**2, x**3
def object_factory(state, parent_class, f, **kwargs):
class CustomObject(parent_class, **kwargs):
def __init__(self, x, comm=COMM_WORLD):
state()
self.x = f(x)

return CustomObject


@pytest.fixture
def state():
return StateIncrement()


@pytest.fixture
def unique_tempdir():
"""This allows us to run with a different tempdir for each test that
requires one"""
return gettempdir()


@pytest.mark.parametrize("decorator, uncached_function", [
(parallel_memory_only_cache, twople),
(parallel_memory_only_cache_no_broadcast, n_comms),
(disk_cached, twople)
])
def test_function_args_twice_caches(request, state, decorator, uncached_function, tmpdir):
if request.node.callspec.params["decorator"] is disk_cached:
kwargs = {"cachedir": tmpdir}
else:
kwargs = {}

cached_function = function_factory(state, decorator, uncached_function, **kwargs)
assert state.value == 0
first = cached_function(2, comm=COMM_WORLD)
assert first == uncached_function(2)
assert state.value == 1
second = cached_function(2, comm=COMM_WORLD)
assert second == uncached_function(2)
assert second is first
assert state.value == 1

clear_memory_cache(COMM_WORLD)


@pytest.mark.parametrize("decorator, uncached_function", [
(parallel_memory_only_cache, twople),
(parallel_memory_only_cache_no_broadcast, n_comms),
(disk_cached, twople)
])
def test_function_args_different(request, state, decorator, uncached_function, tmpdir):
if request.node.callspec.params["decorator"] is disk_cached:
kwargs = {"cachedir": tmpdir}
else:
kwargs = {}

cached_function = function_factory(state, decorator, uncached_function, **kwargs)
assert state.value == 0
first = cached_function(2, comm=COMM_WORLD)
assert first == uncached_function(2)
assert state.value == 1
second = cached_function(3, comm=COMM_WORLD)
assert second == uncached_function(3)
assert state.value == 2

clear_memory_cache(COMM_WORLD)


@pytest.mark.parallel(nprocs=3)
@pytest.mark.parametrize("decorator, uncached_function", [
(parallel_memory_only_cache, twople),
(parallel_memory_only_cache_no_broadcast, n_comms),
(disk_cached, twople)
])
def test_function_over_different_comms(request, state, decorator, uncached_function, tmpdir):
if request.node.callspec.params["decorator"] is disk_cached:
kwargs = {"cachedir": tmpdir}
else:
kwargs = {}

cached_function = function_factory(state, decorator, uncached_function, **kwargs)
assert state.value == 0
for ii in range(10):
color = 0 if COMM_WORLD.rank < 2 else MPI.UNDEFINED
comm12 = COMM_WORLD.Split(color=color)
if COMM_WORLD.rank < 2:
_ = cached_function(2, comm=comm12)
comm12.Free()

color = 0 if COMM_WORLD.rank > 0 else MPI.UNDEFINED
comm23 = COMM_WORLD.Split(color=color)
if COMM_WORLD.rank > 0:
_ = cached_function(2, comm=comm23)
comm23.Free()

clear_memory_cache(COMM_WORLD)

0 comments on commit 597563b

Please sign in to comment.