-
Notifications
You must be signed in to change notification settings - Fork 35
/
Copy pathcaching.py
374 lines (305 loc) · 13.5 KB
/
caching.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
# This file is part of PyOP2
#
# PyOP2 is Copyright (c) 2012, Imperial College London and
# others. Please see the AUTHORS file in the main source directory for
# a full list of copyright holders. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * The name of Imperial College London or that of other
# contributors may not be used to endorse or promote products
# derived from this software without specific prior written
# permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTERS
# ''AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
# COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
# HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
# STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
# OF THE POSSIBILITY OF SUCH DAMAGE.
"""Provides common base classes for cached objects."""
import cachetools
import hashlib
import os
import pickle
from collections.abc import MutableMapping
from pathlib import Path
from warnings import warn # noqa F401
from functools import wraps
from pyop2.configuration import configuration
from pyop2.logger import debug
from pyop2.mpi import MPI, COMM_WORLD, comm_cache_keyval
# TODO: Remove this? Rewrite?
def report_cache(typ):
    """Report the size of caches of type ``typ``.

    Walks every object the garbage collector knows about and prints a
    per-class count of live instances of ``typ``.

    :arg typ: A class of cached object.  For example
        :class:`ObjectCached` or :class:`Cached`.
    """
    from collections import defaultdict
    from inspect import getmodule
    from gc import get_objects
    typs = defaultdict(lambda: 0)
    n = 0
    for x in get_objects():
        if isinstance(x, typ):
            typs[type(x)] += 1
            n += 1
    if n == 0:
        print("\nNo %s objects in caches" % typ.__name__)
        return
    print("\n%d %s objects in caches" % (n, typ.__name__))
    print("Object breakdown")
    print("================")
    # BUG FIX: ``dict.iteritems()`` is Python 2 only and raises
    # AttributeError on Python 3; use ``items()`` instead.
    for k, v in typs.items():
        mod = getmodule(k)
        if mod is not None:
            name = "%s.%s" % (mod.__name__, k.__name__)
        else:
            name = k.__name__
        print('%s: %d' % (name, v))
class ObjectCached(object):
    """Base class for objects that should be cached on another object.

    Derived classes need to implement classmethods :meth:`_process_args`
    and :meth:`_cache_key` (which see for more details).  The object on
    which the cache is stored should contain a dict in its ``_cache``
    attribute.

    .. warning::

       The derived class' :meth:`__init__` is still called if the object
       is retrieved from cache.  If that is not desired, derived classes
       can set a flag indicating whether the constructor has already been
       called and immediately return from :meth:`__init__` if the flag is
       set.  Otherwise the object will be re-initialized even if it was
       returned from cache!
    """

    @classmethod
    def _process_args(cls, *args, **kwargs):
        """Process the arguments to ``__init__`` into a form suitable
        for computing a cache key on.

        The first returned argument is popped off the argument list
        passed to ``__init__`` and is used as the object on which to
        cache this instance.  As such, *args* should be returned as a
        two-tuple of ``(cache_object, ) + (original_args, )``.

        *kwargs* must be a (possibly empty) dict.
        """
        raise NotImplementedError("Subclass must implement _process_args")

    @classmethod
    def _cache_key(cls, *args, **kwargs):
        """Compute a cache key from the constructor's preprocessed arguments.

        If ``None`` is returned, the object is not to be cached.

        .. note::

           The return type **must** be hashable.
        """
        raise NotImplementedError("Subclass must implement _cache_key")

    def __new__(cls, *args, **kwargs):
        args, kwargs = cls._process_args(*args, **kwargs)
        # The first preprocessed argument is the object carrying the
        # cache; the remainder become the subclass constructor arguments.
        cache_obj, *init_args = args
        key = cls._cache_key(*init_args, **kwargs)

        def build():
            instance = super(ObjectCached, cls).__new__(cls)
            instance._initialized = False
            # __init__ runs twice when constructing something not in the
            # cache: once here with the canonicalised args, and a second
            # time directly in the subclass.  The second call should hit
            # the cache (or the _initialized flag) and return straight away.
            instance.__init__(*init_args, **kwargs)
            return instance

        # Skip caching entirely when the key says "don't cache" or there
        # is nothing to cache on.
        if key is None or cache_obj is None:
            return build()

        # The caching object must expose its cache dict.
        if not hasattr(cache_obj, "_cache"):
            raise RuntimeError("Provided caching object does not have a '_cache' attribute.")
        cache = cache_obj._cache

        # Return the cached instance if present, otherwise build and store.
        if key in cache:
            return cache[key]
        instance = build()
        cache[key] = instance
        return instance
# Sentinel marking a failed cache lookup.  A dedicated type (rather than
# ``None``) is used because ``None`` may itself be a legitimately cached
# value; lookups elsewhere compare by identity (``value is CACHE_MISS``).
class _CacheMiss:
    pass


