-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
chore: add cache_with_transforms decorator
- Loading branch information
Showing
3 changed files
with
307 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
from __future__ import annotations | ||
|
||
from functools import wraps | ||
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar | ||
|
||
import upath | ||
|
||
|
||
if TYPE_CHECKING: | ||
from collections.abc import Callable | ||
|
||
|
||
P = ParamSpec("P") | ||
R = TypeVar("R") | ||
|
||
|
||
def cache_with_transforms( | ||
*, | ||
arg_transformers: dict[int, Callable[[Any], Any]] | None = None, | ||
kwarg_transformers: dict[str, Callable[[Any], Any]] | None = None, | ||
) -> Callable[[Callable[P, R]], Callable[P, R]]: | ||
"""A caching decorator with transformation functions for args and kwargs. | ||
Can be used to make specific args / kwargs hashable. | ||
Also adds cache and cache_info objects to the decorated function. | ||
Args: | ||
arg_transformers: Dict mapping positional args indices to transformer functions | ||
kwarg_transformers: Dict mapping kwargs names to transformer functions | ||
Returns: | ||
A decorator function that caches results based on transformed arguments | ||
""" | ||
arg_transformers = arg_transformers or {} | ||
kwarg_transformers = kwarg_transformers or {} | ||
|
||
def decorator(func: Callable[P, R]) -> Callable[P, R]: | ||
cache: dict[tuple[Any, ...], R] = {} | ||
|
||
@wraps(func) | ||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: | ||
# Transform positional arguments | ||
transformed_args = tuple( | ||
arg_transformers.get(i, lambda x: x)(arg) for i, arg in enumerate(args) | ||
) | ||
|
||
# Transform keyword arguments | ||
transformed_kwargs = { | ||
key: kwarg_transformers.get(key, lambda x: x)(value) | ||
for key, value in sorted(kwargs.items()) | ||
} | ||
|
||
# Create cache key from transformed arguments | ||
cache_key = (transformed_args, tuple(transformed_kwargs.items())) | ||
|
||
if cache_key not in cache: | ||
cache[cache_key] = func(*args, **kwargs) | ||
return cache[cache_key] | ||
|
||
def cache_info() -> dict[str, int]: | ||
"""Return information about cache hits and size.""" | ||
return {"cache_size": len(cache)} | ||
|
||
wrapper.cache_info = cache_info # type: ignore | ||
wrapper.cache = cache # type: ignore | ||
|
||
return wrapper | ||
|
||
return decorator | ||
|
||
|
||
# def path_cache( | ||
# maxsize: int | None = None, | ||
# typed: bool = False, | ||
# args_index: int = 0, | ||
# ) -> Callable[[Callable[P, R]], Callable[P, R]]: | ||
# """Decorator that caches function results and normalizes PathLike arguments. | ||
|
||
# Caches the first argument by default. Can be changed via args_index argument. | ||
# Also adds cache_info and cache_clear methods to the decorated function. | ||
|
||
# Args: | ||
# maxsize: Maximum size of the cache. If None, cache is unbounded. | ||
# typed: If True, arguments of different types are cached separately. | ||
# args_index: Index of the argument to cache. Default is 0. | ||
|
||
# Returns: | ||
# Decorated function with caching capability. | ||
# """ | ||
|
||
# def decorator(func: Callable[P, R]) -> Callable[P, R]: | ||
# # Create cached version of the original function | ||
# cached_func = lru_cache(maxsize=maxsize, typed=typed)(func) | ||
|
||
# @wraps(func) | ||
# def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: | ||
# # Convert first argument if it's PathLike | ||
# if args and isinstance(args[args_index], str | PathLike): | ||
# normalized_args = (upath.UPath(args[0]).resolve(),) + args[1:] | ||
# return cached_func(*normalized_args, **kwargs) # type: ignore[arg-type] | ||
# return cached_func(*args, **kwargs) # type: ignore[arg-type] | ||
|
||
# # Add cache_info and cache_clear methods to the wrapper | ||
# wrapper.cache_info = cached_func.cache_info # type: ignore | ||
# wrapper.cache_clear = cached_func.cache_clear # type: ignore | ||
|
||
# return wrapper | ||
|
||
# return decorator | ||
|
||
|
||
if __name__ == "__main__": | ||
|
||
@cache_with_transforms(arg_transformers={0: lambda p: upath.UPath(p).resolve()}) | ||
def read_file_content(filepath: str | upath.UPath) -> str: | ||
"""Read and return the content of a file.""" | ||
with upath.UPath(filepath).open() as f: | ||
return f.read() | ||
|
||
# These calls will use the same cache entry | ||
content1 = read_file_content("pyproject.toml") | ||
content1 = read_file_content("mkdocs.yml") | ||
content2 = read_file_content(upath.UPath("pyproject.toml")) | ||
content3 = read_file_content(upath.UPath("./pyproject.toml").absolute()) | ||
|
||
# Check cache statistics | ||
print(read_file_content.cache_info()) # type: ignore |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
# mypy: disable-error-code="attr-defined" | ||
from typing import Any | ||
|
||
import pytest | ||
|
||
from jinjarope.decorators import cache_with_transforms | ||
|
||
|
||
def test_basic_caching() -> None: | ||
"""Test basic function caching without transformers.""" | ||
call_count = 0 | ||
|
||
@cache_with_transforms() | ||
def add(x: int, y: int) -> int: | ||
nonlocal call_count | ||
call_count += 1 | ||
return x + y | ||
|
||
assert add(1, 2) == 3 | ||
assert add(1, 2) == 3 | ||
assert call_count == 1 | ||
assert add.cache_info()["cache_size"] == 1 | ||
|
||
|
||
def test_arg_transformer() -> None: | ||
"""Test caching with argument transformers.""" | ||
|
||
@cache_with_transforms(arg_transformers={0: lambda x: x.lower()}) | ||
def greet(name: str) -> str: | ||
return f"Hello, {name}!" | ||
|
||
assert greet("John") == "Hello, John!" | ||
assert greet("JOHN") == "Hello, John!" # Should hit cache, previous word cached | ||
assert greet.cache_info()["cache_size"] == 1 | ||
|
||
|
||
def test_kwarg_transformer() -> None: | ||
"""Test caching with keyword argument transformers.""" | ||
|
||
@cache_with_transforms(kwarg_transformers={"items": tuple}) | ||
def process_list(*, items: list[int]) -> int: | ||
return sum(items) | ||
|
||
assert process_list(items=[1, 2, 3]) == 6 | ||
assert process_list(items=[1, 2, 3]) == 6 # Should hit cache | ||
assert process_list.cache_info()["cache_size"] == 1 | ||
|
||
|
||
def test_unhashable_args() -> None: | ||
"""Test caching with unhashable arguments.""" | ||
|
||
@cache_with_transforms(arg_transformers={0: tuple}) | ||
def process_list(lst: list[int]) -> int: | ||
return sum(lst) | ||
|
||
assert process_list([1, 2, 3]) == 6 | ||
assert process_list([1, 2, 3]) == 6 # Should hit cache | ||
assert process_list.cache_info()["cache_size"] == 1 | ||
|
||
|
||
def test_mixed_args_kwargs() -> None: | ||
"""Test caching with both positional and keyword arguments.""" | ||
|
||
@cache_with_transforms( | ||
arg_transformers={0: str.lower}, kwarg_transformers={"items": tuple} | ||
) | ||
def process_data(prefix: str, *, items: list[int]) -> str: | ||
return f"{prefix}: {sum(items)}" | ||
|
||
assert process_data("Sum", items=[1, 2, 3]) == "Sum: 6" | ||
assert process_data("SUM", items=[1, 2, 3]) == "Sum: 6" # Should hit cache | ||
assert process_data.cache_info()["cache_size"] == 1 | ||
|
||
|
||
def test_multiple_calls_different_args() -> None: | ||
"""Test caching behavior with different arguments.""" | ||
call_count = 0 | ||
|
||
@cache_with_transforms() | ||
def add(x: int, y: int) -> int: | ||
nonlocal call_count | ||
call_count += 1 | ||
return x + y | ||
|
||
assert add(1, 2) == 3 | ||
assert add(2, 3) == 5 | ||
assert add(1, 2) == 3 # Should hit cache | ||
assert call_count == 2 | ||
assert add.cache_info()["cache_size"] == 2 | ||
|
||
|
||
def test_none_values() -> None: | ||
"""Test caching behavior with None values.""" | ||
|
||
@cache_with_transforms() | ||
def process_optional(x: Any) -> str: | ||
return str(x) | ||
|
||
assert process_optional(None) == "None" | ||
assert process_optional(None) == "None" # Should hit cache | ||
assert process_optional.cache_info()["cache_size"] == 1 | ||
|
||
|
||
def test_empty_transformers() -> None: | ||
"""Test that decorator works with empty transformers.""" | ||
|
||
@cache_with_transforms(arg_transformers={}, kwarg_transformers={}) | ||
def identity(x: int) -> int: | ||
return x | ||
|
||
assert identity(1) == 1 | ||
assert identity.cache_info()["cache_size"] == 1 | ||
|
||
|
||
def test_cache_persistence() -> None: | ||
"""Test that cache persists between calls.""" | ||
|
||
@cache_with_transforms() | ||
def expensive_operation(x: int) -> int: | ||
return x**2 | ||
|
||
assert expensive_operation(2) == 4 | ||
initial_cache = expensive_operation.cache | ||
assert expensive_operation(2) == 4 | ||
assert expensive_operation.cache is initial_cache | ||
|
||
|
||
def test_different_kwarg_orders() -> None: | ||
"""Test that different keyword argument orders produce same cache key.""" | ||
|
||
@cache_with_transforms() | ||
def process_kwargs(*, a: int, b: int) -> int: | ||
return a + b | ||
|
||
assert process_kwargs(a=1, b=2) == 3 | ||
assert process_kwargs(b=2, a=1) == 3 # Should hit cache | ||
assert process_kwargs.cache_info()["cache_size"] == 1 | ||
|
||
|
||
def test_complex_transformers() -> None: | ||
"""Test with more complex transformer functions.""" | ||
|
||
def complex_transform(x: list[Any]) -> tuple[Any, ...]: | ||
return tuple(sorted(x)) | ||
|
||
@cache_with_transforms( | ||
arg_transformers={0: complex_transform}, | ||
kwarg_transformers={"data": complex_transform}, | ||
) | ||
def process_lists(lst: list[int], *, data: list[int]) -> int: | ||
return sum(lst) + sum(data) | ||
|
||
assert process_lists([3, 1, 2], data=[6, 4, 5]) == 21 | ||
assert process_lists([2, 3, 1], data=[5, 6, 4]) == 21 # Should hit cache | ||
assert process_lists.cache_info()["cache_size"] == 1 | ||
|
||
|
||
def test_error_handling() -> None: | ||
"""Test that errors in the original function are not cached.""" | ||
|
||
@cache_with_transforms() | ||
def failing_function(x: int) -> int: | ||
msg = "Error" | ||
raise ValueError(msg) | ||
|
||
with pytest.raises(ValueError): # noqa: PT011 | ||
failing_function(1) | ||
|
||
with pytest.raises(ValueError): # noqa: PT011 | ||
failing_function(1) # Should not be cached | ||
|
||
assert failing_function.cache_info()["cache_size"] == 0 | ||
|
||
|
||
if __name__ == "__main__": | ||
pytest.main([__file__]) |