From 070df9f8393e75a241e5e1e5a77fe2824ec419d2 Mon Sep 17 00:00:00 2001 From: Kaushik Kulkarni <15399010+kaushikcfd@users.noreply.github.com> Date: Mon, 9 Sep 2024 13:25:51 -0700 Subject: [PATCH] Require that happens_after is not mutable (#866) * Require that happens_after is not mutable * Tweak type tests for happens_after --------- Co-authored-by: Andreas Kloeckner --- loopy/kernel/instruction.py | 34 ++++++++++++++++++++++------------ loopy/tools.py | 9 +++++++++ pyproject.toml | 1 + 3 files changed, 32 insertions(+), 12 deletions(-) diff --git a/loopy/kernel/instruction.py b/loopy/kernel/instruction.py index d564d5e36..a6420b8fc 100644 --- a/loopy/kernel/instruction.py +++ b/loopy/kernel/instruction.py @@ -20,7 +20,10 @@ THE SOFTWARE. """ -from collections.abc import Mapping as MappingABC, Set as abc_Set +from collections.abc import ( + Mapping as MappingABC, + Set as abc_Set, +) from dataclasses import dataclass from functools import cached_property from sys import intern @@ -283,6 +286,7 @@ def __init__(self, *, depends_on: Union[FrozenSet[str], str, None] = None, ) -> None: + from immutabledict import immutabledict if predicates is None: predicates = frozenset() @@ -314,28 +318,29 @@ def __init__(self, raise LoopyError("Setting depends_on_is_final to True requires " "actually specifying happens_after/depends_on") - if happens_after is None: - happens_after = {} + if isinstance(happens_after, immutabledict): + pass + elif happens_after is None: + happens_after = immutabledict() elif isinstance(happens_after, str): warn("Passing a string for happens_after/depends_on is deprecated and " "will stop working in 2025. Instead, pass a full-fledged " "happens_after data structure.", DeprecationWarning, stacklevel=2) - happens_after = { + happens_after = immutabledict({ after_id.strip(): HappensAfter( variable_name=None, instances_rel=None) for after_id in happens_after.split(",") - if after_id.strip()} + if after_id.strip()}) elif isinstance(happens_after, frozenset): - happens_after = { + happens_after = immutabledict({ after_id: HappensAfter( variable_name=None, instances_rel=None) - for after_id in happens_after} - elif isinstance(happens_after, MappingABC): - if isinstance(happens_after, dict): - happens_after = happens_after + for after_id in happens_after}) + elif isinstance(happens_after, dict): + happens_after = immutabledict(happens_after) else: raise TypeError("'happens_after' has unexpected type: " f"{type(happens_after)}") @@ -390,6 +395,9 @@ def __init__(self, assert isinstance(groups, abc_Set) assert isinstance(conflicts_with_groups, abc_Set) + from loopy.tools import is_hashable + assert is_hashable(happens_after) + ImmutableRecord.__init__(self, id=id, happens_after=happens_after, @@ -573,13 +581,15 @@ def update_persistent_hash(self, key_hash, key_builder): def __setstate__(self, val): super().__setstate__(val) + from immutabledict import immutabledict + from loopy.tools import intern_frozenset_of_ids if self.id is not None: # pylint:disable=access-member-before-definition self.id = intern(self.id) - self.happens_after = { + self.happens_after = immutabledict({ intern(after_id): ha - for after_id, ha in self.happens_after.items()} + for after_id, ha in self.happens_after.items()}) self.groups = intern_frozenset_of_ids(self.groups) self.conflicts_with_groups = ( intern_frozenset_of_ids(self.conflicts_with_groups)) diff --git a/loopy/tools.py b/loopy/tools.py index 50a523ee8..2e3b5db4f 100644 --- a/loopy/tools.py +++ b/loopy/tools.py @@ -972,4 +972,13 @@ def _get_persistent_hashable_arg(arg): # }}} + +def is_hashable(o: object) -> bool: + try: + hash(o) + except TypeError: + return False + return True + + # vim: fdm=marker diff --git a/pyproject.toml b/pyproject.toml index 4134ba24d..70672a1ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dependencies = [ "Mako", "pyrsistent", "immutables", + "immutabledict", # for Self, TypeAlias "typing-extensions>=4; python_version<'3.12'",