diff --git a/hugr-py/src/hugr/tys.py b/hugr-py/src/hugr/tys.py index ee07bf626..124beaf20 100644 --- a/hugr-py/src/hugr/tys.py +++ b/hugr-py/src/hugr/tys.py @@ -237,15 +237,18 @@ def as_tuple(self) -> Tuple: def __repr__(self) -> str: return f"Sum({self.variant_rows})" + def __eq__(self, other: object) -> bool: + return isinstance(other, Sum) and self.variant_rows == other.variant_rows + def type_bound(self) -> TypeBound: return TypeBound.join(*(t.type_bound() for r in self.variant_rows for t in r)) -@dataclass() +@dataclass(eq=False) class UnitSum(Sum): """Simple :class:`Sum` type with `size` variants of empty rows.""" - size: int + size: int = field(compare=False) def __init__(self, size: int): self.size = size @@ -262,7 +265,7 @@ def __repr__(self) -> str: return f"UnitSum({self.size})" -@dataclass() +@dataclass(eq=False) class Tuple(Sum): """Product type with `tys` elements. Instances of this type correspond to :class:`Sum` with a single variant. diff --git a/hugr-py/tests/test_tys.py b/hugr-py/tests/test_tys.py new file mode 100644 index 000000000..689234426 --- /dev/null +++ b/hugr-py/tests/test_tys.py @@ -0,0 +1,14 @@ +from __future__ import annotations + +from hugr.tys import Bool, Qubit, Sum, Tuple, UnitSum + + +def test_sums(): + assert Sum([[Bool, Qubit]]) == Tuple(Bool, Qubit) + assert Tuple(Bool, Qubit) == Sum([[Bool, Qubit]]) + assert Sum([[Bool, Qubit]]).as_tuple() == Sum([[Bool, Qubit]]) + + assert Tuple() == Sum([[]]) + assert UnitSum(0) == Sum([]) + assert UnitSum(1) == Tuple() + assert UnitSum(4) == Sum([[], [], [], []])