From f5c377eea26e84a66397a976c5dd544c6b4e58a0 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 22 Mar 2022 22:53:41 -0500 Subject: [PATCH 001/124] Add tag to store array creation traceback --- pytato/array.py | 6 +++++- pytato/tags.py | 10 ++++++++++ test/test_pytato.py | 15 +++++++++++++++ 3 files changed, 30 insertions(+), 1 deletion(-) diff --git a/pytato/array.py b/pytato/array.py index aa17a2f4b..7b3b1e0f6 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -419,8 +419,12 @@ class Array(Taggable): __array_priority__ = 1 # disallow numpy arithmetic to take precedence def __init__(self, axes: AxesT, tags: FrozenSet[Tag]) -> None: + import traceback + v = "".join(traceback.format_stack()) + from pytato.tags import CreatedAt + c = CreatedAt(v) self.axes = axes - self.tags = tags + self.tags = frozenset({*tags, c}) def copy(self: ArrayT, **kwargs: Any) -> ArrayT: for field in self._fields: diff --git a/pytato/tags.py b/pytato/tags.py index f1794c177..f1ed4984c 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -101,3 +101,13 @@ class AssumeNonNegative(Tag): :class:`~pytato.target.Target` that all entries of the tagged array are non-negative. """ + + +@tag_dataclass +class CreatedAt(UniqueTag): + """ + A tag attached to a :class:`~pytato.Array` to store the traceback + of where it was created. 
+ """ + + traceback: str diff --git a/test/test_pytato.py b/test/test_pytato.py index 5f602a240..940ff2353 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -790,6 +790,21 @@ def test_einsum_dot_axes_has_correct_dim(): assert len(einsum.axes) == einsum.ndim +def test_created_at(): + a = pt.make_placeholder("a", (10, 10), "float64") + b = pt.make_placeholder("b", (10, 10), "float64") + + res = a+b + + from pytato.tags import CreatedAt + assert any(isinstance(tag, CreatedAt) for tag in res.tags) + + # Make sure the function name appears in the traceback + for tag in res.tags: + if isinstance(tag, CreatedAt): + assert "test_created_at" in tag.traceback + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) From 4c32cb69b862fcca6f40502a946af1cb13818f8d Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 23 Mar 2022 11:01:28 -0500 Subject: [PATCH 002/124] don't make it a unique tag --- pytato/tags.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytato/tags.py b/pytato/tags.py index f1ed4984c..285062a5c 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -104,7 +104,7 @@ class AssumeNonNegative(Tag): @tag_dataclass -class CreatedAt(UniqueTag): +class CreatedAt(Tag): """ A tag attached to a :class:`~pytato.Array` to store the traceback of where it was created. 
From 56fbf4c9da48b1463044b0df1b3d365427a329d7 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Wed, 23 Mar 2022 16:50:24 -0500 Subject: [PATCH 003/124] adds a common _get_default_tags --- pytato/array.py | 40 +++++++++++++++++++++++++++++++++------- pytato/cmath.py | 9 ++++++--- pytato/utils.py | 8 ++++++-- 3 files changed, 45 insertions(+), 12 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index aa17a2f4b..3e66d8796 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -474,6 +474,7 @@ def ndim(self) -> int: def T(self) -> Array: return AxisPermutation(self, tuple(range(self.ndim)[::-1]), + tags=_get_default_tags(), axes=_get_default_axes(self.ndim)) @memoize_method @@ -539,6 +540,7 @@ def _unary_op(self, op: Any) -> Array: shape=self.shape, dtype=self.dtype, bindings=bindings, + tags=_get_default_tags(), axes=_get_default_axes(self.ndim)) __mul__ = partialmethod(_binary_op, operator.mul) @@ -1133,6 +1135,7 @@ def einsum(subscripts: str, *operands: Array) -> Einsum: access_descriptors.append(access_descriptor) return Einsum(tuple(access_descriptors), operands, + tags=_get_default_tags(), axes=_get_default_axes(len({descr for descr in index_to_descr.values() if isinstance(descr, @@ -1672,6 +1675,10 @@ def _get_default_axes(ndim: int) -> AxesT: return tuple(Axis(frozenset()) for _ in range(ndim)) +def _get_default_tags() -> TagsType: + return frozenset() + + def _get_matmul_ndim(ndim1: int, ndim2: int) -> int: if ndim1 == 1 and ndim2 == 1: return 0 @@ -1738,7 +1745,9 @@ def roll(a: Array, shift: int, axis: Optional[int] = None) -> Array: if shift == 0: return a - return Roll(a, shift, axis, axes=_get_default_axes(a.ndim)) + return Roll(a, shift, axis, + tags=_get_default_tags(), + axes=_get_default_axes(a.ndim)) def transpose(a: Array, axes: Optional[Sequence[int]] = None) -> Array: @@ -1758,7 +1767,9 @@ def transpose(a: Array, axes: Optional[Sequence[int]] = None) -> Array: if set(axes) != set(range(a.ndim)): raise ValueError("repeated or 
out-of-bounds axes detected") - return AxisPermutation(a, tuple(axes), axes=_get_default_axes(a.ndim)) + return AxisPermutation(a, tuple(axes), + tags=_get_default_tags(), + axes=_get_default_axes(a.ndim)) def stack(arrays: Sequence[Array], axis: int = 0) -> Array: @@ -1790,7 +1801,9 @@ def stack(arrays: Sequence[Array], axis: int = 0) -> Array: if not (0 <= axis <= arrays[0].ndim): raise ValueError("invalid axis") - return Stack(tuple(arrays), axis, axes=_get_default_axes(arrays[0].ndim+1)) + return Stack(tuple(arrays), axis, + tags=_get_default_tags(), + axes=_get_default_axes(arrays[0].ndim+1)) def concatenate(arrays: Sequence[Array], axis: int = 0) -> Array: @@ -1823,7 +1836,9 @@ def shape_except_axis(ary: Array) -> ShapeType: if not (0 <= axis <= arrays[0].ndim): raise ValueError("invalid axis") - return Concatenate(tuple(arrays), axis, axes=_get_default_axes(arrays[0].ndim)) + return Concatenate(tuple(arrays), axis, + tags=_get_default_tags(), + axes=_get_default_axes(arrays[0].ndim)) def reshape(array: Array, newshape: Union[int, Sequence[int]], @@ -1885,6 +1900,7 @@ def reshape(array: Array, newshape: Union[int, Sequence[int]], f" into {newshape}") return Reshape(array, tuple(newshape_explicit), order, + tags=_get_default_tags(), axes=_get_default_axes(len(newshape_explicit))) @@ -1925,7 +1941,8 @@ def make_placeholder(name: str, raise ValueError("'axes' dimensionality mismatch:" f" expected {len(shape)}, got {len(axes)}.") - return Placeholder(name, shape, dtype, axes=axes, tags=tags) + return Placeholder(name, shape, dtype, axes=axes, + tags=(tags | _get_default_tags())) def make_size_param(name: str, @@ -1939,7 +1956,7 @@ def make_size_param(name: str, :param tags: implementation tags """ _check_identifier(name, optional=False) - return SizeParam(name, tags=tags) + return SizeParam(name, tags=(tags | _get_default_tags())) def make_data_wrapper(data: DataInterface, @@ -1967,7 +1984,9 @@ def make_data_wrapper(data: DataInterface, raise ValueError("'axes' 
dimensionality mismatch:" f" expected {len(shape)}, got {len(axes)}.") - return DataWrapper(name, data, shape, axes=axes, tags=tags) + return DataWrapper(name, data, shape, + axes=axes, + tags=(tags | _get_default_tags())) # }}} @@ -1985,6 +2004,7 @@ def full(shape: ConvertibleToShape, fill_value: ScalarType, shape = normalize_shape(shape) dtype = np.dtype(dtype) return IndexLambda(dtype.type(fill_value), shape, dtype, {}, + tags=_get_default_tags(), axes=_get_default_axes(len(shape))) @@ -2029,6 +2049,7 @@ def eye(N: int, M: Optional[int] = None, k: int = 0, # noqa: N803 return IndexLambda(parse(f"1 if ((_1 - _0) == {k}) else 0"), shape=(N, M), dtype=dtype, bindings={}, + tags=_get_default_tags(), axes=_get_default_axes(2)) # }}} @@ -2122,6 +2143,7 @@ def arange(*args: Any, **kwargs: Any) -> Array: from pymbolic.primitives import Variable return IndexLambda(start + Variable("_0") * step, shape=(size,), dtype=dtype, bindings={}, + tags=_get_default_tags(), axes=_get_default_axes(1)) # }}} @@ -2222,6 +2244,7 @@ def logical_not(x: ArrayOrScalar) -> Union[Array, bool]: shape=x.shape, dtype=np.dtype(np.bool8), bindings={"_in0": x}, + tags=_get_default_tags(), axes=_get_default_axes(len(x.shape))) # }}} @@ -2274,6 +2297,7 @@ def where(condition: ArrayOrScalar, shape=result_shape, dtype=dtype, bindings=bindings, + tags=_get_default_tags(), axes=_get_default_axes(len(result_shape))) # }}} @@ -2330,6 +2354,7 @@ def make_index_lambda( bindings=bindings, shape=shape, dtype=dtype, + tags=_get_default_tags(), axes=_get_default_axes(len(shape))) # }}} @@ -2407,6 +2432,7 @@ def broadcast_to(array: Array, shape: ShapeType) -> Array: shape=shape, dtype=array.dtype, bindings={"in": array}, + tags=_get_default_tags(), axes=_get_default_axes(len(shape))) diff --git a/pytato/cmath.py b/pytato/cmath.py index 8d4fc1da0..5ada73e6b 100644 --- a/pytato/cmath.py +++ b/pytato/cmath.py @@ -59,7 +59,7 @@ import pymbolic.primitives as prim from typing import Tuple, Optional from pytato.array 
import (Array, ArrayOrScalar, IndexLambda, _dtype_any, - _get_default_axes) + _get_default_axes, _get_default_tags) from pytato.scalar_expr import SCALAR_CLASSES from pymbolic import var @@ -110,8 +110,11 @@ def _apply_elem_wise_func(inputs: Tuple[ArrayOrScalar, ...], assert ret_dtype is not None return IndexLambda( - prim.Call(var(f"pytato.c99.{func_name}"), tuple(sym_args)), - shape, ret_dtype, bindings, axes=_get_default_axes(len(shape))) + prim.Call(var(f"pytato.c99.{func_name}"), + tuple(sym_args)), + shape, ret_dtype, bindings, + tags=_get_default_tags(), + axes=_get_default_axes(len(shape))) def abs(x: Array) -> ArrayOrScalar: diff --git a/pytato/utils.py b/pytato/utils.py index 51959e2f5..c475e1da1 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -169,7 +169,7 @@ def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar, op: Callable[[ScalarExpression, ScalarExpression], ScalarExpression], # noqa:E501 get_result_type: Callable[[DtypeOrScalar, DtypeOrScalar], np.dtype[Any]], # noqa:E501 ) -> ArrayOrScalar: - from pytato.array import _get_default_axes + from pytato.array import _get_default_axes, _get_default_tags if isinstance(a1, SCALAR_CLASSES): a1 = np.dtype(type(a1)).type(a1) @@ -196,6 +196,7 @@ def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar, shape=result_shape, dtype=result_dtype, bindings=bindings, + tags=_get_default_tags(), axes=_get_default_axes(len(result_shape))) @@ -461,7 +462,7 @@ def _normalized_slice_len(slice_: NormalizedSlice) -> ShapeComponent: def _index_into(ary: Array, indices: Tuple[ConvertibleToIndexExpr, ...]) -> Array: from pytato.diagnostic import CannotBroadcastError - from pytato.array import _get_default_axes + from pytato.array import _get_default_axes, _get_default_tags # {{{ handle ellipsis @@ -543,18 +544,21 @@ def _index_into(ary: Array, indices: Tuple[ConvertibleToIndexExpr, ...]) -> Arra return AdvancedIndexInNoncontiguousAxes( ary, tuple(normalized_indices), + tags=_get_default_tags(), 
axes=_get_default_axes(len(array_idx_shape) + len(i_basic_indices))) else: return AdvancedIndexInContiguousAxes( ary, tuple(normalized_indices), + tags=_get_default_tags(), axes=_get_default_axes(len(array_idx_shape) + len(i_basic_indices))) else: # basic indexing expression return BasicIndex(ary, tuple(normalized_indices), + tags=_get_default_tags(), axes=_get_default_axes( len([idx for idx in normalized_indices From 8fee6d4de8e9134333d29a4d0ae6cb3fd0810248 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 23 Mar 2022 17:19:38 -0500 Subject: [PATCH 004/124] Change back to UniqueTag MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Andreas Klöckner --- pytato/tags.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytato/tags.py b/pytato/tags.py index 285062a5c..f1ed4984c 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -104,7 +104,7 @@ class AssumeNonNegative(Tag): @tag_dataclass -class CreatedAt(Tag): +class CreatedAt(UniqueTag): """ A tag attached to a :class:`~pytato.Array` to store the traceback of where it was created. 
From bdff59ceb4e5e993e11b0e5c6ac4ed8cb5d0a459 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 23 Mar 2022 17:42:39 -0500 Subject: [PATCH 005/124] use _get_default_tags --- pytato/array.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index decfd4d1f..a90354e93 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -419,12 +419,8 @@ class Array(Taggable): __array_priority__ = 1 # disallow numpy arithmetic to take precedence def __init__(self, axes: AxesT, tags: FrozenSet[Tag]) -> None: - import traceback - v = "".join(traceback.format_stack()) - from pytato.tags import CreatedAt - c = CreatedAt(v) self.axes = axes - self.tags = frozenset({*tags, c}) + self.tags = tags def copy(self: ArrayT, **kwargs: Any) -> ArrayT: for field in self._fields: @@ -1680,7 +1676,12 @@ def _get_default_axes(ndim: int) -> AxesT: def _get_default_tags() -> TagsType: - return frozenset() + import traceback + from pytato.tags import CreatedAt + + v = "".join(traceback.format_stack()) + c = CreatedAt(v) + return frozenset((c,)) def _get_matmul_ndim(ndim1: int, ndim2: int) -> int: From 6d181447ee2564097c24a3d8658bef378400274a Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 23 Mar 2022 19:01:23 -0500 Subject: [PATCH 006/124] store a tupleized StackSummary --- pytato/array.py | 7 +++++-- test/test_pytato.py | 10 ++++++++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index a90354e93..ed56cbda5 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1679,8 +1679,11 @@ def _get_default_tags() -> TagsType: import traceback from pytato.tags import CreatedAt - v = "".join(traceback.format_stack()) - c = CreatedAt(v) + # extract_stack returns a StackSummary, which is a list + # You can restore the StackSummary object by calling + # StackSummary.from_list(c.traceback) + stack_summary = traceback.extract_stack() + c = CreatedAt(tuple(tuple(t) for t in 
tuple(stack_summary))) return frozenset((c,)) diff --git a/test/test_pytato.py b/test/test_pytato.py index 940ff2353..523aa5639 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -797,12 +797,18 @@ def test_created_at(): res = a+b from pytato.tags import CreatedAt - assert any(isinstance(tag, CreatedAt) for tag in res.tags) + + found = False # Make sure the function name appears in the traceback for tag in res.tags: if isinstance(tag, CreatedAt): - assert "test_created_at" in tag.traceback + for line in tag.traceback: + if line[2] == "test_created_at": + found = True + break + + assert found if __name__ == "__main__": From 2c3faebb4729b6bfd6140aa7410d03edf59a7d55 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 24 Mar 2022 11:34:35 -0500 Subject: [PATCH 007/124] work around mypy --- pytato/tags.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pytato/tags.py b/pytato/tags.py index f1ed4984c..9a4f5c40e 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -12,6 +12,7 @@ """ +from dataclasses import dataclass from pytools.tag import Tag, UniqueTag, tag_dataclass @@ -103,7 +104,9 @@ class AssumeNonNegative(Tag): """ -@tag_dataclass +# See https://mypy.readthedocs.io/en/stable/additional_features.html#caveats-known-issues +# on why this can not be '@tag_dataclass'. 
+@dataclass(init=True, eq=True, frozen=True, repr=True) class CreatedAt(UniqueTag): """ A tag attached to a :class:`~pytato.Array` to store the traceback From 739f3d3032e34e88b8f8e3d9df8588f48224a0c7 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 24 Mar 2022 13:23:24 -0500 Subject: [PATCH 008/124] more line fixes --- pytato/tags.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pytato/tags.py b/pytato/tags.py index 9a4f5c40e..dd35c5dc1 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -14,6 +14,7 @@ from dataclasses import dataclass from pytools.tag import Tag, UniqueTag, tag_dataclass +from typing import Tuple, Any # {{{ pre-defined tag: ImplementationStrategy @@ -104,7 +105,8 @@ class AssumeNonNegative(Tag): """ -# See https://mypy.readthedocs.io/en/stable/additional_features.html#caveats-known-issues +# See +# https://mypy.readthedocs.io/en/stable/additional_features.html#caveats-known-issues # on why this can not be '@tag_dataclass'. @dataclass(init=True, eq=True, frozen=True, repr=True) class CreatedAt(UniqueTag): @@ -113,4 +115,4 @@ class CreatedAt(UniqueTag): of where it was created. """ - traceback: str + traceback: Tuple[Tuple[Any, ...], ...] 
From 4fd3d64439b891abac11571e00d49fd646e1be05 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 28 Mar 2022 14:33:03 -0500 Subject: [PATCH 009/124] use a class for the traceback instead of tuples --- pytato/array.py | 31 ++++++++++++++++++++++++++----- pytato/tags.py | 3 ++- test/test_pytato.py | 4 ++-- 3 files changed, 30 insertions(+), 8 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 11432184b..8757b537b 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1,4 +1,5 @@ from __future__ import annotations +from traceback import FrameSummary, StackSummary __copyright__ = """ Copyright (C) 2020 Andreas Kloeckner @@ -1675,15 +1676,35 @@ def _get_default_axes(ndim: int) -> AxesT: return tuple(Axis(frozenset()) for _ in range(ndim)) +@dataclass(frozen=True, eq=True) +class _PytatoFrameSummary: + filename: str + lineno: int + name: str + line: str + + +class _PytatoStackSummary(Tag): + def __init__(self, stack_summary: StackSummary) -> None: + self.frames: List[_PytatoFrameSummary] = [] + for s in stack_summary: + pfs = _PytatoFrameSummary(s.filename, s.lineno, s.name, s.line) + self.frames.append(pfs) + + def to_stacksummary(self) -> StackSummary: + frames = [] + for f in self.frames: + frames.append(FrameSummary(f.filename, f.lineno, f.name, line=f.line)) + + # type-ignore-reason: from_list also takes List[FrameSummary] + return StackSummary.from_list(frames) # type: ignore[arg-type] + + def _get_default_tags() -> TagsType: import traceback from pytato.tags import CreatedAt - # extract_stack returns a StackSummary, which is a list - # You can restore the StackSummary object by calling - # StackSummary.from_list(c.traceback) - stack_summary = traceback.extract_stack() - c = CreatedAt(tuple(tuple(t) for t in tuple(stack_summary))) + c = CreatedAt(_PytatoStackSummary(traceback.extract_stack())) return frozenset((c,)) diff --git a/pytato/tags.py b/pytato/tags.py index dd35c5dc1..4f2bdaef5 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -15,6 
+15,7 @@ from dataclasses import dataclass from pytools.tag import Tag, UniqueTag, tag_dataclass from typing import Tuple, Any +from pytato.array import _PytatoStackSummary # {{{ pre-defined tag: ImplementationStrategy @@ -115,4 +116,4 @@ class CreatedAt(UniqueTag): of where it was created. """ - traceback: Tuple[Tuple[Any, ...], ...] + traceback: _PytatoStackSummary diff --git a/test/test_pytato.py b/test/test_pytato.py index 523aa5639..0e090ff0d 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -803,8 +803,8 @@ def test_created_at(): # Make sure the function name appears in the traceback for tag in res.tags: if isinstance(tag, CreatedAt): - for line in tag.traceback: - if line[2] == "test_created_at": + for frame in tag.traceback.frames: + if frame.name == "test_created_at": found = True break From 770255024b439f1e7d157d18d5e1607fb6dcc783 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 28 Mar 2022 14:36:42 -0500 Subject: [PATCH 010/124] also test to_stacksummary --- test/test_pytato.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_pytato.py b/test/test_pytato.py index 0e090ff0d..817aad5fa 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -803,6 +803,7 @@ def test_created_at(): # Make sure the function name appears in the traceback for tag in res.tags: if isinstance(tag, CreatedAt): + _unused = tag.traceback.to_stacksummary() for frame in tag.traceback.frames: if frame.name == "test_created_at": found = True From 7a8655770198d0902f3331479e0775da514f4956 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 28 Mar 2022 14:41:18 -0500 Subject: [PATCH 011/124] flake8 --- pytato/tags.py | 1 - test/test_pytato.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/pytato/tags.py b/pytato/tags.py index 4f2bdaef5..77d2fe218 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -14,7 +14,6 @@ from dataclasses import dataclass from pytools.tag import Tag, UniqueTag, tag_dataclass -from typing import Tuple, 
Any from pytato.array import _PytatoStackSummary diff --git a/test/test_pytato.py b/test/test_pytato.py index 817aad5fa..d642c7486 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -803,7 +803,7 @@ def test_created_at(): # Make sure the function name appears in the traceback for tag in res.tags: if isinstance(tag, CreatedAt): - _unused = tag.traceback.to_stacksummary() + _unused = tag.traceback.to_stacksummary() # noqa for frame in tag.traceback.frames: if frame.name == "test_created_at": found = True From 31bcda8d3406fbf2d1f4b36a43c1b2e26bf78104 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Mon, 28 Mar 2022 15:07:19 -0500 Subject: [PATCH 012/124] Add remove_tags_of_type --- pytato/transform.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/pytato/transform.py b/pytato/transform.py index 4af6f8ab3..a4d3a1f88 100644 --- a/pytato/transform.py +++ b/pytato/transform.py @@ -74,6 +74,7 @@ .. autofunction:: copy_dict_of_named_arrays .. autofunction:: get_dependencies .. autofunction:: map_and_copy +.. autofunction:: remove_tags_of_type .. autofunction:: materialize_with_mpms Dict representation of DAGs @@ -1031,6 +1032,21 @@ def map_and_copy(expr: ArrayOrNames, return CachedMapAndCopyMapper(map_fn)(expr) +def remove_tags_of_type(tag_types: Union[type, Tuple[type]], expr: ArrayOrNames + ) -> ArrayOrNames: + def process_node(expr: ArrayOrNames) -> ArrayOrNames: + if isinstance(expr, Array): + return expr.copy(tags=frozenset({ + tag for tag in expr.tags + if not isinstance(tag, tag_types)})) + elif isinstance(expr, AbstractResultWithNamedArrays): + return expr + else: + raise AssertionError() + + return map_and_copy(expr, process_node) + + def materialize_with_mpms(expr: DictOfNamedArrays) -> DictOfNamedArrays: r""" Materialize nodes in *expr* with MPMS materialization strategy. 
From 5c3222229519d570b9da0d89379c76784779c418 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Mon, 28 Mar 2022 15:07:38 -0500 Subject: [PATCH 013/124] test_array_dot_repr: Remove CreatedAt tags before comparing --- test/test_pytato.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test/test_pytato.py b/test/test_pytato.py index d642c7486..c4f3c6e6e 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -27,6 +27,8 @@ import sys +from typing import cast + import numpy as np import pytest @@ -462,7 +464,12 @@ def test_array_dot_repr(): x = pt.make_placeholder("x", (10, 4), np.int64) y = pt.make_placeholder("y", (10, 4), np.int64) + from pytato.transform import remove_tags_of_type + from pytato.tags import CreatedAt + def _assert_stripped_repr(ary: pt.Array, expected_repr: str): + ary = cast(pt.Array, remove_tags_of_type(CreatedAt, ary)) + expected_str = "".join([c for c in repr(ary) if c not in [" ", "\n"]]) result_str = "".join([c for c in expected_repr if c not in [" ", "\n"]]) assert expected_str == result_str From 2d0358fbdd4ea298022be4587eb919cf6a179569 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 28 Mar 2022 15:09:29 -0500 Subject: [PATCH 014/124] only add CreatedAt in debug mode MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Andreas Klöckner --- pytato/array.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 8757b537b..a59414312 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1704,8 +1704,9 @@ def _get_default_tags() -> TagsType: import traceback from pytato.tags import CreatedAt - c = CreatedAt(_PytatoStackSummary(traceback.extract_stack())) - return frozenset((c,)) + if __debug__: + c = CreatedAt(_PytatoStackSummary(traceback.extract_stack())) + return frozenset((c,)) def matmul(x1: Array, x2: Array) -> Array: From 1c275fd80cca82a6f584f7b20b33549c3ff977be Mon Sep 17 00:00:00 2001 From: Matthias 
Diener Date: Mon, 28 Mar 2022 15:26:10 -0500 Subject: [PATCH 015/124] restructure test_created_at --- test/test_pytato.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index c4f3c6e6e..bd4dbdd26 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -805,16 +805,20 @@ def test_created_at(): from pytato.tags import CreatedAt + created_tag = res.tags_of_type(CreatedAt) + + assert len(created_tag) == 1 + + tag, = created_tag + found = False # Make sure the function name appears in the traceback - for tag in res.tags: - if isinstance(tag, CreatedAt): - _unused = tag.traceback.to_stacksummary() # noqa - for frame in tag.traceback.frames: - if frame.name == "test_created_at": - found = True - break + _unused = tag.traceback.to_stacksummary() # noqa + for frame in tag.traceback.frames: + if frame.name == "test_created_at": + found = True + break assert found From 02362bfaf8f6d6da00e2b2e639b9b6773866aa14 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 28 Mar 2022 15:27:14 -0500 Subject: [PATCH 016/124] make _PytatoStackSummary a dataclass --- pytato/array.py | 46 +++++++++++++++++++++++++++++++++++---------- pytato/transform.py | 2 +- 2 files changed, 37 insertions(+), 11 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index a59414312..3ed376caf 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1678,35 +1678,61 @@ def _get_default_axes(ndim: int) -> AxesT: @dataclass(frozen=True, eq=True) class _PytatoFrameSummary: + """Class to store a single call frame.""" filename: str lineno: int name: str line: str + def update_persistent_hash(self, key_hash: int, key_builder: Any) -> None: + key_builder.rec(key_hash, + (self.__class__.__module__, self.__class__.__qualname__)) -class _PytatoStackSummary(Tag): - def __init__(self, stack_summary: StackSummary) -> None: - self.frames: List[_PytatoFrameSummary] = [] - for s in stack_summary: - pfs = 
_PytatoFrameSummary(s.filename, s.lineno, s.name, s.line) - self.frames.append(pfs) + from dataclasses import fields + # Fields are ordered consistently, so ordered hashing is OK. + # + # No need to dispatch to superclass: fields() automatically gives us + # fields from the entire class hierarchy. + for f in fields(self): + key_builder.rec(key_hash, getattr(self, f.name)) + + +@dataclass(frozen=True, eq=True) +class _PytatoStackSummary: + """Class to store a list of :class:`_PytatoFrameSummary` call frames.""" + frames: Tuple[_PytatoFrameSummary, ...] def to_stacksummary(self) -> StackSummary: - frames = [] - for f in self.frames: - frames.append(FrameSummary(f.filename, f.lineno, f.name, line=f.line)) + frames = [FrameSummary(f.filename, f.lineno, f.name, line=f.line) + for f in self.frames] # type-ignore-reason: from_list also takes List[FrameSummary] return StackSummary.from_list(frames) # type: ignore[arg-type] + def update_persistent_hash(self, key_hash: int, key_builder: Any) -> None: + key_builder.rec(key_hash, + (self.__class__.__module__, self.__class__.__qualname__)) + + from dataclasses import fields + # Fields are ordered consistently, so ordered hashing is OK. + # + # No need to dispatch to superclass: fields() automatically gives us + # fields from the entire class hierarchy. 
+ for f in fields(self): + key_builder.rec(key_hash, getattr(self, f.name)) + def _get_default_tags() -> TagsType: import traceback from pytato.tags import CreatedAt if __debug__: - c = CreatedAt(_PytatoStackSummary(traceback.extract_stack())) + frames = tuple(_PytatoFrameSummary(s.filename, s.lineno, s.name, s.line) + for s in traceback.extract_stack()) + c = CreatedAt(_PytatoStackSummary(frames)) return frozenset((c,)) + else: + return frozenset() def matmul(x1: Array, x2: Array) -> Array: diff --git a/pytato/transform.py b/pytato/transform.py index a4d3a1f88..fe39e1243 100644 --- a/pytato/transform.py +++ b/pytato/transform.py @@ -1033,7 +1033,7 @@ def map_and_copy(expr: ArrayOrNames, def remove_tags_of_type(tag_types: Union[type, Tuple[type]], expr: ArrayOrNames - ) -> ArrayOrNames: + ) -> ArrayOrNames: def process_node(expr: ArrayOrNames) -> ArrayOrNames: if isinstance(expr, Array): return expr.copy(tags=frozenset({ From 7b1f7b81bbb33c0567285f3e8f166d95ef4a180a Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 28 Mar 2022 22:11:01 -0500 Subject: [PATCH 017/124] add __repr__ --- pytato/array.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pytato/array.py b/pytato/array.py index 3ed376caf..37cc653bf 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1696,6 +1696,9 @@ def update_persistent_hash(self, key_hash: int, key_builder: Any) -> None: for f in fields(self): key_builder.rec(key_hash, getattr(self, f.name)) + def __repr__(self) -> str: + return f"{self.filename}:{self.lineno}, in {self.name}: {self.line}" + @dataclass(frozen=True, eq=True) class _PytatoStackSummary: @@ -1721,6 +1724,9 @@ def update_persistent_hash(self, key_hash: int, key_builder: Any) -> None: for f in fields(self): key_builder.rec(key_hash, getattr(self, f.name)) + def __repr__(self) -> str: + return "\n " + "\n ".join([str(f) for f in self.frames]) + def _get_default_tags() -> TagsType: import traceback From 07b2fa16c92f5b9b9c2c923ed65dd0e61e502e6d Mon Sep 17 
00:00:00 2001 From: Matthias Diener Date: Mon, 28 Mar 2022 23:00:20 -0500 Subject: [PATCH 018/124] fix 2 tests --- pytato/transform.py | 2 +- test/test_pytato.py | 28 ++++++++++++++++++++++++---- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/pytato/transform.py b/pytato/transform.py index fe39e1243..adb9f1346 100644 --- a/pytato/transform.py +++ b/pytato/transform.py @@ -1042,7 +1042,7 @@ def process_node(expr: ArrayOrNames) -> ArrayOrNames: elif isinstance(expr, AbstractResultWithNamedArrays): return expr else: - raise AssertionError() + raise AssertionError(type(expr)) return map_and_copy(expr, process_node) diff --git a/test/test_pytato.py b/test/test_pytato.py index bd4dbdd26..f227c8a75 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -297,6 +297,14 @@ def test_dict_of_named_arrays_comparison(): dict2 = pt.make_dict_of_named_arrays({"out": 2 * x}) dict3 = pt.make_dict_of_named_arrays({"not_out": 2 * x}) dict4 = pt.make_dict_of_named_arrays({"out": 3 * x}) + + from pytato.transform import remove_tags_of_type + from pytato.tags import CreatedAt + dict1 = cast(pt.Array, remove_tags_of_type(CreatedAt, dict1)) + dict2 = cast(pt.Array, remove_tags_of_type(CreatedAt, dict2)) + dict3 = cast(pt.Array, remove_tags_of_type(CreatedAt, dict3)) + dict4 = cast(pt.Array, remove_tags_of_type(CreatedAt, dict4)) + assert dict1 == dict2 assert dict1 != dict3 assert dict1 != dict4 @@ -626,10 +634,22 @@ def test_rec_get_user_nodes(): expr = pt.make_dict_of_named_arrays({"out1": 2 * x1, "out2": 7 * x1 + 3 * x2}) - assert (pt.transform.rec_get_user_nodes(expr, x1) - == frozenset({2 * x1, 7*x1, 7*x1 + 3 * x2, expr})) - assert (pt.transform.rec_get_user_nodes(expr, x2) - == frozenset({3 * x2, 7*x1 + 3 * x2, expr})) + t1 = pt.transform.rec_get_user_nodes(expr, x1) + t1r = frozenset({2 * x1, 7*x1, 7*x1 + 3 * x2, expr}) + + t2 = pt.transform.rec_get_user_nodes(expr, x2) + t2r = frozenset({3 * x2, 7*x1 + 3 * x2, expr}) + + from pytato.transform import 
remove_tags_of_type + from pytato.tags import CreatedAt + + t1 = frozenset({remove_tags_of_type(CreatedAt, t) for t in t1}) + t1r = frozenset({remove_tags_of_type(CreatedAt, t) for t in t1r}) + t2 = frozenset({remove_tags_of_type(CreatedAt, t) for t in t2}) + t2r = frozenset({remove_tags_of_type(CreatedAt, t) for t in t2r}) + + assert (t1 == t1r) + assert (t2 == t2r) def test_rec_get_user_nodes_linear_complexity(): From 437954cbb203793890c16616bbcabab0757f04a2 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 28 Mar 2022 23:24:00 -0500 Subject: [PATCH 019/124] illustrate test failure with construct_intestine_graph --- test/test_pytato.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/test/test_pytato.py b/test/test_pytato.py index f227c8a75..ac8fd106b 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -707,9 +707,29 @@ def post_visit(self, expr): expected_result[expr] = {"foo"} expr, inp = construct_intestine_graph() + + from pytato.transform import remove_tags_of_type + from pytato.tags import CreatedAt + # node_to_users = remove_tags_of_type(CreatedAt, user_collector.node_to_users) + + node_to_users = {} + + for k in user_collector.node_to_users.keys(): + new_key = remove_tags_of_type(CreatedAt, k) + new_values = set({remove_tags_of_type(CreatedAt, v) for v in user_collector.node_to_users[k]}) + + node_to_users[new_key] = new_values + + + result = pt.transform.tag_user_nodes(user_collector.node_to_users, "foo", inp) ExpectedResultComputer()(expr) + import pudb + pu.db + + + assert expected_result == result From f05592e87cea001c195714fe4a7c59deac0fb28f Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 29 Mar 2022 10:44:55 -0500 Subject: [PATCH 020/124] shorten traceback printing --- pytato/array.py | 28 +++++++++++++++++++++++++++- pytato/tags.py | 3 +++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/pytato/array.py b/pytato/array.py index 37cc653bf..08e7e4bea 100644 --- a/pytato/array.py +++ 
b/pytato/array.py @@ -1697,7 +1697,7 @@ def update_persistent_hash(self, key_hash: int, key_builder: Any) -> None: key_builder.rec(key_hash, getattr(self, f.name)) def __repr__(self) -> str: - return f"{self.filename}:{self.lineno}, in {self.name}: {self.line}" + return f"{self.filename}:{self.lineno}, in {self.name}(): {self.line}" @dataclass(frozen=True, eq=True) @@ -1724,6 +1724,32 @@ def update_persistent_hash(self, key_hash: int, key_builder: Any) -> None: for f in fields(self): key_builder.rec(key_hash, getattr(self, f.name)) + def __str__(self) -> str: + from os.path import dirname + + res = None + + # Find the first file in the frames that it is not in pytato's pytato/ + # directory. + for idx, frame in enumerate(reversed(self.frames)): + frame_dir = dirname(frame.filename) + if not frame_dir.endswith("pytato"): + res = str(frame) + + # Indicate whether frames were omitted + if idx < len(self.frames)-1: + res += " ..." + if idx > 0: + res = "... " + res + break + + if not res: + # Fallback in case we don't find any file that is not in the pytato/ + # directory (should be unlikely). 
+ return self.__repr__() + + return res + def __repr__(self) -> str: return "\n " + "\n ".join([str(f) for f in self.frames]) diff --git a/pytato/tags.py b/pytato/tags.py index 77d2fe218..36b28d588 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -116,3 +116,6 @@ class CreatedAt(UniqueTag): """ traceback: _PytatoStackSummary + + def __repr__(self) -> str: + return "CreatedAt(" + str(self.traceback) + ")" From d0409968304c553f7c62a28463f74881e4e8a4c3 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 30 Mar 2022 11:38:35 -0500 Subject: [PATCH 021/124] use separate field for CreatedAt --- pytato/array.py | 32 ++++++++++++++------------------ pytato/visualization.py | 30 +++++++++++++++++++++++++++++- 2 files changed, 43 insertions(+), 19 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 08e7e4bea..074775b5d 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1696,6 +1696,14 @@ def update_persistent_hash(self, key_hash: int, key_builder: Any) -> None: for f in fields(self): key_builder.rec(key_hash, getattr(self, f.name)) + def short_str(self) -> str: + s = f"{self.filename}:{self.lineno}, in {self.name}():\n{self.line}" + s1, s2 = s.split("\n") + # Limit display to 35 characters + s1 = "[...] " + s1[len(s1)-35:] if len(s1) > 35 else s1 + s2 = s2[:35] + " [...]" if len(s2) > 35 else s2 + return s1 + "\n" + s2 + def __repr__(self) -> str: return f"{self.filename}:{self.lineno}, in {self.name}(): {self.line}" @@ -1724,31 +1732,19 @@ def update_persistent_hash(self, key_hash: int, key_builder: Any) -> None: for f in fields(self): key_builder.rec(key_hash, getattr(self, f.name)) - def __str__(self) -> str: + def short_str(self) -> str: from os.path import dirname - res = None - # Find the first file in the frames that it is not in pytato's pytato/ # directory. 
- for idx, frame in enumerate(reversed(self.frames)): + for frame in reversed(self.frames): frame_dir = dirname(frame.filename) if not frame_dir.endswith("pytato"): - res = str(frame) - - # Indicate whether frames were omitted - if idx < len(self.frames)-1: - res += " ..." - if idx > 0: - res = "... " + res - break - - if not res: - # Fallback in case we don't find any file that is not in the pytato/ - # directory (should be unlikely). - return self.__repr__() + return frame.short_str() - return res + # Fallback in case we don't find any file that is not in the pytato/ + # directory (should be unlikely). + return self.__repr__() def __repr__(self) -> str: return "\n " + "\n ".join([str(f) for f in self.frames]) diff --git a/pytato/visualization.py b/pytato/visualization.py index 573aa1d04..d755afe6b 100644 --- a/pytato/visualization.py +++ b/pytato/visualization.py @@ -73,13 +73,40 @@ class DotNodeInfo: edges: Dict[str, ArrayOrNames] +def stringify_created_at(tags: TagsType) -> str: + from pytato.tags import CreatedAt + for tag in tags: + if isinstance(tag, CreatedAt): + return tag.traceback.short_str() + + return "" + + def stringify_tags(tags: TagsType) -> str: + # The CreatedAt tag is handled in stringify_created_at() + from pytato.tags import CreatedAt + tags = set(tag for tag in tags if not isinstance(tag, CreatedAt)) + components = sorted(str(elem) for elem in tags) return "{" + ", ".join(components) + "}" def stringify_shape(shape: ShapeType) -> str: - components = [str(elem) for elem in shape] + from pytato.tags import CreatedAt + from pytato import SizeParam + + new_elems = set() + for elem in shape: + # Remove CreatedAt tags from SizeParam + if isinstance(elem, SizeParam): + new_elem = elem.copy( + tags=frozenset(tag for tag in elem.tags + if not isinstance(tag, CreatedAt))) + new_elems.add(new_elem) + else: + new_elems.add(elem) + + components = [str(elem) for elem in new_elems] if not components: components = [","] elif len(components) == 1: @@ -95,6 
+122,7 @@ def __init__(self) -> None: def get_common_dot_info(self, expr: Array) -> DotNodeInfo: title = type(expr).__name__ fields = dict(addr=hex(id(expr)), + created_at=stringify_created_at(expr.tags), shape=stringify_shape(expr.shape), dtype=str(expr.dtype), tags=stringify_tags(expr.tags)) From 9fdd602a66f82f4be67175bc85bde4289f318eab Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 30 Mar 2022 13:09:39 -0500 Subject: [PATCH 022/124] fix tests --- pytato/visualization.py | 10 +++++----- test/test_pytato.py | 38 +++++++++++++++++--------------------- 2 files changed, 22 insertions(+), 26 deletions(-) diff --git a/pytato/visualization.py b/pytato/visualization.py index d755afe6b..68b397067 100644 --- a/pytato/visualization.py +++ b/pytato/visualization.py @@ -85,7 +85,7 @@ def stringify_created_at(tags: TagsType) -> str: def stringify_tags(tags: TagsType) -> str: # The CreatedAt tag is handled in stringify_created_at() from pytato.tags import CreatedAt - tags = set(tag for tag in tags if not isinstance(tag, CreatedAt)) + tags = frozenset(tag for tag in tags if not isinstance(tag, CreatedAt)) components = sorted(str(elem) for elem in tags) return "{" + ", ".join(components) + "}" @@ -97,14 +97,14 @@ def stringify_shape(shape: ShapeType) -> str: new_elems = set() for elem in shape: - # Remove CreatedAt tags from SizeParam - if isinstance(elem, SizeParam): + if not isinstance(elem, SizeParam): + new_elems.add(elem) + else: + # Remove CreatedAt tags from SizeParam new_elem = elem.copy( tags=frozenset(tag for tag in elem.tags if not isinstance(tag, CreatedAt))) new_elems.add(new_elem) - else: - new_elems.add(elem) components = [str(elem) for elem in new_elems] if not components: diff --git a/test/test_pytato.py b/test/test_pytato.py index ac8fd106b..38f937229 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -399,7 +399,7 @@ def test_linear_complexity_inequality(): from pytato.equality import EqualityComparer from numpy.random import default_rng - 
def construct_intestine_graph(depth=100, seed=0): + def construct_intestine_graph(depth=90, seed=0): rng = default_rng(seed) x = pt.make_placeholder("x", shape=(10,), dtype=float) @@ -413,6 +413,13 @@ def construct_intestine_graph(depth=100, seed=0): graph2 = construct_intestine_graph() graph3 = construct_intestine_graph(seed=3) + from pytato.transform import remove_tags_of_type + from pytato.tags import CreatedAt + + graph1 = remove_tags_of_type(CreatedAt, graph1) + graph2 = remove_tags_of_type(CreatedAt, graph2) + graph3 = remove_tags_of_type(CreatedAt, graph3) + assert EqualityComparer()(graph1, graph2) assert EqualityComparer()(graph2, graph1) assert not EqualityComparer()(graph1, graph3) @@ -685,7 +692,7 @@ def post_visit(self, expr): def test_tag_user_nodes_linear_complexity(): from numpy.random import default_rng - def construct_intestine_graph(depth=100, seed=0): + def construct_intestine_graph(depth=90, seed=0): rng = default_rng(seed) x = pt.make_placeholder("x", shape=(10,), dtype=float) y = x @@ -696,7 +703,13 @@ def construct_intestine_graph(depth=100, seed=0): return y, x + from pytato.transform import remove_tags_of_type + from pytato.tags import CreatedAt + expr, inp = construct_intestine_graph() + expr = remove_tags_of_type(CreatedAt, expr) + inp = remove_tags_of_type(CreatedAt, inp) + user_collector = pt.transform.UsersCollector() user_collector(expr) @@ -707,29 +720,12 @@ def post_visit(self, expr): expected_result[expr] = {"foo"} expr, inp = construct_intestine_graph() - - from pytato.transform import remove_tags_of_type - from pytato.tags import CreatedAt - # node_to_users = remove_tags_of_type(CreatedAt, user_collector.node_to_users) - - node_to_users = {} - - for k in user_collector.node_to_users.keys(): - new_key = remove_tags_of_type(CreatedAt, k) - new_values = set({remove_tags_of_type(CreatedAt, v) for v in user_collector.node_to_users[k]}) - - node_to_users[new_key] = new_values - - + expr = remove_tags_of_type(CreatedAt, expr) + inp = 
remove_tags_of_type(CreatedAt, inp) result = pt.transform.tag_user_nodes(user_collector.node_to_users, "foo", inp) ExpectedResultComputer()(expr) - import pudb - pu.db - - - assert expected_result == result From e606a4823264a936d04513cce99dddfde033a1d3 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 30 Mar 2022 15:24:57 -0500 Subject: [PATCH 023/124] fix doctest --- pytato/utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytato/utils.py b/pytato/utils.py index fe48c1368..8bd32139d 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -242,6 +242,8 @@ def dim_to_index_lambda_components(expr: ShapeComponent, .. testsetup:: >>> import pytato as pt + >>> from pytato.transform import remove_tags_of_type + >>> from pytato.tags import CreatedAt >>> from pytato.utils import dim_to_index_lambda_components >>> from pytools import UniqueNameGenerator @@ -251,7 +253,7 @@ def dim_to_index_lambda_components(expr: ShapeComponent, >>> expr, bnds = dim_to_index_lambda_components(3*n+8, UniqueNameGenerator()) >>> print(expr) 3*_in + 8 - >>> bnds + >>> {"_in": remove_tags_of_type(CreatedAt, bnds["_in"])} {'_in': SizeParam(name='n')} """ if isinstance(expr, INT_CLASSES): From 235d9a72c009d8765bed92fc8130fd6dbad79f94 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 30 Mar 2022 17:25:35 -0500 Subject: [PATCH 024/124] make it a tag again --- pytato/tags.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytato/tags.py b/pytato/tags.py index 36b28d588..6c6f773f6 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -109,7 +109,7 @@ class AssumeNonNegative(Tag): # https://mypy.readthedocs.io/en/stable/additional_features.html#caveats-known-issues # on why this can not be '@tag_dataclass'. @dataclass(init=True, eq=True, frozen=True, repr=True) -class CreatedAt(UniqueTag): +class CreatedAt(Tag): """ A tag attached to a :class:`~pytato.Array` to store the traceback of where it was created. 
From d80066f45b80130f54e105f41a5f1520187d0381 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 1 Apr 2022 10:36:00 -0500 Subject: [PATCH 025/124] use tooltip instead of table row --- pytato/array.py | 12 ++++++------ pytato/visualization.py | 11 +++++++---- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 074775b5d..e6546edb7 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1696,12 +1696,12 @@ def update_persistent_hash(self, key_hash: int, key_builder: Any) -> None: for f in fields(self): key_builder.rec(key_hash, getattr(self, f.name)) - def short_str(self) -> str: + def short_str(self, maxlen: int = 100) -> str: s = f"{self.filename}:{self.lineno}, in {self.name}():\n{self.line}" s1, s2 = s.split("\n") - # Limit display to 35 characters - s1 = "[...] " + s1[len(s1)-35:] if len(s1) > 35 else s1 - s2 = s2[:35] + " [...]" if len(s2) > 35 else s2 + # Limit display to maxlen characters + s1 = "[...] " + s1[len(s1)-maxlen:] if len(s1) > maxlen else s1 + s2 = s2[:maxlen] + " [...]" if len(s2) > maxlen else s2 return s1 + "\n" + s2 def __repr__(self) -> str: @@ -1732,7 +1732,7 @@ def update_persistent_hash(self, key_hash: int, key_builder: Any) -> None: for f in fields(self): key_builder.rec(key_hash, getattr(self, f.name)) - def short_str(self) -> str: + def short_str(self, maxlen: int = 100) -> str: from os.path import dirname # Find the first file in the frames that it is not in pytato's pytato/ @@ -1740,7 +1740,7 @@ def short_str(self) -> str: for frame in reversed(self.frames): frame_dir = dirname(frame.filename) if not frame_dir.endswith("pytato"): - return frame.short_str() + return frame.short_str(maxlen) # Fallback in case we don't find any file that is not in the pytato/ # directory (should be unlikely). 
diff --git a/pytato/visualization.py b/pytato/visualization.py index 68b397067..498a0d58a 100644 --- a/pytato/visualization.py +++ b/pytato/visualization.py @@ -293,8 +293,10 @@ def _emit_array(emit: DotEmitter, title: str, fields: Dict[str, str], td_attrib = 'border="0"' table_attrib = 'border="0" cellborder="1" cellspacing="0"' - rows = ['%s' - % (td_attrib, dot_escape(title))] + rows = [f"{dot_escape(title)}"] + + created_at = fields.pop("created_at", "") + tooltip = dot_escape(created_at) for name, field in fields.items(): field_content = dot_escape(field).replace("\n", "
") @@ -302,8 +304,9 @@ def _emit_array(emit: DotEmitter, title: str, fields: Dict[str, str], f"{dot_escape(name)}:" f"{field_content}" ) - table = "\n%s
" % (table_attrib, "".join(rows)) - emit("%s [label=<%s> style=filled fillcolor=%s]" % (dot_node_id, table, color)) + table = f"\n{''.join(rows)}
" + emit(f"{dot_node_id} [label=<{table}> style=filled fillcolor={color} " + f'tooltip="{tooltip}"]') def _emit_name_cluster(emit: DotEmitter, names: Mapping[str, ArrayOrNames], From ef3339f906224f6fefbc04834e0dff6c534191e5 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 1 Apr 2022 11:21:04 -0500 Subject: [PATCH 026/124] force openmpi usage --- .test-conda-env-py3.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.test-conda-env-py3.yml b/.test-conda-env-py3.yml index 6459d4c06..7917877fb 100644 --- a/.test-conda-env-py3.yml +++ b/.test-conda-env-py3.yml @@ -12,3 +12,4 @@ dependencies: - islpy - sphinx-autodoc-typehints - mpi4py +- openmpi From 1ff1a2b49f8c867bab87d00907ba3271a11a048e Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 1 Apr 2022 15:00:04 -0500 Subject: [PATCH 027/124] check for existing CreatedAt and make it a UniqueTag again --- pytato/array.py | 14 ++++++++------ pytato/tags.py | 2 +- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index e6546edb7..9a18f4081 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1735,7 +1735,7 @@ def update_persistent_hash(self, key_hash: int, key_builder: Any) -> None: def short_str(self, maxlen: int = 100) -> str: from os.path import dirname - # Find the first file in the frames that it is not in pytato's pytato/ + # Find the first file in the frames that is not in pytato's pytato/ # directory. 
for frame in reversed(self.frames): frame_dir = dirname(frame.filename) @@ -1750,11 +1750,13 @@ def __repr__(self) -> str: return "\n " + "\n ".join([str(f) for f in self.frames]) -def _get_default_tags() -> TagsType: +def _get_default_tags(existing_tags: Optional[TagsType] = None) -> TagsType: import traceback from pytato.tags import CreatedAt - if __debug__: + if __debug__ and ( + existing_tags is None + or not any(isinstance(tag, CreatedAt) for tag in existing_tags)): frames = tuple(_PytatoFrameSummary(s.filename, s.lineno, s.name, s.line) for s in traceback.extract_stack()) c = CreatedAt(_PytatoStackSummary(frames)) @@ -2015,7 +2017,7 @@ def make_placeholder(name: str, f" expected {len(shape)}, got {len(axes)}.") return Placeholder(name, shape, dtype, axes=axes, - tags=(tags | _get_default_tags())) + tags=(tags | _get_default_tags(tags))) def make_size_param(name: str, @@ -2029,7 +2031,7 @@ def make_size_param(name: str, :param tags: implementation tags """ _check_identifier(name, optional=False) - return SizeParam(name, tags=(tags | _get_default_tags())) + return SizeParam(name, tags=(tags | _get_default_tags(tags))) def make_data_wrapper(data: DataInterface, @@ -2059,7 +2061,7 @@ def make_data_wrapper(data: DataInterface, return DataWrapper(name, data, shape, axes=axes, - tags=(tags | _get_default_tags())) + tags=(tags | _get_default_tags(tags))) # }}} diff --git a/pytato/tags.py b/pytato/tags.py index 6c6f773f6..36b28d588 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -109,7 +109,7 @@ class AssumeNonNegative(Tag): # https://mypy.readthedocs.io/en/stable/additional_features.html#caveats-known-issues # on why this can not be '@tag_dataclass'. @dataclass(init=True, eq=True, frozen=True, repr=True) -class CreatedAt(Tag): +class CreatedAt(UniqueTag): """ A tag attached to a :class:`~pytato.Array` to store the traceback of where it was created. 
From c57a4a17b2757e74cff63fd54ec806dc6acb594d Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 16 May 2022 10:50:51 -0500 Subject: [PATCH 028/124] flake8 --- pytato/visualization.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytato/visualization.py b/pytato/visualization.py index 91af34607..d5693ed5a 100644 --- a/pytato/visualization.py +++ b/pytato/visualization.py @@ -31,10 +31,9 @@ import html from typing import (TYPE_CHECKING, Callable, Dict, Union, Iterator, List, - Mapping, Hashable, Any, FrozenSet) + Mapping, Hashable, Any) from pytools import UniqueNameGenerator -from pytools.tag import Tag from pytools.codegen import CodeGenerator as CodeGeneratorBase from pytato.loopy import LoopyCall From 4ae31b1ef8416e152899fe5f0807a8dc560655a1 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 16 May 2022 10:54:53 -0500 Subject: [PATCH 029/124] add simple equality test --- test/test_pytato.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/test_pytato.py b/test/test_pytato.py index 28f1a723e..f2ad7d9b4 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -844,6 +844,10 @@ def test_created_at(): b = pt.make_placeholder("b", (10, 10), "float64") res = a+b + res2 = a+b + + # CreatedAt tags need to be filtered for equality to work correctly. 
+ assert res == res2 from pytato.tags import CreatedAt From f559a59bf35a1d77a7445c3ca5b8bb633aa64849 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 16 May 2022 13:17:32 -0500 Subject: [PATCH 030/124] lint fixes --- pytato/array.py | 9 +++++---- pytato/visualization.py | 7 ++++--- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 36e5de2e9..2ed902a46 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -258,7 +258,7 @@ def normalize_shape_component( # }}} -# {{{ array inteface +# {{{ array interface ConvertibleToIndexExpr = Union[int, slice, "Array", None, EllipsisType] IndexExpr = Union[IntegralT, "NormalizedSlice", "Array", None, EllipsisType] @@ -1682,9 +1682,9 @@ def _get_default_axes(ndim: int) -> AxesT: class _PytatoFrameSummary: """Class to store a single call frame.""" filename: str - lineno: int + lineno: Optional[int] name: str - line: str + line: Optional[str] def update_persistent_hash(self, key_hash: int, key_builder: Any) -> None: key_builder.rec(key_hash, @@ -1752,7 +1752,8 @@ def __repr__(self) -> str: return "\n " + "\n ".join([str(f) for f in self.frames]) -def _get_default_tags(existing_tags: Optional[TagsType] = None) -> TagsType: +def _get_default_tags(existing_tags: Optional[FrozenSet[Tag]] = None) \ + -> FrozenSet[Tag]: import traceback from pytato.tags import CreatedAt diff --git a/pytato/visualization.py b/pytato/visualization.py index d5693ed5a..92215c44d 100644 --- a/pytato/visualization.py +++ b/pytato/visualization.py @@ -31,7 +31,7 @@ import html from typing import (TYPE_CHECKING, Callable, Dict, Union, Iterator, List, - Mapping, Hashable, Any) + Mapping, Hashable, Any, FrozenSet) from pytools import UniqueNameGenerator from pytools.codegen import CodeGenerator as CodeGeneratorBase @@ -44,6 +44,7 @@ from pytato.codegen import normalize_outputs from pytato.transform import CachedMapper, ArrayOrNames +from pytools.tag import Tag from pytato.partition import GraphPartition 
from pytato.distributed import DistributedGraphPart @@ -72,7 +73,7 @@ class DotNodeInfo: edges: Dict[str, ArrayOrNames] -def stringify_created_at(tags: TagsType) -> str: +def stringify_created_at(tags: FrozenSet[Tag]) -> str: from pytato.tags import CreatedAt for tag in tags: if isinstance(tag, CreatedAt): @@ -81,7 +82,7 @@ def stringify_created_at(tags: TagsType) -> str: return "" -def stringify_tags(tags: TagsType) -> str: +def stringify_tags(tags: FrozenSet[Tag]) -> str: # The CreatedAt tag is handled in stringify_created_at() from pytato.tags import CreatedAt tags = frozenset(tag for tag in tags if not isinstance(tag, CreatedAt)) From 0794bdb4fd986549bce4ed6f463e31f34f995dec Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 16 May 2022 13:51:00 -0500 Subject: [PATCH 031/124] add InfoTag class and filter tags based on it --- pytato/array.py | 7 ++++++- pytato/equality.py | 46 +++++++++++++++++++++++++++++++--------------- pytato/tags.py | 7 ++++++- 3 files changed, 43 insertions(+), 17 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 2ed902a46..a257da153 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -480,9 +480,12 @@ def T(self) -> Array: @memoize_method def __hash__(self) -> int: + from pytato.equality import preprocess_tags_for_equality attrs = [] for field in self._fields: attr = getattr(self, field) + if field == "tags": + attr = preprocess_tags_for_equality(attr) if isinstance(attr, dict): attr = frozenset(attr.items()) attrs.append(attr) @@ -1590,7 +1593,9 @@ def __init__(self, self._shape = shape def __hash__(self) -> int: - return id(self) + from pytato.equality import preprocess_tags_for_equality + return hash((self.name, id(self.data), self._shape, self.axes, + preprocess_tags_for_equality(self.tags), self.axes)) def __eq__(self, other: Any) -> bool: return self is other diff --git a/pytato/equality.py b/pytato/equality.py index 984d1f712..8f438b791 100644 --- a/pytato/equality.py +++ b/pytato/equality.py @@ -24,7 +24,7 
@@ THE SOFTWARE. """ -from typing import Any, Callable, Dict, TYPE_CHECKING, Tuple, Union +from typing import Any, Callable, Dict, TYPE_CHECKING, Tuple, Union, FrozenSet from pytato.array import (AdvancedIndexInContiguousAxes, AdvancedIndexInNoncontiguousAxes, AxisPermutation, BasicIndex, Concatenate, DataWrapper, Einsum, @@ -32,11 +32,16 @@ Reshape, Roll, Stack, AbstractResultWithNamedArrays, Array, DictOfNamedArrays, Placeholder, SizeParam) +from pytools.tag import Tag +from pytato.tags import InfoTag + if TYPE_CHECKING: from pytato.loopy import LoopyCall, LoopyCallResult from pytato.distributed import DistributedRecv, DistributedSendRefHolder + __doc__ = """ +.. autofunction:: preprocess_tags_for_equality .. autoclass:: EqualityComparer """ @@ -44,6 +49,13 @@ ArrayOrNames = Union[Array, AbstractResultWithNamedArrays] +def preprocess_tags_for_equality(tags: FrozenSet[Tag]) -> FrozenSet[Tag]: + """Remove tags of :class:`InfoTag` for equality comparison.""" + return frozenset(tag + for tag in tags + if not isinstance(tag, InfoTag)) + + # {{{ EqualityComparer class EqualityComparer: @@ -95,6 +107,10 @@ def handle_unsupported_array(self, expr1: Array, expr2: Any) -> bool: raise NotImplementedError(type(expr1).__name__) + def are_tags_equal(self, tags1: FrozenSet[Tag], tags2: FrozenSet[Tag]) -> bool: + return (preprocess_tags_for_equality(tags1) + == preprocess_tags_for_equality(tags2)) + def map_foreign(self, expr1: Any, expr2: Any) -> bool: raise NotImplementedError(type(expr1).__name__) @@ -103,14 +119,14 @@ def map_placeholder(self, expr1: Placeholder, expr2: Any) -> bool: and expr1.name == expr2.name and expr1.shape == expr2.shape and expr1.dtype == expr2.dtype - and expr1.tags == expr2.tags + and self.are_tags_equal(expr1.tags, expr2.tags) and expr1.axes == expr2.axes ) def map_size_param(self, expr1: SizeParam, expr2: Any) -> bool: return (expr1.__class__ is expr2.__class__ and expr1.name == expr2.name - and expr1.tags == expr2.tags + and 
self.are_tags_equal(expr1.tags, expr2.tags) and expr1.axes == expr2.axes ) @@ -129,7 +145,7 @@ def map_index_lambda(self, expr1: IndexLambda, expr2: Any) -> bool: if isinstance(dim1, Array) else dim1 == dim2 for dim1, dim2 in zip(expr1.shape, expr2.shape)) - and expr1.tags == expr2.tags + and self.are_tags_equal(expr1.tags, expr2.tags) and expr1.axes == expr2.axes) def map_stack(self, expr1: Stack, expr2: Any) -> bool: @@ -138,7 +154,7 @@ def map_stack(self, expr1: Stack, expr2: Any) -> bool: and len(expr1.arrays) == len(expr2.arrays) and all(self.rec(ary1, ary2) for ary1, ary2 in zip(expr1.arrays, expr2.arrays)) - and expr1.tags == expr2.tags + and self.are_tags_equal(expr1.tags, expr2.tags) and expr1.axes == expr2.axes ) @@ -148,7 +164,7 @@ def map_concatenate(self, expr1: Concatenate, expr2: Any) -> bool: and len(expr1.arrays) == len(expr2.arrays) and all(self.rec(ary1, ary2) for ary1, ary2 in zip(expr1.arrays, expr2.arrays)) - and expr1.tags == expr2.tags + and self.are_tags_equal(expr1.tags, expr2.tags) and expr1.axes == expr2.axes ) @@ -156,7 +172,7 @@ def map_roll(self, expr1: Roll, expr2: Any) -> bool: return (expr1.__class__ is expr2.__class__ and expr1.axis == expr2.axis and self.rec(expr1.array, expr2.array) - and expr1.tags == expr2.tags + and self.are_tags_equal(expr1.tags, expr2.tags) and expr1.axes == expr2.axes ) @@ -164,7 +180,7 @@ def map_axis_permutation(self, expr1: AxisPermutation, expr2: Any) -> bool: return (expr1.__class__ is expr2.__class__ and expr1.axis_permutation == expr2.axis_permutation and self.rec(expr1.array, expr2.array) - and expr1.tags == expr2.tags + and self.are_tags_equal(expr1.tags, expr2.tags) and expr1.axes == expr2.axes ) @@ -177,7 +193,7 @@ def _map_index_base(self, expr1: IndexBase, expr2: Any) -> bool: and isinstance(idx2, Array)) else idx1 == idx2 for idx1, idx2 in zip(expr1.indices, expr2.indices)) - and expr1.tags == expr2.tags + and self.are_tags_equal(expr1.tags, expr2.tags) and expr1.axes == expr2.axes ) @@ 
-200,7 +216,7 @@ def map_reshape(self, expr1: Reshape, expr2: Any) -> bool: return (expr1.__class__ is expr2.__class__ and expr1.newshape == expr2.newshape and self.rec(expr1.array, expr2.array) - and expr1.tags == expr2.tags + and self.are_tags_equal(expr1.tags, expr2.tags) and expr1.axes == expr2.axes ) @@ -210,14 +226,14 @@ def map_einsum(self, expr1: Einsum, expr2: Any) -> bool: and all(self.rec(ary1, ary2) for ary1, ary2 in zip(expr1.args, expr2.args)) - and expr1.tags == expr2.tags + and self.are_tags_equal(expr1.tags, expr2.tags) and expr1.axes == expr2.axes ) def map_named_array(self, expr1: NamedArray, expr2: Any) -> bool: return (expr1.__class__ is expr2.__class__ and self.rec(expr1._container, expr2._container) - and expr1.tags == expr2.tags + and self.are_tags_equal(expr1.tags, expr2.tags) and expr1.axes == expr2.axes and expr1.name == expr2.name) @@ -236,7 +252,7 @@ def map_loopy_call(self, expr1: LoopyCall, expr2: Any) -> bool: def map_loopy_call_result(self, expr1: LoopyCallResult, expr2: Any) -> bool: return (expr1.__class__ is expr2.__class__ and self.rec(expr1._container, expr2._container) - and expr1.tags == expr2.tags + and self.are_tags_equal(expr1.tags, expr2.tags) and expr1.axes == expr2.axes and expr1.name == expr2.name) @@ -254,7 +270,7 @@ def map_distributed_send_ref_holder( and expr1.send.dest_rank == expr2.send.dest_rank and expr1.send.comm_tag == expr2.send.comm_tag and expr1.send.tags == expr2.send.tags - and expr1.tags == expr2.tags + and self.are_tags_equal(expr1.tags, expr2.tags) ) def map_distributed_recv(self, expr1: DistributedRecv, expr2: Any) -> bool: @@ -263,7 +279,7 @@ def map_distributed_recv(self, expr1: DistributedRecv, expr2: Any) -> bool: and expr1.comm_tag == expr2.comm_tag and expr1.shape == expr2.shape and expr1.dtype == expr2.dtype - and expr1.tags == expr2.tags + and self.are_tags_equal(expr1.tags, expr2.tags) ) # }}} diff --git a/pytato/tags.py b/pytato/tags.py index 36b28d588..b8e0ad7bd 100644 --- a/pytato/tags.py 
+++ b/pytato/tags.py @@ -105,11 +105,16 @@ class AssumeNonNegative(Tag): """ +class InfoTag(Tag): + """A type of tag whose value is purely informational and should not be used + for equality comparison.""" + + # See # https://mypy.readthedocs.io/en/stable/additional_features.html#caveats-known-issues # on why this can not be '@tag_dataclass'. @dataclass(init=True, eq=True, frozen=True, repr=True) -class CreatedAt(UniqueTag): +class CreatedAt(UniqueTag, InfoTag): """ A tag attached to a :class:`~pytato.Array` to store the traceback of where it was created. From 43c83ec2dfc2007537af12227ca56e2f31be5e60 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 16 May 2022 14:20:41 -0500 Subject: [PATCH 032/124] fix doc --- pytato/equality.py | 2 +- pytato/tags.py | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/pytato/equality.py b/pytato/equality.py index 8f438b791..0dacbbb9f 100644 --- a/pytato/equality.py +++ b/pytato/equality.py @@ -50,7 +50,7 @@ def preprocess_tags_for_equality(tags: FrozenSet[Tag]) -> FrozenSet[Tag]: - """Remove tags of :class:`InfoTag` for equality comparison.""" + """Remove tags of :class:`~pytato.tags.InfoTag` for equality comparison.""" return frozenset(tag for tag in tags if not isinstance(tag, InfoTag)) diff --git a/pytato/tags.py b/pytato/tags.py index b8e0ad7bd..0e480a273 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -9,6 +9,8 @@ .. autoclass:: Named .. autoclass:: PrefixNamed .. autoclass:: AssumeNonNegative +.. autoclass:: InfoTag +.. autoclass:: CreatedAt """ From c09bbf39d3665122b31afe612f1cc13a62b683c3 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 16 May 2022 17:47:08 -0500 Subject: [PATCH 033/124] another doc fix --- pytato/array.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pytato/array.py b/pytato/array.py index a257da153..c0285a909 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -142,6 +142,9 @@ .. autoclass:: ReductionAxis .. autoclass:: NormalizedSlice +.. 
autoclass:: _PytatoFrameSummary +.. autoclass:: _PytatoStackSummary + Internal stuff that is only here because the documentation tool wants it ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ From cd67d684d1c777bc3efef080031446e033e75f42 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 16 May 2022 21:20:21 -0500 Subject: [PATCH 034/124] use IgnoredForEqualityTag --- pytato/equality.py | 8 ++++---- pytato/tags.py | 10 ++-------- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/pytato/equality.py b/pytato/equality.py index 0dacbbb9f..df48c85be 100644 --- a/pytato/equality.py +++ b/pytato/equality.py @@ -32,8 +32,7 @@ Reshape, Roll, Stack, AbstractResultWithNamedArrays, Array, DictOfNamedArrays, Placeholder, SizeParam) -from pytools.tag import Tag -from pytato.tags import InfoTag +from pytools.tag import Tag, IgnoredForEqualityTag if TYPE_CHECKING: from pytato.loopy import LoopyCall, LoopyCallResult @@ -50,10 +49,11 @@ def preprocess_tags_for_equality(tags: FrozenSet[Tag]) -> FrozenSet[Tag]: - """Remove tags of :class:`~pytato.tags.InfoTag` for equality comparison.""" + """Remove tags of :class:`~pytools.tag.IgnoredForEqualityTag` for equality + comparison.""" return frozenset(tag for tag in tags - if not isinstance(tag, InfoTag)) + if not isinstance(tag, IgnoredForEqualityTag)) # {{{ EqualityComparer diff --git a/pytato/tags.py b/pytato/tags.py index 0e480a273..a51ff14c3 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -9,13 +9,12 @@ .. autoclass:: Named .. autoclass:: PrefixNamed .. autoclass:: AssumeNonNegative -.. autoclass:: InfoTag .. 
autoclass:: CreatedAt """ from dataclasses import dataclass -from pytools.tag import Tag, UniqueTag, tag_dataclass +from pytools.tag import Tag, UniqueTag, tag_dataclass, IgnoredForEqualityTag from pytato.array import _PytatoStackSummary @@ -107,16 +106,11 @@ class AssumeNonNegative(Tag): """ -class InfoTag(Tag): - """A type of tag whose value is purely informational and should not be used - for equality comparison.""" - - # See # https://mypy.readthedocs.io/en/stable/additional_features.html#caveats-known-issues # on why this can not be '@tag_dataclass'. @dataclass(init=True, eq=True, frozen=True, repr=True) -class CreatedAt(UniqueTag, InfoTag): +class CreatedAt(UniqueTag, IgnoredForEqualityTag): """ A tag attached to a :class:`~pytato.Array` to store the traceback of where it was created. From 71dd791b45cee918d33ac8c4cb8e9a3cf643ec37 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 17 May 2022 10:34:51 -0500 Subject: [PATCH 035/124] UNDO BEFORE MERGE: use external project branches --- .github/workflows/ci.yml | 6 +++++- requirements.txt | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 81681f6da..032a42991 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -112,7 +112,11 @@ jobs: curl -L -O https://tiker.net/ci-support-v0 . 
./ci-support-v0 - test_downstream "$DOWNSTREAM_PROJECT" + if [[ "$DOWNSTREAM_PROJECT" != "meshmode" ]]; then + test_downstream "$DOWNSTREAM_PROJECT" + else + test_downstream https://github.com/inducer/meshmode@filter_tags + fi if [[ "$DOWNSTREAM_PROJECT" = "meshmode" ]]; then python ../examples/simple-dg.py --lazy diff --git a/requirements.txt b/requirements.txt index a9cf2c76e..65fc57329 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -git+https://github.com/inducer/pytools.git#egg=pytools >= 2021.1 +git+https://github.com/matthiasdiener/pytools.git@eq_tag#egg=pytools git+https://github.com/inducer/pymbolic.git#egg=pymbolic git+https://github.com/inducer/genpy.git#egg=genpy git+https://github.com/inducer/loopy.git#egg=loopy From b4f8b82d080433fcaa6e8affe8394b852ed17017 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 18 May 2022 10:04:52 -0500 Subject: [PATCH 036/124] Revert "UNDO BEFORE MERGE: use external project branches" This reverts commit 71dd791b45cee918d33ac8c4cb8e9a3cf643ec37. --- .github/workflows/ci.yml | 6 +----- requirements.txt | 2 +- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 032a42991..81681f6da 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -112,11 +112,7 @@ jobs: curl -L -O https://tiker.net/ci-support-v0 . 
./ci-support-v0 - if [[ "$DOWNSTREAM_PROJECT" != "meshmode" ]]; then - test_downstream "$DOWNSTREAM_PROJECT" - else - test_downstream https://github.com/inducer/meshmode@filter_tags - fi + test_downstream "$DOWNSTREAM_PROJECT" if [[ "$DOWNSTREAM_PROJECT" = "meshmode" ]]; then python ../examples/simple-dg.py --lazy diff --git a/requirements.txt b/requirements.txt index 65fc57329..a9cf2c76e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -git+https://github.com/matthiasdiener/pytools.git@eq_tag#egg=pytools +git+https://github.com/inducer/pytools.git#egg=pytools >= 2021.1 git+https://github.com/inducer/pymbolic.git#egg=pymbolic git+https://github.com/inducer/genpy.git#egg=genpy git+https://github.com/inducer/loopy.git#egg=loopy From bfb22ba66d0872c3848a4ab41260f9f82ffa188d Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 18 May 2022 10:04:59 -0500 Subject: [PATCH 037/124] Revert "use IgnoredForEqualityTag" This reverts commit cd67d684d1c777bc3efef080031446e033e75f42. 
--- pytato/equality.py | 8 ++++---- pytato/tags.py | 10 ++++++++-- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/pytato/equality.py b/pytato/equality.py index df48c85be..0dacbbb9f 100644 --- a/pytato/equality.py +++ b/pytato/equality.py @@ -32,7 +32,8 @@ Reshape, Roll, Stack, AbstractResultWithNamedArrays, Array, DictOfNamedArrays, Placeholder, SizeParam) -from pytools.tag import Tag, IgnoredForEqualityTag +from pytools.tag import Tag +from pytato.tags import InfoTag if TYPE_CHECKING: from pytato.loopy import LoopyCall, LoopyCallResult @@ -49,11 +50,10 @@ def preprocess_tags_for_equality(tags: FrozenSet[Tag]) -> FrozenSet[Tag]: - """Remove tags of :class:`~pytools.tag.IgnoredForEqualityTag` for equality - comparison.""" + """Remove tags of :class:`~pytato.tags.InfoTag` for equality comparison.""" return frozenset(tag for tag in tags - if not isinstance(tag, IgnoredForEqualityTag)) + if not isinstance(tag, InfoTag)) # {{{ EqualityComparer diff --git a/pytato/tags.py b/pytato/tags.py index a51ff14c3..0e480a273 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -9,12 +9,13 @@ .. autoclass:: Named .. autoclass:: PrefixNamed .. autoclass:: AssumeNonNegative +.. autoclass:: InfoTag .. autoclass:: CreatedAt """ from dataclasses import dataclass -from pytools.tag import Tag, UniqueTag, tag_dataclass, IgnoredForEqualityTag +from pytools.tag import Tag, UniqueTag, tag_dataclass from pytato.array import _PytatoStackSummary @@ -106,11 +107,16 @@ class AssumeNonNegative(Tag): """ +class InfoTag(Tag): + """A type of tag whose value is purely informational and should not be used + for equality comparison.""" + + # See # https://mypy.readthedocs.io/en/stable/additional_features.html#caveats-known-issues # on why this can not be '@tag_dataclass'. 
@dataclass(init=True, eq=True, frozen=True, repr=True) -class CreatedAt(UniqueTag, IgnoredForEqualityTag): +class CreatedAt(UniqueTag, InfoTag): """ A tag attached to a :class:`~pytato.Array` to store the traceback of where it was created. From 99ff0cdaeaa145b1475a05ffd24b4b1ba3a53162 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 18 May 2022 11:14:29 -0500 Subject: [PATCH 038/124] rename InfoTag -> IgnoredForEqualityTag --- pytato/equality.py | 7 ++++--- pytato/tags.py | 6 +++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/pytato/equality.py b/pytato/equality.py index 0dacbbb9f..9a6145093 100644 --- a/pytato/equality.py +++ b/pytato/equality.py @@ -33,7 +33,7 @@ Array, DictOfNamedArrays, Placeholder, SizeParam) from pytools.tag import Tag -from pytato.tags import InfoTag +from pytato.tags import IgnoredForEqualityTag if TYPE_CHECKING: from pytato.loopy import LoopyCall, LoopyCallResult @@ -50,10 +50,11 @@ def preprocess_tags_for_equality(tags: FrozenSet[Tag]) -> FrozenSet[Tag]: - """Remove tags of :class:`~pytato.tags.InfoTag` for equality comparison.""" + """Remove tags of :class:`~pytato.tags.IgnoredForEqualityTag` for equality + comparison.""" return frozenset(tag for tag in tags - if not isinstance(tag, InfoTag)) + if not isinstance(tag, IgnoredForEqualityTag)) # {{{ EqualityComparer diff --git a/pytato/tags.py b/pytato/tags.py index 0e480a273..f6f9383d7 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -9,7 +9,7 @@ .. autoclass:: Named .. autoclass:: PrefixNamed .. autoclass:: AssumeNonNegative -.. autoclass:: InfoTag +.. autoclass:: IgnoredForEqualityTag .. 
autoclass:: CreatedAt """ @@ -107,7 +107,7 @@ class AssumeNonNegative(Tag): """ -class InfoTag(Tag): +class IgnoredForEqualityTag(Tag): """A type of tag whose value is purely informational and should not be used for equality comparison.""" @@ -116,7 +116,7 @@ class InfoTag(Tag): # https://mypy.readthedocs.io/en/stable/additional_features.html#caveats-known-issues # on why this can not be '@tag_dataclass'. @dataclass(init=True, eq=True, frozen=True, repr=True) -class CreatedAt(UniqueTag, InfoTag): +class CreatedAt(UniqueTag, IgnoredForEqualityTag): """ A tag attached to a :class:`~pytato.Array` to store the traceback of where it was created. From a818694d22faf97c68e599a2ccea851df4ff6151 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 18 May 2022 11:36:53 -0500 Subject: [PATCH 039/124] more stringent tests --- test/test_pytato.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index f2ad7d9b4..cf1c05408 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -846,21 +846,32 @@ def test_created_at(): res = a+b res2 = a+b - # CreatedAt tags need to be filtered for equality to work correctly. 
+ # {{{ Check that CreatedAt tags are filtered correctly for equality + from pytato.equality import preprocess_tags_for_equality + assert res == res2 + assert res.tags != res2.tags + assert (preprocess_tags_for_equality(res.tags) + == preprocess_tags_for_equality(res2.tags)) + + # }}} + from pytato.tags import CreatedAt created_tag = res.tags_of_type(CreatedAt) assert len(created_tag) == 1 + # {{{ Make sure the function name appears in the traceback + tag, = created_tag found = False - # Make sure the function name appears in the traceback - _unused = tag.traceback.to_stacksummary() # noqa + stacksummary = tag.traceback.to_stacksummary() + assert len(stacksummary) > 10 + for frame in tag.traceback.frames: if frame.name == "test_created_at": found = True @@ -868,6 +879,8 @@ def test_created_at(): assert found + # }}} + def test_pickling_and_unpickling_is_equal(): from testlib import RandomDAGContext, make_random_dag From 8a0a773124610a203bcec5bc50e21e959bea9333 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 18 May 2022 11:46:43 -0500 Subject: [PATCH 040/124] undo unnecessary test changes --- pytato/visualization.py | 2 +- test/test_pytato.py | 49 +++++------------------------------------ 2 files changed, 7 insertions(+), 44 deletions(-) diff --git a/pytato/visualization.py b/pytato/visualization.py index 92215c44d..b1535ada9 100644 --- a/pytato/visualization.py +++ b/pytato/visualization.py @@ -34,6 +34,7 @@ Mapping, Hashable, Any, FrozenSet) from pytools import UniqueNameGenerator +from pytools.tag import Tag from pytools.codegen import CodeGenerator as CodeGeneratorBase from pytato.loopy import LoopyCall @@ -44,7 +45,6 @@ from pytato.codegen import normalize_outputs from pytato.transform import CachedMapper, ArrayOrNames -from pytools.tag import Tag from pytato.partition import GraphPartition from pytato.distributed import DistributedGraphPart diff --git a/test/test_pytato.py b/test/test_pytato.py index cf1c05408..4ea7fb065 100755 --- 
a/test/test_pytato.py +++ b/test/test_pytato.py @@ -297,14 +297,6 @@ def test_dict_of_named_arrays_comparison(): dict2 = pt.make_dict_of_named_arrays({"out": 2 * x}) dict3 = pt.make_dict_of_named_arrays({"not_out": 2 * x}) dict4 = pt.make_dict_of_named_arrays({"out": 3 * x}) - - from pytato.transform import remove_tags_of_type - from pytato.tags import CreatedAt - dict1 = cast(pt.Array, remove_tags_of_type(CreatedAt, dict1)) - dict2 = cast(pt.Array, remove_tags_of_type(CreatedAt, dict2)) - dict3 = cast(pt.Array, remove_tags_of_type(CreatedAt, dict3)) - dict4 = cast(pt.Array, remove_tags_of_type(CreatedAt, dict4)) - assert dict1 == dict2 assert dict1 != dict3 assert dict1 != dict4 @@ -413,13 +405,6 @@ def construct_intestine_graph(depth=90, seed=0): graph2 = construct_intestine_graph() graph3 = construct_intestine_graph(seed=3) - from pytato.transform import remove_tags_of_type - from pytato.tags import CreatedAt - - graph1 = remove_tags_of_type(CreatedAt, graph1) - graph2 = remove_tags_of_type(CreatedAt, graph2) - graph3 = remove_tags_of_type(CreatedAt, graph3) - assert EqualityComparer()(graph1, graph2) assert EqualityComparer()(graph2, graph1) assert not EqualityComparer()(graph1, graph3) @@ -479,10 +464,9 @@ def test_array_dot_repr(): x = pt.make_placeholder("x", (10, 4), np.int64) y = pt.make_placeholder("y", (10, 4), np.int64) - from pytato.transform import remove_tags_of_type - from pytato.tags import CreatedAt - def _assert_stripped_repr(ary: pt.Array, expected_repr: str): + from pytato.transform import remove_tags_of_type + from pytato.tags import CreatedAt ary = cast(pt.Array, remove_tags_of_type(CreatedAt, ary)) expected_str = "".join([c for c in repr(ary) if c not in [" ", "\n"]]) @@ -641,22 +625,10 @@ def test_rec_get_user_nodes(): expr = pt.make_dict_of_named_arrays({"out1": 2 * x1, "out2": 7 * x1 + 3 * x2}) - t1 = pt.transform.rec_get_user_nodes(expr, x1) - t1r = frozenset({2 * x1, 7*x1, 7*x1 + 3 * x2, expr}) - - t2 = 
pt.transform.rec_get_user_nodes(expr, x2) - t2r = frozenset({3 * x2, 7*x1 + 3 * x2, expr}) - - from pytato.transform import remove_tags_of_type - from pytato.tags import CreatedAt - - t1 = frozenset({remove_tags_of_type(CreatedAt, t) for t in t1}) - t1r = frozenset({remove_tags_of_type(CreatedAt, t) for t in t1r}) - t2 = frozenset({remove_tags_of_type(CreatedAt, t) for t in t2}) - t2r = frozenset({remove_tags_of_type(CreatedAt, t) for t in t2r}) - - assert (t1 == t1r) - assert (t2 == t2r) + assert (pt.transform.rec_get_user_nodes(expr, x1) + == frozenset({2 * x1, 7*x1, 7*x1 + 3 * x2, expr})) + assert (pt.transform.rec_get_user_nodes(expr, x2) + == frozenset({3 * x2, 7*x1 + 3 * x2, expr})) def test_rec_get_user_nodes_linear_complexity(): @@ -703,13 +675,7 @@ def construct_intestine_graph(depth=90, seed=0): return y, x - from pytato.transform import remove_tags_of_type - from pytato.tags import CreatedAt - expr, inp = construct_intestine_graph() - expr = remove_tags_of_type(CreatedAt, expr) - inp = remove_tags_of_type(CreatedAt, inp) - user_collector = pt.transform.UsersCollector() user_collector(expr) @@ -720,9 +686,6 @@ def post_visit(self, expr): expected_result[expr] = {"foo"} expr, inp = construct_intestine_graph() - expr = remove_tags_of_type(CreatedAt, expr) - inp = remove_tags_of_type(CreatedAt, inp) - result = pt.transform.tag_user_nodes(user_collector.node_to_users, "foo", inp) ExpectedResultComputer()(expr) From 91fe92f0f78d52e8cf61dbc4568d64f7463645e1 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 19 May 2022 15:26:43 -0500 Subject: [PATCH 041/124] Revert "Revert "use IgnoredForEqualityTag"" This reverts commit bfb22ba66d0872c3848a4ab41260f9f82ffa188d. 
--- pytato/equality.py | 5 ++--- pytato/tags.py | 8 +------- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/pytato/equality.py b/pytato/equality.py index 9a6145093..df48c85be 100644 --- a/pytato/equality.py +++ b/pytato/equality.py @@ -32,8 +32,7 @@ Reshape, Roll, Stack, AbstractResultWithNamedArrays, Array, DictOfNamedArrays, Placeholder, SizeParam) -from pytools.tag import Tag -from pytato.tags import IgnoredForEqualityTag +from pytools.tag import Tag, IgnoredForEqualityTag if TYPE_CHECKING: from pytato.loopy import LoopyCall, LoopyCallResult @@ -50,7 +49,7 @@ def preprocess_tags_for_equality(tags: FrozenSet[Tag]) -> FrozenSet[Tag]: - """Remove tags of :class:`~pytato.tags.IgnoredForEqualityTag` for equality + """Remove tags of :class:`~pytools.tag.IgnoredForEqualityTag` for equality comparison.""" return frozenset(tag for tag in tags diff --git a/pytato/tags.py b/pytato/tags.py index f6f9383d7..a51ff14c3 100644 --- a/pytato/tags.py +++ b/pytato/tags.py @@ -9,13 +9,12 @@ .. autoclass:: Named .. autoclass:: PrefixNamed .. autoclass:: AssumeNonNegative -.. autoclass:: IgnoredForEqualityTag .. autoclass:: CreatedAt """ from dataclasses import dataclass -from pytools.tag import Tag, UniqueTag, tag_dataclass +from pytools.tag import Tag, UniqueTag, tag_dataclass, IgnoredForEqualityTag from pytato.array import _PytatoStackSummary @@ -107,11 +106,6 @@ class AssumeNonNegative(Tag): """ -class IgnoredForEqualityTag(Tag): - """A type of tag whose value is purely informational and should not be used - for equality comparison.""" - - # See # https://mypy.readthedocs.io/en/stable/additional_features.html#caveats-known-issues # on why this can not be '@tag_dataclass'. 
From 1111b7915e23561161eee132e8d5ccb716b101eb Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 19 May 2022 16:21:12 -0500 Subject: [PATCH 042/124] simplify condition --- pytato/array.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index c0285a909..8d1d5f2ec 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1765,9 +1765,7 @@ def _get_default_tags(existing_tags: Optional[FrozenSet[Tag]] = None) \ import traceback from pytato.tags import CreatedAt - if __debug__ and ( - existing_tags is None - or not any(isinstance(tag, CreatedAt) for tag in existing_tags)): + if __debug__ and not any(isinstance(tag, CreatedAt) for tag in existing_tags): frames = tuple(_PytatoFrameSummary(s.filename, s.lineno, s.name, s.line) for s in traceback.extract_stack()) c = CreatedAt(_PytatoStackSummary(frames)) From ff2582f0d2c37d070ecdbf771ea9df61bff6f310 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 19 May 2022 16:31:23 -0500 Subject: [PATCH 043/124] Revert "simplify condition" This reverts commit 1111b7915e23561161eee132e8d5ccb716b101eb. 
--- pytato/array.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytato/array.py b/pytato/array.py index 8d1d5f2ec..c0285a909 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1765,7 +1765,9 @@ def _get_default_tags(existing_tags: Optional[FrozenSet[Tag]] = None) \ import traceback from pytato.tags import CreatedAt - if __debug__ and not any(isinstance(tag, CreatedAt) for tag in existing_tags): + if __debug__ and ( + existing_tags is None + or not any(isinstance(tag, CreatedAt) for tag in existing_tags)): frames = tuple(_PytatoFrameSummary(s.filename, s.lineno, s.name, s.line) for s in traceback.extract_stack()) c = CreatedAt(_PytatoStackSummary(frames)) From 26c3590e07313788ad2a1b856c5f7e1474ea07db Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Sat, 21 May 2022 18:02:40 -0500 Subject: [PATCH 044/124] bump pytools version + a few spelling fixes --- pytato/array.py | 4 ++-- setup.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index c0285a909..d9d2beeeb 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -355,7 +355,7 @@ class Array(Taggable): :class:`~pytato.array.IndexLambda` is used to produce references to named arrays. Since any array that needs to be referenced in this way needs to obey this restriction anyway, - a decision was made to requir the same of *all* array expressions. + a decision was made to require the same of *all* array expressions. .. attribute:: dtype @@ -677,7 +677,7 @@ def dtype(self) -> np.dtype[Any]: class NamedArray(Array): """An entry in a :class:`AbstractResultWithNamedArrays`. Holds a reference - back to thecontaining instance as well as the name by which *self* is + back to the containing instance as well as the name by which *self* is known there. .. 
automethod:: __init__ diff --git a/setup.py b/setup.py index 336157d5d..e32d7eca0 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ python_requires="~=3.8", install_requires=[ "loopy>=2020.2", - "pytools>=2021.1", + "pytools>=2022.1.8", "pyrsistent" ], package_data={"pytato": ["py.typed"]}, From 423e3fb3aaa6950d458b360bf984f13e5bac8c1c Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Sat, 21 May 2022 18:10:24 -0500 Subject: [PATCH 045/124] remove duplicated self.axes in hash() --- pytato/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytato/array.py b/pytato/array.py index d9d2beeeb..58c6f8ac1 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1598,7 +1598,7 @@ def __init__(self, def __hash__(self) -> int: from pytato.equality import preprocess_tags_for_equality return hash((self.name, id(self.data), self._shape, self.axes, - preprocess_tags_for_equality(self.tags), self.axes)) + preprocess_tags_for_equality(self.tags))) def __eq__(self, other: Any) -> bool: return self is other From 87606ed9569d3d26a36a858316c23f2fcf174840 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 1 Jun 2022 22:32:52 -0500 Subject: [PATCH 046/124] use Taggable{__eq__,__hash__} --- pytato/array.py | 3 +-- pytato/equality.py | 36 ++++++++++++++++-------------------- 2 files changed, 17 insertions(+), 22 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 2d2839f82..157428b83 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1596,9 +1596,8 @@ def __init__(self, self._shape = shape def __hash__(self) -> int: - from pytato.equality import preprocess_tags_for_equality return hash((self.name, id(self.data), self._shape, self.axes, - preprocess_tags_for_equality(self.tags))) + Taggable.__hash__(self))) def __eq__(self, other: Any) -> bool: return self is other diff --git a/pytato/equality.py b/pytato/equality.py index 4dffab13c..4a99bf4a1 100644 --- a/pytato/equality.py +++ b/pytato/equality.py @@ -32,7 +32,7 @@ Reshape, Roll, 
Stack, AbstractResultWithNamedArrays, Array, DictOfNamedArrays, Placeholder, SizeParam) -from pytools.tag import Tag, IgnoredForEqualityTag +from pytools.tag import Tag, IgnoredForEqualityTag, Taggable if TYPE_CHECKING: from pytato.loopy import LoopyCall, LoopyCallResult @@ -107,10 +107,6 @@ def handle_unsupported_array(self, expr1: Array, expr2: Any) -> bool: raise NotImplementedError(type(expr1).__name__) - def are_tags_equal(self, tags1: FrozenSet[Tag], tags2: FrozenSet[Tag]) -> bool: - return (preprocess_tags_for_equality(tags1) - == preprocess_tags_for_equality(tags2)) - def map_foreign(self, expr1: Any, expr2: Any) -> bool: raise NotImplementedError(type(expr1).__name__) @@ -119,14 +115,14 @@ def map_placeholder(self, expr1: Placeholder, expr2: Any) -> bool: and expr1.name == expr2.name and expr1.shape == expr2.shape and expr1.dtype == expr2.dtype - and self.are_tags_equal(expr1.tags, expr2.tags) + and Taggable.__eq__(expr1, expr2) and expr1.axes == expr2.axes ) def map_size_param(self, expr1: SizeParam, expr2: Any) -> bool: return (expr1.__class__ is expr2.__class__ and expr1.name == expr2.name - and self.are_tags_equal(expr1.tags, expr2.tags) + and Taggable.__eq__(expr1, expr2) and expr1.axes == expr2.axes ) @@ -145,7 +141,7 @@ def map_index_lambda(self, expr1: IndexLambda, expr2: Any) -> bool: if isinstance(dim1, Array) else dim1 == dim2 for dim1, dim2 in zip(expr1.shape, expr2.shape)) - and self.are_tags_equal(expr1.tags, expr2.tags) + and Taggable.__eq__(expr1, expr2) and expr1.axes == expr2.axes) def map_stack(self, expr1: Stack, expr2: Any) -> bool: @@ -154,7 +150,7 @@ def map_stack(self, expr1: Stack, expr2: Any) -> bool: and len(expr1.arrays) == len(expr2.arrays) and all(self.rec(ary1, ary2) for ary1, ary2 in zip(expr1.arrays, expr2.arrays)) - and self.are_tags_equal(expr1.tags, expr2.tags) + and Taggable.__eq__(expr1, expr2) and expr1.axes == expr2.axes ) @@ -164,7 +160,7 @@ def map_concatenate(self, expr1: Concatenate, expr2: Any) -> bool: and 
len(expr1.arrays) == len(expr2.arrays) and all(self.rec(ary1, ary2) for ary1, ary2 in zip(expr1.arrays, expr2.arrays)) - and self.are_tags_equal(expr1.tags, expr2.tags) + and Taggable.__eq__(expr1, expr2) and expr1.axes == expr2.axes ) @@ -172,7 +168,7 @@ def map_roll(self, expr1: Roll, expr2: Any) -> bool: return (expr1.__class__ is expr2.__class__ and expr1.axis == expr2.axis and self.rec(expr1.array, expr2.array) - and self.are_tags_equal(expr1.tags, expr2.tags) + and Taggable.__eq__(expr1, expr2) and expr1.axes == expr2.axes ) @@ -180,7 +176,7 @@ def map_axis_permutation(self, expr1: AxisPermutation, expr2: Any) -> bool: return (expr1.__class__ is expr2.__class__ and expr1.axis_permutation == expr2.axis_permutation and self.rec(expr1.array, expr2.array) - and self.are_tags_equal(expr1.tags, expr2.tags) + and Taggable.__eq__(expr1, expr2) and expr1.axes == expr2.axes ) @@ -193,7 +189,7 @@ def _map_index_base(self, expr1: IndexBase, expr2: Any) -> bool: and isinstance(idx2, Array)) else idx1 == idx2 for idx1, idx2 in zip(expr1.indices, expr2.indices)) - and self.are_tags_equal(expr1.tags, expr2.tags) + and Taggable.__eq__(expr1, expr2) and expr1.axes == expr2.axes ) @@ -216,7 +212,7 @@ def map_reshape(self, expr1: Reshape, expr2: Any) -> bool: return (expr1.__class__ is expr2.__class__ and expr1.newshape == expr2.newshape and self.rec(expr1.array, expr2.array) - and self.are_tags_equal(expr1.tags, expr2.tags) + and Taggable.__eq__(expr1, expr2) and expr1.axes == expr2.axes ) @@ -226,14 +222,14 @@ def map_einsum(self, expr1: Einsum, expr2: Any) -> bool: and all(self.rec(ary1, ary2) for ary1, ary2 in zip(expr1.args, expr2.args)) - and self.are_tags_equal(expr1.tags, expr2.tags) + and Taggable.__eq__(expr1, expr2) and expr1.axes == expr2.axes ) def map_named_array(self, expr1: NamedArray, expr2: Any) -> bool: return (expr1.__class__ is expr2.__class__ and self.rec(expr1._container, expr2._container) - and self.are_tags_equal(expr1.tags, expr2.tags) + and 
Taggable.__eq__(expr1, expr2) and expr1.axes == expr2.axes and expr1.name == expr2.name) @@ -252,7 +248,7 @@ def map_loopy_call(self, expr1: LoopyCall, expr2: Any) -> bool: def map_loopy_call_result(self, expr1: LoopyCallResult, expr2: Any) -> bool: return (expr1.__class__ is expr2.__class__ and self.rec(expr1._container, expr2._container) - and self.are_tags_equal(expr1.tags, expr2.tags) + and Taggable.__eq__(expr1, expr2) and expr1.axes == expr2.axes and expr1.name == expr2.name) @@ -269,8 +265,8 @@ def map_distributed_send_ref_holder( and self.rec(expr1.passthrough_data, expr2.passthrough_data) and expr1.send.dest_rank == expr2.send.dest_rank and expr1.send.comm_tag == expr2.send.comm_tag - and expr1.send.tags == expr2.send.tags - and self.are_tags_equal(expr1.tags, expr2.tags) + and Taggable.__eq__(expr1.send, expr2.send) + and Taggable.__eq__(expr1, expr2) ) def map_distributed_recv(self, expr1: DistributedRecv, expr2: Any) -> bool: @@ -279,7 +275,7 @@ def map_distributed_recv(self, expr1: DistributedRecv, expr2: Any) -> bool: and expr1.comm_tag == expr2.comm_tag and expr1.shape == expr2.shape and expr1.dtype == expr2.dtype - and self.are_tags_equal(expr1.tags, expr2.tags) + and Taggable.__eq__(expr1, expr2) ) # }}} From 73c7f7793e900d5a8ed8a4f03d0cf668be8c5067 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Sat, 4 Jun 2022 15:09:36 +0200 Subject: [PATCH 047/124] add another test --- test/test_pytato.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index 67de9461e..961271d53 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -806,23 +806,34 @@ def test_created_at(): a = pt.make_placeholder("a", (10, 10), "float64") b = pt.make_placeholder("b", (10, 10), "float64") - res = a+b + # res1 and res2 are defined on different lines and should have different + # CreatedAt tags. 
+ res1 = a+b res2 = a+b - # {{{ Check that CreatedAt tags are filtered correctly for equality + # res3 and res4 are defined on the same line and should have the same + # CreatedAt tags. + res3 = a+b; res4 = a+b # noqa: E702 + + # {{{ Check that CreatedAt tags are handled correctly for equality + from pytato.equality import preprocess_tags_for_equality - assert res == res2 + assert res1 == res2 == res3 == res4 + + assert res1.tags != res2.tags + assert res3.tags == res4.tags - assert res.tags != res2.tags - assert (preprocess_tags_for_equality(res.tags) + assert (preprocess_tags_for_equality(res1.tags) == preprocess_tags_for_equality(res2.tags)) + assert (preprocess_tags_for_equality(res3.tags) + == preprocess_tags_for_equality(res4.tags)) # }}} from pytato.tags import CreatedAt - created_tag = res.tags_of_type(CreatedAt) + created_tag = res1.tags_of_type(CreatedAt) assert len(created_tag) == 1 From b84b66ef83ac0d31244c6847279900306f977271 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Sat, 4 Jun 2022 15:30:13 +0200 Subject: [PATCH 048/124] add vis test --- test/test_pytato.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index 961271d53..2a76f841d 100755 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -847,7 +847,7 @@ def test_created_at(): assert len(stacksummary) > 10 for frame in tag.traceback.frames: - if frame.name == "test_created_at": + if frame.name == "test_created_at" and "a+b" in frame.line: found = True break @@ -855,6 +855,15 @@ def test_created_at(): # }}} + # {{{ Make sure that CreatedAt tags are in the visualization + + from pytato.visualization import get_dot_graph + s = get_dot_graph(res1) + assert "test_created_at" in s + assert "a+b" in s + + # }}} + def test_pickling_and_unpickling_is_equal(): from testlib import RandomDAGContext, make_random_dag From 6c653bf3ae968640135c07fddeb2756737b97953 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 20 Jun 
2022 10:32:52 -0500 Subject: [PATCH 049/124] make _PytatoFrameSummary, _PytatoStackSummary undocumented --- pytato/array.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 6b0ede462..fc1f9f94f 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -143,8 +143,13 @@ .. autoclass:: EinsumReductionAxis .. autoclass:: NormalizedSlice -.. autoclass:: _PytatoFrameSummary -.. autoclass:: _PytatoStackSummary +Internal classes for traceback +------------------------------ + +Please consider these undocumented and subject to change at any time. + +.. class:: _PytatoFrameSummary +.. class:: _PytatoStackSummary Internal stuff that is only here because the documentation tool wants it ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ From 02dd6f5b353494c507bed3492e9457a91d91323b Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 20 Jun 2022 10:38:41 -0500 Subject: [PATCH 050/124] use Taggable.__hash__ for tags in Array.__hash__ --- pytato/array.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index fc1f9f94f..0e3a3908e 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -502,12 +502,11 @@ def T(self) -> Array: @memoize_method def __hash__(self) -> int: - from pytato.equality import preprocess_tags_for_equality attrs = [] for field in self._fields: attr = getattr(self, field) if field == "tags": - attr = preprocess_tags_for_equality(attr) + attr = Taggable.__hash__(self) if isinstance(attr, dict): attr = frozenset(attr.items()) attrs.append(attr) From 7c937079ec4498917b11e8a3b1cf47a120303376 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 28 Mar 2023 15:44:20 -0500 Subject: [PATCH 051/124] change dataclass to attrs --- pytato/array.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index d912ffc68..f6c2e5d69 100644 --- a/pytato/array.py +++ 
b/pytato/array.py @@ -1808,7 +1808,7 @@ def _get_default_axes(ndim: int) -> AxesT: return tuple(Axis(frozenset()) for _ in range(ndim)) -@dataclass(frozen=True, eq=True) +@attrs.define(frozen=True, eq=True) class _PytatoFrameSummary: """Class to store a single call frame.""" filename: str @@ -1820,12 +1820,12 @@ def update_persistent_hash(self, key_hash: int, key_builder: Any) -> None: key_builder.rec(key_hash, (self.__class__.__module__, self.__class__.__qualname__)) - from dataclasses import fields + from attrs import fields # Fields are ordered consistently, so ordered hashing is OK. # # No need to dispatch to superclass: fields() automatically gives us # fields from the entire class hierarchy. - for f in fields(self): + for f in fields(self.__class__): key_builder.rec(key_hash, getattr(self, f.name)) def short_str(self, maxlen: int = 100) -> str: @@ -1840,7 +1840,7 @@ def __repr__(self) -> str: return f"{self.filename}:{self.lineno}, in {self.name}(): {self.line}" -@dataclass(frozen=True, eq=True) +@attrs.define(frozen=True, eq=True) class _PytatoStackSummary: """Class to store a list of :class:`_PytatoFrameSummary` call frames.""" frames: Tuple[_PytatoFrameSummary, ...] @@ -1849,19 +1849,18 @@ def to_stacksummary(self) -> StackSummary: frames = [FrameSummary(f.filename, f.lineno, f.name, line=f.line) for f in self.frames] - # type-ignore-reason: from_list also takes List[FrameSummary] - return StackSummary.from_list(frames) # type: ignore[arg-type] + return StackSummary.from_list(frames) def update_persistent_hash(self, key_hash: int, key_builder: Any) -> None: key_builder.rec(key_hash, (self.__class__.__module__, self.__class__.__qualname__)) - from dataclasses import fields + from attrs import fields # Fields are ordered consistently, so ordered hashing is OK. # # No need to dispatch to superclass: fields() automatically gives us # fields from the entire class hierarchy. 
- for f in fields(self): + for f in fields(self.__class__): key_builder.rec(key_hash, getattr(self, f.name)) def short_str(self, maxlen: int = 100) -> str: From dd9916bdb5571c405c948a9f259368921a22d583 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 28 Mar 2023 16:13:31 -0500 Subject: [PATCH 052/124] flake8 --- pytato/visualization.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/pytato/visualization.py b/pytato/visualization.py index 413b5495a..dd2448a51 100644 --- a/pytato/visualization.py +++ b/pytato/visualization.py @@ -122,11 +122,11 @@ def __init__(self) -> None: def get_common_dot_info(self, expr: Array) -> DotNodeInfo: title = type(expr).__name__ fields = {"addr": hex(id(expr)), - "shape": stringify_shape(expr.shape), - "dtype": str(expr.dtype), - "tags": stringify_tags(expr.tags), - "created_at": stringify_created_at(expr.tags), - } + "shape": stringify_shape(expr.shape), + "dtype": str(expr.dtype), + "tags": stringify_tags(expr.tags), + "created_at": stringify_created_at(expr.tags), + } edges: Dict[str, ArrayOrNames] = {} return DotNodeInfo(title, fields, edges) From 61c029cbdc184007efbb3282d2be5388e4441180 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 28 Mar 2023 18:38:53 -0500 Subject: [PATCH 053/124] Taggable.__eq__ --- pytato/equality.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytato/equality.py b/pytato/equality.py index b381e96ec..eea88a45f 100644 --- a/pytato/equality.py +++ b/pytato/equality.py @@ -247,7 +247,7 @@ def map_loopy_call(self, expr1: LoopyCall, expr2: Any) -> bool: if isinstance(bnd, Array) else bnd == expr2.bindings[name] for name, bnd in expr1.bindings.items()) - and expr1.tags == expr2.tags + and Taggable.__eq__(expr1, expr2) ) def map_loopy_call_result(self, expr1: LoopyCallResult, expr2: Any) -> bool: @@ -262,7 +262,7 @@ def map_dict_of_named_arrays(self, expr1: DictOfNamedArrays, expr2: Any) -> bool and frozenset(expr1._data.keys()) == 
frozenset(expr2._data.keys()) and all(self.rec(expr1._data[name], expr2._data[name]) for name in expr1._data) - and expr1.tags == expr2.tags + and Taggable.__eq__(expr1, expr2) ) def map_distributed_send_ref_holder( From 3cf1559cdc16d0f4604ef4a3178f35882d3614fc Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 29 Mar 2023 14:22:55 -0500 Subject: [PATCH 054/124] add Array.tagged() --- pytato/array.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pytato/array.py b/pytato/array.py index f6c2e5d69..4a0faf0d9 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -180,7 +180,7 @@ import pymbolic.primitives as prim from pymbolic import var from pytools import memoize_method -from pytools.tag import Tag, Taggable +from pytools.tag import Tag, Taggable, ToTagSetConvertible from pytato.scalar_expr import (ScalarType, SCALAR_CLASSES, ScalarExpression, IntegralT, @@ -689,6 +689,12 @@ def __repr__(self) -> str: from pytato.stringifier import Reprifier return Reprifier()(self) + def tagged(self, tags: ToTagSetConvertible) -> Array: + from pytato.equality import preprocess_tags_for_equality + from pytools.tag import normalize_tags + new_tags = preprocess_tags_for_equality(normalize_tags(tags)) + return super().tagged(new_tags) + # }}} From 44d1c34acd2795feac052c6c4f4cc7a3764f3394 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 19 May 2023 14:11:41 -0500 Subject: [PATCH 055/124] restrict to DEBUG_ENABLED --- pytato/array.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/pytato/array.py b/pytato/array.py index e02f5e1c1..8ea7f9270 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1893,7 +1893,11 @@ def _get_default_tags(existing_tags: Optional[FrozenSet[Tag]] = None) \ import traceback from pytato.tags import CreatedAt - if __debug__ and ( + from pytato import DEBUG_ENABLED + + # This has a significant overhead, so only enable it when PYTATO_DEBUG is + # enabled. 
+ if DEBUG_ENABLED and ( existing_tags is None or not any(isinstance(tag, CreatedAt) for tag in existing_tags)): frames = tuple(_PytatoFrameSummary(s.filename, s.lineno, s.name, s.line) From a150c79e603a4a6bc9ac79bf26cea703fdcfb04f Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 19 May 2023 15:26:49 -0500 Subject: [PATCH 056/124] force DEBUG_ENABLED for test --- test/test_pytato.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/test_pytato.py b/test/test_pytato.py index 80aaf2c9a..bb2a46549 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -836,6 +836,9 @@ def test_created_at(): a = pt.make_placeholder("a", (10, 10), "float64") b = pt.make_placeholder("b", (10, 10), "float64") + _prev_debug_enabled = pt.DEBUG_ENABLED + pt.DEBUG_ENABLED = True + # res1 and res2 are defined on different lines and should have different # CreatedAt tags. res1 = a+b @@ -892,6 +895,8 @@ def test_created_at(): assert "test_created_at" in s assert "a+b" in s + pt.DEBUG_ENABLED = _prev_debug_enabled + # }}} From 07690673d0b161ec9e94ed10cf99b912b2704817 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Wed, 17 Nov 2021 11:19:40 -0600 Subject: [PATCH 057/124] CHERRY-PICK: Preserve High-Level Info in the Pymbolic expressions --- pytato/array.py | 44 ++++++++++++++++++++++++++++---------------- 1 file changed, 28 insertions(+), 16 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 964ff351f..69efdcb1f 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -585,25 +585,37 @@ def _unary_op(self, op: Any) -> Array: axes=_get_default_axes(self.ndim), var_to_reduction_descr=Map()) - __mul__ = partialmethod(_binary_op, operator.mul) - __rmul__ = partialmethod(_binary_op, operator.mul, reverse=True) - - __add__ = partialmethod(_binary_op, operator.add) - __radd__ = partialmethod(_binary_op, operator.add, reverse=True) - - __sub__ = partialmethod(_binary_op, operator.sub) - __rsub__ = partialmethod(_binary_op, operator.sub, reverse=True) - - 
__floordiv__ = partialmethod(_binary_op, operator.floordiv) - __rfloordiv__ = partialmethod(_binary_op, operator.floordiv, reverse=True) - - __truediv__ = partialmethod(_binary_op, operator.truediv, + # NOTE: Initializing the expression to "prim.Product(expr1, expr2)" is + # essential as opposed to performing "expr1 * expr2". This is to account + # for pymbolic's implementation of the "*" operator which might not + # instantiate the node corresponding to the operation when one of + # the operands is the neutral element of the operation. + # + # For the same reason 'prim.(Sum|FloorDiv|Quotient)' is preferred over the + # python operators on the operands. + + __mul__ = partialmethod(_binary_op, lambda l, r: prim.Product((l, r))) + __rmul__ = partialmethod(_binary_op, lambda l, r: prim.Product((l, r)), + reverse=True) + + __add__ = partialmethod(_binary_op, lambda l, r: prim.Sum((l, r))) + __radd__ = partialmethod(_binary_op, lambda l, r: prim.Sum((l, r)), + reverse=True) + + __sub__ = partialmethod(_binary_op, lambda l, r: prim.Sum((l, -r))) + __rsub__ = partialmethod(_binary_op, lambda l, r: prim.Sum((l, -r)), + reverse=True) + + __floordiv__ = partialmethod(_binary_op, prim.FloorDiv) + __rfloordiv__ = partialmethod(_binary_op, prim.FloorDiv, reverse=True) + + __truediv__ = partialmethod(_binary_op, prim.Quotient, get_result_type=_truediv_result_type) - __rtruediv__ = partialmethod(_binary_op, operator.truediv, + __rtruediv__ = partialmethod(_binary_op, prim.Quotient, get_result_type=_truediv_result_type, reverse=True) - __pow__ = partialmethod(_binary_op, operator.pow) - __rpow__ = partialmethod(_binary_op, operator.pow, reverse=True) + __pow__ = partialmethod(_binary_op, prim.Power) + __rpow__ = partialmethod(_binary_op, prim.Power, reverse=True) __neg__ = partialmethod(_unary_op, operator.neg) From 8b3a13b3243055523ab0d3b31edb28b2ed9aeac0 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni Date: Wed, 25 May 2022 21:07:29 -0500 Subject: [PATCH 058/124] [CHERRY-PICK]: 
Call BranchMorpher after dw deduplication --- pytato/transform/__init__.py | 37 +++++++++++++++++++++++++++++++----- test/test_codegen.py | 2 +- 2 files changed, 33 insertions(+), 6 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index a75083f9a..cd2a39674 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -224,10 +224,9 @@ def rec(self, # type: ignore[override] # type-ignore-reason: CachedMapper.rec's return type is imprecise return super().rec(expr) # type: ignore[return-value] - # type-ignore-reason: specialized variant of super-class' rec method - def __call__(self, # type: ignore[override] - expr: CopyMapperResultT) -> CopyMapperResultT: - return self.rec(expr) + # type-ignore reason: incompatible type with Mapper.rec + def __call__(self, expr: MappedT) -> MappedT: # type: ignore[override] + return self.rec(expr) # type: ignore[no-any-return] def rec_idx_or_size_tuple(self, situp: Tuple[IndexOrShapeExpr, ...] ) -> Tuple[IndexOrShapeExpr, ...]: @@ -1569,6 +1568,33 @@ def tag_user_nodes( # }}} +# {{{ BranchMorpher + +class BranchMorpher(CopyMapper): + """ + A mapper that replaces equal segments of graphs with identical objects. 
+ """ + def __init__(self) -> None: + super().__init__() + self.result_cache: Dict[ArrayOrNames, ArrayOrNames] = {} + + def cache_key(self, expr: CachedMapperT) -> Any: + return (id(expr), expr) + + # type-ignore reason: incompatible with Mapper.rec + def rec(self, expr: MappedT) -> MappedT: # type: ignore[override] + rec_expr = super().rec(expr) + try: + # type-ignored because 'result_cache' maps to ArrayOrNames + return self.result_cache[rec_expr] # type: ignore[return-value] + except KeyError: + self.result_cache[rec_expr] = rec_expr + # type-ignored because of super-class' relaxed types + return rec_expr # type: ignore[no-any-return] + +# }}} + + # {{{ deduplicate_data_wrappers def _get_data_dedup_cache_key(ary: DataInterface) -> Hashable: @@ -1658,8 +1684,9 @@ def cached_data_wrapper_if_present(ary: ArrayOrNames) -> ArrayOrNames: len(data_wrapper_cache), data_wrappers_encountered - len(data_wrapper_cache)) - return array_or_names + return BranchMorpher()(array_or_names) # }}} + # vim: foldmethod=marker diff --git a/test/test_codegen.py b/test/test_codegen.py index 874906528..46e7a3186 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -1556,7 +1556,7 @@ def test_zero_size_cl_array_dedup(ctx_factory): x4 = pt.make_data_wrapper(x_cl2) out = pt.make_dict_of_named_arrays({"out1": 2*x1, - "out2": 2*x2, + "out2": 3*x2, "out3": x3 + x4 }) From 8e20d15eafb9c60b991da82967ffadcca98861a0 Mon Sep 17 00:00:00 2001 From: Andreas Kloeckner Date: Fri, 4 Aug 2023 13:18:12 -0500 Subject: [PATCH 059/124] Define __attrs_post_init__ only if __debug__, for all Array classes --- pytato/array.py | 18 ++++++++++-------- pytato/function.py | 11 ++++++----- 2 files changed, 16 insertions(+), 13 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index bbf4ae739..067c1a08d 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -745,10 +745,11 @@ class AbstractResultWithNamedArrays(Mapping[str, NamedArray], Taggable, ABC): def _is_eq_valid(self) -> bool: return 
self.__class__.__eq__ is AbstractResultWithNamedArrays.__eq__ - def __attrs_post_init__(self) -> None: - # ensure that a developer does not uses dataclass' "__eq__" - # or "__hash__" implementation as they have exponential complexity. - assert self._is_eq_valid() + if __debug__: + def __attrs_post_init__(self) -> None: + # ensure that a developer does not uses dataclass' "__eq__" + # or "__hash__" implementation as they have exponential complexity. + assert self._is_eq_valid() @abstractmethod def __contains__(self, name: object) -> bool: @@ -1450,10 +1451,11 @@ class Reshape(IndexRemappingBase): _mapper_method: ClassVar[str] = "map_reshape" - def __attrs_post_init__(self) -> None: - # FIXME: Get rid of this restriction - assert self.order == "C" - super().__attrs_post_init__() + if __debug__: + def __attrs_post_init__(self) -> None: + # FIXME: Get rid of this restriction + assert self.order == "C" + super().__attrs_post_init__() @property def shape(self) -> ShapeType: diff --git a/pytato/function.py b/pytato/function.py index 6e5d044d2..b053831a0 100644 --- a/pytato/function.py +++ b/pytato/function.py @@ -276,11 +276,12 @@ class Call(AbstractResultWithNamedArrays): copy = attrs.evolve - def __attrs_post_init__(self) -> None: - # check that the invocation parameters and the function definition - # parameters agree with each other. - assert frozenset(self.bindings) == self.function.parameters - super().__attrs_post_init__() + if __debug__: + def __attrs_post_init__(self) -> None: + # check that the invocation parameters and the function definition + # parameters agree with each other. 
+ assert frozenset(self.bindings) == self.function.parameters + super().__attrs_post_init__() def __contains__(self, name: object) -> bool: return name in self.function.returns From efcae65bda3679246d24f63ff6a9d8821a1c08df Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Sun, 10 Sep 2023 20:41:15 -0500 Subject: [PATCH 060/124] First shot at implementing 'F' ordered array reshapes --- pytato/array.py | 8 +-- pytato/transform/lower_to_index_lambda.py | 65 +++++++++++++++-------- 2 files changed, 45 insertions(+), 28 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 9fdf1de02..f09e2d8be 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1493,8 +1493,6 @@ class Reshape(IndexRemappingBase): _mapper_method: ClassVar[str] = "map_reshape" def __post_init__(self) -> None: - # FIXME: Get rid of this restriction - assert self.order == "C" super().__post_init__() __attrs_post_init__ = __post_init__ @@ -1958,8 +1956,7 @@ def reshape(array: Array, newshape: Union[int, Sequence[int]], """ :param array: array to be reshaped :param newshape: shape of the resulting array - :param order: ``"C"`` or ``"F"``. Layout order of the result array. Only - ``"C"`` allowed for now. + :param order: ``"C"`` or ``"F"``. Layout order of the resulting array. .. 
note:: @@ -1979,9 +1976,6 @@ def reshape(array: Array, newshape: Union[int, Sequence[int]], if not all(isinstance(axis_len, INT_CLASSES) for axis_len in array.shape): raise ValueError("reshape of arrays with symbolic lengths not allowed") - if order != "C": - raise NotImplementedError("Reshapes to a 'F'-ordered arrays") - newshape_explicit = [] for new_axislen in newshape: diff --git a/pytato/transform/lower_to_index_lambda.py b/pytato/transform/lower_to_index_lambda.py index a2bb443f0..5c2dfbca1 100644 --- a/pytato/transform/lower_to_index_lambda.py +++ b/pytato/transform/lower_to_index_lambda.py @@ -51,27 +51,50 @@ def _get_reshaped_indices(expr: Reshape) -> Tuple[ScalarExpression, ...]: assert expr.size == 1 return () - if expr.order != "C": - raise NotImplementedError(expr.order) - - newstrides: List[IntegralT] = [1] # reshaped array strides - for new_axis_len in reversed(expr.shape[1:]): - assert isinstance(new_axis_len, INT_CLASSES) - newstrides.insert(0, newstrides[0]*new_axis_len) - - flattened_idx = sum(prim.Variable(f"_{i}")*stride - for i, stride in enumerate(newstrides)) - - oldstrides: List[IntegralT] = [1] # input array strides - for axis_len in reversed(expr.array.shape[1:]): - assert isinstance(axis_len, INT_CLASSES) - oldstrides.insert(0, oldstrides[0]*axis_len) - - assert isinstance(expr.array.shape[-1], INT_CLASSES) - oldsizetills = [expr.array.shape[-1]] # input array size till for axes idx - for old_axis_len in reversed(expr.array.shape[:-1]): - assert isinstance(old_axis_len, INT_CLASSES) - oldsizetills.insert(0, oldsizetills[0]*old_axis_len) + if expr.order not in ["C", "F"]: + raise NotImplementedError("Order expected to be 'C' or 'F'", + f" found {expr.order}") + + if expr.order == "C": + newstrides: List[IntegralT] = [1] # reshaped array strides + for new_axis_len in reversed(expr.shape[1:]): + assert isinstance(new_axis_len, INT_CLASSES) + newstrides.insert(0, newstrides[0]*new_axis_len) + + flattened_idx = 
sum(prim.Variable(f"_{i}")*stride + for i, stride in enumerate(newstrides)) + + oldstrides: List[IntegralT] = [1] # input array strides + for axis_len in reversed(expr.array.shape[1:]): + assert isinstance(axis_len, INT_CLASSES) + oldstrides.insert(0, oldstrides[0]*axis_len) + + assert isinstance(expr.array.shape[-1], INT_CLASSES) + oldsizetills = [expr.array.shape[-1]] # input array size + # till for axes idx + for old_axis_len in reversed(expr.array.shape[:-1]): + assert isinstance(old_axis_len, INT_CLASSES) + oldsizetills.insert(0, oldsizetills[0]*old_axis_len) + + else: + newstrides: List[IntegralT] = [1] # reshaped array strides + for new_axis_len in expr.shape[:-1]: + assert isinstance(new_axis_len, INT_CLASSES) + newstrides.append(newstrides[-1]*new_axis_len) + + flattened_idx = sum(prim.Variable(f"_{i}")*stride + for i, stride in enumerate(newstrides)) + + oldstrides: List[IntegralT] = [1] # input array strides + for axis_len in expr.array.shape[:-1]: + assert isinstance(axis_len, INT_CLASSES) + oldstrides.append(oldstrides[-1]*axis_len) + + assert isinstance(expr.array.shape[0], INT_CLASSES) + oldsizetills = [expr.array.shape[0]] # input array size till for axes idx + for old_axis_len in expr.array.shape[1:]: + assert isinstance(old_axis_len, INT_CLASSES) + oldsizetills.append(oldsizetills[-1]*old_axis_len) return tuple(((flattened_idx % sizetill) // stride) for stride, sizetill in zip(oldstrides, oldsizetills)) From 35c6d1fe084b82ac1282b70c0eb9471aab15a460 Mon Sep 17 00:00:00 2001 From: Addison Alvey-Blanco Date: Sun, 10 Sep 2023 21:38:24 -0500 Subject: [PATCH 061/124] Remove restriction on reshape order --- pytato/array.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index a981ee20e..bbfdafda0 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1453,8 +1453,6 @@ class Reshape(IndexRemappingBase): if __debug__: def __attrs_post_init__(self) -> None: - # FIXME: Get rid of this restriction - assert self.order == 
"C" super().__attrs_post_init__() @property From 86233c647e5e2854b15b74a66a15aa7aa62f29ff Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Sat, 14 Oct 2023 15:43:55 -0500 Subject: [PATCH 062/124] work around mypy/attrs issue --- pytato/stringifier.py | 7 +++++-- pytato/visualization/dot.py | 3 ++- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/pytato/stringifier.py b/pytato/stringifier.py index 8aac8d340..269c9a546 100644 --- a/pytato/stringifier.py +++ b/pytato/stringifier.py @@ -95,7 +95,9 @@ def _map_generic_array(self, expr: Array, depth: int) -> str: if depth > self.truncation_depth: return self.truncation_string - fields = tuple(field.name for field in attrs.fields(type(expr))) + # type-ignore-reason: https://github.com/python/mypy/issues/16254 + fields = tuple(field.name + for field in attrs.fields(type(expr))) # type: ignore[misc] if expr.ndim <= 1: # prettify: if ndim <=1 'expr.axes' would be trivial, @@ -153,7 +155,8 @@ def _get_field_val(field: str) -> str: return (f"{type(expr).__name__}(" + ", ".join(f"{field.name}={_get_field_val(field.name)}" - for field in attrs.fields(type(expr))) + # type-ignore-reason: https://github.com/python/mypy/issues/16254 + for field in attrs.fields(type(expr))) # type: ignore[misc] + ")") def map_loopy_call(self, expr: LoopyCall, depth: int) -> str: diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py index d84b3aaec..32a2ae5b0 100644 --- a/pytato/visualization/dot.py +++ b/pytato/visualization/dot.py @@ -184,7 +184,8 @@ def handle_unsupported_array(self, # type: ignore[override] # Default handler, does its best to guess how to handle fields. 
info = self.get_common_dot_info(expr) - for field in attrs.fields(type(expr)): + # type-ignore-reason: https://github.com/python/mypy/issues/16254 + for field in attrs.fields(type(expr)): # type: ignore[misc] if field.name in info.fields: continue attr = getattr(expr, field.name) From 3b6fdade9d66dd53e3eb2d8b54ac607025df4645 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Sat, 14 Oct 2023 15:58:55 -0500 Subject: [PATCH 063/124] fix for fields --- pytato/array.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 0d1ff5b8a..b1a5ecd98 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -517,15 +517,15 @@ def T(self) -> Array: @memoize_method def __hash__(self) -> int: - attrs = [] - for field in self._fields: + attrs_filtered: List[Any] = [] + for field in attrs.fields(type(self)): # type: ignore[misc] attr = getattr(self, field) if field == "tags": attr = Taggable.__hash__(self) if isinstance(attr, dict): attr = frozenset(attr.items()) - attrs.append(attr) - return hash(tuple(attrs)) + attrs_filtered.append(attr) + return hash(tuple(attrs_filtered)) def __eq__(self, other: Any) -> bool: if self is other: @@ -1796,7 +1796,7 @@ def update_persistent_hash(self, key_hash: int, key_builder: Any) -> None: # # No need to dispatch to superclass: fields() automatically gives us # fields from the entire class hierarchy. - for f in fields(self.__class__): + for f in fields(self.__class__): # type: ignore[misc] key_builder.rec(key_hash, getattr(self, f.name)) def short_str(self, maxlen: int = 100) -> str: @@ -1831,7 +1831,7 @@ def update_persistent_hash(self, key_hash: int, key_builder: Any) -> None: # # No need to dispatch to superclass: fields() automatically gives us # fields from the entire class hierarchy. 
- for f in fields(self.__class__): + for f in fields(self.__class__): # type: ignore[misc] key_builder.rec(key_hash, getattr(self, f.name)) def short_str(self, maxlen: int = 100) -> str: From 8a390a583f8450a2a85889cd31016ac40fdde5c8 Mon Sep 17 00:00:00 2001 From: Mike Campbell Date: Wed, 1 Nov 2023 16:59:38 -0500 Subject: [PATCH 064/124] Update comments a little --- pytato/transform/__init__.py | 28 ++++++++-------------------- 1 file changed, 8 insertions(+), 20 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 509718014..62c7a3d40 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -245,26 +245,14 @@ def rec(self, # type: ignore[override] expr: CopyMapperResultT) -> CopyMapperResultT: # type-ignore-reason: CachedMapper.rec's return type is imprecise return super().rec(expr) # type: ignore[return-value] - # DISABLED/REPLACED FROM MAIN - # # type-ignore-reason: specialized variant of super-class' rec method - # def rec(self, # type: ignore[override] - # expr: CopyMapperResultT) -> CopyMapperResultT: - # # type-ignore-reason: CachedMapper.rec's return type is imprecise - # return super().rec(expr) # type: ignore[return-value] - # ----- PREVIOUS CODE IN MAIN - # # type-ignore reason: incompatible type with Mapper.rec - # def __call__(self, expr: MappedT) -> MappedT: # type: ignore[override] - # return self.rec(expr) # type: ignore[no-any-return] - # --------------------------- - # ------- CURRENT CODE IN MAIN - # # type-ignore-reason: specialized variant of super-class' rec method - # def __call__(self, # type: ignore[override] - # expr: CopyMapperResultT) -> CopyMapperResultT: - # return self.rec(expr) - # ------------------------------------------------------ - # --------- CURRENT CODE IN CEESD - __call__ = rec - # ------------------------------- + + # REPLACED WITH NEW CODE FROM MAIN + # __call__ = rec + # ------------------------------- + # type-ignore-reason: specialized variant of 
super-class' rec method + def __call__(self, # type: ignore[override] + expr: CopyMapperResultT) -> CopyMapperResultT: + return self.rec(expr) def clone_for_callee(self: _SelfMapper) -> _SelfMapper: """ From 870849a106c4100e3f3b5a4929e3179cd5ece2de Mon Sep 17 00:00:00 2001 From: Mike Campbell Date: Thu, 9 Nov 2023 12:44:48 -0600 Subject: [PATCH 065/124] attempt to fix tag issue --- pytato/distributed/tags.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytato/distributed/tags.py b/pytato/distributed/tags.py index 41ae3273c..95067ee66 100644 --- a/pytato/distributed/tags.py +++ b/pytato/distributed/tags.py @@ -106,7 +106,7 @@ def set_union( next_tag = base_tag assert isinstance(all_tags, frozenset) - for sym_tag in sorted(all_tags): + for sym_tag in sorted(all_tags, key=lambda tag: repr(tag)): sym_tag_to_int_tag[sym_tag] = next_tag next_tag += 1 From 060f864e2153c9c6d650ad919ff1440dc795e136 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 9 Nov 2023 16:53:19 -0600 Subject: [PATCH 066/124] number_distributed_tags: non-set, non-sorted numbering --- pytato/distributed/__init__.py | 2 +- pytato/distributed/tags.py | 46 ++++++++++------------------------ test/test_distributed.py | 10 ++++++-- 3 files changed, 22 insertions(+), 36 deletions(-) diff --git a/pytato/distributed/__init__.py b/pytato/distributed/__init__.py index 4354b2f0f..ee0ff39de 100644 --- a/pytato/distributed/__init__.py +++ b/pytato/distributed/__init__.py @@ -23,7 +23,7 @@ .. class:: CommTagType A type representing a communication tag. Communication tags must be - hashable and totally ordered (and hence comparable). + hashable. .. 
class:: ShapeType diff --git a/pytato/distributed/tags.py b/pytato/distributed/tags.py index 41ae3273c..4b97b2d50 100644 --- a/pytato/distributed/tags.py +++ b/pytato/distributed/tags.py @@ -31,7 +31,7 @@ """ -from typing import TYPE_CHECKING, Tuple, FrozenSet, Optional, TypeVar +from typing import TYPE_CHECKING, Tuple, TypeVar from pytato.distributed.partition import DistributedGraphPartition @@ -62,53 +62,33 @@ def number_distributed_tags( This is a potentially heavyweight MPI-collective operation on *mpi_communicator*. - - .. note:: - - This function requires that symbolic tags are comparable. """ - tags = frozenset({ + from pytools import flatten + + tags = tuple([ recv.comm_tag for part in partition.parts.values() for recv in part.name_to_recv_node.values() - } | { + ] + [ send.comm_tag for part in partition.parts.values() for sends in part.name_to_send_nodes.values() - for send in sends}) - - from mpi4py import MPI - - def set_union( - set_a: FrozenSet[T], set_b: FrozenSet[T], - mpi_data_type: Optional[MPI.Datatype]) -> FrozenSet[T]: - assert mpi_data_type is None - assert isinstance(set_a, frozenset) - assert isinstance(set_b, frozenset) - - return set_a | set_b + for send in sends]) root_rank = 0 - set_union_mpi_op = MPI.Op.Create( - # type ignore reason: mpi4py misdeclares op functions as returning - # None. 
- set_union, # type: ignore[arg-type] - commute=True) - try: - all_tags = mpi_communicator.reduce( - tags, set_union_mpi_op, root=root_rank) - finally: - set_union_mpi_op.Free() + all_tags = mpi_communicator.gather(tags, root=root_rank) if mpi_communicator.rank == root_rank: sym_tag_to_int_tag = {} next_tag = base_tag - assert isinstance(all_tags, frozenset) + assert isinstance(all_tags, list) + assert len(all_tags) == mpi_communicator.size - for sym_tag in sorted(all_tags): - sym_tag_to_int_tag[sym_tag] = next_tag - next_tag += 1 + for sym_tag in flatten(all_tags): # type: ignore[no-untyped-call] + if sym_tag not in sym_tag_to_int_tag: + sym_tag_to_int_tag[sym_tag] = next_tag + next_tag += 1 mpi_communicator.bcast((sym_tag_to_int_tag, next_tag), root=root_rank) else: diff --git a/test/test_distributed.py b/test/test_distributed.py index f7a8e5b4c..c36f4caae 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -266,7 +266,7 @@ def _do_test_distributed_execution_random_dag(ctx_factory): ntests = 10 for i in range(ntests): seed = 120 + i - print(f"Step {i} {seed}") + print(f"Step {i} {seed=}") # {{{ compute value with communication @@ -278,7 +278,13 @@ def gen_comm(rdagc): nonlocal comm_tag comm_tag += 1 - tag = (comm_tag, _RandomDAGTag) # noqa: B023 + + if comm_tag % 5 == 1: + tag = (comm_tag, frozenset([_RandomDAGTag, _RandomDAGTag])) + elif comm_tag % 5 == 2: + tag = (comm_tag, (_RandomDAGTag,)) + else: + tag = (comm_tag, _RandomDAGTag) # noqa: B023 inner = make_random_dag(rdagc) return pt.staple_distributed_send( From 65d014297d01bb767a11555f32b361193b5bfd28 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 9 Nov 2023 17:33:10 -0600 Subject: [PATCH 067/124] make the test a bit more difficult --- test/test_distributed.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_distributed.py b/test/test_distributed.py index c36f4caae..3a3e785fc 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py 
@@ -279,10 +279,10 @@ def gen_comm(rdagc): nonlocal comm_tag comm_tag += 1 - if comm_tag % 5 == 1: - tag = (comm_tag, frozenset([_RandomDAGTag, _RandomDAGTag])) + if comm_tag % 5 == 1 or 1: + tag = (comm_tag, frozenset([_RandomDAGTag, "a", comm_tag])) elif comm_tag % 5 == 2: - tag = (comm_tag, (_RandomDAGTag,)) + tag = (comm_tag, (_RandomDAGTag, "b")) else: tag = (comm_tag, _RandomDAGTag) # noqa: B023 From 3ebfcfd30f9a75e3d046e637712e53e835ea2719 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 14 Nov 2023 11:51:43 -0600 Subject: [PATCH 068/124] undo mypy ignores --- pytato/array.py | 6 +++--- pytato/stringifier.py | 7 ++----- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index b1a5ecd98..4511b4f96 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -518,7 +518,7 @@ def T(self) -> Array: @memoize_method def __hash__(self) -> int: attrs_filtered: List[Any] = [] - for field in attrs.fields(type(self)): # type: ignore[misc] + for field in attrs.fields(type(self)): attr = getattr(self, field) if field == "tags": attr = Taggable.__hash__(self) @@ -1796,7 +1796,7 @@ def update_persistent_hash(self, key_hash: int, key_builder: Any) -> None: # # No need to dispatch to superclass: fields() automatically gives us # fields from the entire class hierarchy. - for f in fields(self.__class__): # type: ignore[misc] + for f in fields(self.__class__): key_builder.rec(key_hash, getattr(self, f.name)) def short_str(self, maxlen: int = 100) -> str: @@ -1831,7 +1831,7 @@ def update_persistent_hash(self, key_hash: int, key_builder: Any) -> None: # # No need to dispatch to superclass: fields() automatically gives us # fields from the entire class hierarchy. 
- for f in fields(self.__class__): # type: ignore[misc] + for f in fields(self.__class__): key_builder.rec(key_hash, getattr(self, f.name)) def short_str(self, maxlen: int = 100) -> str: diff --git a/pytato/stringifier.py b/pytato/stringifier.py index 269c9a546..b5172e768 100644 --- a/pytato/stringifier.py +++ b/pytato/stringifier.py @@ -95,9 +95,7 @@ def _map_generic_array(self, expr: Array, depth: int) -> str: if depth > self.truncation_depth: return self.truncation_string - # type-ignore-reason: https://github.com/python/mypy/issues/16254 - fields = tuple(field.name - for field in attrs.fields(type(expr))) # type: ignore[misc] + fields = tuple(field.name for field in attrs.fields(type(expr))) if expr.ndim <= 1: # prettify: if ndim <=1 'expr.axes' would be trivial, @@ -155,8 +153,7 @@ def _get_field_val(field: str) -> str: return (f"{type(expr).__name__}(" + ", ".join(f"{field.name}={_get_field_val(field.name)}" - # type-ignore-reason: https://github.com/python/mypy/issues/16254 - for field in attrs.fields(type(expr))) # type: ignore[misc] + for field in attrs.fields(type(expr))) + ")") def map_loopy_call(self, expr: LoopyCall, depth: int) -> str: From eb1c052539e90412d5fad022ee38895309e0ea08 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 14 Nov 2023 17:29:46 -0600 Subject: [PATCH 069/124] rewrite to use a new field in Array, non_equality_tags --- pytato/array.py | 69 ++++++++++++++++++------------------ pytato/cmath.py | 3 +- pytato/equality.py | 14 ++------ pytato/stringifier.py | 2 ++ pytato/transform/__init__.py | 54 +++++++--------------------- pytato/utils.py | 10 ++++-- pytato/visualization/dot.py | 35 ++++++++++++++---- test/test_pytato.py | 21 ++++------- 8 files changed, 95 insertions(+), 113 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 4511b4f96..99f645d37 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -181,7 +181,7 @@ import pymbolic.primitives as prim from pymbolic import var from pytools import memoize_method 
-from pytools.tag import Tag, Taggable, ToTagSetConvertible +from pytools.tag import Tag, Taggable from pytato.scalar_expr import (ScalarType, SCALAR_CLASSES, ScalarExpression, IntegralT, @@ -448,6 +448,10 @@ class Array(Taggable): axes: AxesT = attrs.field(kw_only=True) tags: FrozenSet[Tag] = attrs.field(kw_only=True) + # These are automatically excluded from equality in EqualityComparer + non_equality_tags: FrozenSet[Tag] = attrs.field(kw_only=True, hash=False, + default=None) + _mapper_method: ClassVar[str] # disallow numpy arithmetic from taking precedence @@ -515,18 +519,6 @@ def T(self) -> Array: tags=_get_default_tags(), axes=_get_default_axes(self.ndim)) - @memoize_method - def __hash__(self) -> int: - attrs_filtered: List[Any] = [] - for field in attrs.fields(type(self)): - attr = getattr(self, field) - if field == "tags": - attr = Taggable.__hash__(self) - if isinstance(attr, dict): - attr = frozenset(attr.items()) - attrs_filtered.append(attr) - return hash(tuple(attrs_filtered)) - def __eq__(self, other: Any) -> bool: if self is other: return True @@ -681,12 +673,6 @@ def __repr__(self) -> str: from pytato.stringifier import Reprifier return Reprifier()(self) - def tagged(self, tags: ToTagSetConvertible) -> Array: - from pytato.equality import preprocess_tags_for_equality - from pytools.tag import normalize_tags - new_tags = preprocess_tags_for_equality(normalize_tags(tags)) - return super().tagged(new_tags) - # }}} @@ -1852,24 +1838,21 @@ def __repr__(self) -> str: return "\n " + "\n ".join([str(f) for f in self.frames]) -def _get_default_tags(existing_tags: Optional[FrozenSet[Tag]] = None) \ - -> FrozenSet[Tag]: +def _get_created_at_tag() -> Optional[Tag]: import traceback from pytato.tags import CreatedAt - from pytato import DEBUG_ENABLED + if not __debug__: + return None - # This has a significant overhead, so only enable it when PYTATO_DEBUG is - # enabled. 
- if DEBUG_ENABLED and ( - existing_tags is None - or not any(isinstance(tag, CreatedAt) for tag in existing_tags)): - frames = tuple(_PytatoFrameSummary(s.filename, s.lineno, s.name, s.line) + frames = tuple(_PytatoFrameSummary(s.filename, s.lineno, s.name, s.line) for s in traceback.extract_stack()) - c = CreatedAt(_PytatoStackSummary(frames)) - return frozenset((c,)) - else: - return frozenset() + + return CreatedAt(_PytatoStackSummary(frames)) + + +def _get_default_tags() -> FrozenSet[Tag]: + return frozenset() def matmul(x1: Array, x2: Array) -> Array: @@ -1931,6 +1914,7 @@ def roll(a: Array, shift: int, axis: Optional[int] = None) -> Array: return Roll(a, shift, axis, tags=_get_default_tags(), + non_equality_tags=frozenset({_get_created_at_tag()}), axes=_get_default_axes(a.ndim)) @@ -1953,6 +1937,7 @@ def transpose(a: Array, axes: Optional[Sequence[int]] = None) -> Array: return AxisPermutation(a, tuple(axes), tags=_get_default_tags(), + non_equality_tags=frozenset({_get_created_at_tag()}), axes=_get_default_axes(a.ndim)) @@ -1987,6 +1972,7 @@ def stack(arrays: Sequence[Array], axis: int = 0) -> Array: return Stack(tuple(arrays), axis, tags=_get_default_tags(), + non_equality_tags=frozenset({_get_created_at_tag()}), axes=_get_default_axes(arrays[0].ndim+1)) @@ -2022,6 +2008,7 @@ def shape_except_axis(ary: Array) -> ShapeType: return Concatenate(tuple(arrays), axis, tags=_get_default_tags(), + non_equality_tags=frozenset({_get_created_at_tag()}), axes=_get_default_axes(arrays[0].ndim)) @@ -2085,6 +2072,7 @@ def reshape(array: Array, newshape: Union[int, Sequence[int]], return Reshape(array, tuple(newshape_explicit), order, tags=_get_default_tags(), + non_equality_tags=frozenset({_get_created_at_tag()}), axes=_get_default_axes(len(newshape_explicit))) @@ -2128,7 +2116,8 @@ def make_placeholder(name: str, f" expected {len(shape)}, got {len(axes)}.") return Placeholder(name=name, shape=shape, dtype=dtype, axes=axes, - tags=(tags | _get_default_tags(tags))) + 
tags=(tags | _get_default_tags()), + non_equality_tags=frozenset({_get_created_at_tag()}),) def make_size_param(name: str, @@ -2142,7 +2131,8 @@ def make_size_param(name: str, :param tags: implementation tags """ _check_identifier(name, optional=False) - return SizeParam(name, tags=(tags | _get_default_tags(tags))) + return SizeParam(name, tags=(tags | _get_default_tags()), + non_equality_tags=frozenset({_get_created_at_tag()}),) def make_data_wrapper(data: DataInterface, @@ -2181,7 +2171,8 @@ def make_data_wrapper(data: DataInterface, raise ValueError("'axes' dimensionality mismatch:" f" expected {len(shape)}, got {len(axes)}.") - return DataWrapper(data, shape, axes=axes, tags=(tags | _get_default_tags(tags))) + return DataWrapper(data, shape, axes=axes, tags=(tags | _get_default_tags()), + non_equality_tags=frozenset({_get_created_at_tag()}),) # }}} @@ -2212,6 +2203,7 @@ def full(shape: ConvertibleToShape, fill_value: ScalarType, return IndexLambda(expr=fill_value, shape=shape, dtype=dtype, bindings=immutabledict(), tags=_get_default_tags(), + non_equality_tags=frozenset({_get_created_at_tag()}), axes=_get_default_axes(len(shape)), var_to_reduction_descr=immutabledict()) @@ -2258,6 +2250,7 @@ def eye(N: int, M: Optional[int] = None, k: int = 0, # noqa: N803 return IndexLambda(expr=parse(f"1 if ((_1 - _0) == {k}) else 0"), shape=(N, M), dtype=dtype, bindings=immutabledict({}), tags=_get_default_tags(), + non_equality_tags=frozenset({_get_created_at_tag()}), axes=_get_default_axes(2), var_to_reduction_descr=immutabledict()) @@ -2353,6 +2346,7 @@ def arange(*args: Any, **kwargs: Any) -> Array: return IndexLambda(expr=start + Variable("_0") * step, shape=(size,), dtype=dtype, bindings=immutabledict(), tags=_get_default_tags(), + non_equality_tags=frozenset({_get_created_at_tag()}), axes=_get_default_axes(1), var_to_reduction_descr=immutabledict()) @@ -2464,6 +2458,7 @@ def logical_not(x: ArrayOrScalar) -> Union[Array, bool]: dtype=np.dtype(np.bool_), 
bindings={"_in0": x}, tags=_get_default_tags(), + non_equality_tags=frozenset({_get_created_at_tag()}), axes=_get_default_axes(len(x.shape)), var_to_reduction_descr=immutabledict()) @@ -2520,6 +2515,7 @@ def where(condition: ArrayOrScalar, dtype=dtype, bindings=immutabledict(bindings), tags=_get_default_tags(), + non_equality_tags=frozenset({_get_created_at_tag()}), axes=_get_default_axes(len(result_shape)), var_to_reduction_descr=immutabledict()) @@ -2618,6 +2614,7 @@ def make_index_lambda( shape=shape, dtype=dtype, tags=_get_default_tags(), + non_equality_tags=frozenset({_get_created_at_tag()}), axes=_get_default_axes(len(shape)), var_to_reduction_descr=immutabledict (processed_var_to_reduction_descr)) @@ -2703,6 +2700,7 @@ def broadcast_to(array: Array, shape: ShapeType) -> Array: dtype=array.dtype, bindings=immutabledict({"in": array}), tags=_get_default_tags(), + non_equality_tags=frozenset({_get_created_at_tag()}), axes=_get_default_axes(len(shape)), var_to_reduction_descr=immutabledict()) @@ -2777,6 +2775,7 @@ def expand_dims(array: Array, axis: Union[Tuple[int, ...], int]) -> Array: return Reshape(array=array, newshape=tuple(new_shape), order="C", tags=(_get_default_tags() | {ExpandedDimsReshape(tuple(normalized_axis))}), + non_equality_tags=frozenset({_get_created_at_tag()}), axes=_get_default_axes(len(new_shape))) # }}} diff --git a/pytato/cmath.py b/pytato/cmath.py index 38c520c7e..9f8e8c7fa 100644 --- a/pytato/cmath.py +++ b/pytato/cmath.py @@ -59,7 +59,7 @@ import pymbolic.primitives as prim from typing import Tuple, Optional from pytato.array import (Array, ArrayOrScalar, IndexLambda, _dtype_any, - _get_default_axes, _get_default_tags) + _get_default_axes, _get_default_tags, _get_created_at_tag) from pytato.scalar_expr import SCALAR_CLASSES from pymbolic import var from immutabledict import immutabledict @@ -115,6 +115,7 @@ def _apply_elem_wise_func(inputs: Tuple[ArrayOrScalar, ...], tuple(sym_args)), shape=shape, dtype=ret_dtype, 
bindings=immutabledict(bindings), tags=_get_default_tags(), + non_equality_tags=frozenset({_get_created_at_tag()}), axes=_get_default_axes(len(shape)), var_to_reduction_descr=immutabledict(), ) diff --git a/pytato/equality.py b/pytato/equality.py index 79b7427c4..76c21b4ed 100644 --- a/pytato/equality.py +++ b/pytato/equality.py @@ -24,7 +24,7 @@ THE SOFTWARE. """ -from typing import Any, Callable, Dict, TYPE_CHECKING, Tuple, Union, FrozenSet +from typing import Any, Callable, Dict, TYPE_CHECKING, Tuple, Union from pytato.array import (AdvancedIndexInContiguousAxes, AdvancedIndexInNoncontiguousAxes, AxisPermutation, BasicIndex, Concatenate, DataWrapper, Einsum, @@ -34,7 +34,7 @@ from pytato.function import Call, NamedCallResult, FunctionDefinition from pytools import memoize_method -from pytools.tag import Tag, IgnoredForEqualityTag, Taggable +from pytools.tag import Taggable if TYPE_CHECKING: from pytato.loopy import LoopyCall, LoopyCallResult @@ -42,22 +42,12 @@ __doc__ = """ -.. autofunction:: preprocess_tags_for_equality .. autoclass:: EqualityComparer """ - ArrayOrNames = Union[Array, AbstractResultWithNamedArrays] -def preprocess_tags_for_equality(tags: FrozenSet[Tag]) -> FrozenSet[Tag]: - """Remove tags of :class:`~pytools.tag.IgnoredForEqualityTag` for equality - comparison.""" - return frozenset(tag - for tag in tags - if not isinstance(tag, IgnoredForEqualityTag)) - - # {{{ EqualityComparer class EqualityComparer: diff --git a/pytato/stringifier.py b/pytato/stringifier.py index b5172e768..9afb887c0 100644 --- a/pytato/stringifier.py +++ b/pytato/stringifier.py @@ -97,6 +97,8 @@ def _map_generic_array(self, expr: Array, depth: int) -> str: fields = tuple(field.name for field in attrs.fields(type(expr))) + fields = tuple(field for field in fields if field != "non_equality_tags") + if expr.ndim <= 1: # prettify: if ndim <=1 'expr.axes' would be trivial, # => don't print. 
diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 40b416750..9cd1ed16a 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -33,7 +33,7 @@ from immutabledict import immutabledict from typing import (Any, Callable, Dict, FrozenSet, Union, TypeVar, Set, Generic, List, Mapping, Iterable, Tuple, Optional, TYPE_CHECKING, - Hashable) + Hashable, cast) from pytato.array import ( Array, IndexLambda, Placeholder, Stack, Roll, @@ -82,7 +82,6 @@ .. autofunction:: copy_dict_of_named_arrays .. autofunction:: get_dependencies .. autofunction:: map_and_copy -.. autofunction:: remove_tags_of_type .. autofunction:: materialize_with_mpms .. autofunction:: deduplicate_data_wrappers .. automodule:: pytato.transform.lower_to_index_lambda @@ -207,8 +206,7 @@ def __init__(self) -> None: def get_cache_key(self, expr: ArrayOrNames) -> Hashable: return expr - # type-ignore-reason: incompatible with super class - def rec(self, expr: ArrayOrNames) -> CachedMapperT: # type: ignore[override] + def rec(self, expr: ArrayOrNames) -> CachedMapperT: key = self.get_cache_key(expr) try: return self._cache[key] @@ -219,9 +217,7 @@ def rec(self, expr: ArrayOrNames) -> CachedMapperT: # type: ignore[override] return result # type: ignore[no-any-return] if TYPE_CHECKING: - # type-ignore-reason: incompatible with super class - def __call__(self, expr: ArrayOrNames # type: ignore[override] - ) -> CachedMapperT: + def __call__(self, expr: ArrayOrNames) -> CachedMapperT: return self.rec(expr) # }}} @@ -241,15 +237,10 @@ class CopyMapper(CachedMapper[ArrayOrNames]): This does not copy the data of a :class:`pytato.array.DataWrapper`. 
""" if TYPE_CHECKING: - # type-ignore-reason: specialized variant of super-class' rec method - def rec(self, # type: ignore[override] - expr: CopyMapperResultT) -> CopyMapperResultT: - # type-ignore-reason: CachedMapper.rec's return type is imprecise - return super().rec(expr) # type: ignore[return-value] - - # type-ignore-reason: specialized variant of super-class' rec method - def __call__(self, # type: ignore[override] - expr: CopyMapperResultT) -> CopyMapperResultT: + def rec(self, expr: CopyMapperResultT) -> CopyMapperResultT: + return cast(CopyMapperResultT, super().rec(expr)) + + def __call__(self, expr: CopyMapperResultT) -> CopyMapperResultT: return self.rec(expr) def clone_for_callee(self: _SelfMapper) -> _SelfMapper: @@ -1193,17 +1184,13 @@ def __init__(self) -> None: super().__init__() self.topological_order: List[Array] = [] - # type-ignore-reason: dropped the extra `*args, **kwargs`. - def get_cache_key(self, expr: ArrayOrNames) -> int: # type: ignore[override] + def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) - # type-ignore-reason: dropped the extra `*args, **kwargs`. - def post_visit(self, expr: Any) -> None: # type: ignore[override] + def post_visit(self, expr: Any) -> None: self.topological_order.append(expr) - # type-ignore-reason: dropped the extra `*args, **kwargs`. - def map_function_definition(self, # type: ignore[override] - expr: FunctionDefinition) -> None: + def map_function_definition(self, expr: FunctionDefinition) -> None: # do nothing as it includes arrays from a different namespace. 
return @@ -1227,8 +1214,7 @@ def clone_for_callee(self: _SelfMapper) -> _SelfMapper: # than Mapper.__init__ and does not have map_fn return type(self)(self.map_fn) # type: ignore[call-arg,attr-defined] - # type-ignore-reason:incompatible with Mapper.rec() - def rec(self, expr: MappedT) -> MappedT: # type: ignore[override] + def rec(self, expr: MappedT) -> MappedT: if expr in self._cache: # type-ignore-reason: parametric Mapping types aren't a thing return self._cache[expr] # type: ignore[return-value] @@ -1239,8 +1225,7 @@ def rec(self, expr: MappedT) -> MappedT: # type: ignore[override] return result # type: ignore[return-value] if TYPE_CHECKING: - # type-ignore-reason: Mapper.__call__ returns Any - def __call__(self, expr: MappedT) -> MappedT: # type: ignore[override] + def __call__(self, expr: MappedT) -> MappedT: return self.rec(expr) # }}} @@ -1532,21 +1517,6 @@ def map_and_copy(expr: MappedT, return CachedMapAndCopyMapper(map_fn)(expr) -def remove_tags_of_type(tag_types: Union[type, Tuple[type]], expr: ArrayOrNames - ) -> ArrayOrNames: - def process_node(expr: ArrayOrNames) -> ArrayOrNames: - if isinstance(expr, Array): - return expr.copy(tags=frozenset({ - tag for tag in expr.tags - if not isinstance(tag, tag_types)})) - elif isinstance(expr, AbstractResultWithNamedArrays): - return expr - else: - raise AssertionError(type(expr)) - - return map_and_copy(expr, process_node) - - def materialize_with_mpms(expr: DictOfNamedArrays) -> DictOfNamedArrays: r""" Materialize nodes in *expr* with MPMS materialization strategy. 
diff --git a/pytato/utils.py b/pytato/utils.py index 3937711f6..212197a93 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -179,7 +179,8 @@ def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar, op: Callable[[ScalarExpression, ScalarExpression], ScalarExpression], # noqa:E501 get_result_type: Callable[[DtypeOrScalar, DtypeOrScalar], np.dtype[Any]], # noqa:E501 ) -> ArrayOrScalar: - from pytato.array import _get_default_axes, _get_default_tags + from pytato.array import (_get_default_axes, _get_default_tags, + _get_created_at_tag) if isinstance(a1, SCALAR_CLASSES): a1 = np.dtype(type(a1)).type(a1) @@ -207,6 +208,7 @@ def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar, dtype=result_dtype, bindings=immutabledict(bindings), tags=_get_default_tags(), + non_equality_tags=frozenset({_get_created_at_tag()}), var_to_reduction_descr=immutabledict(), axes=_get_default_axes(len(result_shape))) @@ -475,7 +477,8 @@ def _normalized_slice_len(slice_: NormalizedSlice) -> ShapeComponent: def _index_into(ary: Array, indices: Tuple[ConvertibleToIndexExpr, ...]) -> Array: from pytato.diagnostic import CannotBroadcastError - from pytato.array import _get_default_axes, _get_default_tags + from pytato.array import (_get_default_axes, _get_default_tags, + _get_created_at_tag) # {{{ handle ellipsis @@ -562,6 +565,7 @@ def _index_into(ary: Array, indices: Tuple[ConvertibleToIndexExpr, ...]) -> Arra ary, tuple(normalized_indices), tags=_get_default_tags(), + non_equality_tags=frozenset({_get_created_at_tag()}), axes=_get_default_axes(len(array_idx_shape) + len(i_basic_indices))) else: @@ -569,6 +573,7 @@ def _index_into(ary: Array, indices: Tuple[ConvertibleToIndexExpr, ...]) -> Arra ary, tuple(normalized_indices), tags=_get_default_tags(), + non_equality_tags=frozenset({_get_created_at_tag()}), axes=_get_default_axes(len(array_idx_shape) + len(i_basic_indices))) else: @@ -576,6 +581,7 @@ def _index_into(ary: Array, indices: Tuple[ConvertibleToIndexExpr, ...]) -> Arra 
return BasicIndex(ary, tuple(normalized_indices), tags=_get_default_tags(), + non_equality_tags=frozenset({_get_created_at_tag()}), axes=_get_default_axes( len([idx for idx in normalized_indices diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py index 32a2ae5b0..2cddaa2c2 100644 --- a/pytato/visualization/dot.py +++ b/pytato/visualization/dot.py @@ -171,9 +171,11 @@ def __init__(self) -> None: def get_common_dot_info(self, expr: Array) -> _DotNodeInfo: title = type(expr).__name__ fields = {"addr": hex(id(expr)), - "shape": stringify_shape(expr.shape), - "dtype": str(expr.dtype), - "tags": stringify_tags(expr.tags)} + "shape": stringify_shape(expr.shape), + "dtype": str(expr.dtype), + "tags": stringify_tags(expr.tags), + "non_equality_tags": expr.non_equality_tags, + } edges: Dict[str, Union[ArrayOrNames, FunctionDefinition]] = {} return _DotNodeInfo(title, fields, edges) @@ -188,6 +190,7 @@ def handle_unsupported_array(self, # type: ignore[override] for field in attrs.fields(type(expr)): # type: ignore[misc] if field.name in info.fields: continue + attr = getattr(expr, field.name) if isinstance(attr, Array): @@ -356,15 +359,31 @@ def dot_escape(s: str) -> str: return html.escape(s.replace("\\", "\\\\").replace(" ", "_")) +def dot_escape_leave_space(s: str) -> str: + # "\" and HTML are significant in graphviz. 
+ return html.escape(s.replace("\\", "\\\\")) + + # {{{ emit helpers +def _stringify_created_at(non_equality_tags: frozenset[Tag]) -> str: + from pytato.tags import CreatedAt + for tag in non_equality_tags: + if isinstance(tag, CreatedAt): + return tag.traceback.short_str() + + return "" + + def _emit_array(emit: Callable[[str], None], title: str, fields: Dict[str, str], dot_node_id: str, color: str = "white") -> None: td_attrib = 'border="0"' table_attrib = 'border="0" cellborder="1" cellspacing="0"' - rows = ['%s' - % (td_attrib, dot_escape(title))] + rows = [f"{dot_escape(title)}"] + + non_equality_tags = fields.pop("non_equality_tags", frozenset()) + tooltip = dot_escape_leave_space(_stringify_created_at(non_equality_tags)) for name, field in fields.items(): field_content = dot_escape(field).replace("\n", "
") @@ -372,8 +391,10 @@ def _emit_array(emit: Callable[[str], None], title: str, fields: Dict[str, str], f"{dot_escape(name)}:" f"{field_content}" ) - table = "\n%s
" % (table_attrib, "".join(rows)) - emit("%s [label=<%s> style=filled fillcolor=%s]" % (dot_node_id, table, color)) + + table = f"\n{''.join(rows)}
" + emit(f"{dot_node_id} [label=<{table}> style=filled fillcolor={color} " + f'tooltip="{tooltip}"]') def _emit_name_cluster( diff --git a/test/test_pytato.py b/test/test_pytato.py index 0583bdce0..ad865c4ef 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -807,9 +807,6 @@ def test_created_at(): a = pt.make_placeholder("a", (10, 10), "float64") b = pt.make_placeholder("b", (10, 10), "float64") - _prev_debug_enabled = pt.DEBUG_ENABLED - pt.DEBUG_ENABLED = True - # res1 and res2 are defined on different lines and should have different # CreatedAt tags. res1 = a+b @@ -821,23 +818,21 @@ def test_created_at(): # {{{ Check that CreatedAt tags are handled correctly for equality - from pytato.equality import preprocess_tags_for_equality - assert res1 == res2 == res3 == res4 - assert res1.tags != res2.tags - assert res3.tags == res4.tags + assert res1.non_equality_tags != res2.non_equality_tags + assert res3.non_equality_tags == res4.non_equality_tags - assert (preprocess_tags_for_equality(res1.tags) - == preprocess_tags_for_equality(res2.tags)) - assert (preprocess_tags_for_equality(res3.tags) - == preprocess_tags_for_equality(res4.tags)) + assert res1.tags == res2.tags + assert res3.tags == res4.tags # }}} from pytato.tags import CreatedAt - created_tag = res1.tags_of_type(CreatedAt) + created_tag = frozenset({tag + for tag in res1.non_equality_tags + if isinstance(tag, CreatedAt)}) assert len(created_tag) == 1 @@ -866,8 +861,6 @@ def test_created_at(): assert "test_created_at" in s assert "a+b" in s - pt.DEBUG_ENABLED = _prev_debug_enabled - # }}} From a5cec505af2390c0d35a789d04edeae10ba4e8a6 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 14 Nov 2023 17:52:00 -0600 Subject: [PATCH 070/124] misc fixes --- pytato/array.py | 2 +- pytato/utils.py | 4 +--- pytato/visualization/dot.py | 2 +- test/test_pytato.py | 18 ++++++------------ 4 files changed, 9 insertions(+), 17 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 
99f645d37..bdb4714db 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -450,7 +450,7 @@ class Array(Taggable): # These are automatically excluded from equality in EqualityComparer non_equality_tags: FrozenSet[Tag] = attrs.field(kw_only=True, hash=False, - default=None) + default=frozenset()) _mapper_method: ClassVar[str] diff --git a/pytato/utils.py b/pytato/utils.py index 212197a93..ad2b77377 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -253,8 +253,6 @@ def dim_to_index_lambda_components(expr: ShapeComponent, .. testsetup:: >>> import pytato as pt - >>> from pytato.transform import remove_tags_of_type - >>> from pytato.tags import CreatedAt >>> from pytato.utils import dim_to_index_lambda_components >>> from pytools import UniqueNameGenerator @@ -264,7 +262,7 @@ def dim_to_index_lambda_components(expr: ShapeComponent, >>> expr, bnds = dim_to_index_lambda_components(3*n+8, UniqueNameGenerator()) >>> print(expr) 3*_in + 8 - >>> {"_in": remove_tags_of_type(CreatedAt, bnds["_in"])} + >>> bnds {'_in': SizeParam(name='n')} """ if isinstance(expr, INT_CLASSES): diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py index 2cddaa2c2..3984d986e 100644 --- a/pytato/visualization/dot.py +++ b/pytato/visualization/dot.py @@ -366,7 +366,7 @@ def dot_escape_leave_space(s: str) -> str: # {{{ emit helpers -def _stringify_created_at(non_equality_tags: frozenset[Tag]) -> str: +def _stringify_created_at(non_equality_tags: FrozenSet[Tag]) -> str: from pytato.tags import CreatedAt for tag in non_equality_tags: if isinstance(tag, CreatedAt): diff --git a/test/test_pytato.py b/test/test_pytato.py index ad865c4ef..8976d76d4 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -27,8 +27,6 @@ import sys -from typing import cast - import numpy as np import pytest import attrs @@ -451,16 +449,12 @@ def test_array_dot_repr(): x = pt.make_placeholder("x", (10, 4), np.int64) y = pt.make_placeholder("y", (10, 4), np.int64) - def _assert_stripped_repr(ary: 
pt.Array, expected_repr: str): - from pytato.transform import remove_tags_of_type - from pytato.tags import CreatedAt - ary = cast(pt.Array, remove_tags_of_type(CreatedAt, ary)) - + def _assert_repr(ary: pt.Array, expected_repr: str): expected_str = "".join([c for c in expected_repr if c not in [" ", "\n"]]) result_str = "".join([c for c in repr(ary)if c not in [" ", "\n"]]) assert expected_str == result_str - _assert_stripped_repr( + _assert_repr( 3*x + 4*y, """ IndexLambda( @@ -489,7 +483,7 @@ def _assert_stripped_repr(ary: pt.Array, expected_repr: str): dtype='int64', name='y')})})""") - _assert_stripped_repr( + _assert_repr( pt.roll(x.reshape(2, 20).reshape(-1), 3), """ Roll( @@ -501,7 +495,7 @@ def _assert_stripped_repr(ary: pt.Array, expected_repr: str): newshape=(40), order='C'), shift=3, axis=0)""") - _assert_stripped_repr(y * pt.not_equal(x, 3), + _assert_repr(y * pt.not_equal(x, 3), """ IndexLambda( shape=(10, 4), @@ -521,7 +515,7 @@ def _assert_stripped_repr(ary: pt.Array, expected_repr: str): bindings={'_in0': Placeholder(shape=(10, 4), dtype='int64', name='x')})})""") - _assert_stripped_repr( + _assert_repr( x[y[:, 2:3], x[2, :]], """ AdvancedIndexInContiguousAxes( @@ -536,7 +530,7 @@ def _assert_stripped_repr(ary: pt.Array, expected_repr: str): name='x'), indices=(2, NormalizedSlice(start=0, stop=4, step=1)))))""") - _assert_stripped_repr( + _assert_repr( pt.stack([x[y[:, 2:3], x[2, :]].T, y[x[:, 2:3], y[2, :]].T]), """ Stack( From d9898c9103314afbbf9117c920fc50195d20bc4f Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 14 Nov 2023 18:20:41 -0600 Subject: [PATCH 071/124] undo some unecessary changes --- pytato/equality.py | 38 +++++++++++++++++------------------- pytato/stringifier.py | 7 +++++-- pytato/transform/__init__.py | 38 ++++++++++++++++++++++++------------ test/test_pytato.py | 4 ++-- 4 files changed, 51 insertions(+), 36 deletions(-) diff --git a/pytato/equality.py b/pytato/equality.py index 76c21b4ed..42c2978cd 100644 --- 
a/pytato/equality.py +++ b/pytato/equality.py @@ -34,17 +34,15 @@ from pytato.function import Call, NamedCallResult, FunctionDefinition from pytools import memoize_method -from pytools.tag import Taggable - if TYPE_CHECKING: from pytato.loopy import LoopyCall, LoopyCallResult from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder - __doc__ = """ .. autoclass:: EqualityComparer """ + ArrayOrNames = Union[Array, AbstractResultWithNamedArrays] @@ -107,14 +105,14 @@ def map_placeholder(self, expr1: Placeholder, expr2: Any) -> bool: and expr1.name == expr2.name and expr1.shape == expr2.shape and expr1.dtype == expr2.dtype - and Taggable.__eq__(expr1, expr2) + and expr1.tags == expr2.tags and expr1.axes == expr2.axes ) def map_size_param(self, expr1: SizeParam, expr2: Any) -> bool: return (expr1.__class__ is expr2.__class__ and expr1.name == expr2.name - and Taggable.__eq__(expr1, expr2) + and expr1.tags == expr2.tags and expr1.axes == expr2.axes ) @@ -133,7 +131,7 @@ def map_index_lambda(self, expr1: IndexLambda, expr2: Any) -> bool: if isinstance(dim1, Array) else dim1 == dim2 for dim1, dim2 in zip(expr1.shape, expr2.shape)) - and Taggable.__eq__(expr1, expr2) + and expr1.tags == expr2.tags and expr1.axes == expr2.axes and expr1.var_to_reduction_descr == expr2.var_to_reduction_descr ) @@ -144,7 +142,7 @@ def map_stack(self, expr1: Stack, expr2: Any) -> bool: and len(expr1.arrays) == len(expr2.arrays) and all(self.rec(ary1, ary2) for ary1, ary2 in zip(expr1.arrays, expr2.arrays)) - and Taggable.__eq__(expr1, expr2) + and expr1.tags == expr2.tags and expr1.axes == expr2.axes ) @@ -154,7 +152,7 @@ def map_concatenate(self, expr1: Concatenate, expr2: Any) -> bool: and len(expr1.arrays) == len(expr2.arrays) and all(self.rec(ary1, ary2) for ary1, ary2 in zip(expr1.arrays, expr2.arrays)) - and Taggable.__eq__(expr1, expr2) + and expr1.tags == expr2.tags and expr1.axes == expr2.axes ) @@ -163,7 +161,7 @@ def map_roll(self, expr1: Roll, expr2: Any) -> 
bool: and expr1.axis == expr2.axis and expr1.shift == expr2.shift and self.rec(expr1.array, expr2.array) - and Taggable.__eq__(expr1, expr2) + and expr1.tags == expr2.tags and expr1.axes == expr2.axes ) @@ -171,7 +169,7 @@ def map_axis_permutation(self, expr1: AxisPermutation, expr2: Any) -> bool: return (expr1.__class__ is expr2.__class__ and expr1.axis_permutation == expr2.axis_permutation and self.rec(expr1.array, expr2.array) - and Taggable.__eq__(expr1, expr2) + and expr1.tags == expr2.tags and expr1.axes == expr2.axes ) @@ -184,7 +182,7 @@ def _map_index_base(self, expr1: IndexBase, expr2: Any) -> bool: and isinstance(idx2, Array)) else idx1 == idx2 for idx1, idx2 in zip(expr1.indices, expr2.indices)) - and Taggable.__eq__(expr1, expr2) + and expr1.tags == expr2.tags and expr1.axes == expr2.axes ) @@ -207,7 +205,7 @@ def map_reshape(self, expr1: Reshape, expr2: Any) -> bool: return (expr1.__class__ is expr2.__class__ and expr1.newshape == expr2.newshape and self.rec(expr1.array, expr2.array) - and Taggable.__eq__(expr1, expr2) + and expr1.tags == expr2.tags and expr1.axes == expr2.axes ) @@ -217,7 +215,7 @@ def map_einsum(self, expr1: Einsum, expr2: Any) -> bool: and all(self.rec(ary1, ary2) for ary1, ary2 in zip(expr1.args, expr2.args)) - and Taggable.__eq__(expr1, expr2) + and expr1.tags == expr2.tags and expr1.axes == expr2.axes and expr1.redn_axis_to_redn_descr == expr2.redn_axis_to_redn_descr ) @@ -225,7 +223,7 @@ def map_einsum(self, expr1: Einsum, expr2: Any) -> bool: def map_named_array(self, expr1: NamedArray, expr2: Any) -> bool: return (expr1.__class__ is expr2.__class__ and self.rec(expr1._container, expr2._container) - and Taggable.__eq__(expr1, expr2) + and expr1.tags == expr2.tags and expr1.axes == expr2.axes and expr1.name == expr2.name) @@ -239,13 +237,13 @@ def map_loopy_call(self, expr1: LoopyCall, expr2: Any) -> bool: if isinstance(bnd, Array) else bnd == expr2.bindings[name] for name, bnd in expr1.bindings.items()) - and 
Taggable.__eq__(expr1, expr2) + and expr1.tags == expr2.tags ) def map_loopy_call_result(self, expr1: LoopyCallResult, expr2: Any) -> bool: return (expr1.__class__ is expr2.__class__ and self.rec(expr1._container, expr2._container) - and Taggable.__eq__(expr1, expr2) + and expr1.tags == expr2.tags and expr1.axes == expr2.axes and expr1.name == expr2.name) @@ -254,7 +252,7 @@ def map_dict_of_named_arrays(self, expr1: DictOfNamedArrays, expr2: Any) -> bool and frozenset(expr1._data.keys()) == frozenset(expr2._data.keys()) and all(self.rec(expr1._data[name], expr2._data[name]) for name in expr1._data) - and Taggable.__eq__(expr1, expr2) + and expr1.tags == expr2.tags ) def map_distributed_send_ref_holder( @@ -264,8 +262,8 @@ def map_distributed_send_ref_holder( and self.rec(expr1.passthrough_data, expr2.passthrough_data) and expr1.send.dest_rank == expr2.send.dest_rank and expr1.send.comm_tag == expr2.send.comm_tag - and Taggable.__eq__(expr1.send, expr2.send) - and Taggable.__eq__(expr1, expr2) + and expr1.send.tags == expr2.send.tags + and expr1.tags == expr2.tags ) def map_distributed_recv(self, expr1: DistributedRecv, expr2: Any) -> bool: @@ -274,7 +272,7 @@ def map_distributed_recv(self, expr1: DistributedRecv, expr2: Any) -> bool: and expr1.comm_tag == expr2.comm_tag and expr1.shape == expr2.shape and expr1.dtype == expr2.dtype - and Taggable.__eq__(expr1, expr2) + and expr1.tags == expr2.tags ) @memoize_method diff --git a/pytato/stringifier.py b/pytato/stringifier.py index 9afb887c0..4fc0fd8bf 100644 --- a/pytato/stringifier.py +++ b/pytato/stringifier.py @@ -95,7 +95,9 @@ def _map_generic_array(self, expr: Array, depth: int) -> str: if depth > self.truncation_depth: return self.truncation_string - fields = tuple(field.name for field in attrs.fields(type(expr))) + # type-ignore-reason: https://github.com/python/mypy/issues/16254 + fields = tuple(field.name + for field in attrs.fields(type(expr))) # type: ignore[misc] fields = tuple(field for field in fields if 
field != "non_equality_tags") @@ -155,7 +157,8 @@ def _get_field_val(field: str) -> str: return (f"{type(expr).__name__}(" + ", ".join(f"{field.name}={_get_field_val(field.name)}" - for field in attrs.fields(type(expr))) + # type-ignore-reason: https://github.com/python/mypy/issues/16254 + for field in attrs.fields(type(expr))) # type: ignore[misc] + ")") def map_loopy_call(self, expr: LoopyCall, depth: int) -> str: diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 9cd1ed16a..e759a7e82 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -33,7 +33,7 @@ from immutabledict import immutabledict from typing import (Any, Callable, Dict, FrozenSet, Union, TypeVar, Set, Generic, List, Mapping, Iterable, Tuple, Optional, TYPE_CHECKING, - Hashable, cast) + Hashable) from pytato.array import ( Array, IndexLambda, Placeholder, Stack, Roll, @@ -206,7 +206,8 @@ def __init__(self) -> None: def get_cache_key(self, expr: ArrayOrNames) -> Hashable: return expr - def rec(self, expr: ArrayOrNames) -> CachedMapperT: + # type-ignore-reason: incompatible with super class + def rec(self, expr: ArrayOrNames) -> CachedMapperT: # type: ignore[override] key = self.get_cache_key(expr) try: return self._cache[key] @@ -217,7 +218,9 @@ def rec(self, expr: ArrayOrNames) -> CachedMapperT: return result # type: ignore[no-any-return] if TYPE_CHECKING: - def __call__(self, expr: ArrayOrNames) -> CachedMapperT: + # type-ignore-reason: incompatible with super class + def __call__(self, expr: ArrayOrNames # type: ignore[override] + ) -> CachedMapperT: return self.rec(expr) # }}} @@ -237,10 +240,15 @@ class CopyMapper(CachedMapper[ArrayOrNames]): This does not copy the data of a :class:`pytato.array.DataWrapper`. 
""" if TYPE_CHECKING: - def rec(self, expr: CopyMapperResultT) -> CopyMapperResultT: - return cast(CopyMapperResultT, super().rec(expr)) - - def __call__(self, expr: CopyMapperResultT) -> CopyMapperResultT: + # type-ignore-reason: specialized variant of super-class' rec method + def rec(self, # type: ignore[override] + expr: CopyMapperResultT) -> CopyMapperResultT: + # type-ignore-reason: CachedMapper.rec's return type is imprecise + return super().rec(expr) # type: ignore[return-value] + + # type-ignore-reason: specialized variant of super-class' rec method + def __call__(self, # type: ignore[override] + expr: CopyMapperResultT) -> CopyMapperResultT: return self.rec(expr) def clone_for_callee(self: _SelfMapper) -> _SelfMapper: @@ -1184,13 +1192,17 @@ def __init__(self) -> None: super().__init__() self.topological_order: List[Array] = [] - def get_cache_key(self, expr: ArrayOrNames) -> int: + # type-ignore-reason: dropped the extra `*args, **kwargs`. + def get_cache_key(self, expr: ArrayOrNames) -> int: # type: ignore[override] return id(expr) - def post_visit(self, expr: Any) -> None: + # type-ignore-reason: dropped the extra `*args, **kwargs`. + def post_visit(self, expr: Any) -> None: # type: ignore[override] self.topological_order.append(expr) - def map_function_definition(self, expr: FunctionDefinition) -> None: + # type-ignore-reason: dropped the extra `*args, **kwargs`. + def map_function_definition(self, # type: ignore[override] + expr: FunctionDefinition) -> None: # do nothing as it includes arrays from a different namespace. 
return @@ -1214,7 +1226,8 @@ def clone_for_callee(self: _SelfMapper) -> _SelfMapper: # than Mapper.__init__ and does not have map_fn return type(self)(self.map_fn) # type: ignore[call-arg,attr-defined] - def rec(self, expr: MappedT) -> MappedT: + # type-ignore-reason:incompatible with Mapper.rec() + def rec(self, expr: MappedT) -> MappedT: # type: ignore[override] if expr in self._cache: # type-ignore-reason: parametric Mapping types aren't a thing return self._cache[expr] # type: ignore[return-value] @@ -1225,7 +1238,8 @@ def rec(self, expr: MappedT) -> MappedT: return result # type: ignore[return-value] if TYPE_CHECKING: - def __call__(self, expr: MappedT) -> MappedT: + # type-ignore-reason: Mapper.__call__ returns Any + def __call__(self, expr: MappedT) -> MappedT: # type: ignore[override] return self.rec(expr) # }}} diff --git a/test/test_pytato.py b/test/test_pytato.py index 8976d76d4..be122a925 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -376,7 +376,7 @@ def test_linear_complexity_inequality(): from pytato.equality import EqualityComparer from numpy.random import default_rng - def construct_intestine_graph(depth=90, seed=0): + def construct_intestine_graph(depth=100, seed=0): rng = default_rng(seed) x = pt.make_placeholder("x", shape=(10,), dtype=float) @@ -650,7 +650,7 @@ def post_visit(self, expr): def test_tag_user_nodes_linear_complexity(): from numpy.random import default_rng - def construct_intestine_graph(depth=90, seed=0): + def construct_intestine_graph(depth=100, seed=0): rng = default_rng(seed) x = pt.make_placeholder("x", shape=(10,), dtype=float) y = x From c5c8920ef1f6f97b2ae3e71e7abbbf4af06f7d3c Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 15 Nov 2023 15:16:54 -0600 Subject: [PATCH 072/124] more misc fixes --- pytato/array.py | 2 +- pytato/transform/__init__.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index bdb4714db..fb32fb7f5 100644 --- 
a/pytato/array.py +++ b/pytato/array.py @@ -1902,7 +1902,7 @@ def roll(a: Array, shift: int, axis: Optional[int] = None) -> Array: if axis is None: if a.ndim > 1: raise NotImplementedError( - "shifing along more than one dimension is unsupported") + "shifting along more than one dimension is unsupported") else: axis = 0 diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index e759a7e82..e4316995b 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -419,8 +419,8 @@ def map_function_definition(self, def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: return Call(self.map_function_definition(expr.function), immutabledict({name: self.rec(bnd) - for name, bnd in expr.bindings.items()}), - tags=expr.tags, + for name, bnd in sorted(expr.bindings.items())}), + tags=expr.tags ) def map_named_call_result(self, expr: NamedCallResult) -> Array: @@ -642,7 +642,7 @@ def map_call(self, expr: Call, *args: Any, **kwargs: Any) -> AbstractResultWithNamedArrays: return Call(self.map_function_definition(expr.function, *args, **kwargs), immutabledict({name: self.rec(bnd, *args, **kwargs) - for name, bnd in expr.bindings.items()}), + for name, bnd in sorted(expr.bindings.items())}), tags=expr.tags, ) @@ -1540,7 +1540,7 @@ def materialize_with_mpms(expr: DictOfNamedArrays) -> DictOfNamedArrays: - MPMS materialization strategy is a greedy materialization algorithm in which any node with more than 1 materialized predecessors and more than - 1 successors is materialized. + 1 successor is materialized. - Materializing here corresponds to tagging a node with :class:`~pytato.tags.ImplStored`. 
- Does not attempt to materialize sub-expressions in From 4ec3cbf60dffa9a38e1518269db31dc9c7daa6fc Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 15 Nov 2023 15:48:54 -0600 Subject: [PATCH 073/124] copymapper, tests --- pytato/transform/__init__.py | 78 ++++++++++++++++++++++++------------ test/test_pytato.py | 37 +++++++++++++++++ 2 files changed, 89 insertions(+), 26 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index e4316995b..f04f68b11 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -275,7 +275,8 @@ def map_index_lambda(self, expr: IndexLambda) -> Array: bindings=bindings, axes=expr.axes, var_to_reduction_descr=expr.var_to_reduction_descr, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_placeholder(self, expr: Placeholder) -> Array: assert expr.name is not None @@ -283,35 +284,41 @@ def map_placeholder(self, expr: Placeholder) -> Array: shape=self.rec_idx_or_size_tuple(expr.shape), dtype=expr.dtype, axes=expr.axes, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_stack(self, expr: Stack) -> Array: arrays = tuple(self.rec(arr) for arr in expr.arrays) - return Stack(arrays=arrays, axis=expr.axis, axes=expr.axes, tags=expr.tags) + return Stack(arrays=arrays, axis=expr.axis, axes=expr.axes, tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_concatenate(self, expr: Concatenate) -> Array: arrays = tuple(self.rec(arr) for arr in expr.arrays) return Concatenate(arrays=arrays, axis=expr.axis, - axes=expr.axes, tags=expr.tags) + axes=expr.axes, tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_roll(self, expr: Roll) -> Array: return Roll(array=self.rec(expr.array), shift=expr.shift, axis=expr.axis, axes=expr.axes, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_axis_permutation(self, expr: AxisPermutation) -> Array: return 
AxisPermutation(array=self.rec(expr.array), axis_permutation=expr.axis_permutation, axes=expr.axes, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def _map_index_base(self, expr: IndexBase) -> Array: return type(expr)(self.rec(expr.array), indices=self.rec_idx_or_size_tuple(expr.indices), axes=expr.axes, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_basic_index(self, expr: BasicIndex) -> Array: return self._map_index_base(expr) @@ -331,7 +338,8 @@ def map_data_wrapper(self, expr: DataWrapper) -> Array: data=expr.data, shape=self.rec_idx_or_size_tuple(expr.shape), axes=expr.axes, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_size_param(self, expr: SizeParam) -> Array: assert expr.name is not None @@ -343,13 +351,15 @@ def map_einsum(self, expr: Einsum) -> Array: axes=expr.axes, redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, index_to_access_descr=expr.index_to_access_descr, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_named_array(self, expr: NamedArray) -> Array: return type(expr)(self.rec(expr._container), expr.name, axes=expr.axes, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_dict_of_named_arrays(self, expr: DictOfNamedArrays) -> DictOfNamedArrays: @@ -377,14 +387,16 @@ def map_loopy_call_result(self, expr: LoopyCallResult) -> Array: container=rec_container, name=expr.name, axes=expr.axes, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_reshape(self, expr: Reshape) -> Array: return Reshape(self.rec(expr.array), newshape=self.rec_idx_or_size_tuple(expr.newshape), order=expr.order, axes=expr.axes, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_distributed_send_ref_holder( self, expr: DistributedSendRefHolder) -> Array: @@ -400,7 +412,8 @@ def map_distributed_recv(self, expr: 
DistributedRecv) -> Array: return DistributedRecv( src_rank=expr.src_rank, comm_tag=expr.comm_tag, shape=self.rec_idx_or_size_tuple(expr.shape), - dtype=expr.dtype, tags=expr.tags, axes=expr.axes) + dtype=expr.dtype, tags=expr.tags, axes=expr.axes, + non_equality_tags=expr.non_equality_tags) @memoize_method def map_function_definition(self, @@ -492,7 +505,8 @@ def map_index_lambda(self, expr: IndexLambda, bindings=bindings, axes=expr.axes, var_to_reduction_descr=expr.var_to_reduction_descr, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_placeholder(self, expr: Placeholder, *args: Any, **kwargs: Any) -> Array: assert expr.name is not None @@ -501,37 +515,43 @@ def map_placeholder(self, expr: Placeholder, *args: Any, **kwargs: Any) -> Array *args, **kwargs), dtype=expr.dtype, axes=expr.axes, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_stack(self, expr: Stack, *args: Any, **kwargs: Any) -> Array: arrays = tuple(self.rec(arr, *args, **kwargs) for arr in expr.arrays) - return Stack(arrays=arrays, axis=expr.axis, axes=expr.axes, tags=expr.tags) + return Stack(arrays=arrays, axis=expr.axis, axes=expr.axes, tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_concatenate(self, expr: Concatenate, *args: Any, **kwargs: Any) -> Array: arrays = tuple(self.rec(arr, *args, **kwargs) for arr in expr.arrays) return Concatenate(arrays=arrays, axis=expr.axis, - axes=expr.axes, tags=expr.tags) + axes=expr.axes, tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_roll(self, expr: Roll, *args: Any, **kwargs: Any) -> Array: return Roll(array=self.rec(expr.array, *args, **kwargs), shift=expr.shift, axis=expr.axis, axes=expr.axes, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_axis_permutation(self, expr: AxisPermutation, *args: Any, **kwargs: Any) -> Array: return AxisPermutation(array=self.rec(expr.array, *args, **kwargs), 
axis_permutation=expr.axis_permutation, axes=expr.axes, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def _map_index_base(self, expr: IndexBase, *args: Any, **kwargs: Any) -> Array: return type(expr)(self.rec(expr.array, *args, **kwargs), indices=self.rec_idx_or_size_tuple(expr.indices, *args, **kwargs), axes=expr.axes, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_basic_index(self, expr: BasicIndex, *args: Any, **kwargs: Any) -> Array: return self._map_index_base(expr, *args, **kwargs) @@ -555,7 +575,8 @@ def map_data_wrapper(self, expr: DataWrapper, data=expr.data, shape=self.rec_idx_or_size_tuple(expr.shape, *args, **kwargs), axes=expr.axes, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_size_param(self, expr: SizeParam, *args: Any, **kwargs: Any) -> Array: assert expr.name is not None @@ -567,13 +588,15 @@ def map_einsum(self, expr: Einsum, *args: Any, **kwargs: Any) -> Array: axes=expr.axes, redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr, index_to_access_descr=expr.index_to_access_descr, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_named_array(self, expr: NamedArray, *args: Any, **kwargs: Any) -> Array: return type(expr)(self.rec(expr._container, *args, **kwargs), expr.name, axes=expr.axes, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_dict_of_named_arrays(self, expr: DictOfNamedArrays, *args: Any, **kwargs: Any) -> DictOfNamedArrays: @@ -613,7 +636,8 @@ def map_reshape(self, expr: Reshape, *args, **kwargs), order=expr.order, axes=expr.axes, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder, *args: Any, **kwargs: Any) -> Array: @@ -623,14 +647,16 @@ def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder, dest_rank=expr.send.dest_rank, 
comm_tag=expr.send.comm_tag), self.rec(expr.passthrough_data, *args, **kwargs), - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) def map_distributed_recv(self, expr: DistributedRecv, *args: Any, **kwargs: Any) -> Array: return DistributedRecv( src_rank=expr.src_rank, comm_tag=expr.comm_tag, shape=self.rec_idx_or_size_tuple(expr.shape, *args, **kwargs), - dtype=expr.dtype, tags=expr.tags, axes=expr.axes) + dtype=expr.dtype, tags=expr.tags, axes=expr.axes, + non_equality_tags=expr.non_equality_tags) def map_function_definition(self, expr: FunctionDefinition, *args: Any, **kwargs: Any) -> FunctionDefinition: diff --git a/test/test_pytato.py b/test/test_pytato.py index be122a925..8c0461685 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -857,6 +857,43 @@ def test_created_at(): # }}} + # {{{ Make sure only a single CreatedAt tag is created + + old_tag = tag + + res1 = res1 + res2 + + created_tag = frozenset({tag + for tag in res1.non_equality_tags + if isinstance(tag, CreatedAt)}) + + assert len(created_tag) == 1 + + tag, = created_tag + + # Tag should be recreated + assert tag != old_tag + + # }}} + + # {{{ Make sure that copying preserves the tag + + old_tag = tag + + res1_new = pt.transform.map_and_copy(res1, lambda x: x) + + created_tag = frozenset({tag + for tag in res1_new.non_equality_tags + if isinstance(tag, CreatedAt)}) + + assert len(created_tag) == 1 + + tag, = created_tag + + assert old_tag == tag + + # }}} + def test_pickling_and_unpickling_is_equal(): from testlib import RandomDAGContext, make_random_dag From 176595dd30381f71dbda40066084c1641be76508 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 17 Nov 2023 10:00:47 -0600 Subject: [PATCH 074/124] explicitly enable/disable traceback --- pytato/__init__.py | 2 ++ pytato/array.py | 16 +++++++++++++--- test/test_pytato.py | 16 ++++++++++++++++ 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/pytato/__init__.py b/pytato/__init__.py index 
572e4a7ab..5255820af 100644 --- a/pytato/__init__.py +++ b/pytato/__init__.py @@ -113,6 +113,7 @@ def set_debug_enabled(flag: bool) -> None: rewrite_einsums_with_no_broadcasts) from pytato.transform.metadata import unify_axes_tags from pytato.function import trace_call +from pytato.array import enable_traceback_tag __all__ = ( "dtype", @@ -183,4 +184,5 @@ def set_debug_enabled(flag: bool) -> None: # sub-modules "analysis", "tags", "transform", "function", + "enable_traceback_tag", ) diff --git a/pytato/array.py b/pytato/array.py index fb32fb7f5..bcfaa3a3a 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -145,11 +145,12 @@ .. autoclass:: EinsumReductionAxis .. autoclass:: NormalizedSlice -Internal classes for traceback ------------------------------- +Traceback functionality +----------------------- Please consider these undocumented and subject to change at any time. +.. autofunction:: enable_traceback_tag .. class:: _PytatoFrameSummary .. class:: _PytatoStackSummary @@ -1838,11 +1839,20 @@ def __repr__(self) -> str: return "\n " + "\n ".join([str(f) for f in self.frames]) +_ENABLE_TRACEBACK_TAG = False + + +def enable_traceback_tag(enable: bool = True) -> None: + """Enable or disable the traceback tag.""" + global _ENABLE_TRACEBACK_TAG + _ENABLE_TRACEBACK_TAG = enable + + def _get_created_at_tag() -> Optional[Tag]: import traceback from pytato.tags import CreatedAt - if not __debug__: + if not _ENABLE_TRACEBACK_TAG: return None frames = tuple(_PytatoFrameSummary(s.filename, s.lineno, s.name, s.line) diff --git a/test/test_pytato.py b/test/test_pytato.py index 8c0461685..674f640d6 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -798,6 +798,8 @@ def test_einsum_dot_axes_has_correct_dim(): def test_created_at(): + pt.enable_traceback_tag() + a = pt.make_placeholder("a", (10, 10), "float64") b = pt.make_placeholder("b", (10, 10), "float64") @@ -894,6 +896,20 @@ def test_created_at(): # }}} + # {{{ Test disabling traceback creation + + 
pt.enable_traceback_tag(False) + + a = pt.make_placeholder("a", (10, 10), "float64") + + created_tag = frozenset({tag + for tag in a.non_equality_tags + if isinstance(tag, CreatedAt)}) + + assert len(created_tag) == 0 + + # }}} + def test_pickling_and_unpickling_is_equal(): from testlib import RandomDAGContext, make_random_dag From 40557e9f13ac3d7167dac76f93e87a2d70929b84 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 17 Nov 2023 10:41:15 -0600 Subject: [PATCH 075/124] add hash test --- test/test_pytato.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index 674f640d6..127f0b7dc 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -812,15 +812,20 @@ def test_created_at(): # CreatedAt tags. res3 = a+b; res4 = a+b # noqa: E702 - # {{{ Check that CreatedAt tags are handled correctly for equality + # {{{ Check that CreatedAt tags are handled correctly for equality/hashing assert res1 == res2 == res3 == res4 + assert hash(res1) == hash(res2) == hash(res3) == hash(res4) assert res1.non_equality_tags != res2.non_equality_tags assert res3.non_equality_tags == res4.non_equality_tags - assert res1.tags == res2.tags - assert res3.tags == res4.tags + assert hash(res1.non_equality_tags) != hash(res2.non_equality_tags) + assert hash(res3.non_equality_tags) == hash(res4.non_equality_tags) + + assert res1.tags == res2.tags == res3.tags == res4.tags + + assert hash(res1.tags) == hash(res2.tags) == hash(res3.tags) == hash(res4.tags) # }}} From 524049564eb60de3083e87a42cf74fdec267495f Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 17 Nov 2023 11:16:53 -0600 Subject: [PATCH 076/124] undo more unnecessary changes --- pytato/array.py | 6 +++--- pytato/transform/__init__.py | 8 ++++---- pytato/visualization/dot.py | 1 - test/test_pytato.py | 14 ++++++-------- 4 files changed, 13 insertions(+), 16 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index bcfaa3a3a..adefe787b 
100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -277,7 +277,7 @@ def normalize_shape_component( # }}} -# {{{ array interface +# {{{ array inteface ConvertibleToIndexExpr = Union[int, slice, "Array", None, EllipsisType] IndexExpr = Union[IntegralT, "NormalizedSlice", "Array", None, EllipsisType] @@ -386,7 +386,7 @@ class Array(Taggable): :class:`~pytato.array.IndexLambda` is used to produce references to named arrays. Since any array that needs to be referenced in this way needs to obey this restriction anyway, - a decision was made to require the same of *all* array expressions. + a decision was made to requir the same of *all* array expressions. .. attribute:: dtype @@ -1912,7 +1912,7 @@ def roll(a: Array, shift: int, axis: Optional[int] = None) -> Array: if axis is None: if a.ndim > 1: raise NotImplementedError( - "shifting along more than one dimension is unsupported") + "shifing along more than one dimension is unsupported") else: axis = 0 diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index f04f68b11..041e8df01 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -432,8 +432,8 @@ def map_function_definition(self, def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: return Call(self.map_function_definition(expr.function), immutabledict({name: self.rec(bnd) - for name, bnd in sorted(expr.bindings.items())}), - tags=expr.tags + for name, bnd in expr.bindings.items()}), + tags=expr.tags, ) def map_named_call_result(self, expr: NamedCallResult) -> Array: @@ -668,7 +668,7 @@ def map_call(self, expr: Call, *args: Any, **kwargs: Any) -> AbstractResultWithNamedArrays: return Call(self.map_function_definition(expr.function, *args, **kwargs), immutabledict({name: self.rec(bnd, *args, **kwargs) - for name, bnd in sorted(expr.bindings.items())}), + for name, bnd in expr.bindings.items()}), tags=expr.tags, ) @@ -1566,7 +1566,7 @@ def materialize_with_mpms(expr: DictOfNamedArrays) -> DictOfNamedArrays: - 
MPMS materialization strategy is a greedy materialization algorithm in which any node with more than 1 materialized predecessors and more than - 1 successor is materialized. + 1 successors is materialized. - Materializing here corresponds to tagging a node with :class:`~pytato.tags.ImplStored`. - Does not attempt to materialize sub-expressions in diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py index 3984d986e..2798a5e2a 100644 --- a/pytato/visualization/dot.py +++ b/pytato/visualization/dot.py @@ -190,7 +190,6 @@ def handle_unsupported_array(self, # type: ignore[override] for field in attrs.fields(type(expr)): # type: ignore[misc] if field.name in info.fields: continue - attr = getattr(expr, field.name) if isinstance(attr, Array): diff --git a/test/test_pytato.py b/test/test_pytato.py index 127f0b7dc..7672999e7 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -449,12 +449,12 @@ def test_array_dot_repr(): x = pt.make_placeholder("x", (10, 4), np.int64) y = pt.make_placeholder("y", (10, 4), np.int64) - def _assert_repr(ary: pt.Array, expected_repr: str): + def _assert_stripped_repr(ary: pt.Array, expected_repr: str): expected_str = "".join([c for c in expected_repr if c not in [" ", "\n"]]) result_str = "".join([c for c in repr(ary)if c not in [" ", "\n"]]) assert expected_str == result_str - _assert_repr( + _assert_stripped_repr( 3*x + 4*y, """ IndexLambda( @@ -483,7 +483,7 @@ def _assert_repr(ary: pt.Array, expected_repr: str): dtype='int64', name='y')})})""") - _assert_repr( + _assert_stripped_repr( pt.roll(x.reshape(2, 20).reshape(-1), 3), """ Roll( @@ -495,7 +495,7 @@ def _assert_repr(ary: pt.Array, expected_repr: str): newshape=(40), order='C'), shift=3, axis=0)""") - _assert_repr(y * pt.not_equal(x, 3), + _assert_stripped_repr(y * pt.not_equal(x, 3), """ IndexLambda( shape=(10, 4), @@ -515,7 +515,7 @@ def _assert_repr(ary: pt.Array, expected_repr: str): bindings={'_in0': Placeholder(shape=(10, 4), dtype='int64', 
name='x')})})""") - _assert_repr( + _assert_stripped_repr( x[y[:, 2:3], x[2, :]], """ AdvancedIndexInContiguousAxes( @@ -530,7 +530,7 @@ def _assert_repr(ary: pt.Array, expected_repr: str): name='x'), indices=(2, NormalizedSlice(start=0, stop=4, step=1)))))""") - _assert_repr( + _assert_stripped_repr( pt.stack([x[y[:, 2:3], x[2, :]].T, y[x[:, 2:3], y[2, :]].T]), """ Stack( @@ -819,12 +819,10 @@ def test_created_at(): assert res1.non_equality_tags != res2.non_equality_tags assert res3.non_equality_tags == res4.non_equality_tags - assert hash(res1.non_equality_tags) != hash(res2.non_equality_tags) assert hash(res3.non_equality_tags) == hash(res4.non_equality_tags) assert res1.tags == res2.tags == res3.tags == res4.tags - assert hash(res1.tags) == hash(res2.tags) == hash(res3.tags) == hash(res4.tags) # }}} From 9338f0be61a1213563d3851467aca4e80a2cdba7 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 21 Nov 2023 15:10:23 -0600 Subject: [PATCH 077/124] more lint fixes --- pytato/array.py | 5 +++-- pytato/distributed/nodes.py | 6 ++++-- pytato/visualization/dot.py | 10 ++++++---- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 2cefb9078..c188ec72c 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -450,8 +450,9 @@ class Array(Taggable): tags: FrozenSet[Tag] = attrs.field(kw_only=True) # These are automatically excluded from equality in EqualityComparer - non_equality_tags: FrozenSet[Tag] = attrs.field(kw_only=True, hash=False, - default=frozenset()) + non_equality_tags: FrozenSet[Optional[Tag]] = attrs.field(kw_only=True, + hash=False, + default=frozenset()) _mapper_method: ClassVar[str] diff --git a/pytato/distributed/nodes.py b/pytato/distributed/nodes.py index 465bda312..e95217b82 100644 --- a/pytato/distributed/nodes.py +++ b/pytato/distributed/nodes.py @@ -149,8 +149,10 @@ class DistributedSendRefHolder(Array): _mapper_method: ClassVar[str] = "map_distributed_send_ref_holder" def __init__(self, send: 
DistributedSend, passthrough_data: Array, - tags: FrozenSet[Tag] = frozenset()) -> None: - super().__init__(axes=passthrough_data.axes, tags=tags) + tags: FrozenSet[Tag] = frozenset(), + non_equality_tags: FrozenSet[Optional[Tag]] = frozenset()) -> None: + super().__init__(axes=passthrough_data.axes, tags=tags, + non_equality_tags=non_equality_tags) object.__setattr__(self, "send", send) object.__setattr__(self, "passthrough_data", passthrough_data) diff --git a/pytato/visualization/dot.py b/pytato/visualization/dot.py index b4128f1fc..69b7cd21a 100644 --- a/pytato/visualization/dot.py +++ b/pytato/visualization/dot.py @@ -144,11 +144,11 @@ def emit_subgraph(sg: _SubgraphTree) -> None: @attrs.define class _DotNodeInfo: title: str - fields: Dict[str, str] + fields: Dict[str, Any] edges: Dict[str, Union[ArrayOrNames, FunctionDefinition]] -def stringify_tags(tags: FrozenSet[Tag]) -> str: +def stringify_tags(tags: FrozenSet[Optional[Tag]]) -> str: components = sorted(str(elem) for elem in tags) return "{" + ", ".join(components) + "}" @@ -373,14 +373,16 @@ def _stringify_created_at(non_equality_tags: FrozenSet[Tag]) -> str: return "" -def _emit_array(emit: Callable[[str], None], title: str, fields: Dict[str, str], +def _emit_array(emit: Callable[[str], None], title: str, fields: Dict[str, Any], dot_node_id: str, color: str = "white") -> None: td_attrib = 'border="0"' table_attrib = 'border="0" cellborder="1" cellspacing="0"' rows = [f"{dot_escape(title)}"] - non_equality_tags = fields.pop("non_equality_tags", frozenset()) + non_equality_tags: FrozenSet[Any] = fields.pop("non_equality_tags", frozenset()) + + print(non_equality_tags) tooltip = dot_escape_leave_space(_stringify_created_at(non_equality_tags)) for name, field in fields.items(): From f5cb92ffe209ff506e5956ae57975d558dbb537a Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 21 Nov 2023 15:56:44 -0600 Subject: [PATCH 078/124] run all examples, fix demo_distributed_node_duplication --- 
.github/workflows/ci.yml | 2 +- examples/demo_distributed_node_duplication.py | 11 ++++++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6fd07337c..07d11a052 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -77,7 +77,7 @@ jobs: . ci-support-v0 build_py_project_in_conda_env pip install pytest # for advection.py - run_examples + run_examples --no-require-main docs: name: Documentation diff --git a/examples/demo_distributed_node_duplication.py b/examples/demo_distributed_node_duplication.py index 39307ccfb..9dd0670ae 100644 --- a/examples/demo_distributed_node_duplication.py +++ b/examples/demo_distributed_node_duplication.py @@ -1,15 +1,20 @@ """ An example to demonstrate the behavior of -:func:`pytato.find_distrbuted_partition`. One of the key characteristic of the -partitioning routine is to recompute expressions that appear in the multiple +:func:`pytato.find_distributed_partition`. One of the key characteristics of the +partitioning routine is to recompute expressions that appear in multiple partitions but are not materialized. """ import pytato as pt import numpy as np +from mpi4py import MPI + +comm = MPI.COMM_WORLD size = 2 rank = 0 +pt.enable_traceback_tag() + x1 = pt.make_placeholder("x1", shape=(10, 4), dtype=np.float64) x2 = pt.make_placeholder("x2", shape=(10, 4), dtype=np.float64) x3 = pt.make_placeholder("x3", shape=(10, 4), dtype=np.float64) @@ -30,7 +35,7 @@ out = tmp2 + recv result = pt.make_dict_of_named_arrays({"out": out}) -partitions = pt.find_distributed_partition(result) +partitions = pt.find_distributed_partition(comm, result) # Visualize *partitions* to see that each of the two partitions contains a node # named 'tmp2'. 
From 36166c6e85a55e6530c98d80d6a759e2e7df6fd5 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 21 Nov 2023 16:15:25 -0600 Subject: [PATCH 079/124] enable CreatedAt for distributed nodes --- examples/mpi-distributed.py | 2 ++ examples/visualization.py | 2 ++ pytato/array.py | 9 ++++++--- pytato/distributed/nodes.py | 13 +++++++++---- pytato/visualization/dot.py | 1 - 5 files changed, 19 insertions(+), 8 deletions(-) diff --git a/examples/mpi-distributed.py b/examples/mpi-distributed.py index dd29a82ab..ce8bcdab4 100644 --- a/examples/mpi-distributed.py +++ b/examples/mpi-distributed.py @@ -15,6 +15,8 @@ def main(): + pt.enable_traceback_tag() + ctx = cl.create_some_context() queue = cl.CommandQueue(ctx) diff --git a/examples/visualization.py b/examples/visualization.py index ac71e6060..f18262cb5 100755 --- a/examples/visualization.py +++ b/examples/visualization.py @@ -17,6 +17,8 @@ def main(): + pt.enable_traceback_tag() + n = pt.make_size_param("n") array = pt.make_placeholder(name="array", shape=n, dtype=np.float64) stack = pt.stack([array, 2*array, array + 6]) diff --git a/pytato/array.py b/pytato/array.py index c188ec72c..4b5373505 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1091,6 +1091,7 @@ def with_tagged_reduction(self, (new_redn_axis_to_redn_descr), tags=self.tags, index_to_access_descr=self.index_to_access_descr, + non_equality_tags=self.non_equality_tags, ) @@ -1297,6 +1298,7 @@ def einsum(subscripts: str, *operands: Array, ), redn_axis_to_redn_descr=immutabledict(redn_axis_to_redn_descr), index_to_access_descr=index_to_descr, + non_equality_tags=frozenset({_get_created_at_tag()}), ) # }}} @@ -1825,11 +1827,12 @@ def update_persistent_hash(self, key_hash: int, key_builder: Any) -> None: def short_str(self, maxlen: int = 100) -> str: from os.path import dirname - # Find the first file in the frames that is not in pytato's pytato/ - # directory. + # Find the first file in the frames that is not in pytato's internal + # directories. 
for frame in reversed(self.frames): frame_dir = dirname(frame.filename) - if not frame_dir.endswith("pytato"): + if (not frame_dir.endswith("pytato") + and not frame_dir.endswith("pytato/distributed")): return frame.short_str(maxlen) # Fallback in case we don't find any file that is not in the pytato/ diff --git a/pytato/distributed/nodes.py b/pytato/distributed/nodes.py index e95217b82..617876c97 100644 --- a/pytato/distributed/nodes.py +++ b/pytato/distributed/nodes.py @@ -64,7 +64,8 @@ from pytato.array import ( Array, _SuppliedShapeAndDtypeMixin, ShapeType, AxesT, - _get_default_axes, ConvertibleToShape, normalize_shape) + _get_default_axes, ConvertibleToShape, normalize_shape, + _get_created_at_tag) CommTagType = Hashable @@ -170,13 +171,15 @@ def copy(self, **kwargs: Any) -> DistributedSendRefHolder: send = kwargs.pop("send", self.send) passthrough_data = kwargs.pop("passthrough_data", self.passthrough_data) tags = kwargs.pop("tags", self.tags) + non_equality_tags = kwargs.pop("non_equality_tags", self.non_equality_tags) if kwargs: raise ValueError("Cannot assign" f" DistributedSendRefHolder.'{set(kwargs)}'") return DistributedSendRefHolder(send, passthrough_data, - tags) + tags, + non_equality_tags) # }}} @@ -238,7 +241,8 @@ def staple_distributed_send(sent_data: Array, dest_rank: int, comm_tag: CommTagT return DistributedSendRefHolder( send=DistributedSend(data=sent_data, dest_rank=dest_rank, comm_tag=comm_tag, tags=send_tags), - passthrough_data=stapled_to, tags=ref_holder_tags) + passthrough_data=stapled_to, tags=ref_holder_tags, + non_equality_tags=frozenset({_get_created_at_tag()})) def make_distributed_recv(src_rank: int, comm_tag: CommTagType, @@ -255,7 +259,8 @@ def make_distributed_recv(src_rank: int, comm_tag: CommTagType, dtype = np.dtype(dtype) return DistributedRecv( src_rank=src_rank, comm_tag=comm_tag, shape=shape, dtype=dtype, - tags=tags, axes=axes) + tags=tags, axes=axes, + non_equality_tags=frozenset({_get_created_at_tag()})) # }}} diff 
--git a/pytato/visualization/dot.py b/pytato/visualization/dot.py index 69b7cd21a..9c59c4e4a 100644 --- a/pytato/visualization/dot.py +++ b/pytato/visualization/dot.py @@ -382,7 +382,6 @@ def _emit_array(emit: Callable[[str], None], title: str, fields: Dict[str, Any], non_equality_tags: FrozenSet[Any] = fields.pop("non_equality_tags", frozenset()) - print(non_equality_tags) tooltip = dot_escape_leave_space(_stringify_created_at(non_equality_tags)) for name, field in fields.items(): From 4c3b06a82a9e9b2eefc3307d3ed19d29efcb00d2 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 29 Nov 2023 12:43:11 -0600 Subject: [PATCH 080/124] undo MPI tag ordering --- pytato/distributed/tags.py | 49 +++++++++++++++++++++----------------- test/test_distributed.py | 10 ++------ 2 files changed, 29 insertions(+), 30 deletions(-) diff --git a/pytato/distributed/tags.py b/pytato/distributed/tags.py index 1420b7532..f9eb3b7ec 100644 --- a/pytato/distributed/tags.py +++ b/pytato/distributed/tags.py @@ -31,7 +31,7 @@ """ -from typing import TYPE_CHECKING, Tuple, TypeVar +from typing import TYPE_CHECKING, Tuple, FrozenSet, Optional, TypeVar from pytato.distributed.partition import DistributedGraphPartition @@ -63,40 +63,45 @@ def number_distributed_tags( This is a potentially heavyweight MPI-collective operation on *mpi_communicator*. 
""" - from pytools import flatten - - tags = tuple([ + tags = frozenset({ recv.comm_tag for part in partition.parts.values() for recv in part.name_to_recv_node.values() - ] + [ + } | { send.comm_tag for part in partition.parts.values() for sends in part.name_to_send_nodes.values() - for send in sends]) + for send in sends}) + + from mpi4py import MPI + + def set_union( + set_a: FrozenSet[T], set_b: FrozenSet[T], + mpi_data_type: Optional[MPI.Datatype]) -> FrozenSet[T]: + assert mpi_data_type is None + assert isinstance(set_a, frozenset) + assert isinstance(set_b, frozenset) + + return set_a | set_b root_rank = 0 - all_tags = mpi_communicator.gather(tags, root=root_rank) + set_union_mpi_op = MPI.Op.Create( + # type ignore reason: mpi4py misdeclares op functions as returning + # None. + set_union, # type: ignore[arg-type] + commute=True) + try: + all_tags = mpi_communicator.reduce( + tags, set_union_mpi_op, root=root_rank) + finally: + set_union_mpi_op.Free() if mpi_communicator.rank == root_rank: sym_tag_to_int_tag = {} next_tag = base_tag - assert isinstance(all_tags, list) - assert len(all_tags) == mpi_communicator.size - - # First previous version - # for sym_tag in sorted(all_tags, key=lambda tag: repr(tag)): - # sym_tag_to_int_tag[sym_tag] = next_tag - # next_tag += 1 - # - # - # Second previous version - # for sym_tag in flatten(all_tags): # type: ignore[no-untyped-call] - # if sym_tag not in sym_tag_to_int_tag: - # sym_tag_to_int_tag[sym_tag] = next_tag - # next_tag += 1 - # Current main version + assert isinstance(all_tags, frozenset) + for sym_tag in all_tags: sym_tag_to_int_tag[sym_tag] = next_tag next_tag += 1 diff --git a/test/test_distributed.py b/test/test_distributed.py index 6e2e34376..925d2e070 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -266,7 +266,7 @@ def _do_test_distributed_execution_random_dag(ctx_factory): ntests = 10 for i in range(ntests): seed = 120 + i - print(f"Step {i} {seed=}") + print(f"Step {i} {seed}") # 
{{{ compute value with communication @@ -278,13 +278,7 @@ def gen_comm(rdagc): nonlocal comm_tag comm_tag += 1 - - if comm_tag % 5 == 1 or 1: - tag = (comm_tag, frozenset([_RandomDAGTag, "a", comm_tag])) - elif comm_tag % 5 == 2: - tag = (comm_tag, (_RandomDAGTag, "b")) - else: - tag = (comm_tag, _RandomDAGTag) # noqa: B023 + tag = (comm_tag, _RandomDAGTag) # noqa: B023 inner = make_random_dag(rdagc) return pt.staple_distributed_send( From 06503b1cc548b6ae6061900bcc933453431e269e Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 2 Feb 2024 11:00:41 -0600 Subject: [PATCH 081/124] get precise traceback of array creation --- pytato/array.py | 47 +++++++++++++++++++++++++++++++++++++++-------- pytato/cmath.py | 2 +- pytato/utils.py | 33 +++++++++++++++++++-------------- 3 files changed, 59 insertions(+), 23 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 4b5373505..ffbee5b21 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -508,7 +508,11 @@ def __getitem__(self, slice_spec = (slice_spec,) from pytato.utils import _index_into - return _index_into(self, slice_spec) + return _index_into( + self, + slice_spec, + tags=_get_default_tags(), + non_equality_tags=frozenset({_get_created_at_tag()})) @property def ndim(self) -> int: @@ -553,13 +557,20 @@ def _binary_op(self, # }}} + tags = _get_default_tags() + non_equality_tags = frozenset({_get_created_at_tag(stacklevel=2)}) + import pytato.utils as utils if reverse: result = utils.broadcast_binary_op(other, self, op, - get_result_type) + get_result_type, + tags=tags, + non_equality_tags=non_equality_tags) else: result = utils.broadcast_binary_op(self, other, op, - get_result_type) + get_result_type, + tags=tags, + non_equality_tags=non_equality_tags) assert isinstance(result, Array) return result @@ -579,6 +590,7 @@ def _unary_op(self, op: Any) -> Array: bindings=bindings, tags=_get_default_tags(), axes=_get_default_axes(self.ndim), + non_equality_tags=frozenset({_get_created_at_tag(stacklevel=2)}), 
var_to_reduction_descr=immutabledict()) __mul__ = partialmethod(_binary_op, operator.mul) @@ -1852,15 +1864,25 @@ def enable_traceback_tag(enable: bool = True) -> None: _ENABLE_TRACEBACK_TAG = enable -def _get_created_at_tag() -> Optional[Tag]: +def _get_created_at_tag(stacklevel: int = 1) -> Optional[Tag]: + """ + Get a :class:`CreatedAt` tag storing the stack trace of an array's creation. + + :param stacklevel: the number of stack levels above this call to record as the + array creation location + """ import traceback from pytato.tags import CreatedAt if not _ENABLE_TRACEBACK_TAG: return None + # Drop the stack levels corresponding to extract_stack() and any additional + # levels specified via stacklevel + stack = traceback.extract_stack()[:-(1+stacklevel)] + frames = tuple(_PytatoFrameSummary(s.filename, s.lineno, s.name, s.line) - for s in traceback.extract_stack()) + for s in stack) return CreatedAt(_PytatoStackSummary(frames)) @@ -2376,7 +2398,10 @@ def _compare(x1: ArrayOrScalar, x2: ArrayOrScalar, which: str) -> Union[Array, b # '_compare' returns a bool. 
return utils.broadcast_binary_op(x1, x2, lambda x, y: prim.Comparison(x, which, y), - lambda x, y: np.dtype(np.bool_) + lambda x, y: np.dtype(np.bool_), + tags=_get_default_tags(), + non_equality_tags=frozenset({ + _get_created_at_tag(stacklevel=2)}), ) # type: ignore[return-value] @@ -2436,7 +2461,10 @@ def logical_or(x1: ArrayOrScalar, x2: ArrayOrScalar) -> Union[Array, bool]: import pytato.utils as utils return utils.broadcast_binary_op(x1, x2, lambda x, y: prim.LogicalOr((x, y)), - lambda x, y: np.dtype(np.bool_) + lambda x, y: np.dtype(np.bool_), + tags=_get_default_tags(), + non_equality_tags=frozenset({ + _get_created_at_tag()}), ) # type: ignore[return-value] @@ -2450,7 +2478,10 @@ def logical_and(x1: ArrayOrScalar, x2: ArrayOrScalar) -> Union[Array, bool]: import pytato.utils as utils return utils.broadcast_binary_op(x1, x2, lambda x, y: prim.LogicalAnd((x, y)), - lambda x, y: np.dtype(np.bool_) + lambda x, y: np.dtype(np.bool_), + tags=_get_default_tags(), + non_equality_tags=frozenset({ + _get_created_at_tag()}), ) # type: ignore[return-value] diff --git a/pytato/cmath.py b/pytato/cmath.py index 9f8e8c7fa..3afbe82ed 100644 --- a/pytato/cmath.py +++ b/pytato/cmath.py @@ -115,7 +115,7 @@ def _apply_elem_wise_func(inputs: Tuple[ArrayOrScalar, ...], tuple(sym_args)), shape=shape, dtype=ret_dtype, bindings=immutabledict(bindings), tags=_get_default_tags(), - non_equality_tags=frozenset({_get_created_at_tag()}), + non_equality_tags=frozenset({_get_created_at_tag(stacklevel=2)}), axes=_get_default_axes(len(shape)), var_to_reduction_descr=immutabledict(), ) diff --git a/pytato/utils.py b/pytato/utils.py index 94c43938c..e5463d8f5 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -27,7 +27,7 @@ import pymbolic.primitives as prim from typing import (Tuple, List, Union, Callable, Any, Sequence, Dict, - Optional, Iterable, TypeVar) + Optional, Iterable, TypeVar, FrozenSet) from pytato.array import (Array, ShapeType, IndexLambda, SizeParam, ShapeComponent, 
DtypeOrScalar, ArrayOrScalar, BasicIndex, AdvancedIndexInContiguousAxes, @@ -38,6 +38,7 @@ SCALAR_CLASSES, INT_CLASSES, BoolT, ScalarType) from pytools import UniqueNameGenerator from pytato.transform import Mapper +from pytools.tag import Tag from immutabledict import immutabledict @@ -178,9 +179,10 @@ def update_bindings_and_get_broadcasted_expr(arr: ArrayOrScalar, def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar, op: Callable[[ScalarExpression, ScalarExpression], ScalarExpression], # noqa:E501 get_result_type: Callable[[DtypeOrScalar, DtypeOrScalar], np.dtype[Any]], # noqa:E501 + tags: FrozenSet[Tag], + non_equality_tags: FrozenSet[Optional[Tag]], ) -> ArrayOrScalar: - from pytato.array import (_get_default_axes, _get_default_tags, - _get_created_at_tag) + from pytato.array import _get_default_axes if isinstance(a1, SCALAR_CLASSES): a1 = np.dtype(type(a1)).type(a1) @@ -207,8 +209,8 @@ def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar, shape=result_shape, dtype=result_dtype, bindings=immutabledict(bindings), - tags=_get_default_tags(), - non_equality_tags=frozenset({_get_created_at_tag()}), + tags=tags, + non_equality_tags=non_equality_tags, var_to_reduction_descr=immutabledict(), axes=_get_default_axes(len(result_shape))) @@ -475,10 +477,13 @@ def _normalized_slice_len(slice_: NormalizedSlice) -> ShapeComponent: # }}} -def _index_into(ary: Array, indices: Tuple[ConvertibleToIndexExpr, ...]) -> Array: +def _index_into( + ary: Array, + indices: Tuple[ConvertibleToIndexExpr, ...], + tags: FrozenSet[Tag], + non_equality_tags: FrozenSet[Optional[Tag]]) -> Array: from pytato.diagnostic import CannotBroadcastError - from pytato.array import (_get_default_axes, _get_default_tags, - _get_created_at_tag) + from pytato.array import _get_default_axes # {{{ handle ellipsis @@ -564,24 +569,24 @@ def _index_into(ary: Array, indices: Tuple[ConvertibleToIndexExpr, ...]) -> Arra return AdvancedIndexInNoncontiguousAxes( ary, tuple(normalized_indices), - 
tags=_get_default_tags(), - non_equality_tags=frozenset({_get_created_at_tag()}), + tags=tags, + non_equality_tags=non_equality_tags, axes=_get_default_axes(len(array_idx_shape) + len(i_basic_indices))) else: return AdvancedIndexInContiguousAxes( ary, tuple(normalized_indices), - tags=_get_default_tags(), - non_equality_tags=frozenset({_get_created_at_tag()}), + tags=tags, + non_equality_tags=non_equality_tags, axes=_get_default_axes(len(array_idx_shape) + len(i_basic_indices))) else: # basic indexing expression return BasicIndex(ary, tuple(normalized_indices), - tags=_get_default_tags(), - non_equality_tags=frozenset({_get_created_at_tag()}), + tags=tags, + non_equality_tags=non_equality_tags, axes=_get_default_axes( len([idx for idx in normalized_indices From ab87fbf4ad2fb3d1741f31055a5665e94ec6926e Mon Sep 17 00:00:00 2001 From: Matthew Smith Date: Fri, 2 Feb 2024 13:28:02 -0600 Subject: [PATCH 082/124] partialmethod doesn't introduce a stack frame --- pytato/array.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index ffbee5b21..66d566576 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -558,7 +558,7 @@ def _binary_op(self, # }}} tags = _get_default_tags() - non_equality_tags = frozenset({_get_created_at_tag(stacklevel=2)}) + non_equality_tags = frozenset({_get_created_at_tag()}) import pytato.utils as utils if reverse: @@ -590,7 +590,7 @@ def _unary_op(self, op: Any) -> Array: bindings=bindings, tags=_get_default_tags(), axes=_get_default_axes(self.ndim), - non_equality_tags=frozenset({_get_created_at_tag(stacklevel=2)}), + non_equality_tags=frozenset({_get_created_at_tag()}), var_to_reduction_descr=immutabledict()) __mul__ = partialmethod(_binary_op, operator.mul) From d8df5f83b16762e2f9e10f4f067bad2401c02c35 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 6 Feb 2024 10:15:53 -0600 Subject: [PATCH 083/124] add support for make_distributed_send_ref_holder --- pytato/distributed/nodes.py | 
11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/pytato/distributed/nodes.py b/pytato/distributed/nodes.py index b5796d888..07befad64 100644 --- a/pytato/distributed/nodes.py +++ b/pytato/distributed/nodes.py @@ -231,12 +231,16 @@ def make_distributed_send(sent_data: Array, dest_rank: int, comm_tag: CommTagTyp def make_distributed_send_ref_holder( send: DistributedSend, passthrough_data: Array, - tags: FrozenSet[Tag] = frozenset() + tags: FrozenSet[Tag] = frozenset(), + non_equality_tags: FrozenSet[Optional[Tag]] = frozenset(), ) -> DistributedSendRefHolder: """Make a :class:`DistributedSendRefHolder` object.""" + if not non_equality_tags: + non_equality_tags = frozenset({_get_created_at_tag()}) return DistributedSendRefHolder( send=send, passthrough_data=passthrough_data, - tags=(tags | _get_default_tags())) + tags=(tags | _get_default_tags()), + non_equality_tags=non_equality_tags) def staple_distributed_send(sent_data: Array, dest_rank: int, comm_tag: CommTagType, @@ -251,7 +255,8 @@ def staple_distributed_send(sent_data: Array, dest_rank: int, comm_tag: CommTagT sent_data=sent_data, dest_rank=dest_rank, comm_tag=comm_tag, send_tags=send_tags), passthrough_data=stapled_to, - tags=ref_holder_tags) + tags=ref_holder_tags, + non_equality_tags=frozenset({_get_created_at_tag()})) def make_distributed_recv(src_rank: int, comm_tag: CommTagType, From e1b918156723d0c2f582a780c4ba6feb413e6c11 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 16 Feb 2024 10:44:27 -0600 Subject: [PATCH 084/124] add to MPMSMaterializer --- pytato/transform/__init__.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 57cf7d6ce..6de82bb29 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -1344,14 +1344,16 @@ def map_index_lambda(self, expr: IndexLambda) -> MPMSMaterializerAccumulator: for bnd_name, bnd in 
sorted(children_rec.items())}), axes=expr.axes, var_to_reduction_descr=expr.var_to_reduction_descr, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], children_rec.values()) def map_stack(self, expr: Stack) -> MPMSMaterializerAccumulator: rec_arrays = [self.rec(ary) for ary in expr.arrays] new_expr = Stack(tuple(ary.expr for ary in rec_arrays), - expr.axis, axes=expr.axes, tags=expr.tags) + expr.axis, axes=expr.axes, tags=expr.tags, + non_equality_tags=expr.non_equality_tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], @@ -1362,7 +1364,8 @@ def map_concatenate(self, expr: Concatenate) -> MPMSMaterializerAccumulator: new_expr = Concatenate(tuple(ary.expr for ary in rec_arrays), expr.axis, axes=expr.axes, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], rec_arrays) @@ -1370,7 +1373,8 @@ def map_concatenate(self, expr: Concatenate) -> MPMSMaterializerAccumulator: def map_roll(self, expr: Roll) -> MPMSMaterializerAccumulator: rec_array = self.rec(expr.array) new_expr = Roll(rec_array.expr, expr.shift, expr.axis, axes=expr.axes, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], (rec_array,)) @@ -1378,7 +1382,8 @@ def map_axis_permutation(self, expr: AxisPermutation ) -> MPMSMaterializerAccumulator: rec_array = self.rec(expr.array) new_expr = AxisPermutation(rec_array.expr, expr.axis_permutation, - axes=expr.axes, tags=expr.tags) + axes=expr.axes, tags=expr.tags, + non_equality_tags=expr.non_equality_tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], (rec_array,)) @@ -1396,7 +1401,8 @@ def _map_index_base(self, expr: IndexBase) -> MPMSMaterializerAccumulator: for i in range( len(expr.indices))), axes=expr.axes, - tags=expr.tags) + tags=expr.tags, + 
non_equality_tags=expr.non_equality_tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], @@ -1410,7 +1416,8 @@ def _map_index_base(self, expr: IndexBase) -> MPMSMaterializerAccumulator: def map_reshape(self, expr: Reshape) -> MPMSMaterializerAccumulator: rec_array = self.rec(expr.array) new_expr = Reshape(rec_array.expr, expr.newshape, - expr.order, axes=expr.axes, tags=expr.tags) + expr.order, axes=expr.axes, tags=expr.tags, + non_equality_tags=expr.non_equality_tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], @@ -1423,7 +1430,8 @@ def map_einsum(self, expr: Einsum) -> MPMSMaterializerAccumulator: expr.redn_axis_to_redn_descr, expr.index_to_access_descr, axes=expr.axes, - tags=expr.tags) + tags=expr.tags, + non_equality_tags=expr.non_equality_tags) return _materialize_if_mpms(new_expr, self.nsuccessors[expr], From be9dcddd8c73475faf7dff21e11a9371e436eead Mon Sep 17 00:00:00 2001 From: Mike Campbell Date: Sat, 2 Mar 2024 08:47:37 -0600 Subject: [PATCH 085/124] Spew array tracing to stdout. 
--- pytato/array.py | 1 + pytato/transform/metadata.py | 6 +++++- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/pytato/array.py b/pytato/array.py index 2b1ec11b1..65d91857c 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1488,6 +1488,7 @@ class Reshape(IndexRemappingBase): if __debug__: def __attrs_post_init__(self) -> None: + assert self.non_equality_tags super().__attrs_post_init__() @property diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 7f015d739..780d96fd4 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -614,8 +614,12 @@ def rec(self, expr: ArrayOrNames) -> Any: assert expr_copy.ndim == expr.ndim for iaxis in range(expr.ndim): + axis_tags = self.axis_to_tags.get((expr, iaxis), []) + if len(axis_tags) == 0: + print(f"failed to infer axis {iaxis} of array of type {type(expr)}.") + print(f"{expr.non_equality_tags=}") expr_copy = expr_copy.with_tagged_axis( - iaxis, self.axis_to_tags.get((expr, iaxis), [])) + iaxis, axis_tags) # {{{ tag reduction descrs From ad0aa4c01f3d9aabcf8a53ea0cbf1ca5b7bce0f1 Mon Sep 17 00:00:00 2001 From: Matt Smith Date: Thu, 7 Mar 2024 12:44:10 -0600 Subject: [PATCH 086/124] Get precise traceback of array creation (#480) --- pytato/array.py | 47 +++++++++++++++++++++++++++++++++++++++-------- pytato/cmath.py | 2 +- pytato/utils.py | 33 +++++++++++++++++++-------------- 3 files changed, 59 insertions(+), 23 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 447b4d3d6..a88c0e948 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -508,7 +508,11 @@ def __getitem__(self, slice_spec = (slice_spec,) from pytato.utils import _index_into - return _index_into(self, slice_spec) + return _index_into( + self, + slice_spec, + tags=_get_default_tags(), + non_equality_tags=frozenset({_get_created_at_tag()})) @property def ndim(self) -> int: @@ -553,13 +557,20 @@ def _binary_op(self, # }}} + tags = _get_default_tags() + non_equality_tags = 
frozenset({_get_created_at_tag()}) + import pytato.utils as utils if reverse: result = utils.broadcast_binary_op(other, self, op, - get_result_type) + get_result_type, + tags=tags, + non_equality_tags=non_equality_tags) else: result = utils.broadcast_binary_op(self, other, op, - get_result_type) + get_result_type, + tags=tags, + non_equality_tags=non_equality_tags) assert isinstance(result, Array) return result @@ -579,6 +590,7 @@ def _unary_op(self, op: Any) -> Array: bindings=bindings, tags=_get_default_tags(), axes=_get_default_axes(self.ndim), + non_equality_tags=frozenset({_get_created_at_tag()}), var_to_reduction_descr=immutabledict()) __mul__ = partialmethod(_binary_op, operator.mul) @@ -1852,15 +1864,25 @@ def enable_traceback_tag(enable: bool = True) -> None: _ENABLE_TRACEBACK_TAG = enable -def _get_created_at_tag() -> Optional[Tag]: +def _get_created_at_tag(stacklevel: int = 1) -> Optional[Tag]: + """ + Get a :class:`CreatedAt` tag storing the stack trace of an array's creation. + + :param stacklevel: the number of stack levels above this call to record as the + array creation location + """ import traceback from pytato.tags import CreatedAt if not _ENABLE_TRACEBACK_TAG: return None + # Drop the stack levels corresponding to extract_stack() and any additional + # levels specified via stacklevel + stack = traceback.extract_stack()[:-(1+stacklevel)] + frames = tuple(_PytatoFrameSummary(s.filename, s.lineno, s.name, s.line) - for s in traceback.extract_stack()) + for s in stack) return CreatedAt(_PytatoStackSummary(frames)) @@ -2378,7 +2400,10 @@ def _compare(x1: ArrayOrScalar, x2: ArrayOrScalar, which: str) -> Union[Array, b # '_compare' returns a bool. 
return utils.broadcast_binary_op(x1, x2, lambda x, y: prim.Comparison(x, which, y), - lambda x, y: np.dtype(np.bool_) + lambda x, y: np.dtype(np.bool_), + tags=_get_default_tags(), + non_equality_tags=frozenset({ + _get_created_at_tag(stacklevel=2)}), ) # type: ignore[return-value] @@ -2438,7 +2463,10 @@ def logical_or(x1: ArrayOrScalar, x2: ArrayOrScalar) -> Union[Array, bool]: import pytato.utils as utils return utils.broadcast_binary_op(x1, x2, lambda x, y: prim.LogicalOr((x, y)), - lambda x, y: np.dtype(np.bool_) + lambda x, y: np.dtype(np.bool_), + tags=_get_default_tags(), + non_equality_tags=frozenset({ + _get_created_at_tag()}), ) # type: ignore[return-value] @@ -2452,7 +2480,10 @@ def logical_and(x1: ArrayOrScalar, x2: ArrayOrScalar) -> Union[Array, bool]: import pytato.utils as utils return utils.broadcast_binary_op(x1, x2, lambda x, y: prim.LogicalAnd((x, y)), - lambda x, y: np.dtype(np.bool_) + lambda x, y: np.dtype(np.bool_), + tags=_get_default_tags(), + non_equality_tags=frozenset({ + _get_created_at_tag()}), ) # type: ignore[return-value] diff --git a/pytato/cmath.py b/pytato/cmath.py index 9f8e8c7fa..3afbe82ed 100644 --- a/pytato/cmath.py +++ b/pytato/cmath.py @@ -115,7 +115,7 @@ def _apply_elem_wise_func(inputs: Tuple[ArrayOrScalar, ...], tuple(sym_args)), shape=shape, dtype=ret_dtype, bindings=immutabledict(bindings), tags=_get_default_tags(), - non_equality_tags=frozenset({_get_created_at_tag()}), + non_equality_tags=frozenset({_get_created_at_tag(stacklevel=2)}), axes=_get_default_axes(len(shape)), var_to_reduction_descr=immutabledict(), ) diff --git a/pytato/utils.py b/pytato/utils.py index 94c43938c..e5463d8f5 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -27,7 +27,7 @@ import pymbolic.primitives as prim from typing import (Tuple, List, Union, Callable, Any, Sequence, Dict, - Optional, Iterable, TypeVar) + Optional, Iterable, TypeVar, FrozenSet) from pytato.array import (Array, ShapeType, IndexLambda, SizeParam, ShapeComponent, 
DtypeOrScalar, ArrayOrScalar, BasicIndex, AdvancedIndexInContiguousAxes, @@ -38,6 +38,7 @@ SCALAR_CLASSES, INT_CLASSES, BoolT, ScalarType) from pytools import UniqueNameGenerator from pytato.transform import Mapper +from pytools.tag import Tag from immutabledict import immutabledict @@ -178,9 +179,10 @@ def update_bindings_and_get_broadcasted_expr(arr: ArrayOrScalar, def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar, op: Callable[[ScalarExpression, ScalarExpression], ScalarExpression], # noqa:E501 get_result_type: Callable[[DtypeOrScalar, DtypeOrScalar], np.dtype[Any]], # noqa:E501 + tags: FrozenSet[Tag], + non_equality_tags: FrozenSet[Optional[Tag]], ) -> ArrayOrScalar: - from pytato.array import (_get_default_axes, _get_default_tags, - _get_created_at_tag) + from pytato.array import _get_default_axes if isinstance(a1, SCALAR_CLASSES): a1 = np.dtype(type(a1)).type(a1) @@ -207,8 +209,8 @@ def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar, shape=result_shape, dtype=result_dtype, bindings=immutabledict(bindings), - tags=_get_default_tags(), - non_equality_tags=frozenset({_get_created_at_tag()}), + tags=tags, + non_equality_tags=non_equality_tags, var_to_reduction_descr=immutabledict(), axes=_get_default_axes(len(result_shape))) @@ -475,10 +477,13 @@ def _normalized_slice_len(slice_: NormalizedSlice) -> ShapeComponent: # }}} -def _index_into(ary: Array, indices: Tuple[ConvertibleToIndexExpr, ...]) -> Array: +def _index_into( + ary: Array, + indices: Tuple[ConvertibleToIndexExpr, ...], + tags: FrozenSet[Tag], + non_equality_tags: FrozenSet[Optional[Tag]]) -> Array: from pytato.diagnostic import CannotBroadcastError - from pytato.array import (_get_default_axes, _get_default_tags, - _get_created_at_tag) + from pytato.array import _get_default_axes # {{{ handle ellipsis @@ -564,24 +569,24 @@ def _index_into(ary: Array, indices: Tuple[ConvertibleToIndexExpr, ...]) -> Arra return AdvancedIndexInNoncontiguousAxes( ary, tuple(normalized_indices), - 
tags=_get_default_tags(), - non_equality_tags=frozenset({_get_created_at_tag()}), + tags=tags, + non_equality_tags=non_equality_tags, axes=_get_default_axes(len(array_idx_shape) + len(i_basic_indices))) else: return AdvancedIndexInContiguousAxes( ary, tuple(normalized_indices), - tags=_get_default_tags(), - non_equality_tags=frozenset({_get_created_at_tag()}), + tags=tags, + non_equality_tags=non_equality_tags, axes=_get_default_axes(len(array_idx_shape) + len(i_basic_indices))) else: # basic indexing expression return BasicIndex(ary, tuple(normalized_indices), - tags=_get_default_tags(), - non_equality_tags=frozenset({_get_created_at_tag()}), + tags=tags, + non_equality_tags=non_equality_tags, axes=_get_default_axes( len([idx for idx in normalized_indices From ab5728e67e4e5ed0f251bfaadd09fb11d8f74fc7 Mon Sep 17 00:00:00 2001 From: Mike Campbell Date: Thu, 11 Apr 2024 08:06:31 -0500 Subject: [PATCH 087/124] Disable assert non_equality_tag --- pytato/array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytato/array.py b/pytato/array.py index f4b952947..598fa8c7b 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -1501,7 +1501,7 @@ class Reshape(IndexRemappingBase): if __debug__: def __attrs_post_init__(self) -> None: - assert self.non_equality_tags + # assert self.non_equality_tags super().__attrs_post_init__() @property From 38e4332d8aea44900061b4e4d5891be33eff064b Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 25 Sep 2023 17:36:38 -0500 Subject: [PATCH 088/124] add PytatoKeyBuilder --- pytato/analysis/__init__.py | 29 +++++++++++++++++++++++++++++ test/test_pytato.py | 22 ++++++++++++++++++++++ 2 files changed, 51 insertions(+) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 5bf374746..59350329f 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -36,6 +36,7 @@ from pytato.loopy import LoopyCall from pymbolic.mapper.optimize import optimize_mapper from pytools import 
memoize_method +from loopy.tools import LoopyKeyBuilder, PersistentHashWalkMapper if TYPE_CHECKING: from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder @@ -463,4 +464,32 @@ def get_num_call_sites(outputs: Union[Array, DictOfNamedArrays]) -> int: # }}} + +# {{{ PytatoKeyBuilder + +class PytatoKeyBuilder(LoopyKeyBuilder): + """A custom :class:`pytools.persistent_dict.KeyBuilder` subclass + for objects within :mod:`pytato`. + """ + + def update_for_ndarray(self, key_hash: Any, key: Any) -> None: + self.rec(key_hash, hash(key.data.tobytes())) # type: ignore[no-untyped-call] + + def update_for_pymbolic_expression(self, key_hash: Any, key: Any) -> None: + if key is None: + self.update_for_NoneType(key_hash, key) # type: ignore[no-untyped-call] + else: + PersistentHashWalkMapper(key_hash)(key) + + update_for_Product = update_for_pymbolic_expression # noqa: N815 + update_for_Sum = update_for_pymbolic_expression # noqa: N815 + update_for_If = update_for_pymbolic_expression # noqa: N815 + update_for_LogicalOr = update_for_pymbolic_expression # noqa: N815 + update_for_Call = update_for_pymbolic_expression # noqa: N815 + update_for_Comparison = update_for_pymbolic_expression # noqa: N815 + update_for_Quotient = update_for_pymbolic_expression # noqa: N815 + update_for_Power = update_for_pymbolic_expression # noqa: N815 + +# }}} + # vim: fdm=marker diff --git a/test/test_pytato.py b/test/test_pytato.py index 8939073cb..e5d8c2b5f 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1231,6 +1231,28 @@ def test_dot_visualizers(): # }}} +def test_persistent_dict(): + from pytools.persistent_dict import WriteOncePersistentDict, ReadOnlyEntryError + from pytato.analysis import PytatoKeyBuilder + + axis_len = 5 + + pd = WriteOncePersistentDict("test_persistent_dict", + key_builder=PytatoKeyBuilder(), + container_dir="./pytest-pdict") + + for i in range(100): + rdagc = RandomDAGContext(np.random.default_rng(seed=i), + axis_len=axis_len, 
use_numpy=True) + + dag = make_random_dag(rdagc) + pd[dag] = 42 + + # Make sure key stays the same + with pytest.raises(ReadOnlyEntryError): + pd[dag] = 42 + + if __name__ == "__main__": if len(sys.argv) > 1: exec(sys.argv[1]) From 4dd3250a4b7be6f86690c075a69c1ef569ca3137 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 25 Sep 2023 18:04:55 -0500 Subject: [PATCH 089/124] mypy fixes --- pytato/analysis/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 59350329f..c31cb1e5a 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -467,17 +467,17 @@ def get_num_call_sites(outputs: Union[Array, DictOfNamedArrays]) -> int: # {{{ PytatoKeyBuilder -class PytatoKeyBuilder(LoopyKeyBuilder): +class PytatoKeyBuilder(LoopyKeyBuilder): # type: ignore[misc] """A custom :class:`pytools.persistent_dict.KeyBuilder` subclass for objects within :mod:`pytato`. """ def update_for_ndarray(self, key_hash: Any, key: Any) -> None: - self.rec(key_hash, hash(key.data.tobytes())) # type: ignore[no-untyped-call] + self.rec(key_hash, hash(key.data.tobytes())) def update_for_pymbolic_expression(self, key_hash: Any, key: Any) -> None: if key is None: - self.update_for_NoneType(key_hash, key) # type: ignore[no-untyped-call] + self.update_for_NoneType(key_hash, key) else: PersistentHashWalkMapper(key_hash)(key) From 970e7bbdc0bd055554bdc38e0853243c2b5440ef Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 28 Sep 2023 16:52:12 -0500 Subject: [PATCH 090/124] support TaggableCLArray, Subscript --- pytato/analysis/__init__.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index c31cb1e5a..e5e852143 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -475,6 +475,9 @@ class PytatoKeyBuilder(LoopyKeyBuilder): # type: ignore[misc] def update_for_ndarray(self, key_hash: Any, key: Any) 
-> None: self.rec(key_hash, hash(key.data.tobytes())) + def update_for_TaggableCLArray(self, key_hash: Any, key: Any) -> None: + self.update_for_ndarray(key_hash, key.get()) + def update_for_pymbolic_expression(self, key_hash: Any, key: Any) -> None: if key is None: self.update_for_NoneType(key_hash, key) @@ -489,6 +492,7 @@ def update_for_pymbolic_expression(self, key_hash: Any, key: Any) -> None: update_for_Comparison = update_for_pymbolic_expression # noqa: N815 update_for_Quotient = update_for_pymbolic_expression # noqa: N815 update_for_Power = update_for_pymbolic_expression # noqa: N815 + update_for_Subscript = update_for_pymbolic_expression # noqa: N815 # }}} From 95dec097a03bae5380fa24541ee11b701faef572 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 28 Sep 2023 18:24:55 -0500 Subject: [PATCH 091/124] CL Array, function --- pytato/analysis/__init__.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index e5e852143..9af7b14bb 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -478,6 +478,13 @@ def update_for_ndarray(self, key_hash: Any, key: Any) -> None: def update_for_TaggableCLArray(self, key_hash: Any, key: Any) -> None: self.update_for_ndarray(key_hash, key.get()) + def update_for_Array(self, key_hash: Any, key: Any) -> None: + # CL Array + self.update_for_ndarray(key_hash, key.get()) + + def update_for_function(self, key_hash: Any, key: Any) -> None: + self.rec(key_hash, key.__name__) + def update_for_pymbolic_expression(self, key_hash: Any, key: Any) -> None: if key is None: self.update_for_NoneType(key_hash, key) From 2ac10eea98a6520648a676e8ae089bed5ddb4f2a Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 5 Feb 2024 11:13:20 -0600 Subject: [PATCH 092/124] add prim.Variable --- pytato/analysis/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 9af7b14bb..fb0bb1025 
100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -500,6 +500,7 @@ def update_for_pymbolic_expression(self, key_hash: Any, key: Any) -> None: update_for_Quotient = update_for_pymbolic_expression # noqa: N815 update_for_Power = update_for_pymbolic_expression # noqa: N815 update_for_Subscript = update_for_pymbolic_expression # noqa: N815 + update_for_Variable = update_for_pymbolic_expression # noqa: N815 # }}} From 62a13aeca6214b8848769507e3a32205c9e76fe5 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 5 Feb 2024 12:13:21 -0600 Subject: [PATCH 093/124] fixes to ndarray, pymb expressions --- pytato/analysis/__init__.py | 36 +++++++++++++++--------------------- 1 file changed, 15 insertions(+), 21 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index fb0bb1025..d74092590 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -473,34 +473,28 @@ class PytatoKeyBuilder(LoopyKeyBuilder): # type: ignore[misc] """ def update_for_ndarray(self, key_hash: Any, key: Any) -> None: - self.rec(key_hash, hash(key.data.tobytes())) + self.rec(key_hash, key.data.tobytes()) def update_for_TaggableCLArray(self, key_hash: Any, key: Any) -> None: - self.update_for_ndarray(key_hash, key.get()) + self.rec(key_hash, key.get()) def update_for_Array(self, key_hash: Any, key: Any) -> None: # CL Array - self.update_for_ndarray(key_hash, key.get()) + self.rec(key_hash, key.get()) def update_for_function(self, key_hash: Any, key: Any) -> None: - self.rec(key_hash, key.__name__) - - def update_for_pymbolic_expression(self, key_hash: Any, key: Any) -> None: - if key is None: - self.update_for_NoneType(key_hash, key) - else: - PersistentHashWalkMapper(key_hash)(key) - - update_for_Product = update_for_pymbolic_expression # noqa: N815 - update_for_Sum = update_for_pymbolic_expression # noqa: N815 - update_for_If = update_for_pymbolic_expression # noqa: N815 - update_for_LogicalOr = 
update_for_pymbolic_expression # noqa: N815 - update_for_Call = update_for_pymbolic_expression # noqa: N815 - update_for_Comparison = update_for_pymbolic_expression # noqa: N815 - update_for_Quotient = update_for_pymbolic_expression # noqa: N815 - update_for_Power = update_for_pymbolic_expression # noqa: N815 - update_for_Subscript = update_for_pymbolic_expression # noqa: N815 - update_for_Variable = update_for_pymbolic_expression # noqa: N815 + self.rec(key_hash, key.__module__ + key.__qualname__) + + update_for_Product = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + update_for_Sum = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + update_for_If = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + update_for_LogicalOr = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + update_for_Call = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + update_for_Comparison = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + update_for_Quotient = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + update_for_Power = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + update_for_Subscript = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + update_for_Variable = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 # }}} From b8e04bf00130faa3c17ce6f938f2ca15e7ac85b8 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 5 Feb 2024 12:17:49 -0600 Subject: [PATCH 094/124] flake8 --- pytato/analysis/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index d74092590..41a5842b0 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -36,7 +36,7 @@ from pytato.loopy import LoopyCall from pymbolic.mapper.optimize import optimize_mapper from pytools import memoize_method -from loopy.tools import LoopyKeyBuilder, PersistentHashWalkMapper 
+from loopy.tools import LoopyKeyBuilder if TYPE_CHECKING: from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder From ad9aa2818d45e7e24f733edc67fdb07d225f569f Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 5 Feb 2024 12:25:51 -0600 Subject: [PATCH 095/124] improve test --- test/test_pytato.py | 37 ++++++++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index e5d8c2b5f..4befdbf7b 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1231,27 +1231,42 @@ def test_dot_visualizers(): # }}} -def test_persistent_dict(): +def test_persistent_hashing_and_persistent_dict(): from pytools.persistent_dict import WriteOncePersistentDict, ReadOnlyEntryError from pytato.analysis import PytatoKeyBuilder + import shutil + import tempfile axis_len = 5 - pd = WriteOncePersistentDict("test_persistent_dict", - key_builder=PytatoKeyBuilder(), - container_dir="./pytest-pdict") + try: + tmpdir = tempfile.mkdtemp() - for i in range(100): - rdagc = RandomDAGContext(np.random.default_rng(seed=i), - axis_len=axis_len, use_numpy=True) + pkb = PytatoKeyBuilder() - dag = make_random_dag(rdagc) - pd[dag] = 42 + pd = WriteOncePersistentDict("test_persistent_dict", + key_builder=pkb, + container_dir=tmpdir) + + for i in range(100): + rdagc = RandomDAGContext(np.random.default_rng(seed=i), + axis_len=axis_len, use_numpy=True) - # Make sure key stays the same - with pytest.raises(ReadOnlyEntryError): + dag = make_random_dag(rdagc) + + # Make sure the PytatoKeyBuilder can handle 'dag' pd[dag] = 42 + # make sure the key stays the same across invocations + if i == 0: + assert pkb(dag) == "eaa8ad49c9490cb6f0b61a33c17d0c2fd10fafc6ce02705105cc9c379c91b9c8" + + # Make sure key stays the same + with pytest.raises(ReadOnlyEntryError): + pd[dag] = 42 + finally: + shutil.rmtree(tmpdir) + if __name__ == "__main__": if len(sys.argv) > 1: From 60d8e41452b36f0321972c5017b032530eff9850 
Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 5 Feb 2024 13:03:26 -0600 Subject: [PATCH 096/124] add full invocation test --- test/test_pytato.py | 72 ++++++++++++++++++++++++++++++++++++++------- 1 file changed, 62 insertions(+), 10 deletions(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index 4befdbf7b..26af8d746 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1231,14 +1231,42 @@ def test_dot_visualizers(): # }}} -def test_persistent_hashing_and_persistent_dict(): +# {{{ Test PytatoKeyBuilder + +def run_test_with_new_python_invocation(f, *args, extra_env_vars = None) -> None: + import os + if extra_env_vars is None: + extra_env_vars = {} + + from base64 import b64encode + from pickle import dumps + from subprocess import check_call + + env_vars = { + "INVOCATION_INFO": b64encode(dumps((f, args))).decode(), + } + env_vars.update(extra_env_vars) + + my_env = os.environ.copy() + my_env.update(env_vars) + + check_call([sys.executable, __file__], env=my_env) + + +def run_test_with_new_python_invocation_inner() -> None: + from base64 import b64decode + from pickle import loads + f, args = loads(b64decode(os.environ["INVOCATION_INFO"].encode())) + + f(*args) + + +def test_persistent_hashing_and_persistent_dict() -> None: from pytools.persistent_dict import WriteOncePersistentDict, ReadOnlyEntryError from pytato.analysis import PytatoKeyBuilder import shutil import tempfile - axis_len = 5 - try: tmpdir = tempfile.mkdtemp() @@ -1250,26 +1278,50 @@ def test_persistent_hashing_and_persistent_dict(): for i in range(100): rdagc = RandomDAGContext(np.random.default_rng(seed=i), - axis_len=axis_len, use_numpy=True) + axis_len=5, use_numpy=True) dag = make_random_dag(rdagc) # Make sure the PytatoKeyBuilder can handle 'dag' pd[dag] = 42 - # make sure the key stays the same across invocations - if i == 0: - assert pkb(dag) == "eaa8ad49c9490cb6f0b61a33c17d0c2fd10fafc6ce02705105cc9c379c91b9c8" - - # Make sure key stays the same + # Make sure 
that the key stays the same within the same Python invocation with pytest.raises(ReadOnlyEntryError): pd[dag] = 42 + + # Make sure that the key stays the same across Python invocations + run_test_with_new_python_invocation(_test_persistent_hashing_and_persistent_dict_stage2, + tmpdir) finally: shutil.rmtree(tmpdir) +def _test_persistent_hashing_and_persistent_dict_stage2(tmpdir) -> None: + from pytools.persistent_dict import WriteOncePersistentDict, ReadOnlyEntryError + + from pytato.analysis import PytatoKeyBuilder + pkb = PytatoKeyBuilder() + + pd = WriteOncePersistentDict("test_persistent_dict", + key_builder=pkb, + container_dir=tmpdir) + + for i in range(100): + rdagc = RandomDAGContext(np.random.default_rng(seed=i), + axis_len=5, use_numpy=True) + + dag = make_random_dag(rdagc) + + with pytest.raises(ReadOnlyEntryError): + pd[dag] = 42 + +# }}} + if __name__ == "__main__": - if len(sys.argv) > 1: + import os + if "INVOCATION_INFO" in os.environ: + run_test_with_new_python_invocation_inner() + elif len(sys.argv) > 1: exec(sys.argv[1]) else: from pytest import main From 9d45e6577675c99cec4a0a836f2bbf49e7ad274e Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 5 Feb 2024 13:10:22 -0600 Subject: [PATCH 097/124] lint fixes --- test/test_pytato.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/test/test_pytato.py b/test/test_pytato.py index 26af8d746..cb4438eea 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -1233,7 +1233,7 @@ def test_dot_visualizers(): # {{{ Test PytatoKeyBuilder -def run_test_with_new_python_invocation(f, *args, extra_env_vars = None) -> None: +def run_test_with_new_python_invocation(f, *args, extra_env_vars=None) -> None: import os if extra_env_vars is None: extra_env_vars = {} @@ -1256,6 +1256,8 @@ def run_test_with_new_python_invocation(f, *args, extra_env_vars = None) -> None def run_test_with_new_python_invocation_inner() -> None: from base64 import b64decode from pickle import loads + 
import os + f, args = loads(b64decode(os.environ["INVOCATION_INFO"].encode())) f(*args) @@ -1290,11 +1292,12 @@ def test_persistent_hashing_and_persistent_dict() -> None: pd[dag] = 42 # Make sure that the key stays the same across Python invocations - run_test_with_new_python_invocation(_test_persistent_hashing_and_persistent_dict_stage2, - tmpdir) + run_test_with_new_python_invocation( + _test_persistent_hashing_and_persistent_dict_stage2, tmpdir) finally: shutil.rmtree(tmpdir) + def _test_persistent_hashing_and_persistent_dict_stage2(tmpdir) -> None: from pytools.persistent_dict import WriteOncePersistentDict, ReadOnlyEntryError From 08be3806b4f2b153ae4adb083fd38ff63c773bde Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 5 Feb 2024 16:30:43 -0600 Subject: [PATCH 098/124] add missing pymbolic expressions --- pytato/analysis/__init__.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 41a5842b0..ffaba4a8d 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -485,15 +485,27 @@ def update_for_Array(self, key_hash: Any, key: Any) -> None: def update_for_function(self, key_hash: Any, key: Any) -> None: self.rec(key_hash, key.__module__ + key.__qualname__) - update_for_Product = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 - update_for_Sum = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 - update_for_If = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 - update_for_LogicalOr = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + update_for_BitwiseAnd = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + update_for_BitwiseNot = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + update_for_BitwiseOr = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + update_for_BitwiseXor = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: 
N815, E501 update_for_Call = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + update_for_CallWithKwargs = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 update_for_Comparison = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 - update_for_Quotient = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + update_for_If = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + update_for_FloorDiv = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + update_for_LeftShift = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + update_for_LogicalAnd = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + update_for_LogicalNot = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + update_for_LogicalOr = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + update_for_Lookup = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 update_for_Power = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + update_for_Product = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + update_for_Quotient = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + update_for_Remainder = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + update_for_RightShift = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 update_for_Subscript = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + update_for_Sum = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 update_for_Variable = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 # }}} From 058f6f9284afc63ea5fb910b570c39588c572a0b Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 5 Feb 2024 22:28:59 -0600 Subject: [PATCH 099/124] flake8 --- pytato/analysis/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 
ffaba4a8d..3f9446dfe 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -490,10 +490,10 @@ def update_for_function(self, key_hash: Any, key: Any) -> None: update_for_BitwiseOr = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 update_for_BitwiseXor = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 update_for_Call = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 - update_for_CallWithKwargs = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + update_for_CallWithKwargs = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 update_for_Comparison = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 update_for_If = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 - update_for_FloorDiv = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + update_for_FloorDiv = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 update_for_LeftShift = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 update_for_LogicalAnd = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 update_for_LogicalNot = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 From 0360e211272675628acfb67726bbbcb4ab52d4b7 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 13 Jun 2024 14:43:04 -0500 Subject: [PATCH 100/124] remove update_for_function (now handled directly by pytools) --- pytato/analysis/__init__.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 3f9446dfe..22ab396b7 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -482,9 +482,6 @@ def update_for_Array(self, key_hash: Any, key: Any) -> None: # CL Array self.rec(key_hash, key.get()) - def update_for_function(self, key_hash: Any, key: Any) -> None: - self.rec(key_hash, key.__module__ + key.__qualname__) - update_for_BitwiseAnd = 
LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 update_for_BitwiseNot = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 update_for_BitwiseOr = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 From 3364a4fda9cc11d46c2ebe8e32fee0b161da27cf Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 25 Jul 2024 16:19:06 -0500 Subject: [PATCH 101/124] working pass 1 --- pytato/analysis/__init__.py | 50 +++++++++---------- pytato/distributed/partition.py | 86 ++++++++++++++++++--------------- pytato/transform/__init__.py | 5 +- 3 files changed, 74 insertions(+), 67 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 38ed276fe..1a4359e4b 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -310,48 +310,48 @@ def is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) -> bool: # {{{ DirectPredecessorsGetter +from orderedsets import FrozenOrderedSet + class DirectPredecessorsGetter(Mapper): """ Mapper to get the `direct predecessors `__ of a node. - .. note:: - We only consider the predecessors of a nodes in a data-flow sense. 
""" - def _get_preds_from_shape(self, shape: ShapeType) -> frozenset[Array]: - return frozenset({dim for dim in shape if isinstance(dim, Array)}) + def _get_preds_from_shape(self, shape: ShapeType) -> abc_Set[Array]: + return FrozenOrderedSet([dim for dim in shape if isinstance(dim, Array)]) - def map_index_lambda(self, expr: IndexLambda) -> frozenset[Array]: - return (frozenset(expr.bindings.values()) + def map_index_lambda(self, expr: IndexLambda) -> abc_Set[Array]: + return (FrozenOrderedSet(expr.bindings.values()) | self._get_preds_from_shape(expr.shape)) - def map_stack(self, expr: Stack) -> frozenset[Array]: - return (frozenset(expr.arrays) + def map_stack(self, expr: Stack) -> abc_Set[Array]: + return (FrozenOrderedSet(expr.arrays) | self._get_preds_from_shape(expr.shape)) - def map_concatenate(self, expr: Concatenate) -> frozenset[Array]: - return (frozenset(expr.arrays) + def map_concatenate(self, expr: Concatenate) -> abc_Set[Array]: + return (FrozenOrderedSet(expr.arrays) | self._get_preds_from_shape(expr.shape)) - def map_einsum(self, expr: Einsum) -> frozenset[Array]: - return (frozenset(expr.args) + def map_einsum(self, expr: Einsum) -> abc_Set[Array]: + return (FrozenOrderedSet(expr.args) | self._get_preds_from_shape(expr.shape)) - def map_loopy_call_result(self, expr: NamedArray) -> frozenset[Array]: - from pytato.loopy import LoopyCall, LoopyCallResult + def map_loopy_call_result(self, expr: NamedArray) -> abc_Set[Array]: + from pytato.loopy import LoopyCallResult, LoopyCall assert isinstance(expr, LoopyCallResult) assert isinstance(expr._container, LoopyCall) - return (frozenset(ary + return (FrozenOrderedSet(ary for ary in expr._container.bindings.values() if isinstance(ary, Array)) | self._get_preds_from_shape(expr.shape)) - def _map_index_base(self, expr: IndexBase) -> frozenset[Array]: - return (frozenset([expr.array]) - | frozenset(idx for idx in expr.indices + def _map_index_base(self, expr: IndexBase) -> abc_Set[Array]: + return 
(FrozenOrderedSet([expr.array]) + | FrozenOrderedSet(idx for idx in expr.indices if isinstance(idx, Array)) | self._get_preds_from_shape(expr.shape)) @@ -360,29 +360,29 @@ def _map_index_base(self, expr: IndexBase) -> frozenset[Array]: map_non_contiguous_advanced_index = _map_index_base def _map_index_remapping_base(self, expr: IndexRemappingBase - ) -> frozenset[Array]: - return frozenset([expr.array]) + ) -> abc_Set[Array]: + return FrozenOrderedSet([expr.array]) map_roll = _map_index_remapping_base map_axis_permutation = _map_index_remapping_base map_reshape = _map_index_remapping_base - def _map_input_base(self, expr: InputArgumentBase) -> frozenset[Array]: + def _map_input_base(self, expr: InputArgumentBase) -> abc_Set[Array]: return self._get_preds_from_shape(expr.shape) map_placeholder = _map_input_base map_data_wrapper = _map_input_base map_size_param = _map_input_base - def map_distributed_recv(self, expr: DistributedRecv) -> frozenset[Array]: + def map_distributed_recv(self, expr: DistributedRecv) -> abc_Set[Array]: return self._get_preds_from_shape(expr.shape) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder - ) -> frozenset[Array]: - return frozenset([expr.passthrough_data]) + ) -> abc_Set[Array]: + return FrozenOrderedSet([expr.passthrough_data]) - def map_named_call_result(self, expr: NamedCallResult) -> frozenset[Array]: + def map_named_call_result(self, expr: NamedCallResult) -> abc_Set[Array]: raise NotImplementedError( "DirectPredecessorsGetter does not yet support expressions containing " "functions.") diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 5865ec491..7c9b510e7 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -476,9 +476,10 @@ def __init__(self, local_rank: int) -> None: self.local_rank = local_rank def combine( - self, *args: frozenset[CommunicationOpIdentifier] - ) -> frozenset[CommunicationOpIdentifier]: - return reduce(frozenset.union, 
args, frozenset()) + self, *args: Tuple[CommunicationOpIdentifier] + ) -> Tuple[CommunicationOpIdentifier]: + from pytools import unique + return reduce(lambda x, y: tuple(unique(x+y)), args, tuple()) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder @@ -496,8 +497,8 @@ def map_distributed_send_ref_holder(self, return self.rec(expr.passthrough_data) - def _map_input_base(self, expr: Array) -> frozenset[CommunicationOpIdentifier]: - return frozenset() + def _map_input_base(self, expr: Array) -> Tuple[CommunicationOpIdentifier]: + return tuple() map_placeholder = _map_input_base map_data_wrapper = _map_input_base @@ -505,21 +506,21 @@ def _map_input_base(self, expr: Array) -> frozenset[CommunicationOpIdentifier]: def map_distributed_recv( self, expr: DistributedRecv - ) -> frozenset[CommunicationOpIdentifier]: + ) -> Tuple[CommunicationOpIdentifier]: recv_id = _recv_to_comm_id(self.local_rank, expr) if recv_id in self.local_recv_id_to_recv_node: from pytato.distributed.verify import DuplicateRecvError raise DuplicateRecvError(f"Multiple receives found for '{recv_id}'") - self.local_comm_ids_to_needed_comm_ids[recv_id] = frozenset() + self.local_comm_ids_to_needed_comm_ids[recv_id] = tuple() self.local_recv_id_to_recv_node[recv_id] = expr - return frozenset({recv_id}) + return (recv_id,) def map_named_call_result( - self, expr: NamedCallResult) -> frozenset[CommunicationOpIdentifier]: + self, expr: NamedCallResult) -> Tuple[CommunicationOpIdentifier]: raise NotImplementedError( "LocalSendRecvDepGatherer does not support functions.") @@ -557,10 +558,10 @@ def _schedule_task_batches_counted( task_to_dep_level, visits_in_depend = \ _calculate_dependency_levels(task_ids_to_needed_task_ids) nlevels = 1 + max(task_to_dep_level.values(), default=-1) - task_batches: Sequence[set[TaskType]] = [set() for _ in range(nlevels)] + task_batches: Sequence[List[TaskType]] = [list() for _ in range(nlevels)] for task_id, dep_level in task_to_dep_level.items(): - 
task_batches[dep_level].add(task_id) + task_batches[dep_level].append(task_id) return task_batches, visits_in_depend + len(task_to_dep_level.keys()) @@ -623,7 +624,7 @@ class _MaterializedArrayCollector(CachedWalkMapper): """ def __init__(self) -> None: super().__init__() - self.materialized_arrays: _OrderedSet[Array] = _OrderedSet() + self.materialized_arrays: List[Array] = [] def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) @@ -633,15 +634,15 @@ def post_visit(self, expr: Any) -> None: from pytato.tags import ImplStored if (isinstance(expr, Array) and expr.tags_of_type(ImplStored)): - self.materialized_arrays.add(expr) + self.materialized_arrays.append(expr) if isinstance(expr, LoopyCallResult): - self.materialized_arrays.add(expr) + self.materialized_arrays.append(expr) from pytato.loopy import LoopyCall assert isinstance(expr._container, LoopyCall) for _, subexpr in sorted(expr._container.bindings.items()): if isinstance(subexpr, Array): - self.materialized_arrays.add(subexpr) + self.materialized_arrays.append(subexpr) else: assert isinstance(subexpr, SCALAR_CLASSES) @@ -651,13 +652,13 @@ def post_visit(self, expr: Any) -> None: # {{{ _set_dict_union_mpi def _set_dict_union_mpi( - dict_a: Mapping[_KeyT, frozenset[_ValueT]], - dict_b: Mapping[_KeyT, frozenset[_ValueT]], - mpi_data_type: mpi4py.MPI.Datatype) -> Mapping[_KeyT, frozenset[_ValueT]]: + dict_a: Mapping[_KeyT, Sequence[_ValueT]], + dict_b: Mapping[_KeyT, Sequence[_ValueT]], + mpi_data_type: mpi4py.MPI.Datatype) -> Mapping[_KeyT, Sequence[_ValueT]]: assert mpi_data_type is None result = dict(dict_a) for key, values in dict_b.items(): - result[key] = result.get(key, frozenset()) | values + result[key] = result.get(key, tuple()) + values return result # }}} @@ -782,6 +783,8 @@ def find_distributed_partition( - Gather sent arrays into assigned in :attr:`DistributedGraphPart.name_to_send_nodes`. 
""" + from pytools import unique + import mpi4py.MPI as MPI from pytato.transform import SubsetDependencyMapper @@ -833,12 +836,13 @@ def find_distributed_partition( # {{{ create (local) parts out of batch ids - part_comm_ids: list[_PartCommIDs] = [] + + part_comm_ids: List[_PartCommIDs] = [] if comm_batches: - recv_ids: frozenset[CommunicationOpIdentifier] = frozenset() + recv_ids: Tuple[CommunicationOpIdentifier] = tuple() for batch in comm_batches: - send_ids = frozenset( - comm_id for comm_id in batch + send_ids = tuple( + comm_id for comm_id in unique(batch) if comm_id.src_rank == local_rank) if recv_ids or send_ids: part_comm_ids.append( @@ -846,19 +850,19 @@ def find_distributed_partition( recv_ids=recv_ids, send_ids=send_ids)) # These go into the next part - recv_ids = frozenset( - comm_id for comm_id in batch + recv_ids = tuple( + comm_id for comm_id in unique(batch) if comm_id.dest_rank == local_rank) if recv_ids: part_comm_ids.append( _PartCommIDs( recv_ids=recv_ids, - send_ids=frozenset())) + send_ids=tuple())) else: part_comm_ids.append( _PartCommIDs( - recv_ids=frozenset(), - send_ids=frozenset())) + recv_ids=tuple(), + send_ids=tuple())) nparts = len(part_comm_ids) @@ -876,7 +880,7 @@ def find_distributed_partition( comm_id_to_part_id = { comm_id: ipart for ipart, comm_ids in enumerate(part_comm_ids) - for comm_id in comm_ids.send_ids | comm_ids.recv_ids} + for comm_id in unique(comm_ids.send_ids + comm_ids.recv_ids)} # }}} @@ -888,10 +892,10 @@ def find_distributed_partition( # The sets of arrays below must have a deterministic order in order to ensure # that the resulting partition is also deterministic - sent_arrays = _OrderedSet( + sent_arrays = tuple( send_node.data for send_node in lsrdg.local_send_id_to_send_node.values()) - received_arrays = _OrderedSet(lsrdg.local_recv_id_to_recv_node.values()) + received_arrays = tuple(lsrdg.local_recv_id_to_recv_node.values()) # While receive nodes may be marked as materialized, we shouldn't be # including 
them here because we're using them (along with the send nodes) @@ -899,14 +903,16 @@ def find_distributed_partition( # We could allow sent *arrays* to be included here because they are distinct # from send *nodes*, but we choose to exclude them in order to simplify the # processing below. - materialized_arrays = ( - materialized_arrays_collector.materialized_arrays - - received_arrays - - sent_arrays) + materialized_arrays_set = set(materialized_arrays_collector.materialized_arrays) \ + - set(received_arrays) \ + - set(sent_arrays) + + from pytools import unique + materialized_arrays = tuple(a for a in materialized_arrays_collector.materialized_arrays if a in materialized_arrays_set) # "mso" for "materialized/sent/output" - output_arrays = _OrderedSet(outputs._data.values()) - mso_arrays = materialized_arrays | sent_arrays | output_arrays + output_arrays = tuple(outputs._data.values()) + mso_arrays = materialized_arrays + sent_arrays + output_arrays # FIXME: This gathers up materialized_arrays recursively, leading to # result sizes potentially quadratic in the number of materialized arrays. 
@@ -970,7 +976,7 @@ def find_distributed_partition( assert all(0 <= part_id < nparts for part_id in stored_ary_to_part_id.values()) - stored_arrays = _OrderedSet(stored_ary_to_part_id) + stored_arrays = tuple(unique(stored_ary_to_part_id)) # {{{ find which stored arrays should become part outputs # (because they are used in not just their local part, but also others) @@ -986,13 +992,13 @@ def get_materialized_predecessors(ary: Array) -> _OrderedSet[Array]: materialized_preds |= get_materialized_predecessors(pred) return materialized_preds - stored_arrays_promoted_to_part_outputs = { + stored_arrays_promoted_to_part_outputs = tuple(unique( stored_pred for stored_ary in stored_arrays for stored_pred in get_materialized_predecessors(stored_ary) if (stored_ary_to_part_id[stored_ary] != stored_ary_to_part_id[stored_pred]) - } + )) # }}} diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index b78c24301..56b2a53d6 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -926,9 +926,10 @@ def __init__(self, universe: frozenset[Array]): def combine(self, *args: frozenset[Array]) -> frozenset[Array]: from functools import reduce - return reduce(lambda acc, arg: acc | (arg & self.universe), + from pytools import unique + return reduce(lambda acc, arg: unique(tuple(acc) + tuple(set(arg) & self.universe)), args, - frozenset()) + tuple()) # }}} From ef7ea0bb74e1543b317558017f03f7be6352e5c1 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 25 Jul 2024 16:33:06 -0500 Subject: [PATCH 102/124] cleanups --- pytato/analysis/__init__.py | 6 ++++- pytato/distributed/partition.py | 43 ++++++++++++++++++--------------- pytato/transform/__init__.py | 6 +++-- setup.py | 1 + 4 files changed, 33 insertions(+), 23 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 1a4359e4b..d7a8e3353 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -310,14 +310,18 @@ def 
is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) -> bool: # {{{ DirectPredecessorsGetter +from collections.abc import Set as abc_Set + from orderedsets import FrozenOrderedSet + class DirectPredecessorsGetter(Mapper): """ Mapper to get the `direct predecessors `__ of a node. + .. note:: We only consider the predecessors of a nodes in a data-flow sense. """ @@ -341,7 +345,7 @@ def map_einsum(self, expr: Einsum) -> abc_Set[Array]: | self._get_preds_from_shape(expr.shape)) def map_loopy_call_result(self, expr: NamedArray) -> abc_Set[Array]: - from pytato.loopy import LoopyCallResult, LoopyCall + from pytato.loopy import LoopyCall, LoopyCallResult assert isinstance(expr, LoopyCallResult) assert isinstance(expr._container, LoopyCall) return (FrozenOrderedSet(ary diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 7c9b510e7..2d4b1c93a 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -476,10 +476,10 @@ def __init__(self, local_rank: int) -> None: self.local_rank = local_rank def combine( - self, *args: Tuple[CommunicationOpIdentifier] - ) -> Tuple[CommunicationOpIdentifier]: + self, *args: tuple[CommunicationOpIdentifier] + ) -> tuple[CommunicationOpIdentifier]: from pytools import unique - return reduce(lambda x, y: tuple(unique(x+y)), args, tuple()) + return reduce(lambda x, y: tuple(unique(x+y)), args, ()) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder @@ -497,8 +497,8 @@ def map_distributed_send_ref_holder(self, return self.rec(expr.passthrough_data) - def _map_input_base(self, expr: Array) -> Tuple[CommunicationOpIdentifier]: - return tuple() + def _map_input_base(self, expr: Array) -> tuple[CommunicationOpIdentifier]: + return () map_placeholder = _map_input_base map_data_wrapper = _map_input_base @@ -506,21 +506,21 @@ def _map_input_base(self, expr: Array) -> Tuple[CommunicationOpIdentifier]: def map_distributed_recv( self, expr: DistributedRecv - ) -> 
Tuple[CommunicationOpIdentifier]: + ) -> tuple[CommunicationOpIdentifier]: recv_id = _recv_to_comm_id(self.local_rank, expr) if recv_id in self.local_recv_id_to_recv_node: from pytato.distributed.verify import DuplicateRecvError raise DuplicateRecvError(f"Multiple receives found for '{recv_id}'") - self.local_comm_ids_to_needed_comm_ids[recv_id] = tuple() + self.local_comm_ids_to_needed_comm_ids[recv_id] = () self.local_recv_id_to_recv_node[recv_id] = expr return (recv_id,) def map_named_call_result( - self, expr: NamedCallResult) -> Tuple[CommunicationOpIdentifier]: + self, expr: NamedCallResult) -> tuple[CommunicationOpIdentifier]: raise NotImplementedError( "LocalSendRecvDepGatherer does not support functions.") @@ -558,7 +558,7 @@ def _schedule_task_batches_counted( task_to_dep_level, visits_in_depend = \ _calculate_dependency_levels(task_ids_to_needed_task_ids) nlevels = 1 + max(task_to_dep_level.values(), default=-1) - task_batches: Sequence[List[TaskType]] = [list() for _ in range(nlevels)] + task_batches: Sequence[list[TaskType]] = [[] for _ in range(nlevels)] for task_id, dep_level in task_to_dep_level.items(): task_batches[dep_level].append(task_id) @@ -624,7 +624,7 @@ class _MaterializedArrayCollector(CachedWalkMapper): """ def __init__(self) -> None: super().__init__() - self.materialized_arrays: List[Array] = [] + self.materialized_arrays: list[Array] = [] def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) @@ -658,7 +658,7 @@ def _set_dict_union_mpi( assert mpi_data_type is None result = dict(dict_a) for key, values in dict_b.items(): - result[key] = result.get(key, tuple()) + values + result[key] = result.get(key, ()) + values return result # }}} @@ -783,10 +783,10 @@ def find_distributed_partition( - Gather sent arrays into assigned in :attr:`DistributedGraphPart.name_to_send_nodes`. 
""" - from pytools import unique - import mpi4py.MPI as MPI + from pytools import unique + from pytato.transform import SubsetDependencyMapper local_rank = mpi_communicator.rank @@ -837,9 +837,9 @@ def find_distributed_partition( # {{{ create (local) parts out of batch ids - part_comm_ids: List[_PartCommIDs] = [] + part_comm_ids: list[_PartCommIDs] = [] if comm_batches: - recv_ids: Tuple[CommunicationOpIdentifier] = tuple() + recv_ids: tuple[CommunicationOpIdentifier] = () for batch in comm_batches: send_ids = tuple( comm_id for comm_id in unique(batch) @@ -857,12 +857,12 @@ def find_distributed_partition( part_comm_ids.append( _PartCommIDs( recv_ids=recv_ids, - send_ids=tuple())) + send_ids=())) else: part_comm_ids.append( _PartCommIDs( - recv_ids=tuple(), - send_ids=tuple())) + recv_ids=(), + send_ids=())) nparts = len(part_comm_ids) @@ -908,7 +908,9 @@ def find_distributed_partition( - set(sent_arrays) from pytools import unique - materialized_arrays = tuple(a for a in materialized_arrays_collector.materialized_arrays if a in materialized_arrays_set) + materialized_arrays = tuple( + a for a in materialized_arrays_collector.materialized_arrays + if a in materialized_arrays_set) # "mso" for "materialized/sent/output" output_arrays = tuple(outputs._data.values()) @@ -927,7 +929,8 @@ def find_distributed_partition( comm_id_to_part_id[send_id]) if __debug__: - recvd_array_dep_mapper = SubsetDependencyMapper(frozenset(received_arrays)) + recvd_array_dep_mapper = SubsetDependencyMapper(frozenset + (received_arrays)) mso_ary_to_last_dep_recv_part_id: dict[Array, int] = { ary: max( diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 56b2a53d6..642f52839 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -926,10 +926,12 @@ def __init__(self, universe: frozenset[Array]): def combine(self, *args: frozenset[Array]) -> frozenset[Array]: from functools import reduce + from pytools import unique - return reduce(lambda acc, 
arg: unique(tuple(acc) + tuple(set(arg) & self.universe)), + return reduce(lambda acc, arg: + unique(tuple(acc) + tuple(set(arg) & self.universe)), args, - tuple()) + ()) # }}} diff --git a/setup.py b/setup.py index ba0bd1b4d..9fe0df6b1 100644 --- a/setup.py +++ b/setup.py @@ -40,6 +40,7 @@ "immutabledict", "attrs", "bidict", + "orderedsets", ], package_data={"pytato": ["py.typed"]}, author="Andreas Kloeckner, Matt Wala, Xiaoyu Wei", From 817b255ca54f3232bc01fbd9c11ece34b54b7cac Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 25 Jul 2024 16:44:33 -0500 Subject: [PATCH 103/124] enable determinism test --- test/test_distributed.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/test/test_distributed.py b/test/test_distributed.py index ac7ca1389..1554a024b 100644 --- a/test/test_distributed.py +++ b/test/test_distributed.py @@ -899,13 +899,11 @@ def test_number_symbolic_tags_bare_classes(ctx_factory): outputs = pt.make_dict_of_named_arrays({"out": res}) partition = pt.find_distributed_partition(comm, outputs) - (_distp, next_tag) = pt.number_distributed_tags(comm, partition, base_tag=4242) + (distp, next_tag) = pt.number_distributed_tags(comm, partition, base_tag=4242) assert next_tag == 4244 - # FIXME: For the next assertion, find_distributed_partition needs to be - # deterministic too (https://github.com/inducer/pytato/pull/465). 
- # assert next(iter(distp.parts[0].name_to_send_nodes.values()))[0].comm_tag == 4242 # noqa: E501 + assert next(iter(distp.parts[0].name_to_send_nodes.values()))[0].comm_tag == 4242 # }}} From f3f3c7df968088a05a2792e189d43a70c14e34c5 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 25 Jul 2024 16:54:43 -0500 Subject: [PATCH 104/124] eliminate _OrderedSets --- pytato/distributed/partition.py | 70 +++------------------------------ 1 file changed, 6 insertions(+), 64 deletions(-) diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 2d4b1c93a..6d3adb319 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -62,7 +62,6 @@ THE SOFTWARE. """ -import collections from functools import reduce from typing import ( TYPE_CHECKING, @@ -70,8 +69,6 @@ Any, FrozenSet, Hashable, - Iterable, - Iterator, Mapping, Sequence, TypeVar, @@ -131,61 +128,6 @@ class CommunicationOpIdentifier: _ValueT = TypeVar("_ValueT") -# {{{ crude ordered set - - -class _OrderedSet(collections.abc.MutableSet[_ValueT]): - def __init__(self, items: Iterable[_ValueT] | None = None): - # Could probably also use a valueless dictionary; not sure if it matters - self._items: set[_ValueT] = set() - self._items_ordered: list[_ValueT] = [] - if items is not None: - for item in items: - self.add(item) - - def add(self, item: _ValueT) -> None: - if item not in self._items: - self._items.add(item) - self._items_ordered.append(item) - - def discard(self, item: _ValueT) -> None: - # Not currently needed - raise NotImplementedError - - def __len__(self) -> int: - return len(self._items) - - def __iter__(self) -> Iterator[_ValueT]: - return iter(self._items_ordered) - - def __contains__(self, item: Any) -> bool: - return item in self._items - - def __and__(self, other: AbstractSet[_ValueT]) -> _OrderedSet[_ValueT]: - result: _OrderedSet[_ValueT] = _OrderedSet() - for item in self._items_ordered: - if item in other: - result.add(item) - return result 
- - # Must be "Any" instead of "_ValueT", otherwise it violates Liskov substitution - # according to mypy. *shrug* - def __or__(self, other: AbstractSet[Any]) -> _OrderedSet[_ValueT]: - result: _OrderedSet[_ValueT] = _OrderedSet(self._items_ordered) - for item in other: - result.add(item) - return result - - def __sub__(self, other: AbstractSet[_ValueT]) -> _OrderedSet[_ValueT]: - result: _OrderedSet[_ValueT] = _OrderedSet() - for item in self._items_ordered: - if item not in other: - result.add(item) - return result - -# }}} - - # {{{ distributed graph part PartId = Hashable @@ -836,7 +778,6 @@ def find_distributed_partition( # {{{ create (local) parts out of batch ids - part_comm_ids: list[_PartCommIDs] = [] if comm_batches: recv_ids: tuple[CommunicationOpIdentifier] = () @@ -986,14 +927,15 @@ def find_distributed_partition( direct_preds_getter = DirectPredecessorsGetter() - def get_materialized_predecessors(ary: Array) -> _OrderedSet[Array]: - materialized_preds: _OrderedSet[Array] = _OrderedSet() + def get_materialized_predecessors(ary: Array) -> tuple[Array]: + materialized_preds: dict[Array, None] = {} for pred in direct_preds_getter(ary): if pred in materialized_arrays: - materialized_preds.add(pred) + materialized_preds[pred] = None else: - materialized_preds |= get_materialized_predecessors(pred) - return materialized_preds + for p in get_materialized_predecessors(pred): + materialized_preds[p] = None + return tuple(materialized_preds.keys()) stored_arrays_promoted_to_part_outputs = tuple(unique( stored_pred From 8bf2daf69260673b03ed8ef6fd2c03d01483eb04 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 25 Jul 2024 17:29:07 -0500 Subject: [PATCH 105/124] misc improvements --- pytato/distributed/partition.py | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 6d3adb319..81a493c9a 100644 --- a/pytato/distributed/partition.py +++ 
b/pytato/distributed/partition.py @@ -503,7 +503,8 @@ def _schedule_task_batches_counted( task_batches: Sequence[list[TaskType]] = [[] for _ in range(nlevels)] for task_id, dep_level in task_to_dep_level.items(): - task_batches[dep_level].append(task_id) + if task_id not in task_batches[dep_level]: + task_batches[dep_level].append(task_id) return task_batches, visits_in_depend + len(task_to_dep_level.keys()) @@ -566,7 +567,7 @@ class _MaterializedArrayCollector(CachedWalkMapper): """ def __init__(self) -> None: super().__init__() - self.materialized_arrays: list[Array] = [] + self.materialized_arrays: dict[Array, None] = {} def get_cache_key(self, expr: ArrayOrNames) -> int: return id(expr) @@ -576,15 +577,15 @@ def post_visit(self, expr: Any) -> None: from pytato.tags import ImplStored if (isinstance(expr, Array) and expr.tags_of_type(ImplStored)): - self.materialized_arrays.append(expr) + self.materialized_arrays[expr] = None if isinstance(expr, LoopyCallResult): - self.materialized_arrays.append(expr) + self.materialized_arrays[expr] = None from pytato.loopy import LoopyCall assert isinstance(expr._container, LoopyCall) for _, subexpr in sorted(expr._container.bindings.items()): if isinstance(subexpr, Array): - self.materialized_arrays.append(subexpr) + self.materialized_arrays[subexpr] = None else: assert isinstance(subexpr, SCALAR_CLASSES) @@ -596,11 +597,12 @@ def post_visit(self, expr: Any) -> None: def _set_dict_union_mpi( dict_a: Mapping[_KeyT, Sequence[_ValueT]], dict_b: Mapping[_KeyT, Sequence[_ValueT]], - mpi_data_type: mpi4py.MPI.Datatype) -> Mapping[_KeyT, Sequence[_ValueT]]: + mpi_data_type: mpi4py.MPI.Datatype | None) -> Mapping[_KeyT, Sequence[_ValueT]]: assert mpi_data_type is None + from pytools import unique result = dict(dict_a) for key, values in dict_b.items(): - result[key] = result.get(key, ()) + values + result[key] = tuple(unique(result.get(key, ()) + values)) return result # }}} @@ -833,10 +835,10 @@ def find_distributed_partition( # The 
sets of arrays below must have a deterministic order in order to ensure # that the resulting partition is also deterministic - sent_arrays = tuple( - send_node.data for send_node in lsrdg.local_send_id_to_send_node.values()) + sent_arrays = tuple(unique( + send_node.data for send_node in lsrdg.local_send_id_to_send_node.values())) - received_arrays = tuple(lsrdg.local_recv_id_to_recv_node.values()) + received_arrays = tuple(unique(lsrdg.local_recv_id_to_recv_node.values())) # While receive nodes may be marked as materialized, we shouldn't be # including them here because we're using them (along with the send nodes) @@ -849,13 +851,13 @@ def find_distributed_partition( - set(sent_arrays) from pytools import unique - materialized_arrays = tuple( + materialized_arrays = tuple(unique( a for a in materialized_arrays_collector.materialized_arrays - if a in materialized_arrays_set) + if a in materialized_arrays_set)) # "mso" for "materialized/sent/output" - output_arrays = tuple(outputs._data.values()) - mso_arrays = materialized_arrays + sent_arrays + output_arrays + output_arrays = tuple(unique(outputs._data.values())) + mso_arrays = tuple(unique(materialized_arrays + sent_arrays + output_arrays)) # FIXME: This gathers up materialized_arrays recursively, leading to # result sizes potentially quadratic in the number of materialized arrays. 
@@ -870,8 +872,7 @@ def find_distributed_partition( comm_id_to_part_id[send_id]) if __debug__: - recvd_array_dep_mapper = SubsetDependencyMapper(frozenset - (received_arrays)) + recvd_array_dep_mapper = SubsetDependencyMapper(frozenset(received_arrays)) mso_ary_to_last_dep_recv_part_id: dict[Array, int] = { ary: max( From 5d906b5edb98425f48bcc3aada14725f0c9223b1 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 25 Jul 2024 17:41:48 -0500 Subject: [PATCH 106/124] revert change to SubsetDependencyMapper --- pytato/transform/__init__.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 642f52839..b78c24301 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -926,12 +926,9 @@ def __init__(self, universe: frozenset[Array]): def combine(self, *args: frozenset[Array]) -> frozenset[Array]: from functools import reduce - - from pytools import unique - return reduce(lambda acc, arg: - unique(tuple(acc) + tuple(set(arg) & self.universe)), + return reduce(lambda acc, arg: acc | (arg & self.universe), args, - ()) + frozenset()) # }}} From 142c8e63cf4ca9d5c570fb4e99a3d86594770bc8 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 25 Jul 2024 17:50:03 -0500 Subject: [PATCH 107/124] some mypy fixes --- pytato/distributed/partition.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 81a493c9a..e8f4b1fb2 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -67,7 +67,6 @@ TYPE_CHECKING, AbstractSet, Any, - FrozenSet, Hashable, Mapping, Sequence, @@ -316,8 +315,8 @@ def _get_placeholder_for(self, name: str, expr: Array) -> Placeholder: class _PartCommIDs: """A *part*, unlike a *batch*, begins with receives and ends with sends. 
""" - recv_ids: frozenset[CommunicationOpIdentifier] - send_ids: frozenset[CommunicationOpIdentifier] + recv_ids: tuple[CommunicationOpIdentifier] + send_ids: tuple[CommunicationOpIdentifier] # {{{ _make_distributed_partition @@ -403,12 +402,12 @@ def _recv_to_comm_id( class _LocalSendRecvDepGatherer( - CombineMapper[FrozenSet[CommunicationOpIdentifier]]): + CombineMapper[tuple[CommunicationOpIdentifier]]): def __init__(self, local_rank: int) -> None: super().__init__() self.local_comm_ids_to_needed_comm_ids: \ dict[CommunicationOpIdentifier, - frozenset[CommunicationOpIdentifier]] = {} + tuple[CommunicationOpIdentifier]] = {} self.local_recv_id_to_recv_node: \ dict[CommunicationOpIdentifier, DistributedRecv] = {} @@ -425,7 +424,7 @@ def combine( def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder - ) -> frozenset[CommunicationOpIdentifier]: + ) -> tuple[CommunicationOpIdentifier]: send_id = _send_to_comm_id(self.local_rank, expr.send) if send_id in self.local_send_id_to_send_node: @@ -476,7 +475,7 @@ def map_named_call_result( def _schedule_task_batches( task_ids_to_needed_task_ids: Mapping[TaskType, AbstractSet[TaskType]]) \ - -> Sequence[AbstractSet[TaskType]]: + -> Sequence[list[TaskType]]: """For each :type:`TaskType`, determine the 'round'/'batch' during which it will be performed. A 'batch' of tasks consists of tasks which do not depend on each other. @@ -491,7 +490,7 @@ def _schedule_task_batches( def _schedule_task_batches_counted( task_ids_to_needed_task_ids: Mapping[TaskType, AbstractSet[TaskType]]) \ - -> tuple[Sequence[AbstractSet[TaskType]], int]: + -> tuple[Sequence[list[TaskType]], int]: """ Static type checkers need the functions to return the same type regardless of the input. 
The testing code needs to know about the number of tasks visited @@ -773,7 +772,7 @@ def find_distributed_partition( raise comm_batches_or_exc comm_batches = cast( - Sequence[AbstractSet[CommunicationOpIdentifier]], + Sequence[list[CommunicationOpIdentifier]], comm_batches_or_exc) # }}} @@ -928,7 +927,7 @@ def find_distributed_partition( direct_preds_getter = DirectPredecessorsGetter() - def get_materialized_predecessors(ary: Array) -> tuple[Array]: + def get_materialized_predecessors(ary: Array) -> tuple[Array, ...]: materialized_preds: dict[Array, None] = {} for pred in direct_preds_getter(ary): if pred in materialized_arrays: From bd7062071e283e7b063da74ca3f3d15c51986854 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Mon, 12 Aug 2024 16:12:04 -0700 Subject: [PATCH 108/124] ruff --- pytato/reductions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytato/reductions.py b/pytato/reductions.py index 999ef1af4..b0f2b7fb2 100644 --- a/pytato/reductions.py +++ b/pytato/reductions.py @@ -178,9 +178,9 @@ def _normalize_reduction_axes( raise ValueError(f"{axis} is out of bounds for array of dimension" f" {len(shape)}.") - new_shape = tuple([axis_len + new_shape = tuple(axis_len for i, axis_len in enumerate(shape) - if i not in reduction_axes]) + if i not in reduction_axes) return new_shape, reduction_axes From 076a76ebe152f8d82c47de31a05f25586b981e4f Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 13 Aug 2024 13:51:12 -0500 Subject: [PATCH 109/124] replace orderedsets with unique tuples in DirectPredecessorsGetter --- pytato/analysis/__init__.py | 70 +++++++++++++++++-------------------- setup.py | 1 - 2 files changed, 33 insertions(+), 38 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index fa8ac31e7..883030a43 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -29,7 +29,7 @@ from typing import TYPE_CHECKING, Any, Mapping from pymbolic.mapper.optimize import 
optimize_mapper -from pytools import memoize_method +from pytools import memoize_method, unique from pytato.array import ( Array, @@ -314,11 +314,6 @@ def is_einsum_similar_to_subscript(expr: Einsum, subscripts: str) -> bool: # {{{ DirectPredecessorsGetter -from collections.abc import Set as abc_Set - -from orderedsets import FrozenOrderedSet - - class DirectPredecessorsGetter(Mapper): """ Mapper to get the @@ -327,74 +322,75 @@ class DirectPredecessorsGetter(Mapper): of a node. .. note:: + We only consider the predecessors of a nodes in a data-flow sense. """ - def _get_preds_from_shape(self, shape: ShapeType) -> abc_Set[ArrayOrNames]: - return FrozenOrderedSet([dim for dim in shape if isinstance(dim, Array)]) + def _get_preds_from_shape(self, shape: ShapeType) -> tuple[ArrayOrNames]: + return tuple(unique(dim for dim in shape if isinstance(dim, Array))) - def map_index_lambda(self, expr: IndexLambda) -> abc_Set[ArrayOrNames]: - return (FrozenOrderedSet(expr.bindings.values()) - | self._get_preds_from_shape(expr.shape)) + def map_index_lambda(self, expr: IndexLambda) -> tuple[ArrayOrNames]: + return tuple(unique(tuple(expr.bindings.values()) + + self._get_preds_from_shape(expr.shape))) - def map_stack(self, expr: Stack) -> abc_Set[ArrayOrNames]: - return (FrozenOrderedSet(expr.arrays) - | self._get_preds_from_shape(expr.shape)) + def map_stack(self, expr: Stack) -> tuple[ArrayOrNames]: + return tuple(unique(tuple(expr.arrays) + + self._get_preds_from_shape(expr.shape))) - def map_concatenate(self, expr: Concatenate) -> abc_Set[ArrayOrNames]: - return (FrozenOrderedSet(expr.arrays) - | self._get_preds_from_shape(expr.shape)) + map_concatenate = map_stack - def map_einsum(self, expr: Einsum) -> abc_Set[ArrayOrNames]: - return (FrozenOrderedSet(expr.args) - | self._get_preds_from_shape(expr.shape)) + def map_einsum(self, expr: Einsum) -> tuple[ArrayOrNames]: + return tuple(unique(tuple(expr.args) + + self._get_preds_from_shape(expr.shape))) - def 
map_loopy_call_result(self, expr: NamedArray) -> abc_Set[Array]: + def map_loopy_call_result(self, expr: NamedArray) -> tuple[ArrayOrNames]: from pytato.loopy import LoopyCall, LoopyCallResult assert isinstance(expr, LoopyCallResult) assert isinstance(expr._container, LoopyCall) - return (FrozenOrderedSet(ary + return tuple(unique(tuple(ary for ary in expr._container.bindings.values() if isinstance(ary, Array)) - | self._get_preds_from_shape(expr.shape)) + + self._get_preds_from_shape(expr.shape))) - def _map_index_base(self, expr: IndexBase) -> abc_Set[ArrayOrNames]: - return (FrozenOrderedSet([expr.array]) - | FrozenOrderedSet(idx for idx in expr.indices + def _map_index_base(self, expr: IndexBase) -> tuple[ArrayOrNames]: + return tuple(unique((expr.array,) # noqa: RUF005 + + tuple(idx for idx in expr.indices if isinstance(idx, Array)) - | self._get_preds_from_shape(expr.shape)) + + self._get_preds_from_shape(expr.shape))) map_basic_index = _map_index_base map_contiguous_advanced_index = _map_index_base map_non_contiguous_advanced_index = _map_index_base def _map_index_remapping_base(self, expr: IndexRemappingBase - ) -> abc_Set[ArrayOrNames]: - return FrozenOrderedSet([expr.array]) + ) -> tuple[ArrayOrNames]: + return (expr.array,) map_roll = _map_index_remapping_base map_axis_permutation = _map_index_remapping_base map_reshape = _map_index_remapping_base - def _map_input_base(self, expr: InputArgumentBase) -> abc_Set[ArrayOrNames]: + def _map_input_base(self, expr: InputArgumentBase) -> tuple[ArrayOrNames]: return self._get_preds_from_shape(expr.shape) map_placeholder = _map_input_base map_data_wrapper = _map_input_base map_size_param = _map_input_base - def map_distributed_recv(self, expr: DistributedRecv) -> abc_Set[ArrayOrNames]: + def map_distributed_recv(self, expr: DistributedRecv) -> tuple[ArrayOrNames]: return self._get_preds_from_shape(expr.shape) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder - ) -> abc_Set[ArrayOrNames]: - 
return FrozenOrderedSet([expr.passthrough_data]) + ) -> tuple[ArrayOrNames]: + return (expr.passthrough_data,) + + def map_call(self, expr: Call) -> tuple[ArrayOrNames]: + return tuple(unique(expr.bindings.values())) - def map_call(self, expr: Call) -> abc_Set[ArrayOrNames]: - return FrozenOrderedSet(expr.bindings.values()) + def map_named_call_result( + self, expr: NamedCallResult) -> tuple[ArrayOrNames]: + return (expr._container,) - def map_named_call_result(self, expr: NamedCallResult) -> abc_Set[ArrayOrNames]: - return FrozenOrderedSet([expr._container]) # }}} diff --git a/setup.py b/setup.py index 9fe0df6b1..ba0bd1b4d 100644 --- a/setup.py +++ b/setup.py @@ -40,7 +40,6 @@ "immutabledict", "attrs", "bidict", - "orderedsets", ], package_data={"pytato": ["py.typed"]}, author="Andreas Kloeckner, Matt Wala, Xiaoyu Wei", From ea1462c0bd71ea8bf2a6d80e26e83fdbc255fc57 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 14 Aug 2024 11:37:53 -0500 Subject: [PATCH 110/124] mypy fixes --- pytato/analysis/__init__.py | 18 +++++++++--------- pytato/distributed/partition.py | 26 +++++++++++++------------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 883030a43..c568c8f9c 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -325,24 +325,24 @@ class DirectPredecessorsGetter(Mapper): We only consider the predecessors of a nodes in a data-flow sense. 
""" - def _get_preds_from_shape(self, shape: ShapeType) -> tuple[ArrayOrNames]: + def _get_preds_from_shape(self, shape: ShapeType) -> tuple[ArrayOrNames, ...]: return tuple(unique(dim for dim in shape if isinstance(dim, Array))) - def map_index_lambda(self, expr: IndexLambda) -> tuple[ArrayOrNames]: + def map_index_lambda(self, expr: IndexLambda) -> tuple[ArrayOrNames, ...]: return tuple(unique(tuple(expr.bindings.values()) + self._get_preds_from_shape(expr.shape))) - def map_stack(self, expr: Stack) -> tuple[ArrayOrNames]: + def map_stack(self, expr: Stack) -> tuple[ArrayOrNames, ...]: return tuple(unique(tuple(expr.arrays) + self._get_preds_from_shape(expr.shape))) map_concatenate = map_stack - def map_einsum(self, expr: Einsum) -> tuple[ArrayOrNames]: + def map_einsum(self, expr: Einsum) -> tuple[ArrayOrNames, ...]: return tuple(unique(tuple(expr.args) + self._get_preds_from_shape(expr.shape))) - def map_loopy_call_result(self, expr: NamedArray) -> tuple[ArrayOrNames]: + def map_loopy_call_result(self, expr: NamedArray) -> tuple[ArrayOrNames, ...]: from pytato.loopy import LoopyCall, LoopyCallResult assert isinstance(expr, LoopyCallResult) assert isinstance(expr._container, LoopyCall) @@ -351,7 +351,7 @@ def map_loopy_call_result(self, expr: NamedArray) -> tuple[ArrayOrNames]: if isinstance(ary, Array)) + self._get_preds_from_shape(expr.shape))) - def _map_index_base(self, expr: IndexBase) -> tuple[ArrayOrNames]: + def _map_index_base(self, expr: IndexBase) -> tuple[ArrayOrNames, ...]: return tuple(unique((expr.array,) # noqa: RUF005 + tuple(idx for idx in expr.indices if isinstance(idx, Array)) @@ -369,14 +369,14 @@ def _map_index_remapping_base(self, expr: IndexRemappingBase map_axis_permutation = _map_index_remapping_base map_reshape = _map_index_remapping_base - def _map_input_base(self, expr: InputArgumentBase) -> tuple[ArrayOrNames]: + def _map_input_base(self, expr: InputArgumentBase) -> tuple[ArrayOrNames, ...]: return 
self._get_preds_from_shape(expr.shape) map_placeholder = _map_input_base map_data_wrapper = _map_input_base map_size_param = _map_input_base - def map_distributed_recv(self, expr: DistributedRecv) -> tuple[ArrayOrNames]: + def map_distributed_recv(self, expr: DistributedRecv) -> tuple[ArrayOrNames, ...]: return self._get_preds_from_shape(expr.shape) def map_distributed_send_ref_holder(self, @@ -384,7 +384,7 @@ def map_distributed_send_ref_holder(self, ) -> tuple[ArrayOrNames]: return (expr.passthrough_data,) - def map_call(self, expr: Call) -> tuple[ArrayOrNames]: + def map_call(self, expr: Call) -> tuple[ArrayOrNames, ...]: return tuple(unique(expr.bindings.values())) def map_named_call_result( diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index e8f4b1fb2..68a924c8a 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -315,8 +315,8 @@ def _get_placeholder_for(self, name: str, expr: Array) -> Placeholder: class _PartCommIDs: """A *part*, unlike a *batch*, begins with receives and ends with sends. """ - recv_ids: tuple[CommunicationOpIdentifier] - send_ids: tuple[CommunicationOpIdentifier] + recv_ids: tuple[CommunicationOpIdentifier, ...] + send_ids: tuple[CommunicationOpIdentifier, ...] 
# {{{ _make_distributed_partition @@ -402,12 +402,12 @@ def _recv_to_comm_id( class _LocalSendRecvDepGatherer( - CombineMapper[tuple[CommunicationOpIdentifier]]): + CombineMapper[tuple[CommunicationOpIdentifier, ...]]): def __init__(self, local_rank: int) -> None: super().__init__() self.local_comm_ids_to_needed_comm_ids: \ dict[CommunicationOpIdentifier, - tuple[CommunicationOpIdentifier]] = {} + tuple[CommunicationOpIdentifier, ...]] = {} self.local_recv_id_to_recv_node: \ dict[CommunicationOpIdentifier, DistributedRecv] = {} @@ -417,14 +417,14 @@ def __init__(self, local_rank: int) -> None: self.local_rank = local_rank def combine( - self, *args: tuple[CommunicationOpIdentifier] - ) -> tuple[CommunicationOpIdentifier]: + self, *args: tuple[CommunicationOpIdentifier, ...] + ) -> tuple[CommunicationOpIdentifier, ...]: from pytools import unique return reduce(lambda x, y: tuple(unique(x+y)), args, ()) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder - ) -> tuple[CommunicationOpIdentifier]: + ) -> tuple[CommunicationOpIdentifier, ...]: send_id = _send_to_comm_id(self.local_rank, expr.send) if send_id in self.local_send_id_to_send_node: @@ -438,7 +438,7 @@ def map_distributed_send_ref_holder(self, return self.rec(expr.passthrough_data) - def _map_input_base(self, expr: Array) -> tuple[CommunicationOpIdentifier]: + def _map_input_base(self, expr: Array) -> tuple[CommunicationOpIdentifier, ...]: return () map_placeholder = _map_input_base @@ -447,7 +447,7 @@ def _map_input_base(self, expr: Array) -> tuple[CommunicationOpIdentifier]: def map_distributed_recv( self, expr: DistributedRecv - ) -> tuple[CommunicationOpIdentifier]: + ) -> tuple[CommunicationOpIdentifier, ...]: recv_id = _recv_to_comm_id(self.local_rank, expr) if recv_id in self.local_recv_id_to_recv_node: @@ -461,7 +461,7 @@ def map_distributed_recv( return (recv_id,) def map_named_call_result( - self, expr: NamedCallResult) -> tuple[CommunicationOpIdentifier]: + self, expr: 
NamedCallResult) -> tuple[CommunicationOpIdentifier, ...]: raise NotImplementedError( "LocalSendRecvDepGatherer does not support functions.") @@ -594,8 +594,8 @@ def post_visit(self, expr: Any) -> None: # {{{ _set_dict_union_mpi def _set_dict_union_mpi( - dict_a: Mapping[_KeyT, Sequence[_ValueT]], - dict_b: Mapping[_KeyT, Sequence[_ValueT]], + dict_a: Mapping[_KeyT, tuple[_ValueT, ...]], + dict_b: Mapping[_KeyT, tuple[_ValueT, ...]], mpi_data_type: mpi4py.MPI.Datatype | None) -> Mapping[_KeyT, Sequence[_ValueT]]: assert mpi_data_type is None from pytools import unique @@ -781,7 +781,7 @@ def find_distributed_partition( part_comm_ids: list[_PartCommIDs] = [] if comm_batches: - recv_ids: tuple[CommunicationOpIdentifier] = () + recv_ids: tuple[CommunicationOpIdentifier, ...] = () for batch in comm_batches: send_ids = tuple( comm_id for comm_id in unique(batch) From 168ef532057be4e81649acb9353510a28ed62d84 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 14 Aug 2024 11:48:48 -0500 Subject: [PATCH 111/124] remove unnecessary cast --- pytato/distributed/partition.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 68a924c8a..38f26c8d5 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -771,9 +771,7 @@ def find_distributed_partition( if isinstance(comm_batches_or_exc, Exception): raise comm_batches_or_exc - comm_batches = cast( - Sequence[list[CommunicationOpIdentifier]], - comm_batches_or_exc) + comm_batches = comm_batches_or_exc # }}} From d711989c8cc0198cafdbbd94d294f384667b7c25 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Wed, 14 Aug 2024 14:43:43 -0500 Subject: [PATCH 112/124] adjust comment --- pytato/distributed/partition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 38f26c8d5..dea81e925 100644 --- a/pytato/distributed/partition.py 
+++ b/pytato/distributed/partition.py @@ -829,7 +829,7 @@ def find_distributed_partition( materialized_arrays_collector = _MaterializedArrayCollector() materialized_arrays_collector(outputs) - # The sets of arrays below must have a deterministic order in order to ensure + # The collections of arrays below must have a deterministic order in order to ensure # that the resulting partition is also deterministic sent_arrays = tuple(unique( From 679f5cdb64e7166ee2acd609fe8bd6a53274aee8 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 20 Aug 2024 12:40:19 -0500 Subject: [PATCH 113/124] Revert "Implement numpy 2 type promotion" This reverts commit f0deaea1130fe0dd4dfa52bee540910305de9c70. --- README.rst | 7 +- pytato/array.py | 153 +++++++++--------------------------------- pytato/scalar_expr.py | 1 - pytato/utils.py | 67 ++++++++---------- test/test_codegen.py | 16 +---- test/test_pytato.py | 4 +- 6 files changed, 67 insertions(+), 181 deletions(-) diff --git a/README.rst b/README.rst index 524d923de..71ebccaa8 100644 --- a/README.rst +++ b/README.rst @@ -32,9 +32,4 @@ Numpy compatibility Pytato is written to pose no particular restrictions on the version of numpy used for execution. To use mypy-based type checking on Pytato itself or packages using Pytato, numpy 1.20 or newer is required, due to the -typing-based changes to numpy in that release. Furthermore, pytato -now uses type promotion rules aiming to match those in -`numpy 2 `__. -This will not break compatibility with older numpy versions, but may -result in differing data types between computations carried out in -numpy and pytato. +typing-based changes to numpy in that release. 
diff --git a/pytato/array.py b/pytato/array.py index eefe7c69a..2b0f38681 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -302,96 +302,22 @@ def normalize_shape_component( ConvertibleToIndexExpr = Union[int, slice, "Array", None, EllipsisType] IndexExpr = Union[IntegralT, "NormalizedSlice", "Array", None, EllipsisType] -PyScalarType = Union[type[bool], type[int], type[float], type[complex]] -DtypeOrPyScalarType = Union[_dtype_any, PyScalarType] +DtypeOrScalar = Union[_dtype_any, Scalar] ArrayOrScalar = Union["Array", Scalar] -# https://numpy.org/neps/nep-0050-scalar-promotion.html -class DtypeKindCategory(IntEnum): - BOOLEAN = 0 - INTEGRAL = 1 - INEXACT = 2 +# https://github.com/numpy/numpy/issues/19302 +def _np_result_type( + # actual dtype: + #*arrays_and_dtypes: Union[np.typing.ArrayLike, np.typing.DTypeLike], + # our dtype: + *arrays_and_dtypes: DtypeOrScalar, + ) -> np.dtype[Any]: + return np.result_type(*arrays_and_dtypes) -_dtype_kind_char_to_kind_cat = { - "b": DtypeKindCategory.BOOLEAN, - "i": DtypeKindCategory.INTEGRAL, - "u": DtypeKindCategory.INTEGRAL, - "f": DtypeKindCategory.INEXACT, - "c": DtypeKindCategory.INEXACT, -} - - -_py_type_to_kind_cat = { - bool: DtypeKindCategory.BOOLEAN, - int: DtypeKindCategory.INTEGRAL, - float: DtypeKindCategory.INEXACT, - complex: DtypeKindCategory.INEXACT, -} - - -_float_dtype_to_complex: dict[np.dtype[Any], np.dtype[Any]] = { - np.dtype(np.float32): np.dtype(np.complex64), - np.dtype(np.float64): np.dtype(np.complex128), -} - - -def _complexify_dtype(dtype: np.dtype[Any]) -> np.dtype[Any]: - if dtype.kind == "c": - return dtype - elif dtype.kind == "f": - return _float_dtype_to_complex[dtype] - else: - raise ValueError("can only complexify types that are already inexact") - - -def _np_result_dtype(*dtypes: DtypeOrPyScalarType) -> np.dtype[Any]: - # For numpy 2.0, np.result_type does not implement numpy's type - # promotion behavior. Weird. Hence all this nonsense is needed. 
- - py_types = [dtype for dtype in dtypes if isinstance(dtype, type)] - - if not py_types: - return np.result_type(*dtypes) - - np_dtypes = [dtype for dtype in dtypes if isinstance(dtype, np.dtype)] - np_kind_cats = { - _dtype_kind_char_to_kind_cat[dtype.kind] for dtype in np_dtypes} - py_kind_cats = {_py_type_to_kind_cat[tp] for tp in py_types} - kind_cats = np_kind_cats | py_kind_cats - - res_kind_cat = max(kind_cats) - max_py_kind_cats = max(py_kind_cats) - max_np_kind_cats = max(np_kind_cats) - - is_complex = (complex in py_types - or any(dtype.kind == "c" for dtype in np_dtypes)) - - if max_py_kind_cats > max_np_kind_cats: - if res_kind_cat == DtypeKindCategory.INTEGRAL: - # FIXME: Perhaps this should be int32 "on some systems, e.g. Windows" - py_promotion_dtype: np.dtype[Any] = np.dtype(np.int64) - elif res_kind_cat == DtypeKindCategory.INEXACT: - if is_complex: - py_promotion_dtype = np.dtype(np.complex128) - else: - py_promotion_dtype = np.dtype(np.float64) - else: - # bool won't ever be promoted to - raise AssertionError() - return np.result_type(*([*np_dtypes, py_promotion_dtype])) - - else: - # Just ignore the python types for promotion. 
- result = np.result_type(*np_dtypes) - if is_complex: - result = _complexify_dtype(result) - return result - - -def _truediv_result_type(*dtypes: DtypeOrPyScalarType) -> np.dtype[Any]: - dtype = _np_result_dtype(*dtypes) +def _truediv_result_type(arg1: DtypeOrScalar, arg2: DtypeOrScalar) -> np.dtype[Any]: + dtype = _np_result_type(arg1, arg2) # See: test_true_divide in numpy/core/tests/test_ufunc.py # pylint: disable=no-member if dtype.kind in "iu": @@ -652,16 +578,11 @@ def __matmul__(self, other: Array, reverse: bool = False) -> Array: __rmatmul__ = partialmethod(__matmul__, reverse=True) - def _binary_op( - self, - op: Callable[[ScalarExpression, ScalarExpression], ScalarExpression], - other: ArrayOrScalar, - get_result_type: Callable[ - [DtypeOrPyScalarType, DtypeOrPyScalarType], - np.dtype[Any]] = _np_result_dtype, - reverse: bool = False, - cast_to_result_dtype: bool = True, - ) -> Array: + def _binary_op(self, + op: Callable[[ScalarExpression, ScalarExpression], ScalarExpression], + other: ArrayOrScalar, + get_result_type: Callable[[DtypeOrScalar, DtypeOrScalar], np.dtype[Any]] = _np_result_type, # noqa + reverse: bool = False) -> Array: # {{{ sanity checks @@ -675,19 +596,15 @@ def _binary_op( import pytato.utils as utils if reverse: - result = utils.broadcast_binary_op( - other, self, op, - get_result_type, - tags=tags, - non_equality_tags=non_equality_tags, - cast_to_result_dtype=cast_to_result_dtype) + result = utils.broadcast_binary_op(other, self, op, + get_result_type, + tags=tags, + non_equality_tags=non_equality_tags) else: - result = utils.broadcast_binary_op( - self, other, op, - get_result_type, - tags=tags, - non_equality_tags=non_equality_tags, - cast_to_result_dtype=cast_to_result_dtype) + result = utils.broadcast_binary_op(self, other, op, + get_result_type, + tags=tags, + non_equality_tags=non_equality_tags) assert isinstance(result, Array) return result @@ -1495,7 +1412,7 @@ class Stack(_SuppliedAxesAndTagsMixin, Array): @property def 
dtype(self) -> np.dtype[Any]: - return _np_result_dtype(*(arr.dtype for arr in self.arrays)) + return _np_result_type(*(arr.dtype for arr in self.arrays)) @property def shape(self) -> ShapeType: @@ -1528,7 +1445,7 @@ class Concatenate(_SuppliedAxesAndTagsMixin, Array): @property def dtype(self) -> np.dtype[Any]: - return _np_result_dtype(*(arr.dtype for arr in self.arrays)) + return _np_result_type(*(arr.dtype for arr in self.arrays)) @property def shape(self) -> ShapeType: @@ -2149,7 +2066,7 @@ def reshape(array: Array, newshape: int | Sequence[int], """ :param array: array to be reshaped :param newshape: shape of the resulting array - :param order: ``"C"`` or ``"F"``. Layout order of the resulting array. + :param order: ``"C"`` or ``"F"``. Layout order of the resulting array. .. note:: @@ -2491,14 +2408,12 @@ def _compare(x1: ArrayOrScalar, x2: ArrayOrScalar, which: str) -> Array | bool: import pytato.utils as utils # type-ignored because 'broadcast_binary_op' returns Scalar, while # '_compare' returns a bool. 
- return utils.broadcast_binary_op( - x1, x2, - lambda x, y: prim.Comparison(x, which, y), - lambda x, y: np.dtype(np.bool_), - tags=_get_default_tags(), - non_equality_tags=_get_created_at_tag(stacklevel=2), - cast_to_result_dtype=False - ) # type: ignore[return-value] + return utils.broadcast_binary_op(x1, x2, + lambda x, y: prim.Comparison(x, which, y), + lambda x, y: np.dtype(np.bool_), + tags=_get_default_tags(), + non_equality_tags=_get_created_at_tag(stacklevel=2), + ) # type: ignore[return-value] def equal(x1: ArrayOrScalar, x2: ArrayOrScalar) -> Array | bool: @@ -2560,7 +2475,6 @@ def logical_or(x1: ArrayOrScalar, x2: ArrayOrScalar) -> Array | bool: lambda x, y: np.dtype(np.bool_), tags=_get_default_tags(), non_equality_tags=_get_created_at_tag(), - cast_to_result_dtype=False, ) # type: ignore[return-value] @@ -2577,7 +2491,6 @@ def logical_and(x1: ArrayOrScalar, x2: ArrayOrScalar) -> Array | bool: lambda x, y: np.dtype(np.bool_), tags=_get_default_tags(), non_equality_tags=_get_created_at_tag(), - cast_to_result_dtype=False, ) # type: ignore[return-value] diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index 989d2c405..606e4752a 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -82,7 +82,6 @@ Scalar = Union[np.number[Any], int, np.bool_, bool, float, complex] ScalarExpression = Union[Scalar, prim.Expression] -PYTHON_SCALAR_CLASSES = (int, float, complex, bool) SCALAR_CLASSES = prim.VALID_CONSTANT_CLASSES diff --git a/pytato/utils.py b/pytato/utils.py index a7f817df4..722a0e3b2 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -31,9 +31,21 @@ TypeVar, ) -import islpy as isl -import numpy as np +from typing import (Tuple, List, Union, Callable, Any, Sequence, Dict, + Optional, Iterable, TypeVar, FrozenSet) +from pytato.array import (Array, ShapeType, IndexLambda, SizeParam, ShapeComponent, + DtypeOrScalar, ArrayOrScalar, BasicIndex, + AdvancedIndexInContiguousAxes, + AdvancedIndexInNoncontiguousAxes, + ConvertibleToIndexExpr, 
IndexExpr, NormalizedSlice, + _dtype_any, Einsum) +from pytato.scalar_expr import (ScalarExpression, IntegralScalarExpression, + SCALAR_CLASSES, INT_CLASSES, BoolT) +from pytools import UniqueNameGenerator +from pytato.transform import Mapper +from pytools.tag import Tag from immutabledict import immutabledict +import numpy as np import pymbolic.primitives as prim from pytools import UniqueNameGenerator @@ -46,7 +58,6 @@ ArrayOrScalar, BasicIndex, ConvertibleToIndexExpr, - DtypeOrPyScalarType, Einsum, IndexExpr, IndexLambda, @@ -58,7 +69,6 @@ ) from pytato.scalar_expr import ( INT_CLASSES, - PYTHON_SCALAR_CLASSES, SCALAR_CLASSES, BoolT, IntegralScalarExpression, @@ -164,18 +174,15 @@ def with_indices_for_broadcasted_shape(val: prim.Variable, shape: ShapeType, return val[get_indexing_expression(shape, result_shape)] -def _extract_dtypes( - exprs: Sequence[ArrayOrScalar]) -> list[DtypeOrPyScalarType]: - dtypes: list[DtypeOrPyScalarType] = [] +def extract_dtypes_or_scalars( + exprs: Sequence[ArrayOrScalar]) -> List[DtypeOrScalar]: + dtypes: List[DtypeOrScalar] = [] for expr in exprs: if isinstance(expr, Array): dtypes.append(expr.dtype) - elif isinstance(expr, np.generic): - dtypes.append(expr.dtype) - elif isinstance(expr, PYTHON_SCALAR_CLASSES): - dtypes.append(type(expr)) else: - raise TypeError(f"unexpected expression type: '{type(expr)}'") + assert isinstance(expr, SCALAR_CLASSES) + dtypes.append(expr) return dtypes @@ -208,21 +215,24 @@ def update_bindings_and_get_broadcasted_expr(arr: ArrayOrScalar, def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar, op: Callable[[ScalarExpression, ScalarExpression], ScalarExpression], # noqa:E501 - get_result_type: Callable[[DtypeOrPyScalarType, DtypeOrPyScalarType], np.dtype[Any]], # noqa:E501 - *, - tags: frozenset[Tag], - non_equality_tags: frozenset[Tag], - cast_to_result_dtype: bool, + get_result_type: Callable[[DtypeOrScalar, DtypeOrScalar], np.dtype[Any]], # noqa:E501 + tags: FrozenSet[Tag], + 
non_equality_tags: FrozenSet[Tag], ) -> ArrayOrScalar: from pytato.array import _get_default_axes + if isinstance(a1, SCALAR_CLASSES): + a1 = np.dtype(type(a1)).type(a1) + + if isinstance(a2, SCALAR_CLASSES): + a2 = np.dtype(type(a2)).type(a2) + if np.isscalar(a1) and np.isscalar(a2): from pytato.scalar_expr import evaluate return evaluate(op(a1, a2)) # type: ignore result_shape = get_shape_after_broadcasting([a1, a2]) - - dtypes = _extract_dtypes([a1, a2]) + dtypes = extract_dtypes_or_scalars([a1, a2]) result_dtype = get_result_type(*dtypes) bindings: dict[str, Array] = {} @@ -232,25 +242,6 @@ def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar, expr2 = update_bindings_and_get_broadcasted_expr(a2, "_in1", bindings, result_shape) - def cast_to_result_type( - array: ArrayOrScalar, - expr: ScalarExpression - ) -> ScalarExpression: - if ((isinstance(array, Array) or isinstance(array, np.generic)) - and array.dtype != result_dtype): - # Loopy's type casts don't like casting to bool - assert result_dtype != np.bool_ - - expr = TypeCast(result_dtype, expr) - elif isinstance(expr, SCALAR_CLASSES): - expr = result_dtype.type(expr) - - return expr - - if cast_to_result_dtype: - expr1 = cast_to_result_type(a1, expr1) - expr2 = cast_to_result_type(a2, expr2) - return IndexLambda(expr=op(expr1, expr2), shape=result_shape, dtype=result_dtype, diff --git a/test/test_codegen.py b/test/test_codegen.py index a7c2d615a..661a4092f 100755 --- a/test/test_codegen.py +++ b/test/test_codegen.py @@ -272,9 +272,6 @@ def wrapper(*args): "logical_or")) @pytest.mark.parametrize("reverse", (False, True)) def test_scalar_array_binary_arith(ctx_factory, which, reverse): - from numpy.lib import NumpyVersion - is_old_numpy = NumpyVersion(np.__version__) < "2.0.0" - cl_ctx = ctx_factory() queue = cl.CommandQueue(cl_ctx) not_valid_in_complex = which in ["equal", "not_equal", "less", "less_equal", @@ -319,18 +316,9 @@ def test_scalar_array_binary_arith(ctx_factory, which, reverse): out = 
outputs[dtype] out_ref = np_op(x_in, y_orig.astype(dtype)) - if not is_old_numpy: - assert out.dtype == out_ref.dtype, (out.dtype, out_ref.dtype) - + assert out.dtype == out_ref.dtype, (out.dtype, out_ref.dtype) # In some cases ops are done in float32 in loopy but float64 in numpy. - is_allclose = np.allclose(out, out_ref), (out, out_ref) - if not is_old_numpy: - assert is_allclose - else: - if out_ref.dtype.itemsize == 1: - pass - else: - assert is_allclose + assert np.allclose(out, out_ref), (out, out_ref) @pytest.mark.parametrize("which", ("add", "sub", "mul", "truediv", "pow", diff --git a/test/test_pytato.py b/test/test_pytato.py index e9c23d8fe..946983ac6 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -514,8 +514,8 @@ def _assert_stripped_repr(ary: pt.Array, expected_repr: str): dtype='int64', expr=Product((Subscript(Variable('_in0'), (Variable('_0'), Variable('_1'))), - TypeCast(dtype('int64'), Subscript(Variable('_in1'), - (Variable('_0'), Variable('_1')))))), + Subscript(Variable('_in1'), + (Variable('_0'), Variable('_1'))))), bindings={'_in0': Placeholder(shape=(10, 4), dtype='int64', name='y'), '_in1': IndexLambda( shape=(10, 4), From a6a91c6f6844bedd57e360254f581b6377470a82 Mon Sep 17 00:00:00 2001 From: Mike Campbell Date: Thu, 26 Sep 2024 10:24:19 -0500 Subject: [PATCH 114/124] Fix a merge fail --- pytato/array.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 10e2ba727..3867b0641 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -586,11 +586,12 @@ def __matmul__(self, other: Array, reverse: bool = False) -> Array: # ======= def _binary_op( self, - op: Callable[[ScalarExpression, ScalarExpression], ScalarExpression], + op: Callable[[ScalarExpression, ScalarExpression], + ScalarExpression], other: ArrayOrScalar, get_result_type: Callable[ [ArrayOrScalar, ArrayOrScalar], - np.dtype[Any]] = _np_result_dtype, + np.dtype[Any]] = _np_result_type, reverse: bool = False, 
cast_to_result_dtype: bool = True, is_pow: bool = False, From 3b7385e6fa7b5806ebe169491eecda30eb5a3b72 Mon Sep 17 00:00:00 2001 From: Mike Campbell Date: Thu, 26 Sep 2024 12:25:47 -0500 Subject: [PATCH 115/124] Merge with main --- pytato/array.py | 40 ---------------------------------------- pytato/utils.py | 9 --------- 2 files changed, 49 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 3867b0641..b375ebfb3 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -577,13 +577,6 @@ def __matmul__(self, other: Array, reverse: bool = False) -> Array: __rmatmul__ = partialmethod(__matmul__, reverse=True) -# <<<<<<< HEAD -# def _binary_op(self, -# op: Callable[[ScalarExpression, ScalarExpression], ScalarExpression], -# other: ArrayOrScalar, -# get_result_type: Callable[[DtypeOrScalar, DtypeOrScalar], np.dtype[Any]] = _np_result_type, # noqa -# reverse: bool = False) -> Array: -# ======= def _binary_op( self, op: Callable[[ScalarExpression, ScalarExpression], @@ -596,7 +589,6 @@ def _binary_op( cast_to_result_dtype: bool = True, is_pow: bool = False, ) -> Array: -# >>>>>>> main # {{{ sanity checks @@ -610,17 +602,6 @@ def _binary_op( import pytato.utils as utils if reverse: -# <<<<<<< HEAD -# result = utils.broadcast_binary_op(other, self, op, -# get_result_type, -# tags=tags, -# non_equality_tags=non_equality_tags) -# else: -# result = utils.broadcast_binary_op(self, other, op, -# get_result_type, -# tags=tags, -# non_equality_tags=non_equality_tags) -# ======= result = utils.broadcast_binary_op( other, self, op, get_result_type, @@ -636,7 +617,6 @@ def _binary_op( non_equality_tags=non_equality_tags, cast_to_result_dtype=cast_to_result_dtype, is_pow=is_pow) -# >>>>>>> main assert isinstance(result, Array) return result @@ -688,13 +668,8 @@ def _unary_op(self, op: Any) -> Array: __rtruediv__ = partialmethod(_binary_op, prim.Quotient, get_result_type=_truediv_result_type, reverse=True) -# <<<<<<< HEAD -# __pow__ = partialmethod(_binary_op, prim.Power) -# 
__rpow__ = partialmethod(_binary_op, prim.Power, reverse=True) -#======= __pow__ = partialmethod(_binary_op, operator.pow, is_pow=True) __rpow__ = partialmethod(_binary_op, operator.pow, reverse=True, is_pow=True) -# >>>>>>> main __neg__ = partialmethod(_unary_op, operator.neg) @@ -2443,14 +2418,6 @@ def _compare(x1: ArrayOrScalar, x2: ArrayOrScalar, which: str) -> Array | bool: import pytato.utils as utils # type-ignored because 'broadcast_binary_op' returns Scalar, while # '_compare' returns a bool. -# <<<<<<< HEAD -# return utils.broadcast_binary_op(x1, x2, -# lambda x, y: prim.Comparison(x, which, y), -# lambda x, y: np.dtype(np.bool_), -# tags=_get_default_tags(), -# non_equality_tags=_get_created_at_tag(stacklevel=2), -# ) # type: ignore[return-value] -# ======= return utils.broadcast_binary_op( x1, x2, lambda x, y: prim.Comparison(x, which, y), @@ -2460,7 +2427,6 @@ def _compare(x1: ArrayOrScalar, x2: ArrayOrScalar, which: str) -> Array | bool: cast_to_result_dtype=False, is_pow=False, ) # type: ignore[return-value] -# >>>>>>> main def equal(x1: ArrayOrScalar, x2: ArrayOrScalar) -> Array | bool: @@ -2522,11 +2488,8 @@ def logical_or(x1: ArrayOrScalar, x2: ArrayOrScalar) -> Array | bool: lambda x, y: np.dtype(np.bool_), tags=_get_default_tags(), non_equality_tags=_get_created_at_tag(), -# <<<<<<< HEAD -# ======= cast_to_result_dtype=False, is_pow=False, -# >>>>>>> main ) # type: ignore[return-value] @@ -2543,11 +2506,8 @@ def logical_and(x1: ArrayOrScalar, x2: ArrayOrScalar) -> Array | bool: lambda x, y: np.dtype(np.bool_), tags=_get_default_tags(), non_equality_tags=_get_created_at_tag(), -# <<<<<<< HEAD -# ======= cast_to_result_dtype=False, is_pow=False, -# >>>>>>> main ) # type: ignore[return-value] diff --git a/pytato/utils.py b/pytato/utils.py index 9b429f097..9cb4a9d92 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -215,18 +215,12 @@ def update_bindings_and_get_broadcasted_expr(arr: ArrayOrScalar, def broadcast_binary_op(a1: ArrayOrScalar, a2: 
ArrayOrScalar, op: Callable[[ScalarExpression, ScalarExpression], ScalarExpression], # noqa:E501 -# <<<<<<< HEAD -# get_result_type: Callable[[DtypeOrScalar, DtypeOrScalar], np.dtype[Any]], # noqa:E501 -# tags: FrozenSet[Tag], -# non_equality_tags: FrozenSet[Tag], -# ======= get_result_type: Callable[[ArrayOrScalar, ArrayOrScalar], np.dtype[Any]], # noqa:E501 *, tags: frozenset[Tag], non_equality_tags: frozenset[Tag], cast_to_result_dtype: bool, is_pow: bool, -# >>>>>>> main ) -> ArrayOrScalar: from pytato.array import _get_default_axes @@ -251,8 +245,6 @@ def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar, expr2 = update_bindings_and_get_broadcasted_expr(a2, "_in1", bindings, result_shape) -# <<<<<<< HEAD -# ======= def cast_to_result_type( array: ArrayOrScalar, expr: ScalarExpression @@ -282,7 +274,6 @@ def cast_to_result_type( expr1 = cast_to_result_type(a1, expr1) expr2 = cast_to_result_type(a2, expr2) -# >>>>>>> main return IndexLambda(expr=op(expr1, expr2), shape=result_shape, dtype=result_dtype, From 58478004b85604d9ac0b238ab55d601326ef3b94 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 27 Sep 2024 13:01:19 -0500 Subject: [PATCH 116/124] performance fix --- pytato/distributed/partition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index dea81e925..03691e6ae 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -928,7 +928,7 @@ def find_distributed_partition( def get_materialized_predecessors(ary: Array) -> tuple[Array, ...]: materialized_preds: dict[Array, None] = {} for pred in direct_preds_getter(ary): - if pred in materialized_arrays: + if pred in materialized_arrays_set: materialized_preds[pred] = None else: for p in get_materialized_predecessors(pred): From 7dd83bb7ab25e701f3a9b5d05477d991d0fbb1fa Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 27 Sep 2024 14:22:21 -0500 Subject: [PATCH 117/124] switch 
to dicts --- pytato/analysis/__init__.py | 62 ++++++++++++++++---------------- pytato/distributed/partition.py | 63 +++++++++++++++------------------ 2 files changed, 61 insertions(+), 64 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index c568c8f9c..e721a7a8a 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -29,7 +29,7 @@ from typing import TYPE_CHECKING, Any, Mapping from pymbolic.mapper.optimize import optimize_mapper -from pytools import memoize_method, unique +from pytools import memoize_method from pytato.array import ( Array, @@ -325,71 +325,73 @@ class DirectPredecessorsGetter(Mapper): We only consider the predecessors of a nodes in a data-flow sense. """ - def _get_preds_from_shape(self, shape: ShapeType) -> tuple[ArrayOrNames, ...]: - return tuple(unique(dim for dim in shape if isinstance(dim, Array))) + def _get_preds_from_shape(self, shape: ShapeType) -> dict[Array, None]: + return dict.fromkeys(dim for dim in shape if isinstance(dim, Array)) - def map_index_lambda(self, expr: IndexLambda) -> tuple[ArrayOrNames, ...]: - return tuple(unique(tuple(expr.bindings.values()) - + self._get_preds_from_shape(expr.shape))) + def map_index_lambda(self, expr: IndexLambda) -> dict[Array, None]: + return (dict.fromkeys(expr.bindings.values()) + | self._get_preds_from_shape(expr.shape)) - def map_stack(self, expr: Stack) -> tuple[ArrayOrNames, ...]: - return tuple(unique(tuple(expr.arrays) - + self._get_preds_from_shape(expr.shape))) + def map_stack(self, expr: Stack) -> dict[Array, None]: + return (dict.fromkeys(expr.arrays) + | self._get_preds_from_shape(expr.shape)) - map_concatenate = map_stack + def map_concatenate(self, expr: Concatenate) -> dict[Array, None]: + return (dict.fromkeys(expr.arrays) + | self._get_preds_from_shape(expr.shape)) - def map_einsum(self, expr: Einsum) -> tuple[ArrayOrNames, ...]: - return tuple(unique(tuple(expr.args) - + self._get_preds_from_shape(expr.shape))) + def 
map_einsum(self, expr: Einsum) -> dict[Array, None]: + return (dict.fromkeys(expr.args) + | self._get_preds_from_shape(expr.shape)) - def map_loopy_call_result(self, expr: NamedArray) -> tuple[ArrayOrNames, ...]: + def map_loopy_call_result(self, expr: NamedArray) -> dict[Array, None]: from pytato.loopy import LoopyCall, LoopyCallResult assert isinstance(expr, LoopyCallResult) assert isinstance(expr._container, LoopyCall) - return tuple(unique(tuple(ary + return (dict.fromkeys(ary for ary in expr._container.bindings.values() if isinstance(ary, Array)) - + self._get_preds_from_shape(expr.shape))) + | self._get_preds_from_shape(expr.shape)) - def _map_index_base(self, expr: IndexBase) -> tuple[ArrayOrNames, ...]: - return tuple(unique((expr.array,) # noqa: RUF005 - + tuple(idx for idx in expr.indices + def _map_index_base(self, expr: IndexBase) -> dict[Array, None]: + return (dict.fromkeys([expr.array]) + | dict.fromkeys(idx for idx in expr.indices if isinstance(idx, Array)) - + self._get_preds_from_shape(expr.shape))) + | self._get_preds_from_shape(expr.shape)) map_basic_index = _map_index_base map_contiguous_advanced_index = _map_index_base map_non_contiguous_advanced_index = _map_index_base def _map_index_remapping_base(self, expr: IndexRemappingBase - ) -> tuple[ArrayOrNames]: - return (expr.array,) + ) -> dict[ArrayOrNames, None]: + return dict.fromkeys([expr.array]) map_roll = _map_index_remapping_base map_axis_permutation = _map_index_remapping_base map_reshape = _map_index_remapping_base - def _map_input_base(self, expr: InputArgumentBase) -> tuple[ArrayOrNames, ...]: + def _map_input_base(self, expr: InputArgumentBase) -> dict[Array, None]: return self._get_preds_from_shape(expr.shape) map_placeholder = _map_input_base map_data_wrapper = _map_input_base map_size_param = _map_input_base - def map_distributed_recv(self, expr: DistributedRecv) -> tuple[ArrayOrNames, ...]: + def map_distributed_recv(self, expr: DistributedRecv) -> dict[Array, None]: return 
self._get_preds_from_shape(expr.shape) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder - ) -> tuple[ArrayOrNames]: - return (expr.passthrough_data,) + ) -> dict[ArrayOrNames, None]: + return dict.fromkeys([expr.passthrough_data]) - def map_call(self, expr: Call) -> tuple[ArrayOrNames, ...]: - return tuple(unique(expr.bindings.values())) + def map_call(self, expr: Call) -> dict[ArrayOrNames, None]: + return dict.fromkeys(expr.bindings.values()) def map_named_call_result( - self, expr: NamedCallResult) -> tuple[ArrayOrNames]: - return (expr._container,) + self, expr: NamedCallResult) -> dict[ArrayOrNames, None]: + return dict.fromkeys([expr._container]) # }}} diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 03691e6ae..9e9f47913 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -315,8 +315,8 @@ def _get_placeholder_for(self, name: str, expr: Array) -> Placeholder: class _PartCommIDs: """A *part*, unlike a *batch*, begins with receives and ends with sends. """ - recv_ids: tuple[CommunicationOpIdentifier, ...] - send_ids: tuple[CommunicationOpIdentifier, ...] + recv_ids: immutabledict[CommunicationOpIdentifier, None] + send_ids: immutabledict[CommunicationOpIdentifier, None] # {{{ _make_distributed_partition @@ -727,8 +727,7 @@ def find_distributed_partition( assigned in :attr:`DistributedGraphPart.name_to_send_nodes`. """ import mpi4py.MPI as MPI - - from pytools import unique + from immutabledict import immutabledict from pytato.transform import SubsetDependencyMapper @@ -779,30 +778,31 @@ def find_distributed_partition( part_comm_ids: list[_PartCommIDs] = [] if comm_batches: - recv_ids: tuple[CommunicationOpIdentifier, ...] 
= () + recv_ids: immutabledict[CommunicationOpIdentifier, None] = immutabledict() for batch in comm_batches: - send_ids = tuple( - comm_id for comm_id in unique(batch) - if comm_id.src_rank == local_rank) + send_ids: immutabledict[CommunicationOpIdentifier, None] \ + = immutabledict.fromkeys( + comm_id for comm_id in batch + if comm_id.src_rank == local_rank) if recv_ids or send_ids: part_comm_ids.append( _PartCommIDs( recv_ids=recv_ids, send_ids=send_ids)) # These go into the next part - recv_ids = tuple( - comm_id for comm_id in unique(batch) + recv_ids = immutabledict.fromkeys( + comm_id for comm_id in batch if comm_id.dest_rank == local_rank) if recv_ids: part_comm_ids.append( _PartCommIDs( recv_ids=recv_ids, - send_ids=())) + send_ids=immutabledict())) else: part_comm_ids.append( _PartCommIDs( - recv_ids=(), - send_ids=())) + recv_ids=immutabledict(), + send_ids=immutabledict())) nparts = len(part_comm_ids) @@ -820,7 +820,7 @@ def find_distributed_partition( comm_id_to_part_id = { comm_id: ipart for ipart, comm_ids in enumerate(part_comm_ids) - for comm_id in unique(comm_ids.send_ids + comm_ids.recv_ids)} + for comm_id in comm_ids.send_ids | comm_ids.recv_ids} # }}} @@ -832,10 +832,10 @@ def find_distributed_partition( # The collections of arrays below must have a deterministic order in order to ensure # that the resulting partition is also deterministic - sent_arrays = tuple(unique( - send_node.data for send_node in lsrdg.local_send_id_to_send_node.values())) + sent_arrays = dict.fromkeys( + send_node.data for send_node in lsrdg.local_send_id_to_send_node.values()) - received_arrays = tuple(unique(lsrdg.local_recv_id_to_recv_node.values())) + received_arrays = dict.fromkeys(lsrdg.local_recv_id_to_recv_node.values()) # While receive nodes may be marked as materialized, we shouldn't be # including them here because we're using them (along with the send nodes) @@ -843,18 +843,13 @@ def find_distributed_partition( # We could allow sent *arrays* to be included 
here because they are distinct # from send *nodes*, but we choose to exclude them in order to simplify the # processing below. - materialized_arrays_set = set(materialized_arrays_collector.materialized_arrays) \ - - set(received_arrays) \ - - set(sent_arrays) - - from pytools import unique - materialized_arrays = tuple(unique( - a for a in materialized_arrays_collector.materialized_arrays - if a in materialized_arrays_set)) + materialized_arrays = {a: None + for a in materialized_arrays_collector.materialized_arrays + if a not in received_arrays | sent_arrays} # "mso" for "materialized/sent/output" - output_arrays = tuple(unique(outputs._data.values())) - mso_arrays = tuple(unique(materialized_arrays + sent_arrays + output_arrays)) + output_arrays = dict.fromkeys(outputs._data.values()) + mso_arrays = materialized_arrays | sent_arrays | output_arrays # FIXME: This gathers up materialized_arrays recursively, leading to # result sizes potentially quadratic in the number of materialized arrays. 
@@ -918,30 +913,30 @@ def find_distributed_partition( assert all(0 <= part_id < nparts for part_id in stored_ary_to_part_id.values()) - stored_arrays = tuple(unique(stored_ary_to_part_id)) + stored_arrays = dict.fromkeys(stored_ary_to_part_id) # {{{ find which stored arrays should become part outputs # (because they are used in not just their local part, but also others) direct_preds_getter = DirectPredecessorsGetter() - def get_materialized_predecessors(ary: Array) -> tuple[Array, ...]: + def get_materialized_predecessors(ary: Array) -> dict[Array, None]: materialized_preds: dict[Array, None] = {} for pred in direct_preds_getter(ary): - if pred in materialized_arrays_set: + if pred in materialized_arrays: materialized_preds[pred] = None else: for p in get_materialized_predecessors(pred): materialized_preds[p] = None - return tuple(materialized_preds.keys()) + return materialized_preds - stored_arrays_promoted_to_part_outputs = tuple(unique( - stored_pred + stored_arrays_promoted_to_part_outputs = { + stored_pred: None for stored_ary in stored_arrays for stored_pred in get_materialized_predecessors(stored_ary) if (stored_ary_to_part_id[stored_ary] != stored_ary_to_part_id[stored_pred]) - )) + } # }}} From 1ea962cc9460b6407e453dd5bebfc28cc02830ae Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Fri, 27 Sep 2024 14:40:51 -0500 Subject: [PATCH 118/124] more dict usage --- pytato/distributed/partition.py | 43 ++++++++++++++++----------------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index 9e9f47913..111f07d2e 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -402,12 +402,12 @@ def _recv_to_comm_id( class _LocalSendRecvDepGatherer( - CombineMapper[tuple[CommunicationOpIdentifier, ...]]): + CombineMapper[dict[CommunicationOpIdentifier, None]]): def __init__(self, local_rank: int) -> None: super().__init__() 
self.local_comm_ids_to_needed_comm_ids: \ dict[CommunicationOpIdentifier, - tuple[CommunicationOpIdentifier, ...]] = {} + dict[CommunicationOpIdentifier, None]] = {} self.local_recv_id_to_recv_node: \ dict[CommunicationOpIdentifier, DistributedRecv] = {} @@ -417,14 +417,13 @@ def __init__(self, local_rank: int) -> None: self.local_rank = local_rank def combine( - self, *args: tuple[CommunicationOpIdentifier, ...] - ) -> tuple[CommunicationOpIdentifier, ...]: - from pytools import unique - return reduce(lambda x, y: tuple(unique(x+y)), args, ()) + self, *args: dict[CommunicationOpIdentifier, None] + ) -> dict[CommunicationOpIdentifier, None]: + return reduce(lambda x, y: x | y, args, {}) def map_distributed_send_ref_holder(self, expr: DistributedSendRefHolder - ) -> tuple[CommunicationOpIdentifier, ...]: + ) -> dict[CommunicationOpIdentifier, None]: send_id = _send_to_comm_id(self.local_rank, expr.send) if send_id in self.local_send_id_to_send_node: @@ -438,8 +437,8 @@ def map_distributed_send_ref_holder(self, return self.rec(expr.passthrough_data) - def _map_input_base(self, expr: Array) -> tuple[CommunicationOpIdentifier, ...]: - return () + def _map_input_base(self, expr: Array) -> dict[CommunicationOpIdentifier, None]: + return {} map_placeholder = _map_input_base map_data_wrapper = _map_input_base @@ -447,21 +446,21 @@ def _map_input_base(self, expr: Array) -> tuple[CommunicationOpIdentifier, ...]: def map_distributed_recv( self, expr: DistributedRecv - ) -> tuple[CommunicationOpIdentifier, ...]: + ) -> dict[CommunicationOpIdentifier, None]: recv_id = _recv_to_comm_id(self.local_rank, expr) if recv_id in self.local_recv_id_to_recv_node: from pytato.distributed.verify import DuplicateRecvError raise DuplicateRecvError(f"Multiple receives found for '{recv_id}'") - self.local_comm_ids_to_needed_comm_ids[recv_id] = () + self.local_comm_ids_to_needed_comm_ids[recv_id] = {} self.local_recv_id_to_recv_node[recv_id] = expr - return (recv_id,) + return {recv_id: None} 
def map_named_call_result( - self, expr: NamedCallResult) -> tuple[CommunicationOpIdentifier, ...]: + self, expr: NamedCallResult) -> dict[CommunicationOpIdentifier, None]: raise NotImplementedError( "LocalSendRecvDepGatherer does not support functions.") @@ -475,7 +474,7 @@ def map_named_call_result( def _schedule_task_batches( task_ids_to_needed_task_ids: Mapping[TaskType, AbstractSet[TaskType]]) \ - -> Sequence[list[TaskType]]: + -> Sequence[dict[TaskType, None]]: """For each :type:`TaskType`, determine the 'round'/'batch' during which it will be performed. A 'batch' of tasks consists of tasks which do not depend on each other. @@ -490,7 +489,7 @@ def _schedule_task_batches( def _schedule_task_batches_counted( task_ids_to_needed_task_ids: Mapping[TaskType, AbstractSet[TaskType]]) \ - -> tuple[Sequence[list[TaskType]], int]: + -> tuple[Sequence[dict[TaskType, None]], int]: """ Static type checkers need the functions to return the same type regardless of the input. The testing code needs to know about the number of tasks visited @@ -499,11 +498,11 @@ def _schedule_task_batches_counted( task_to_dep_level, visits_in_depend = \ _calculate_dependency_levels(task_ids_to_needed_task_ids) nlevels = 1 + max(task_to_dep_level.values(), default=-1) - task_batches: Sequence[list[TaskType]] = [[] for _ in range(nlevels)] + task_batches: Sequence[dict[TaskType, None]] = [{} for _ in range(nlevels)] for task_id, dep_level in task_to_dep_level.items(): if task_id not in task_batches[dep_level]: - task_batches[dep_level].append(task_id) + task_batches[dep_level][task_id] = None return task_batches, visits_in_depend + len(task_to_dep_level.keys()) @@ -594,14 +593,14 @@ def post_visit(self, expr: Any) -> None: # {{{ _set_dict_union_mpi def _set_dict_union_mpi( - dict_a: Mapping[_KeyT, tuple[_ValueT, ...]], - dict_b: Mapping[_KeyT, tuple[_ValueT, ...]], - mpi_data_type: mpi4py.MPI.Datatype | None) -> Mapping[_KeyT, Sequence[_ValueT]]: + dict_a: Mapping[_KeyT, dict[_ValueT, None]], + 
dict_b: Mapping[_KeyT, dict[_ValueT, None]], + mpi_data_type: mpi4py.MPI.Datatype | None) \ + -> Mapping[_KeyT, dict[_ValueT, None]]: assert mpi_data_type is None - from pytools import unique result = dict(dict_a) for key, values in dict_b.items(): - result[key] = tuple(unique(result.get(key, ()) + values)) + result[key] = result.get(key, {}) | values return result # }}} From 99f4d10ae535dc93afe5b1304966a808e33ca938 Mon Sep 17 00:00:00 2001 From: Mike Campbell Date: Mon, 11 Nov 2024 11:12:22 -0600 Subject: [PATCH 119/124] Import union --- pytato/scalar_expr.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index bfbda9e71..bc137def8 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -50,6 +50,7 @@ TYPE_CHECKING, Any, Never, + Union, cast, ) From cede10c40d803590f326f22e890c7243ca95d487 Mon Sep 17 00:00:00 2001 From: Mike Campbell Date: Mon, 11 Nov 2024 11:35:15 -0600 Subject: [PATCH 120/124] Use IntegralT --> IntegerT --- pytato/array.py | 2 +- pytato/scalar_expr.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index 2c197b552..a58338e12 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -297,7 +297,7 @@ def normalize_shape_component( # {{{ array interface ConvertibleToIndexExpr = Union[int, slice, "Array", None, EllipsisType] -IndexExpr = Union[IntegralT, "NormalizedSlice", "Array", None, EllipsisType] +IndexExpr = Union[IntegerT, "NormalizedSlice", "Array", None, EllipsisType] DtypeOrScalar = Union[_dtype_any, Scalar] ArrayOrScalar = Union["Array", Scalar] PyScalarType = type[bool] | type[int] | type[float] | type[complex] diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index bc137def8..58564b2fb 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -85,7 +85,7 @@ # {{{ scalar expressions INT_CLASSES = (int, np.integer) -IntegralScalarExpression = Union[IntegralT, prim.Expression] +IntegralScalarExpression = 
Union[IntegerT, prim.Expression] Scalar = Union[np.number[Any], int, np.bool_, bool, float, complex] ScalarExpression = Union[Scalar, prim.Expression] PYTHON_SCALAR_CLASSES = (int, float, complex, bool) From 02b1980cafbf0fa124ed43db218df467cd8a5983 Mon Sep 17 00:00:00 2001 From: Mike Campbell Date: Mon, 11 Nov 2024 11:45:27 -0600 Subject: [PATCH 121/124] Use Scalar --> ScalarT --- pytato/array.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytato/array.py b/pytato/array.py index a58338e12..805b3671e 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -298,8 +298,8 @@ def normalize_shape_component( ConvertibleToIndexExpr = Union[int, slice, "Array", None, EllipsisType] IndexExpr = Union[IntegerT, "NormalizedSlice", "Array", None, EllipsisType] -DtypeOrScalar = Union[_dtype_any, Scalar] -ArrayOrScalar = Union["Array", Scalar] +DtypeOrScalar = Union[_dtype_any, ScalarT] +ArrayOrScalar = Union["Array", ScalarT] PyScalarType = type[bool] | type[int] | type[float] | type[complex] DtypeOrPyScalarType = _dtype_any | PyScalarType From b820049160910fa48b1a2fd2efcd757b787ed113 Mon Sep 17 00:00:00 2001 From: Mike Campbell Date: Mon, 11 Nov 2024 15:26:50 -0600 Subject: [PATCH 122/124] Disable update_for_pymbolic_expression --- pytato/analysis/__init__.py | 44 ++++++++++++++++++------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/pytato/analysis/__init__.py b/pytato/analysis/__init__.py index 8577961d0..9c156786d 100644 --- a/pytato/analysis/__init__.py +++ b/pytato/analysis/__init__.py @@ -584,28 +584,28 @@ def update_for_Array(self, key_hash: Any, key: Any) -> None: # CL Array self.rec(key_hash, key.get()) - update_for_BitwiseAnd = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 - update_for_BitwiseNot = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 - update_for_BitwiseOr = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 - update_for_BitwiseXor = 
LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 - update_for_Call = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 - update_for_CallWithKwargs = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 - update_for_Comparison = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 - update_for_If = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 - update_for_FloorDiv = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 - update_for_LeftShift = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 - update_for_LogicalAnd = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 - update_for_LogicalNot = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 - update_for_LogicalOr = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 - update_for_Lookup = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 - update_for_Power = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 - update_for_Product = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 - update_for_Quotient = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 - update_for_Remainder = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 - update_for_RightShift = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 - update_for_Subscript = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 - update_for_Sum = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 - update_for_Variable = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_BitwiseAnd = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_BitwiseNot = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_BitwiseOr = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_BitwiseXor = 
LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_Call = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + # update_for_CallWithKwargs = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_Comparison = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_If = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + # update_for_FloorDiv = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_LeftShift = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_LogicalAnd = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_LogicalNot = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_LogicalOr = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_Lookup = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_Power = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + # update_for_Product = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + # update_for_Quotient = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_Remainder = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_RightShift = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_Subscript = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 + # update_for_Sum = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815 + # update_for_Variable = LoopyKeyBuilder.update_for_pymbolic_expression # noqa: N815, E501 # }}} From 1337590a29acb3fcc3b487734a200eda32a1ebc5 Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 12 Nov 2024 13:22:16 -0600 Subject: [PATCH 123/124] remove duplicate Hashable --- pytato/distributed/partition.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/pytato/distributed/partition.py 
b/pytato/distributed/partition.py index 43f9d9197..86fe5df51 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -69,9 +69,6 @@ TYPE_CHECKING, Any, Generic, - Hashable, - Mapping, - Sequence, TypeVar, cast, ) From 2c001d6be6b876e2471da2f2261eff5e081e753d Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Tue, 12 Nov 2024 13:23:56 -0600 Subject: [PATCH 124/124] add missing import --- pytato/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytato/utils.py b/pytato/utils.py index a8f9a2086..e243a20ee 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -44,6 +44,7 @@ from pytools.tag import Tag from immutabledict import immutabledict import numpy as np +import islpy as isl import pymbolic.primitives as prim from pymbolic import ScalarT