diff --git a/pytato/array.py b/pytato/array.py index c1bdb7d52..75fb16e0e 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -177,7 +177,15 @@ import operator import re from abc import ABC, abstractmethod -from collections.abc import Callable, Collection, Iterable, Iterator, Mapping, Sequence +from collections.abc import ( + Callable, + Collection, + Iterable, + Iterator, + KeysView, + Mapping, + Sequence, +) from functools import cached_property, partialmethod from sys import intern from typing import ( @@ -923,6 +931,7 @@ class AbstractResultWithNamedArrays(Mapping[str, NamedArray], Taggable, ABC): .. automethod:: __contains__ .. automethod:: __getitem__ .. automethod:: __len__ + .. automethod:: keys .. note:: @@ -961,6 +970,11 @@ def __eq__(self, other: Any) -> bool: from pytato.equality import EqualityComparer return EqualityComparer()(self, other) + @abstractmethod + def keys(self) -> KeysView[str]: + """Return a :class:`KeysView` of the names of the named arrays.""" + pass + @dataclasses.dataclass(frozen=True, eq=False, init=False) class DictOfNamedArrays(AbstractResultWithNamedArrays): @@ -1009,7 +1023,12 @@ def __iter__(self) -> Iterator[str]: return iter(self._data) def __repr__(self) -> str: - return "DictOfNamedArrays(" + str(self._data) + ")" + return f"DictOfNamedArrays(tags={self.tags!r}, data={self._data!r})" + + # Note: items() and values() are not implemented here, they go through + # __iter__()/__getitem__() above. + def keys(self) -> KeysView[str]: + return self._data.keys() # }}} diff --git a/pytato/function.py b/pytato/function.py index 3450e62e6..c96070bf5 100644 --- a/pytato/function.py +++ b/pytato/function.py @@ -58,7 +58,14 @@ import dataclasses import enum import re -from collections.abc import Callable, Hashable, Iterable, Iterator, Mapping +from collections.abc import ( + Callable, + Hashable, + Iterable, + Iterator, + KeysView, + Mapping, +) from functools import cached_property from typing import ( Any, @@ -339,6 +346,9 @@ def __len__(self) -> int: def _with_new_tags(self: Call, tags: frozenset[Tag]) -> Call: return dataclasses.replace(self, tags=tags) + def keys(self) -> KeysView[str]: + return self.function.returns.keys() + # }}} diff --git a/pytato/loopy.py b/pytato/loopy.py index 30cc69128..5bb86fcd1 100644 --- a/pytato/loopy.py +++ b/pytato/loopy.py @@ -149,6 +149,10 @@ def __len__(self) -> int: def __iter__(self) -> Iterator[str]: return iter(self._result_names) + # type-ignore-reason: AbstractResultWithNamedArrays returns a KeysView here + def keys(self) -> frozenset[str]: # type: ignore[override] + return self._result_names + @array_dataclass() # https://github.com/python/mypy/issues/18115 diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index 079430867..2e66828cd 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -10,7 +10,7 @@ .. class:: Expression - See :attr:`pymbolic.typing.Expression`. + See :data:`pymbolic.typing.Expression`. """ # FIXME: Unclear why the direct links to pymbolic don't work diff --git a/requirements.txt b/requirements.txt index 043cbf347..b1f1df20a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -git+https://github.com/inducer/pytools.git#egg=pytools >= 2024.1.14 +git+https://github.com/inducer/pytools.git#egg=pytools git+https://github.com/inducer/pymbolic.git#egg=pymbolic git+https://github.com/inducer/genpy.git#egg=genpy git+https://github.com/inducer/loopy.git#egg=loopy