Skip to content

Commit

Permalink
chore: add cache_with_transforms decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
phil65 committed Oct 30, 2024
1 parent 131f0e9 commit f51c0a7
Show file tree
Hide file tree
Showing 3 changed files with 307 additions and 3 deletions.
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,9 @@ select = [
"YTT", # flake8-2020
]
ignore = [
"C408", # Unnecessary {obj_type} call (rewrite as a literal)
"B905", # zip() without an explicit strict= parameter
"C901", # {name} is too complex ({complexity} > {max_complexity})
"C408", # Unnecessary {obj_type} call (rewrite as a literal)
"B905", # zip() without an explicit strict= parameter
"C901", # {name} is too complex ({complexity} > {max_complexity})
"COM812",
# "CPY001", # Missing copyright notice at top of file
"D100", # Missing docstring in public module
Expand Down Expand Up @@ -211,6 +211,7 @@ combine-as-imports = true

[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["E402", "I001"]
"tests/test_decorators.py" = ["PLR2004"]

[tool.pyright]
reportUnusedCallResult = false
Expand Down
127 changes: 127 additions & 0 deletions src/jinjarope/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from __future__ import annotations

from functools import wraps
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar

import upath


if TYPE_CHECKING:
from collections.abc import Callable


P = ParamSpec("P")
R = TypeVar("R")


def cache_with_transforms(
*,
arg_transformers: dict[int, Callable[[Any], Any]] | None = None,
kwarg_transformers: dict[str, Callable[[Any], Any]] | None = None,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""A caching decorator with transformation functions for args and kwargs.
Can be used to make specific args / kwargs hashable.
Also adds cache and cache_info objects to the decorated function.
Args:
arg_transformers: Dict mapping positional args indices to transformer functions
kwarg_transformers: Dict mapping kwargs names to transformer functions
Returns:
A decorator function that caches results based on transformed arguments
"""
arg_transformers = arg_transformers or {}
kwarg_transformers = kwarg_transformers or {}

def decorator(func: Callable[P, R]) -> Callable[P, R]:
cache: dict[tuple[Any, ...], R] = {}

@wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
# Transform positional arguments
transformed_args = tuple(
arg_transformers.get(i, lambda x: x)(arg) for i, arg in enumerate(args)
)

# Transform keyword arguments
transformed_kwargs = {
key: kwarg_transformers.get(key, lambda x: x)(value)
for key, value in sorted(kwargs.items())
}

# Create cache key from transformed arguments
cache_key = (transformed_args, tuple(transformed_kwargs.items()))

if cache_key not in cache:
cache[cache_key] = func(*args, **kwargs)
return cache[cache_key]

def cache_info() -> dict[str, int]:
"""Return information about cache hits and size."""
return {"cache_size": len(cache)}

wrapper.cache_info = cache_info # type: ignore
wrapper.cache = cache # type: ignore

return wrapper

return decorator


# def path_cache(
# maxsize: int | None = None,
# typed: bool = False,
# args_index: int = 0,
# ) -> Callable[[Callable[P, R]], Callable[P, R]]:
# """Decorator that caches function results and normalizes PathLike arguments.

# Caches the first argument by default. Can be changed via args_index argument.
# Also adds cache_info and cache_clear methods to the decorated function.

# Args:
# maxsize: Maximum size of the cache. If None, cache is unbounded.
# typed: If True, arguments of different types are cached separately.
# args_index: Index of the argument to cache. Default is 0.

# Returns:
# Decorated function with caching capability.
# """

# def decorator(func: Callable[P, R]) -> Callable[P, R]:
# # Create cached version of the original function
# cached_func = lru_cache(maxsize=maxsize, typed=typed)(func)

# @wraps(func)
# def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
# # Convert first argument if it's PathLike
# if args and isinstance(args[args_index], str | PathLike):
# normalized_args = (upath.UPath(args[0]).resolve(),) + args[1:]
# return cached_func(*normalized_args, **kwargs) # type: ignore[arg-type]
# return cached_func(*args, **kwargs) # type: ignore[arg-type]

# # Add cache_info and cache_clear methods to the wrapper
# wrapper.cache_info = cached_func.cache_info # type: ignore
# wrapper.cache_clear = cached_func.cache_clear # type: ignore

# return wrapper

# return decorator


if __name__ == "__main__":

@cache_with_transforms(arg_transformers={0: lambda p: upath.UPath(p).resolve()})
def read_file_content(filepath: str | upath.UPath) -> str:
"""Read and return the content of a file."""
with upath.UPath(filepath).open() as f:
return f.read()

# These calls will use the same cache entry
content1 = read_file_content("pyproject.toml")
content1 = read_file_content("mkdocs.yml")
content2 = read_file_content(upath.UPath("pyproject.toml"))
content3 = read_file_content(upath.UPath("./pyproject.toml").absolute())

# Check cache statistics
print(read_file_content.cache_info()) # type: ignore
176 changes: 176 additions & 0 deletions tests/test_decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
# mypy: disable-error-code="attr-defined"
from typing import Any

import pytest

from jinjarope.decorators import cache_with_transforms


def test_basic_caching() -> None:
"""Test basic function caching without transformers."""
call_count = 0

@cache_with_transforms()
def add(x: int, y: int) -> int:
nonlocal call_count
call_count += 1
return x + y

assert add(1, 2) == 3
assert add(1, 2) == 3
assert call_count == 1
assert add.cache_info()["cache_size"] == 1


def test_arg_transformer() -> None:
"""Test caching with argument transformers."""

@cache_with_transforms(arg_transformers={0: lambda x: x.lower()})
def greet(name: str) -> str:
return f"Hello, {name}!"

assert greet("John") == "Hello, John!"
assert greet("JOHN") == "Hello, John!" # Should hit cache, previous word cached
assert greet.cache_info()["cache_size"] == 1


def test_kwarg_transformer() -> None:
"""Test caching with keyword argument transformers."""

@cache_with_transforms(kwarg_transformers={"items": tuple})
def process_list(*, items: list[int]) -> int:
return sum(items)

assert process_list(items=[1, 2, 3]) == 6
assert process_list(items=[1, 2, 3]) == 6 # Should hit cache
assert process_list.cache_info()["cache_size"] == 1


def test_unhashable_args() -> None:
"""Test caching with unhashable arguments."""

@cache_with_transforms(arg_transformers={0: tuple})
def process_list(lst: list[int]) -> int:
return sum(lst)

assert process_list([1, 2, 3]) == 6
assert process_list([1, 2, 3]) == 6 # Should hit cache
assert process_list.cache_info()["cache_size"] == 1


def test_mixed_args_kwargs() -> None:
"""Test caching with both positional and keyword arguments."""

@cache_with_transforms(
arg_transformers={0: str.lower}, kwarg_transformers={"items": tuple}
)
def process_data(prefix: str, *, items: list[int]) -> str:
return f"{prefix}: {sum(items)}"

assert process_data("Sum", items=[1, 2, 3]) == "Sum: 6"
assert process_data("SUM", items=[1, 2, 3]) == "Sum: 6" # Should hit cache
assert process_data.cache_info()["cache_size"] == 1


def test_multiple_calls_different_args() -> None:
"""Test caching behavior with different arguments."""
call_count = 0

@cache_with_transforms()
def add(x: int, y: int) -> int:
nonlocal call_count
call_count += 1
return x + y

assert add(1, 2) == 3
assert add(2, 3) == 5
assert add(1, 2) == 3 # Should hit cache
assert call_count == 2
assert add.cache_info()["cache_size"] == 2


def test_none_values() -> None:
"""Test caching behavior with None values."""

@cache_with_transforms()
def process_optional(x: Any) -> str:
return str(x)

assert process_optional(None) == "None"
assert process_optional(None) == "None" # Should hit cache
assert process_optional.cache_info()["cache_size"] == 1


def test_empty_transformers() -> None:
"""Test that decorator works with empty transformers."""

@cache_with_transforms(arg_transformers={}, kwarg_transformers={})
def identity(x: int) -> int:
return x

assert identity(1) == 1
assert identity.cache_info()["cache_size"] == 1


def test_cache_persistence() -> None:
"""Test that cache persists between calls."""

@cache_with_transforms()
def expensive_operation(x: int) -> int:
return x**2

assert expensive_operation(2) == 4
initial_cache = expensive_operation.cache
assert expensive_operation(2) == 4
assert expensive_operation.cache is initial_cache


def test_different_kwarg_orders() -> None:
"""Test that different keyword argument orders produce same cache key."""

@cache_with_transforms()
def process_kwargs(*, a: int, b: int) -> int:
return a + b

assert process_kwargs(a=1, b=2) == 3
assert process_kwargs(b=2, a=1) == 3 # Should hit cache
assert process_kwargs.cache_info()["cache_size"] == 1


def test_complex_transformers() -> None:
"""Test with more complex transformer functions."""

def complex_transform(x: list[Any]) -> tuple[Any, ...]:
return tuple(sorted(x))

@cache_with_transforms(
arg_transformers={0: complex_transform},
kwarg_transformers={"data": complex_transform},
)
def process_lists(lst: list[int], *, data: list[int]) -> int:
return sum(lst) + sum(data)

assert process_lists([3, 1, 2], data=[6, 4, 5]) == 21
assert process_lists([2, 3, 1], data=[5, 6, 4]) == 21 # Should hit cache
assert process_lists.cache_info()["cache_size"] == 1


def test_error_handling() -> None:
"""Test that errors in the original function are not cached."""

@cache_with_transforms()
def failing_function(x: int) -> int:
msg = "Error"
raise ValueError(msg)

with pytest.raises(ValueError): # noqa: PT011
failing_function(1)

with pytest.raises(ValueError): # noqa: PT011
failing_function(1) # Should not be cached

assert failing_function.cache_info()["cache_size"] == 0


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit f51c0a7

Please sign in to comment.