# The canonical sentinel instance shared by all caches in this module.
CACHE_MISS = _CacheMiss()
def _as_hexdigest(*args):
hash_ = hashlib.md5()
for a in args:
hash_.update(str(a).encode())
return hash_.hexdigest()
def clear_memory_cache(comm):
    """Empty the in-memory cache collection attached to ``comm``.

    :arg comm: The communicator whose cache collection should be cleared.
        A no-op if no cache collection has been attached yet.
    """
    existing = comm.Get_attr(comm_cache_keyval)
    if existing is not None:
        comm.Set_attr(comm_cache_keyval, {})
class DictLikeDiskAccess(MutableMapping):
    """A dictionary-like interface to a directory of pickled cache entries.

    Keys are 2-tuples of strings ``(k1, k2)``; an entry is stored at
    ``cachedir/k1[:2]/k1[2:] + k2`` so that files are sharded into
    subdirectories by the first two characters of the first key part.
    Deletion, iteration and length queries are unsupported.
    """

    def __init__(self, cachedir):
        """
        :arg cachedir: The cache directory.
        """
        self.cachedir = cachedir

    def __getitem__(self, key):
        """Retrieve a value from the disk cache.

        :arg key: The cache key, a 2-tuple of strings.
        :returns: The cached object if found.
        :raises KeyError: If no entry exists on disk for ``key``.
        """
        filepath = Path(self.cachedir, key[0][:2], key[0][2:] + key[1])
        try:
            with self.open(filepath, "rb") as fh:
                value = self.read(fh)
        except FileNotFoundError as e:
            # Translate a filesystem miss into the mapping protocol's
            # exception, keeping the original cause for debugging.
            raise KeyError("File not on disk, cache miss") from e
        return value

    def __setitem__(self, key, value):
        """Store a new value in the disk cache.

        :arg key: The cache key, a 2-tuple of strings.
        :arg value: The new item to store in the cache.
        """
        k1, k2 = key[0][:2], key[0][2:] + key[1]
        basedir = Path(self.cachedir, k1)
        basedir.mkdir(parents=True, exist_ok=True)

        # Write to a per-process temporary file first, then move it into
        # place, so concurrent writers never expose a half-written entry.
        tempfile = basedir.joinpath(f"{k2}_p{os.getpid()}.tmp")
        filepath = basedir.joinpath(k2)
        with self.open(tempfile, "wb") as fh:
            self.write(fh, value)
        # BUG FIX: use replace() rather than rename().  Path.rename raises
        # FileExistsError on Windows when the destination already exists,
        # which can happen when two processes race to cache the same key;
        # Path.replace overwrites atomically on all platforms.
        tempfile.replace(filepath)

    def __delitem__(self, key):
        raise ValueError(f"Cannot remove items from {self.__class__.__name__}")

    def __iter__(self):
        raise ValueError(f"Cannot iterate over keys in {self.__class__.__name__}")

    def __len__(self):
        raise ValueError(f"Cannot query length of {self.__class__.__name__}")

    def __repr__(self):
        return f"{self.__class__.__name__}(cachedir={self.cachedir})"

    def open(self, *args, **kwargs):
        # Indirection point: subclasses may override to customise file access
        # (e.g. compression).
        return open(*args, **kwargs)

    def read(self, filehandle):
        # NOTE: unpickling is only safe because the cache directory is
        # written by this same code, not by untrusted parties.
        return pickle.load(filehandle)

    def write(self, filehandle, value):
        pickle.dump(value, filehandle)
def default_comm_fetcher(*args, **kwargs):
    """Return the first MPI communicator found among the arguments.

    Positional arguments are scanned before keyword argument values.

    :raises TypeError: If no communicator is present.
    """
    for candidate in (*args, *kwargs.values()):
        if isinstance(candidate, MPI.Comm):
            return candidate
    raise TypeError("No comms found in args or kwargs")
def default_parallel_hashkey(*args, **kwargs):
    """Compute a cache key over the arguments, skipping communicators.

    We now want to actively remove any comms from args and kwargs to get
    the same disk cache key on every rank regardless of which communicator
    was passed in.
    """
    comm_free_args = tuple(
        a for a in args if not isinstance(a, MPI.Comm)
    )
    comm_free_kwargs = {
        name: val for name, val in kwargs.items()
        if not isinstance(val, MPI.Comm)
    }
    return cachetools.keys.hashkey(*comm_free_args, **comm_free_kwargs)
# Default in-memory cache: a plain ``dict`` given its own class name so it
# occupies a distinct slot in a communicator's cache collection (which is
# keyed by the cache class's ``__name__``).
class DEFAULT_CACHE(dict):
    pass
