diff --git a/pytato/codegen.py b/pytato/codegen.py index 85ac4052d..29fce4307 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -68,7 +68,7 @@ if TYPE_CHECKING: - from collections.abc import Hashable, Mapping + from collections.abc import Mapping from pytato.function import FunctionDefinition, NamedCallResult from pytato.target import Target @@ -135,12 +135,13 @@ class CodeGenPreprocessor(ToIndexLambdaMixin, CopyMapper): # type: ignore[misc] :class:`~pytato.array.Stack` :class:`~pytato.array.IndexLambda` ====================================== ===================================== """ + _FunctionCacheT: TypeAlias = CopyMapper._FunctionCacheT def __init__( self, target: Target, kernels_seen: dict[str, lp.LoopKernel] | None = None, - _function_cache: dict[Hashable, FunctionDefinition] | None = None + _function_cache: _FunctionCacheT | None = None ) -> None: super().__init__(_function_cache=_function_cache) self.bound_arguments: dict[str, DataInterface] = {} diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 8e5940a06..557104838 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -93,6 +93,8 @@ if TYPE_CHECKING: + from typing import TypeAlias + import mpi4py.MPI from pytato.function import FunctionDefinition, NamedCallResult @@ -283,12 +285,13 @@ class _DistributedInputReplacer(CopyMapper): instances for their assigned names. Also gathers names for user-supplied inputs needed by the part """ + _FunctionCacheT: TypeAlias = CopyMapper._FunctionCacheT def __init__(self, recvd_ary_to_name: Mapping[Array, str], sptpo_ary_to_name: Mapping[Array, str], name_to_output: Mapping[str, Array], - _function_cache: dict[Hashable, FunctionDefinition] | None = None, + _function_cache: _FunctionCacheT | None = None, ) -> None: super().__init__(_function_cache=_function_cache) @@ -337,9 +340,9 @@ def map_distributed_send(self, expr: DistributedSend) -> DistributedSend: return new_send def rec(self, expr: ArrayOrNames) -> ArrayOrNames: - key = self.get_cache_key(expr) + key = self._cache.get_key(expr) try: - return self._cache[key] + return self._cache.retrieve(expr, key=key) except KeyError: pass diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index b4be9aceb..6601d4b29 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -28,6 +28,7 @@ """ import dataclasses import logging +from collections.abc import Hashable from typing import ( TYPE_CHECKING, Any, @@ -79,17 +80,21 @@ if TYPE_CHECKING: - from collections.abc import Callable, Hashable, Iterable, Mapping + from collections.abc import Callable, Iterable, Mapping ArrayOrNames: TypeAlias = Array | AbstractResultWithNamedArrays MappedT = TypeVar("MappedT", Array, AbstractResultWithNamedArrays, ArrayOrNames) +CacheExprT = TypeVar("CacheExprT") # used in CachedMapperCache +CacheKeyT = TypeVar("CacheKeyT") # used in CachedMapperCache +CacheResultT = TypeVar("CacheResultT") # used in CachedMapperCache IndexOrShapeExpr = TypeVar("IndexOrShapeExpr") R = frozenset[Array] __doc__ = """ .. autoclass:: Mapper +.. autoclass:: CachedMapperCache .. autoclass:: CachedMapper .. autoclass:: TransformMapper .. autoclass:: TransformMapperWithExtraArgs @@ -246,61 +251,147 @@ def __call__(self, # {{{ CachedMapper +class CachedMapperCache(Generic[CacheExprT, CacheKeyT, CacheResultT, P]): + """ + Cache for :class:`CachedMapper`. + + .. automethod:: __init__ + .. automethod:: get_key + .. automethod:: add + .. automethod:: retrieve + """ + def __init__( + self, + # FIXME: Figure out the right way to type annotate this + key_func: Callable[..., CacheKeyT]) -> None: + """ + Initialize the cache. + + :arg key_func: Function to compute a hashable cache key from an input + expression. + """ + self._key_func = key_func + self._expr_key_to_result: dict[CacheKeyT, CacheResultT] = {} + + # FIXME: Can this be inlined? + def get_key( + self, expr: CacheExprT, *args: P.args, **kwargs: P.kwargs) -> CacheKeyT: + """Compute the key for an input expression.""" + return self._key_func(expr, *args, **kwargs) + + def add( + self, + key_inputs: + CacheExprT + # FIXME: Figure out the right way to type annotate these + | tuple[CacheExprT, tuple[Any, ...], dict[str, Any]], + result: CacheResultT, + key: CacheKeyT | None = None) -> CacheResultT: + """Cache a mapping result.""" + if key is None: + if isinstance(key_inputs, tuple): + expr, key_args, key_kwargs = key_inputs + key = self._key_func(expr, *key_args, **key_kwargs) + else: + key = self._key_func(key_inputs) + + self._expr_key_to_result[key] = result + + return result + + def retrieve( + self, + key_inputs: + CacheExprT + # FIXME: Figure out the right way to type annotate these + | tuple[CacheExprT, tuple[Any, ...], dict[str, Any]], + key: CacheKeyT | None = None) -> CacheResultT: + """Retrieve the cached mapping result.""" + if key is None: + if isinstance(key_inputs, tuple): + expr, key_args, key_kwargs = key_inputs + key = self._key_func(expr, *key_args, **key_kwargs) + else: + key = self._key_func(key_inputs) + + return self._expr_key_to_result[key] + + class CachedMapper(Mapper[ResultT, FunctionResultT, P]): """Mapper class that maps each node in the DAG exactly once. This loses some information compared to :class:`Mapper` as a node is visited only from one of its predecessors. - .. automethod:: get_cache_key - .. automethod:: get_function_definition_cache_key .. automethod:: clone_for_callee """ + # Not sure if there's a way to simplify this stuff? + _OtherP = ParamSpec("_OtherP") + + _CacheType: type[Any] = CachedMapperCache[ + ArrayOrNames, + Hashable, + ResultT, P] + _OtherResultT = TypeVar("_OtherResultT") + _CacheT: TypeAlias = CachedMapperCache[ + ArrayOrNames, + Hashable, + _OtherResultT, _OtherP] + + _FunctionCacheType: type[Any] = CachedMapperCache[ + FunctionDefinition, + Hashable, + FunctionResultT, P] + _OtherFunctionResultT = TypeVar("_OtherFunctionResultT") + _FunctionCacheT: TypeAlias = CachedMapperCache[ + FunctionDefinition, + Hashable, + _OtherFunctionResultT, _OtherP] def __init__( self, # Arrays are cached separately for each call stack frame, but # functions are cached globally - _function_cache: dict[Hashable, FunctionResultT] | None = None + _function_cache: _FunctionCacheT[FunctionResultT, P] | None = None ) -> None: super().__init__() - self._cache: dict[Hashable, ResultT] = {} + + def key_func( + expr: ArrayOrNames | FunctionDefinition, + *args: Any, **kwargs: Any) -> Hashable: + return (expr, args, tuple(sorted(kwargs.items()))) + + self._cache: CachedMapper._CacheT[ResultT, P] = \ + CachedMapper._CacheType(key_func) if _function_cache is not None: function_cache = _function_cache else: - function_cache = {} + function_cache = CachedMapper._FunctionCacheType(key_func) - self._function_cache: dict[Hashable, FunctionResultT] = function_cache - - def get_cache_key( - self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs - ) -> Hashable: - return (expr, *args, tuple(sorted(kwargs.items()))) - - def get_function_definition_cache_key( - self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs - ) -> Hashable: - return (expr, *args, tuple(sorted(kwargs.items()))) + self._function_cache: CachedMapper._FunctionCacheT[FunctionResultT, P] = \ + function_cache def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT: - key = self.get_cache_key(expr, *args, **kwargs) + key = self._cache.get_key(expr, *args, **kwargs) try: - return self._cache[key] + return self._cache.retrieve((expr, args, kwargs), key=key) except KeyError: - result = super().rec(expr, *args, **kwargs) - self._cache[key] = result - return result + return self._cache.add( + (expr, args, kwargs), + super().rec(expr, *args, **kwargs), + key=key) def rec_function_definition( self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs ) -> FunctionResultT: - key = self.get_function_definition_cache_key(expr, *args, **kwargs) + key = self._function_cache.get_key(expr, *args, **kwargs) try: - return self._function_cache[key] + return self._function_cache.retrieve((expr, args, kwargs), key=key) except KeyError: - result = super().rec_function_definition(expr, *args, **kwargs) - self._function_cache[key] = result - return result + return self._function_cache.add( + (expr, args, kwargs), + super().rec_function_definition(expr, *args, **kwargs), + key=key) def clone_for_callee( self, function: FunctionDefinition) -> Self: @@ -320,10 +411,19 @@ class TransformMapper(CachedMapper[ArrayOrNames, FunctionDefinition, []]): other :class:`pytato.array.Array`\\ s. Enables certain operations that can only be done if the mapping results are also - arrays (e.g., calling :meth:`~CachedMapper.get_cache_key` on them). Does not - implement default mapper methods; for that, see :class:`CopyMapper`. - + arrays (e.g., computing a cache key from them). Does not implement default + mapper methods; for that, see :class:`CopyMapper`. """ + _CacheType: type[Any] = CachedMapperCache[ + ArrayOrNames, Hashable, ArrayOrNames, []] + _CacheT: TypeAlias = CachedMapperCache[ + ArrayOrNames, Hashable, ArrayOrNames, []] + + _FunctionCacheType: type[Any] = CachedMapperCache[ + FunctionDefinition, Hashable, FunctionDefinition, []] + _FunctionCacheT: TypeAlias = CachedMapperCache[ + FunctionDefinition, Hashable, FunctionDefinition, []] + def rec_ary(self, expr: Array) -> Array: res = self.rec(expr) assert isinstance(res, Array) @@ -345,6 +445,18 @@ class TransformMapperWithExtraArgs( The logic in :class:`TransformMapper` purposely does not take the extra arguments to keep the cost of its each call frame low. """ + _OtherP = ParamSpec("_OtherP") + + _CacheType: type[Any] = CachedMapperCache[ + ArrayOrNames, Hashable, ArrayOrNames, P] + _CacheT: TypeAlias = CachedMapperCache[ + ArrayOrNames, Hashable, ArrayOrNames, _OtherP] + + _FunctionCacheType: type[Any] = CachedMapperCache[ + FunctionDefinition, Hashable, FunctionDefinition, P] + _FunctionCacheT: TypeAlias = CachedMapperCache[ + FunctionDefinition, Hashable, FunctionDefinition, _OtherP] + def rec_ary(self, expr: Array, *args: P.args, **kwargs: P.kwargs) -> Array: res = self.rec(expr, *args, **kwargs) assert isinstance(res, Array) @@ -1381,11 +1493,12 @@ class CachedMapAndCopyMapper(CopyMapper): Mapper that applies *map_fn* to each node and copies it. Results of traversals are memoized i.e. each node is mapped via *map_fn* exactly once. """ + _FunctionCacheT: TypeAlias = CopyMapper._FunctionCacheT def __init__( self, map_fn: Callable[[ArrayOrNames], ArrayOrNames], - _function_cache: dict[Hashable, FunctionDefinition] | None = None + _function_cache: _FunctionCacheT | None = None ) -> None: super().__init__(_function_cache=_function_cache) self.map_fn: Callable[[ArrayOrNames], ArrayOrNames] = map_fn @@ -1395,12 +1508,12 @@ def clone_for_callee( return type(self)(self.map_fn, _function_cache=self._function_cache) def rec(self, expr: ArrayOrNames) -> ArrayOrNames: - if expr in self._cache: - return self._cache[expr] - - result = super().rec(self.map_fn(expr)) - self._cache[expr] = result - return result + key = self._cache.get_key(expr) + try: + return self._cache.retrieve(expr, key=key) + except KeyError: + return self._cache.add( + expr, super().rec(self.map_fn(expr)), key=key) # }}} diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index eda50bd71..009439f4f 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -90,9 +90,10 @@ if TYPE_CHECKING: - from collections.abc import Collection, Hashable, Mapping + from collections.abc import Collection, Mapping + from typing import TypeAlias - from pytato.function import FunctionDefinition, NamedCallResult + from pytato.function import NamedCallResult from pytato.loopy import LoopyCall @@ -593,10 +594,12 @@ class AxisTagAttacher(CopyMapper): """ A mapper that tags the axes in a DAG as prescribed by *axis_to_tags*. """ + _FunctionCacheT: TypeAlias = CopyMapper._FunctionCacheT + def __init__(self, axis_to_tags: Mapping[tuple[Array, int], Collection[Tag]], tag_corresponding_redn_descr: bool, - _function_cache: dict[Hashable, FunctionDefinition] | None = None): + _function_cache: _FunctionCacheT | None = None): super().__init__(_function_cache=_function_cache) self.axis_to_tags: Mapping[tuple[Array, int], Collection[Tag]] = axis_to_tags self.tag_corresponding_redn_descr: bool = tag_corresponding_redn_descr @@ -644,9 +647,9 @@ def _attach_tags(self, expr: Array, rec_expr: Array) -> Array: return result def rec(self, expr: ArrayOrNames) -> Any: - key = self.get_cache_key(expr) + key = self._cache.get_key(expr) try: - return self._cache[key] + return self._cache.retrieve(expr, key=key) except KeyError: result = Mapper.rec(self, expr) if not isinstance( @@ -654,8 +657,7 @@ def rec(self, expr: ArrayOrNames) -> Any: assert isinstance(expr, Array) # type-ignore reason: passed "ArrayOrNames"; expected "Array" result = self._attach_tags(expr, result) # type: ignore[arg-type] - self._cache[key] = result - return result + return self._cache.add(expr, result, key=key) def map_named_call_result(self, expr: NamedCallResult) -> Array: raise NotImplementedError( diff --git a/test/test_apps.py b/test/test_apps.py index afb3a9ae4..89d52218e 100644 --- a/test/test_apps.py +++ b/test/test_apps.py @@ -94,7 +94,7 @@ def __init__(self, fft_vec_gatherer): arrays = fft_vec_gatherer.level_to_arrays[lev] rec_arrays = [self.rec(ary) for ary in arrays] # reset cache so that the partial subs are not stored - self._cache = {} + self._cache = type(self._cache)(lambda expr: expr) lev_array = pt.concatenate(rec_arrays, axis=0) assert lev_array.shape == (fft_vec_gatherer.n,)