diff --git a/pyproject.toml b/pyproject.toml index 0f1a8c3..df2cbbc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -161,9 +161,9 @@ select = [ "YTT", # flake8-2020 ] ignore = [ - "C408", # Unnecessary {obj_type} call (rewrite as a literal) - "B905", # zip() without an explicit strict= parameter - "C901", # {name} is too complex ({complexity} > {max_complexity}) + "C408", # Unnecessary {obj_type} call (rewrite as a literal) + "B905", # zip() without an explicit strict= parameter + "C901", # {name} is too complex ({complexity} > {max_complexity}) "COM812", # "CPY001", # Missing copyright notice at top of file "D100", # Missing docstring in public module @@ -211,6 +211,7 @@ combine-as-imports = true [tool.ruff.lint.per-file-ignores] "__init__.py" = ["E402", "I001"] +"tests/test_decorators.py" = ["PLR2004"] [tool.pyright] reportUnusedCallResult = false diff --git a/src/jinjarope/decorators.py b/src/jinjarope/decorators.py new file mode 100644 index 0000000..7850e75 --- /dev/null +++ b/src/jinjarope/decorators.py @@ -0,0 +1,127 @@ +from __future__ import annotations + +from functools import wraps +from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar + +import upath + + +if TYPE_CHECKING: + from collections.abc import Callable + + +P = ParamSpec("P") +R = TypeVar("R") + + +def cache_with_transforms( + *, + arg_transformers: dict[int, Callable[[Any], Any]] | None = None, + kwarg_transformers: dict[str, Callable[[Any], Any]] | None = None, +) -> Callable[[Callable[P, R]], Callable[P, R]]: + """A caching decorator with transformation functions for args and kwargs. + + Can be used to make specific args / kwargs hashable. + Also adds cache and cache_info objects to the decorated function. + + Args: + arg_transformers: Dict mapping positional args indices to transformer functions + kwarg_transformers: Dict mapping kwargs names to transformer functions + + Returns: + A decorator function that caches results based on transformed arguments + """ + arg_transformers = arg_transformers or {} + kwarg_transformers = kwarg_transformers or {} + + def decorator(func: Callable[P, R]) -> Callable[P, R]: + cache: dict[tuple[Any, ...], R] = {} + + @wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + # Transform positional arguments + transformed_args = tuple( + arg_transformers.get(i, lambda x: x)(arg) for i, arg in enumerate(args) + ) + + # Transform keyword arguments + transformed_kwargs = { + key: kwarg_transformers.get(key, lambda x: x)(value) + for key, value in sorted(kwargs.items()) + } + + # Create cache key from transformed arguments + cache_key = (transformed_args, tuple(transformed_kwargs.items())) + + if cache_key not in cache: + cache[cache_key] = func(*args, **kwargs) + return cache[cache_key] + + def cache_info() -> dict[str, int]: + """Return information about cache hits and size.""" + return {"cache_size": len(cache)} + + wrapper.cache_info = cache_info # type: ignore + wrapper.cache = cache # type: ignore + + return wrapper + + return decorator + + +# def path_cache( +# maxsize: int | None = None, +# typed: bool = False, +# args_index: int = 0, +# ) -> Callable[[Callable[P, R]], Callable[P, R]]: +# """Decorator that caches function results and normalizes PathLike arguments. + +# Caches the first argument by default. Can be changed via args_index argument. +# Also adds cache_info and cache_clear methods to the decorated function. + +# Args: +# maxsize: Maximum size of the cache. If None, cache is unbounded. +# typed: If True, arguments of different types are cached separately. +# args_index: Index of the argument to cache. Default is 0. + +# Returns: +# Decorated function with caching capability. +# """ + +# def decorator(func: Callable[P, R]) -> Callable[P, R]: +# # Create cached version of the original function +# cached_func = lru_cache(maxsize=maxsize, typed=typed)(func) + +# @wraps(func) +# def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: +# # Convert first argument if it's PathLike +# if args and isinstance(args[args_index], str | PathLike): +# normalized_args = (upath.UPath(args[0]).resolve(),) + args[1:] +# return cached_func(*normalized_args, **kwargs) # type: ignore[arg-type] +# return cached_func(*args, **kwargs) # type: ignore[arg-type] + +# # Add cache_info and cache_clear methods to the wrapper +# wrapper.cache_info = cached_func.cache_info # type: ignore +# wrapper.cache_clear = cached_func.cache_clear # type: ignore + +# return wrapper + +# return decorator + + +if __name__ == "__main__": + + @cache_with_transforms(arg_transformers={0: lambda p: upath.UPath(p).resolve()}) + def read_file_content(filepath: str | upath.UPath) -> str: + """Read and return the content of a file.""" + with upath.UPath(filepath).open() as f: + return f.read() + + # These calls will use the same cache entry + content1 = read_file_content("pyproject.toml") + content1 = read_file_content("mkdocs.yml") + content2 = read_file_content(upath.UPath("pyproject.toml")) + content3 = read_file_content(upath.UPath("./pyproject.toml").absolute()) + + # Check cache statistics + print(read_file_content.cache_info()) # type: ignore diff --git a/tests/test_decorators.py b/tests/test_decorators.py new file mode 100644 index 0000000..5b7d6c6 --- /dev/null +++ b/tests/test_decorators.py @@ -0,0 +1,176 @@ +# mypy: disable-error-code="attr-defined" +from typing import Any + +import pytest + +from jinjarope.decorators import cache_with_transforms + + +def test_basic_caching() -> None: + """Test basic function caching without transformers.""" + call_count = 0 + + @cache_with_transforms() + def add(x: int, y: int) -> int: + nonlocal call_count + call_count += 1 + return x + y + + assert add(1, 2) == 3 + assert add(1, 2) == 3 + assert call_count == 1 + assert add.cache_info()["cache_size"] == 1 + + +def test_arg_transformer() -> None: + """Test caching with argument transformers.""" + + @cache_with_transforms(arg_transformers={0: lambda x: x.lower()}) + def greet(name: str) -> str: + return f"Hello, {name}!" + + assert greet("John") == "Hello, John!" + assert greet("JOHN") == "Hello, John!" # Should hit cache, previous word cached + assert greet.cache_info()["cache_size"] == 1 + + +def test_kwarg_transformer() -> None: + """Test caching with keyword argument transformers.""" + + @cache_with_transforms(kwarg_transformers={"items": tuple}) + def process_list(*, items: list[int]) -> int: + return sum(items) + + assert process_list(items=[1, 2, 3]) == 6 + assert process_list(items=[1, 2, 3]) == 6 # Should hit cache + assert process_list.cache_info()["cache_size"] == 1 + + +def test_unhashable_args() -> None: + """Test caching with unhashable arguments.""" + + @cache_with_transforms(arg_transformers={0: tuple}) + def process_list(lst: list[int]) -> int: + return sum(lst) + + assert process_list([1, 2, 3]) == 6 + assert process_list([1, 2, 3]) == 6 # Should hit cache + assert process_list.cache_info()["cache_size"] == 1 + + +def test_mixed_args_kwargs() -> None: + """Test caching with both positional and keyword arguments.""" + + @cache_with_transforms( + arg_transformers={0: str.lower}, kwarg_transformers={"items": tuple} + ) + def process_data(prefix: str, *, items: list[int]) -> str: + return f"{prefix}: {sum(items)}" + + assert process_data("Sum", items=[1, 2, 3]) == "Sum: 6" + assert process_data("SUM", items=[1, 2, 3]) == "Sum: 6" # Should hit cache + assert process_data.cache_info()["cache_size"] == 1 + + +def test_multiple_calls_different_args() -> None: + """Test caching behavior with different arguments.""" + call_count = 0 + + @cache_with_transforms() + def add(x: int, y: int) -> int: + nonlocal call_count + call_count += 1 + return x + y + + assert add(1, 2) == 3 + assert add(2, 3) == 5 + assert add(1, 2) == 3 # Should hit cache + assert call_count == 2 + assert add.cache_info()["cache_size"] == 2 + + +def test_none_values() -> None: + """Test caching behavior with None values.""" + + @cache_with_transforms() + def process_optional(x: Any) -> str: + return str(x) + + assert process_optional(None) == "None" + assert process_optional(None) == "None" # Should hit cache + assert process_optional.cache_info()["cache_size"] == 1 + + +def test_empty_transformers() -> None: + """Test that decorator works with empty transformers.""" + + @cache_with_transforms(arg_transformers={}, kwarg_transformers={}) + def identity(x: int) -> int: + return x + + assert identity(1) == 1 + assert identity.cache_info()["cache_size"] == 1 + + +def test_cache_persistence() -> None: + """Test that cache persists between calls.""" + + @cache_with_transforms() + def expensive_operation(x: int) -> int: + return x**2 + + assert expensive_operation(2) == 4 + initial_cache = expensive_operation.cache + assert expensive_operation(2) == 4 + assert expensive_operation.cache is initial_cache + + +def test_different_kwarg_orders() -> None: + """Test that different keyword argument orders produce same cache key.""" + + @cache_with_transforms() + def process_kwargs(*, a: int, b: int) -> int: + return a + b + + assert process_kwargs(a=1, b=2) == 3 + assert process_kwargs(b=2, a=1) == 3 # Should hit cache + assert process_kwargs.cache_info()["cache_size"] == 1 + + +def test_complex_transformers() -> None: + """Test with more complex transformer functions.""" + + def complex_transform(x: list[Any]) -> tuple[Any, ...]: + return tuple(sorted(x)) + + @cache_with_transforms( + arg_transformers={0: complex_transform}, + kwarg_transformers={"data": complex_transform}, + ) + def process_lists(lst: list[int], *, data: list[int]) -> int: + return sum(lst) + sum(data) + + assert process_lists([3, 1, 2], data=[6, 4, 5]) == 21 + assert process_lists([2, 3, 1], data=[5, 6, 4]) == 21 # Should hit cache + assert process_lists.cache_info()["cache_size"] == 1 + + +def test_error_handling() -> None: + """Test that errors in the original function are not cached.""" + + @cache_with_transforms() + def failing_function(x: int) -> int: + msg = "Error" + raise ValueError(msg) + + with pytest.raises(ValueError): # noqa: PT011 + failing_function(1) + + with pytest.raises(ValueError): # noqa: PT011 + failing_function(1) # Should not be cached + + assert failing_function.cache_info()["cache_size"] == 0 + + +if __name__ == "__main__": + pytest.main([__file__])