Skip to content

Commit

Permalink
Refactor JSON caching to resemble satpy caching decorators
Browse files Browse the repository at this point in the history
  • Loading branch information
djhoese committed Nov 13, 2023
1 parent 6fd85a1 commit 94ca7a8
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 44 deletions.
106 changes: 68 additions & 38 deletions pyresample/_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,57 +4,69 @@
throughout pyresample.
"""

import functools
import hashlib
import json
import shutil
from functools import update_wrapper
from glob import glob
from pathlib import Path
from typing import Any, Callable

import pyresample


class JSONCache:
class JSONCacheHelper:
"""Decorator class to cache results to a JSON file on-disk."""

def __init__(self, *args, **kwargs):
self._callable = None
if len(args) == 1 and not kwargs:
self._callable = args[0]
def __init__(self, func: Callable, cache_config_key: str, cache_version: int = 1):
    """Store the wrapped callable and the settings that control its caching.

    Args:
        func: The callable whose results are to be cached.
        cache_config_key: Name of the configuration key stored for later
            lookup by the caching logic.
        cache_version: Version identifier stored alongside the callable.
    """
    self._cache_version = cache_version
    self._cache_config_key = cache_config_key
    self._callable = func

def cache_clear(self, cache_dir: str | None = None):
"""Remove all on-disk files associated with this function.
Intended to mimic the :func:`functools.cache` behavior.
"""
cache_dir = self._get_cache_dir_from_config(cache_dir=cache_dir, cache_version="*")
for zarr_dir in glob(str(cache_dir / "*.json")):
shutil.rmtree(zarr_dir, ignore_errors=True)

def __call__(self, *args, **kwargs):
"""Call decorated function and cache the result to JSON."""
is_decorated = len(args) == 1 and isinstance(args[0], Callable)
if is_decorated:
self._callable = args[0]

@functools.wraps(self._callable)
def _func(*args, **kwargs):
if not pyresample.config.get("cache_geom_slices", False):
return self._callable(*args, **kwargs)

# TODO: kwargs
existing_hash = hashlib.sha1()
# hashable_args = [hash(arg) if isinstance(arg, AreaDefinition) else arg for arg in args]
hashable_args = [hash(arg) if arg.__class__.__name__ == "AreaDefinition" else arg for arg in args]
existing_hash.update(json.dumps(tuple(hashable_args)).encode("utf8"))
arg_hash = existing_hash.hexdigest()
print(arg_hash)
base_cache_dir = Path(pyresample.config.get("cache_dir")) / "geometry_slices"
json_path = base_cache_dir / f"{arg_hash}.json"
if not json_path.is_file():
res = self._callable(*args, **kwargs)
json_path.parent.mkdir(exist_ok=True)
with open(json_path, "w") as json_cache:
json.dump(res, json_cache, cls=_ExtraJSONEncoder)
else:
with open(json_path, "r") as json_cache:
res = json.load(json_cache, object_hook=_object_hook)
return res

if is_decorated:
return _func
return _func(*args, **kwargs)
if not pyresample.config.get(self._cache_config_key, False):
return self._callable(*args, **kwargs)

existing_hash = hashlib.sha1()
# TODO: exclude SwathDefinition for hashing reasons

Check notice on line 46 in pyresample/_caching.py

View check run for this annotation

codefactor.io / CodeFactor

pyresample/_caching.py#L46

unresolved comment '# TODO: exclude SwathDefinition for hashing reasons' (C100)
hashable_args = [hash(arg) if arg.__class__.__name__ in ("AreaDefinition",) else arg for arg in args]
hashable_args += sorted(kwargs.items())
existing_hash.update(json.dumps(tuple(hashable_args)).encode("utf8"))
arg_hash = existing_hash.hexdigest()
base_cache_dir = self._get_cache_dir_from_config(cache_version=self._cache_version)
json_path = base_cache_dir / f"{arg_hash}.json"
if not json_path.is_file():
res = self._callable(*args, **kwargs)
json_path.parent.mkdir(exist_ok=True)
with open(json_path, "w") as json_cache:
json.dump(res, json_cache, cls=_ExtraJSONEncoder)
else:
with open(json_path, "r") as json_cache:
res = json.load(json_cache, object_hook=_object_hook)
return res

@staticmethod
def _get_cache_dir_from_config(cache_dir: str | None = None, cache_version: int | str = 1) -> Path:
cache_dir = cache_dir or pyresample.config.get("cache_dir")
if cache_dir is None:
raise RuntimeError("Can't use JSON caching. No 'cache_dir' configured.")
subdir = f"geometry_slices_v{cache_version}"
return Path(cache_dir) / subdir


class _ExtraJSONEncoder(json.JSONEncoder):
Expand All @@ -68,3 +80,21 @@ def _object_hook(obj: object) -> Any:
if isinstance(obj, dict) and obj.get("__slice__", False):
return slice(obj["start"], obj["stop"], obj["step"])
return obj


def cache_to_json_if(cache_config_key: str) -> Callable:
    """Build a decorator that caches a function's results as JSON on disk.

    Results are only cached when the ``pyresample.config`` value stored
    under ``cache_config_key`` is truthy; further conditions are described
    on :class:`JSONCacheHelper`. Note that the cache is unbounded: this
    decorator neither limits the number of cached items nor removes old
    entries, so managing the size of the cache is left to the user.
    """
    def _decorator(func: Callable) -> Callable:
        # Make the helper instance carry the metadata (name, docstring, ...)
        # of the function it wraps.
        json_cacher = JSONCacheHelper(func, cache_config_key)
        return update_wrapper(json_cacher, func)

    return _decorator
4 changes: 2 additions & 2 deletions pyresample/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
from pyproj.aoi import AreaOfUse

from pyresample import CHUNK_SIZE
from pyresample._caching import JSONCache
from pyresample._caching import cache_to_json_if
from pyresample._spatial_mp import Cartesian, Cartesian_MP, Proj_MP
from pyresample.area_config import create_area_def
from pyresample.boundary import Boundary, SimpleBoundary
Expand Down Expand Up @@ -2680,7 +2680,7 @@ def geocentric_resolution(self, ellps='WGS84', radius=None):
return res


@JSONCache()
@cache_to_json_if("cache_geom_slices")
def get_area_slices(
src_area: AreaDefinition,
area_to_cover: AreaDefinition,
Expand Down
8 changes: 4 additions & 4 deletions pyresample/test/test_geometry/test_area.py
Original file line number Diff line number Diff line change
Expand Up @@ -1850,13 +1850,13 @@ def test_area_slices_caching(self, create_test_area, tmp_path):
100, 100,
(15.9689, 58.5284, 16.4346, 58.6995))
with pyresample.config.set(cache_dir=tmp_path, cache_geom_slices=False):
assert len(glob(str(tmp_path / "geometry_slices" / "*.json"))) == 0
assert len(glob(str(tmp_path / "geometry_slices_v1" / "*.json"))) == 0
slice_x, slice_y = src_area.get_area_slices(crop_area)
assert len(glob(str(tmp_path / "geometry_slices" / "*.json"))) == 0
assert len(glob(str(tmp_path / "geometry_slices_v1" / "*.json"))) == 0
with pyresample.config.set(cache_dir=tmp_path, cache_geom_slices=True):
assert len(glob(str(tmp_path / "geometry_slices" / "*.json"))) == 0
assert len(glob(str(tmp_path / "geometry_slices_v1" / "*.json"))) == 0
slice_x, slice_y = src_area.get_area_slices(crop_area)
assert len(glob(str(tmp_path / "geometry_slices" / "*.json"))) == 1
assert len(glob(str(tmp_path / "geometry_slices_v1" / "*.json"))) == 1
assert slice_x == slice(5630, 8339)
assert slice_y == slice(9261, 10980)

Expand Down

0 comments on commit 94ca7a8

Please sign in to comment.