diff --git a/pyop2/caching.py b/pyop2/caching.py index 96c64de75..d0539609a 100644 --- a/pyop2/caching.py +++ b/pyop2/caching.py @@ -351,6 +351,38 @@ def write(self, filehandle, value): pickle.dump(value, filehandle) +class NoShardDiskAccess(DictLikeDiskAccess): + def __getitem__(self, key): + """Retrieve a value from the disk cache. + + :arg key: The cache key, a 2-tuple of strings. + :returns: The cached object if found. + """ + filepath = Path(self.cachedir, key[0] + key[1]) + try: + with self.open(filepath.with_suffix(self.extension), mode="rb") as fh: + value = self.read(fh) + except FileNotFoundError: + raise KeyError("File not on disk, cache miss") + return value + + def __setitem__(self, key, value): + """Store a new value in the disk cache. + + :arg key: The cache key, a 2-tuple of strings. + :arg value: The new item to store in the cache. + """ + k = key[0] + key[1] + basedir = Path(self.cachedir) + basedir.mkdir(parents=True, exist_ok=True) + + tempfile = basedir.joinpath(f"{k}_p{os.getpid()}.tmp") + filepath = basedir.joinpath(k) + with self.open(tempfile, mode="wb") as fh: + self.write(fh, value) + tempfile.rename(filepath.with_suffix(self.extension)) + + def default_comm_fetcher(*args, **kwargs): """ A sensible default comm fetcher for use with `parallel_cache`. """ diff --git a/pyop2/compilation.py b/pyop2/compilation.py index 9fb654d9c..58d3c7928 100644 --- a/pyop2/compilation.py +++ b/pyop2/compilation.py @@ -51,7 +51,7 @@ from pyop2 import mpi -from pyop2.caching import parallel_cache, memory_cache, default_parallel_hashkey, _as_hexdigest, DictLikeDiskAccess +from pyop2.caching import parallel_cache, memory_cache, default_parallel_hashkey, _as_hexdigest, NoShardDiskAccess from pyop2.configuration import configuration from pyop2.logger import warning, debug, progress, INFO from pyop2.exceptions import CompilationError @@ -501,7 +501,7 @@ def expandWl(ldflags): yield flag -class CompilerDiskAccess(DictLikeDiskAccess): +class CompilerDiskAccess(NoShardDiskAccess): @contextmanager def open(self, filename, *args, **kwargs): yield filename