Skip to content

Commit

Permalink
Require that happens_after is not mutable (inducer#866)
Browse files Browse the repository at this point in the history
* Require that happens_after is not mutable

* Tweak type tests for happens_after

---------

Co-authored-by: Andreas Kloeckner <[email protected]>
  • Loading branch information
kaushikcfd and inducer authored Sep 9, 2024
1 parent 0f78426 commit 070df9f
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 12 deletions.
34 changes: 22 additions & 12 deletions loopy/kernel/instruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
THE SOFTWARE.
"""

from collections.abc import Mapping as MappingABC, Set as abc_Set
from collections.abc import (
Mapping as MappingABC,
Set as abc_Set,
)
from dataclasses import dataclass
from functools import cached_property
from sys import intern
Expand Down Expand Up @@ -283,6 +286,7 @@ def __init__(self,
*,
depends_on: Union[FrozenSet[str], str, None] = None,
) -> None:
from immutabledict import immutabledict

if predicates is None:
predicates = frozenset()
Expand Down Expand Up @@ -314,28 +318,29 @@ def __init__(self,
raise LoopyError("Setting depends_on_is_final to True requires "
"actually specifying happens_after/depends_on")

if happens_after is None:
happens_after = {}
if isinstance(happens_after, immutabledict):
pass
elif happens_after is None:
happens_after = immutabledict()
elif isinstance(happens_after, str):
warn("Passing a string for happens_after/depends_on is deprecated and "
"will stop working in 2025. Instead, pass a full-fledged "
"happens_after data structure.", DeprecationWarning, stacklevel=2)

happens_after = {
happens_after = immutabledict({
after_id.strip(): HappensAfter(
variable_name=None,
instances_rel=None)
for after_id in happens_after.split(",")
if after_id.strip()}
if after_id.strip()})
elif isinstance(happens_after, frozenset):
happens_after = {
happens_after = immutabledict({
after_id: HappensAfter(
variable_name=None,
instances_rel=None)
for after_id in happens_after}
elif isinstance(happens_after, MappingABC):
if isinstance(happens_after, dict):
happens_after = happens_after
for after_id in happens_after})
elif isinstance(happens_after, dict):
happens_after = immutabledict(happens_after)
else:
raise TypeError("'happens_after' has unexpected type: "
f"{type(happens_after)}")
Expand Down Expand Up @@ -390,6 +395,9 @@ def __init__(self,
assert isinstance(groups, abc_Set)
assert isinstance(conflicts_with_groups, abc_Set)

from loopy.tools import is_hashable
assert is_hashable(happens_after)

ImmutableRecord.__init__(self,
id=id,
happens_after=happens_after,
Expand Down Expand Up @@ -573,13 +581,15 @@ def update_persistent_hash(self, key_hash, key_builder):
def __setstate__(self, val):
super().__setstate__(val)

from immutabledict import immutabledict

from loopy.tools import intern_frozenset_of_ids

if self.id is not None: # pylint:disable=access-member-before-definition
self.id = intern(self.id)
self.happens_after = {
self.happens_after = immutabledict({
intern(after_id): ha
for after_id, ha in self.happens_after.items()}
for after_id, ha in self.happens_after.items()})
self.groups = intern_frozenset_of_ids(self.groups)
self.conflicts_with_groups = (
intern_frozenset_of_ids(self.conflicts_with_groups))
Expand Down
9 changes: 9 additions & 0 deletions loopy/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,4 +972,13 @@ def _get_persistent_hashable_arg(arg):

# }}}


def is_hashable(o: object) -> bool:
try:
hash(o)
except TypeError:
return False
return True


# vim: fdm=marker
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ dependencies = [
"Mako",
"pyrsistent",
"immutables",
"immutabledict",

# for Self, TypeAlias
"typing-extensions>=4; python_version<'3.12'",
Expand Down

0 comments on commit 070df9f

Please sign in to comment.