Skip to content

Commit

Permalink
use a class for CachedMapper caches instead of using a dict directly
Browse files Browse the repository at this point in the history
  • Loading branch information
majosm committed Jan 10, 2025
1 parent 27aa393 commit d40153c
Show file tree
Hide file tree
Showing 5 changed files with 210 additions and 64 deletions.
5 changes: 3 additions & 2 deletions pytato/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@


if TYPE_CHECKING:
from collections.abc import Hashable, Mapping
from collections.abc import Mapping

from pytato.function import FunctionDefinition, NamedCallResult
from pytato.target import Target
Expand Down Expand Up @@ -135,12 +135,13 @@ class CodeGenPreprocessor(ToIndexLambdaMixin, CopyMapper): # type: ignore[misc]
:class:`~pytato.array.Stack` :class:`~pytato.array.IndexLambda`
====================================== =====================================
"""
_FunctionCacheT: TypeAlias = CopyMapper._FunctionCacheT

def __init__(
self,
target: Target,
kernels_seen: dict[str, lp.LoopKernel] | None = None,
_function_cache: dict[Hashable, FunctionDefinition] | None = None
_function_cache: _FunctionCacheT | None = None
) -> None:
super().__init__(_function_cache=_function_cache)
self.bound_arguments: dict[str, DataInterface] = {}
Expand Down
9 changes: 6 additions & 3 deletions pytato/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@


if TYPE_CHECKING:
from typing import TypeAlias

import mpi4py.MPI

from pytato.function import FunctionDefinition, NamedCallResult
Expand Down Expand Up @@ -290,12 +292,13 @@ class _DistributedInputReplacer(CopyMapper):
instances for their assigned names. Also gathers names for
user-supplied inputs needed by the part
"""
_FunctionCacheT: TypeAlias = CopyMapper._FunctionCacheT

def __init__(self,
recvd_ary_to_name: Mapping[Array, str],
sptpo_ary_to_name: Mapping[Array, str],
name_to_output: Mapping[str, Array],
_function_cache: dict[Hashable, FunctionDefinition] | None = None,
_function_cache: _FunctionCacheT | None = None,
) -> None:
super().__init__(_function_cache=_function_cache)

Expand Down Expand Up @@ -344,9 +347,9 @@ def map_distributed_send(self, expr: DistributedSend) -> DistributedSend:
return new_send

def rec(self, expr: ArrayOrNames) -> ArrayOrNames:
key = self.get_cache_key(expr)
key = self._cache.get_key(expr)
try:
return self._cache[key]
return self._cache.retrieve(expr, key=key)
except KeyError:
pass

Expand Down
Loading

0 comments on commit d40153c

Please sign in to comment.