Skip to content

Commit

Permalink
ENH: implement FrozenDict with frozendict (#310)
Browse files Browse the repository at this point in the history
* DX: implement hash test for `FrozenDict` and `ReactionInfo`
* ENH: inherit `FrozenDict` from `frozendict`
* MAINT: install `frozendict` as direct dependency
  • Loading branch information
redeboer authored Dec 20, 2024
1 parent 63fff63 commit 769ade3
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 68 deletions.
19 changes: 7 additions & 12 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from _extend_docstrings import extend_docstrings # noqa: PLC2701


def pick_newtype_attrs(some_type: type) -> list:
def __get_newtypes(some_type: type) -> list:
return [
attr
for attr in dir(some_type)
Expand Down Expand Up @@ -278,25 +278,20 @@ def pick_newtype_attrs(some_type: type) -> list:
nb_execution_show_tb = True
nb_execution_timeout = -1
nb_output_stderr = "remove"


nitpick_temp_names = [
*pick_newtype_attrs(EdgeQuantumNumbers),
*pick_newtype_attrs(NodeQuantumNumbers),
]
nitpick_temp_patterns = [
(r"py:(class|obj)", r"qrules\.quantum_numbers\." + name)
for name in nitpick_temp_names
]
nitpick_ignore_regex = [
(r"py:(class|obj)", "json.encoder.JSONEncoder"),
(r"py:(class|obj)", r"frozendict(\.frozendict)?"),
(r"py:(class|obj)", r"qrules\.topology\.EdgeType"),
(r"py:(class|obj)", r"qrules\.topology\.KT"),
(r"py:(class|obj)", r"qrules\.topology\.NewEdgeType"),
(r"py:(class|obj)", r"qrules\.topology\.NewNodeType"),
(r"py:(class|obj)", r"qrules\.topology\.NodeType"),
(r"py:(class|obj)", r"qrules\.topology\.VT"),
*nitpick_temp_patterns,
*[
(r"py:(class|obj)", r"qrules\.quantum_numbers\." + name)
for name in __get_newtypes(EdgeQuantumNumbers)
+ __get_newtypes(NodeQuantumNumbers)
],
]
nitpicky = True
primary_domain = "py"
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ classifiers = [
dependencies = [
"PyYAML",
"attrs >=20.1.0", # on_setattr and https://www.attrs.org/en/stable/api.html#next-gen
"frozendict",
"jsonschema",
"particle",
"python-constraint",
Expand Down
59 changes: 4 additions & 55 deletions src/qrules/topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,12 @@
import attrs
from attrs import define, field, frozen
from attrs.validators import deep_iterable, deep_mapping, instance_of
from frozendict import frozendict

from qrules._implementers import implement_pretty_repr

if TYPE_CHECKING:
from collections.abc import (
ItemsView,
Iterable,
Iterator,
KeysView,
Mapping,
Sequence,
ValuesView,
)
from collections.abc import Iterable, Mapping, Sequence

from IPython.lib.pretty import PrettyPrinter

Expand All @@ -56,31 +49,8 @@ def __lt__(self, other: Any) -> bool: ...


@total_ordering
class FrozenDict(abc.Hashable, abc.Mapping, Generic[KT, VT]):
"""An **immutable** and **hashable** version of a `dict`.
`FrozenDict` makes it possible to make classes hashable if they are decorated with
:func:`attr.frozen` and contain `~typing.Mapping`-like attributes. If these
attributes were to be implemented with a normal `dict`, the instance is strictly
speaking still mutable (even if those attributes are a `property`) and the class is
therefore not safely hashable.
.. warning:: The keys have to be comparable, that is, they need to have a
:meth:`~object.__lt__` method.
"""

def __init__(self, mapping: Mapping | None = None) -> None:
self.__mapping: dict[KT, VT] = {}
if mapping is not None:
self.__mapping = dict(mapping)
self.__hash = hash(None)
if len(self.__mapping) != 0:
self.__hash = 0
for key_value_pair in self.items():
self.__hash ^= hash(key_value_pair)

def __repr__(self) -> str:
return f"{type(self).__name__}({self.__mapping})"
class FrozenDict(frozendict, Generic[KT, VT]):
"""A sortable version of :code:`frozendict`."""

def _repr_pretty_(self, p: PrettyPrinter, cycle: bool) -> None:
class_name = type(self).__name__
Expand All @@ -96,15 +66,6 @@ def _repr_pretty_(self, p: PrettyPrinter, cycle: bool) -> None:
p.breakable()
p.text("})")

def __iter__(self) -> Iterator[KT]:
return iter(self.__mapping)

def __len__(self) -> int:
return len(self.__mapping)

def __getitem__(self, key: KT) -> VT:
return self.__mapping[key]

def __gt__(self, other: Any) -> bool:
if isinstance(other, abc.Mapping):
sorted_self = _convert_mapping_to_sorted_tuple(self)
Expand All @@ -117,18 +78,6 @@ def __gt__(self, other: Any) -> bool:
)
raise NotImplementedError(msg)

def __hash__(self) -> int:
return self.__hash

def keys(self) -> KeysView[KT]:
return self.__mapping.keys()

def items(self) -> ItemsView[KT, VT]:
return self.__mapping.items()

def values(self) -> ValuesView[VT]:
return self.__mapping.values()


def _convert_mapping_to_sorted_tuple(
mapping: Mapping[KT, VT],
Expand Down
36 changes: 35 additions & 1 deletion tests/unit/test_topology.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import hashlib
import pickle # noqa: S403
import typing

import pytest
Expand All @@ -6,7 +8,7 @@

from qrules.topology import (
Edge,
FrozenDict, # noqa: F401 # pyright: ignore[reportUnusedImport]
FrozenDict, # pyright: ignore[reportUnusedImport]
InteractionNode,
MutableTopology,
SimpleStateTransitionTopologyBuilder,
Expand Down Expand Up @@ -39,6 +41,23 @@ def test_immutability(self):
edge.ending_node_id += 1


class TestFrozenDict:
def test_hash(self):
obj: FrozenDict = FrozenDict({})
assert _compute_hash(obj) == "067705e70d037311d05daae1e32e1fce"

obj = FrozenDict({"key1": "value1"})
assert _compute_hash(obj) == "56b0520e2a3af550c0f488cd5de2d474"

obj = FrozenDict({
"key1": "value1",
"key2": 2,
"key3": (1, 2, 3),
"key4": FrozenDict({"nested_key": "nested_value"}),
})
assert _compute_hash(obj) == "8568f73c07fce099311f010061f070c6"


class TestInteractionNode:
def test_constructor_exceptions(self):
with pytest.raises(TypeError):
Expand Down Expand Up @@ -188,6 +207,9 @@ def test_constructor_exceptions(self, nodes, edges):
):
assert Topology(nodes, edges)

def test_hash(self, two_to_three_decay: Topology):
assert _compute_hash(two_to_three_decay) == "cbaea5d94038a3ad30888014a7b3ae20"

@pytest.mark.parametrize("repr_method", [repr, pretty])
def test_repr_and_eq(self, repr_method, two_to_three_decay: Topology):
topology = eval(repr_method(two_to_three_decay))
Expand Down Expand Up @@ -299,3 +321,15 @@ def test_create_n_body_topology(n_initial: int, n_final: int, exception):
assert len(topology.outgoing_edge_ids) == n_final
assert len(topology.intermediate_edge_ids) == 0
assert len(topology.nodes) == 1


def _compute_hash(obj) -> str:
b = _to_bytes(obj)
h = hashlib.md5(b) # noqa: S324
return h.hexdigest()


def _to_bytes(obj) -> bytes:
if isinstance(obj, bytes | bytearray):
return obj
return pickle.dumps(obj)
21 changes: 21 additions & 0 deletions tests/unit/test_transition.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# pyright: reportUnusedImport=false
import hashlib
import pickle # noqa: S403
from copy import deepcopy
from fractions import Fraction

Expand Down Expand Up @@ -44,6 +46,13 @@ def test_repr(self, repr_method, reaction: ReactionInfo):
def test_hash(self, reaction: ReactionInfo):
assert hash(deepcopy(reaction)) == hash(reaction)

def test_hash_value(self, reaction: ReactionInfo):
expected_hash = {
"canonical-helicity": "65106a44301f9340e633d09f66ad7d17",
"helicity": "9646d3ee5c5e8534deb8019435161f2e",
}[reaction.formalism]
assert _compute_hash(reaction) == expected_hash


class TestState:
@pytest.mark.parametrize(
Expand Down Expand Up @@ -106,3 +115,15 @@ def test_regex_pattern(self):
"Delta(1900)++",
"Delta(1920)++",
]


def _compute_hash(obj) -> str:
b = _to_bytes(obj)
h = hashlib.md5(b) # noqa: S324
return h.hexdigest()


def _to_bytes(obj) -> bytes:
if isinstance(obj, bytes | bytearray):
return obj
return pickle.dumps(obj)
29 changes: 29 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 769ade3

Please sign in to comment.