diff --git a/lambeq/backend/grammar.py b/lambeq/backend/grammar.py index 9b5534e3..2aab6381 100644 --- a/lambeq/backend/grammar.py +++ b/lambeq/backend/grammar.py @@ -362,6 +362,9 @@ def rotate(self, z: int) -> Diagrammable: """ + def dagger(self) -> Diagrammable: + """Implements conjugation of diagrams.""" + def __matmul__(self, rhs: Diagrammable | Ty) -> Diagrammable: """Implements the tensor operator `@` with another diagram.""" @@ -1572,6 +1575,9 @@ def rotate(self, z: int) -> Self: def dagger(self) -> Box: return self.box + def apply_functor(self, functor: Functor) -> Diagrammable: + return functor(self.dagger()).dagger() + @classmethod def from_json(cls, data: _JSONDictT | str) -> Self: data_dict = json.loads(data) if isinstance(data, str) else data diff --git a/tests/backend/test_grammar.py b/tests/backend/test_grammar.py index 1ba68b3f..661d3884 100644 --- a/tests/backend/test_grammar.py +++ b/tests/backend/test_grammar.py @@ -277,6 +277,7 @@ def ar(func, box): assert func(g.r) == func(g).r == g_z.r assert func(f >> g) == f_z >> g_z assert func(f @ g) == f_z @ g_z + assert func(f.dagger()) == func(f).dagger() def bad_ar(func, box): return Box("BOX", a, c) if box.cod == b else box diff --git a/tests/test_circuit.py b/tests/test_circuit.py index 29aa80d2..2f2bd54e 100644 --- a/tests/test_circuit.py +++ b/tests/test_circuit.py @@ -320,3 +320,20 @@ def test_lambeq_tket_conversion(): def test_special_characters(box, expected_sym_count): ansatz = Sim15Ansatz({n_ty: 2, comma_ty: 2, space_ty: 2}, n_layers=1) assert(len(ansatz(box).free_symbols) == expected_sym_count) + + +def test_ansatz_is_dagger_functor(): + ansatz = IQPAnsatz({N: 1, S: 1}, n_layers=1) + diagram = Word('John', N) + circuit1 = ansatz(diagram).dagger() + circuit2 = ansatz(diagram.dagger()) + assert circuit1 == circuit2 + +def test_ansatz_is_dagger_functor_sentence(): + ansatz = IQPAnsatz({N: 1, S: 1}, n_layers=1) + diagram = (Word('Alice', N) @ Word('runs', N >> S) >> + Cup(N, N.r) @ S) + + circuit1 = ansatz(diagram).dagger().normal_form() + circuit2 = ansatz(diagram.dagger()).normal_form() + assert circuit1 == circuit2