Skip to content

Commit

Permalink
with_tagged_axis: accept a collection of tags
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer authored and majosm committed Sep 24, 2024
1 parent 72ce101 commit 7ea4699
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
2 changes: 1 addition & 1 deletion pytato/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,7 +714,7 @@ def any(self, axis: int = 0) -> ArrayOrScalar:
return pt.any(self, axis)

def with_tagged_axis(self, iaxis: int,
tags: Sequence[Tag] | Tag) -> Array:
tags: Collection[Tag] | Tag) -> Array:
"""
Returns a copy of *self* with *iaxis*-th axis tagged with *tags*.
"""
Expand Down
4 changes: 2 additions & 2 deletions pytato/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,11 @@
Any,
Callable,
ClassVar,
Collection,
Hashable,
Iterable,
Iterator,
Mapping,
Sequence,
Tuple,
TypeVar,
)
Expand Down Expand Up @@ -264,7 +264,7 @@ class NamedCallResult(NamedArray):
_mapper_method: ClassVar[str] = "map_named_call_result"

def with_tagged_axis(self, iaxis: int,
tags: Sequence[Tag] | Tag) -> Array:
tags: Collection[Tag] | Tag) -> Array:
raise ValueError("Tagging a NamedCallResult's axis is illegal, use"
" Call.with_tagged_axis instead")

Expand Down
5 changes: 3 additions & 2 deletions pytato/transform/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from typing import (
TYPE_CHECKING,
Any,
Collection,
Hashable,
Iterable,
List,
Expand Down Expand Up @@ -594,11 +595,11 @@ class AxisTagAttacher(CopyMapper):
A mapper that tags the axes in a DAG as prescribed by *axis_to_tags*.
"""
def __init__(self,
axis_to_tags: Mapping[tuple[Array, int], Iterable[Tag]],
axis_to_tags: Mapping[tuple[Array, int], Collection[Tag]],
tag_corresponding_redn_descr: bool,
_function_cache: dict[Hashable, FunctionDefinition] | None = None):
super().__init__(_function_cache=_function_cache)
self.axis_to_tags: Mapping[tuple[Array, int], Iterable[Tag]] = axis_to_tags
self.axis_to_tags: Mapping[tuple[Array, int], Collection[Tag]] = axis_to_tags
self.tag_corresponding_redn_descr: bool = tag_corresponding_redn_descr

def rec(self, expr: ArrayOrNames) -> Any:
Expand Down

0 comments on commit 7ea4699

Please sign in to comment.