def parallel_cache(
    hashkey=default_parallel_hashkey,
    comm_fetcher=default_comm_fetcher,
    cache_factory=lambda: DEFAULT_CACHE(),
    broadcast=True
):
    """Memory only cache decorator.

    Decorator for wrapping a function to be called over a communicator in a
    cache that stores broadcastable values in memory.  If the value is found
    in the cache of rank 0 it is broadcast to all other ranks.

    :arg hashkey: Callable returning the cache key for the function inputs.
    :arg comm_fetcher: Callable returning the communicator to be collective
        over, given the function inputs.  Using the comm from the inputs is
        required to ensure that deadlocks do not occur when using different
        subcommunicators.
    :arg cache_factory: Zero-argument callable creating the per-communicator
        cache instance.
    :arg broadcast: If ``True`` only rank 0 consults its cache and broadcasts
        the result; otherwise every rank consults its own cache and the value
        is recomputed everywhere unless *all* ranks hit.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            """Extract the key and then try the memory cache before falling
            back on calling the function and populating the cache.
            """
            comm = comm_fetcher(*args, **kwargs)
            k = hashkey(*args, **kwargs)
            # Qualify the key with the function name so different functions
            # with identical arguments do not collide.
            key = _as_hexdigest(k), func.__qualname__

            # Fetch the per-comm cache_collection or set it up if not present.
            # A collection is required since different types of cache can be
            # set up on the same comm.
            cache_collection = comm.Get_attr(comm_cache_keyval)
            if cache_collection is None:
                cache_collection = {}
                comm.Set_attr(comm_cache_keyval, cache_collection)
            # If this kind of cache is already present on the
            # cache_collection, get it, otherwise create it.
            local_cache = cache_collection.setdefault(
                (cf := cache_factory()).__class__.__name__,
                cf
            )

            if broadcast:
                # Grab value from rank 0 memory cache and broadcast result.
                if comm.rank == 0:
                    value = local_cache.get(key, CACHE_MISS)
                    # BUG FIX: the miss sentinel is CACHE_MISS, not None
                    # (cf. the non-broadcast branch below); the old
                    # ``value is None`` test logged "hit" on every miss.
                    if value is CACHE_MISS:
                        debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} memory cache miss')
                    else:
                        debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} memory cache hit')
                    # TODO: Add communication tags to avoid cross-broadcasting
                    comm.bcast(value, root=0)
                else:
                    value = comm.bcast(CACHE_MISS, root=0)
                    if isinstance(value, _CacheMiss):
                        # We might have the CACHE_MISS from rank 0 and
                        # `(value is CACHE_MISS) == False` which is confusing,
                        # so we set it back to the local value
                        value = CACHE_MISS
            else:
                # Grab value from all ranks cache and broadcast cache hit/miss.
                value = local_cache.get(key, CACHE_MISS)
                if value is CACHE_MISS:
                    debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} memory cache miss')
                    cache_hit = False
                else:
                    debug(f'{COMM_WORLD.name} R{COMM_WORLD.rank}, {comm.name} R{comm.rank}: {k} memory cache hit')
                    cache_hit = True
                all_present = comm.allgather(cache_hit)

                # If not present in the cache of all ranks we need to
                # recompute on all ranks.
                if not all(all_present):
                    value = CACHE_MISS

            if value is CACHE_MISS:
                value = func(*args, **kwargs)
            # setdefault keeps the first-stored value if another caller
            # populated the cache concurrently.
            return local_cache.setdefault(key, value)

        return wrapper
    return decorator
# A small collection of default simple caches.
# ``memory_cache`` is ``parallel_cache`` with its default (in-memory dict)
# cache factory.
memory_cache = parallel_cache
def disk_only_cache(*args, cachedir=configuration["cache_dir"], **kwargs):
    """Decorator wrapping a function in a disk-backed parallel cache.

    :arg cachedir: Directory in which cache entries are written; defaults
        to the configured PyOP2 cache directory.
    """
    def make_cache():
        return DictLikeDiskAccess(cachedir)
    return parallel_cache(*args, cache_factory=make_cache, **kwargs)
def memory_and_disk_cache(*args, cachedir=configuration["cache_dir"], **kwargs):
    """Decorator layering an in-memory cache on top of a disk cache.

    The disk cache wraps the function first, so a memory miss falls back to
    disk before the function itself is re-executed.

    :arg cachedir: Directory in which disk cache entries are written.
    """
    def decorator(func):
        disk_cached = disk_only_cache(*args, cachedir=cachedir, **kwargs)(func)
        return memory_cache(*args, **kwargs)(disk_cached)
    return decorator
# TODO: (Wishlist)
# * Try more exotic caches ie: memory_cache = partial(parallel_cache, cache_factory=lambda: cachetools.LRUCache(maxsize=1000))
# * Add some sort of cache reporting
# * Add some sort of cache statistics
# * Refactor compilation.py to use @mem_and_disk_cached, where get_so just uses DictLikeDiskAccess with an overloaded self.write() method
# * Add some docstrings and maybe some exposition!