From 8dd2f5ace16dbf54b9d9a5d2568071cfdb9322a7 Mon Sep 17 00:00:00 2001 From: Kin Ian Lo Date: Wed, 25 Sep 2024 21:52:14 +0100 Subject: [PATCH 1/6] split CopyNode --- lambeq/backend/tensor.py | 26 ++++++++++++++++++++------ 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/lambeq/backend/tensor.py b/lambeq/backend/tensor.py index af7a9f3f..37b7817a 100644 --- a/lambeq/backend/tensor.py +++ b/lambeq/backend/tensor.py @@ -385,21 +385,35 @@ def to_tn(self, dtype: type | None = None): del scan[len(l): len(l) + 2] else: if isinstance(box, Spider): - node = tn.CopyNode(box.n_legs_in + box.n_legs_out, - box.type.product, dtype=dtype, - backend=backend) + rank = box.n_legs_in + box.n_legs_out + dim = box.type.product + if rank <= 3: + node = tn.CopyNode(rank, dim, dtype=dtype, backend=backend) + nodes.append(node) + legs = node.edges + else: + internal_nodes = [tn.CopyNode(3, dim, dtype=dtype, backend=backend) for _ in range(rank-2)] + nodes.extend(internal_nodes) + + for i in range(len(internal_nodes)-1): + tn.connect(internal_nodes[i][0], internal_nodes[i+1][1]) + + legs = ([internal_nodes[0][1]] + + [n[2] for n in internal_nodes] + + [internal_nodes[-1][0]]) else: node = tn.Node(box.array, str(box.name), backend=backend) - nodes.append(node) + nodes.append(node) + legs = node.edges for i in range(len(box.dom)): - tn.connect(scan[len(l) + i], node[i]) + tn.connect(scan[len(l) + i], legs[i]) scan = (scan[:len(l)] - + node[len(box.dom):] + + legs[len(box.dom):] + scan[len(l) + len(box.dom):]) # nodes, input_edge_order, output_edge_order From 7d0b1d63d0946309b23d6f7b35249cf6be76d53e Mon Sep 17 00:00:00 2001 From: Kin Ian Lo Date: Thu, 26 Sep 2024 01:15:59 +0100 Subject: [PATCH 2/6] add comments --- lambeq/backend/tensor.py | 32 ++++++++++++++++++++++---------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/lambeq/backend/tensor.py b/lambeq/backend/tensor.py index 37b7817a..e0ac9a07 100644 --- a/lambeq/backend/tensor.py +++ b/lambeq/backend/tensor.py @@ -388,19 +388,31 @@ def to_tn(self, dtype: type | None = None): rank = box.n_legs_in + box.n_legs_out dim = box.type.product if rank <= 3: - node = tn.CopyNode(rank, dim, dtype=dtype, backend=backend) + node = tn.CopyNode(rank, dim, dtype=dtype, + backend=backend) nodes.append(node) legs = node.edges else: - internal_nodes = [tn.CopyNode(3, dim, dtype=dtype, backend=backend) for _ in range(rank-2)] - nodes.extend(internal_nodes) - - for i in range(len(internal_nodes)-1): - tn.connect(internal_nodes[i][0], internal_nodes[i+1][1]) - - legs = ([internal_nodes[0][1]] - + [n[2] for n in internal_nodes] - + [internal_nodes[-1][0]]) + # Decompose the spider into a chain of + # three-legged spiders of length rank - 2 + # For example, a 5-legged spider will be + # decomposed into: + # 2 2 2 + # | | | + # ---1-[N0]-0----1-[N1]-0----1-[N2]-0--- + # where the numbers indicate the leg indices of + # the spiders. + spiders = [tn.CopyNode(3, dim, dtype=dtype, + backend=backend) + for _ in range(rank-2)] + nodes.extend(spiders) + + for i in range(len(spiders)-1): + tn.connect(spiders[i][0], spiders[i+1][1]) + + legs = ([spiders[0][1]] + + [n[2] for n in spiders] + + [spiders[-1][0]]) else: node = tn.Node(box.array, str(box.name), From cfd0a7045e5be716dfcd47e3ba702fe8accaeefa Mon Sep 17 00:00:00 2001 From: Kin Ian Lo Date: Fri, 27 Sep 2024 16:03:53 +0100 Subject: [PATCH 3/6] add test for spider --- tests/backend/test_tensor.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/backend/test_tensor.py b/tests/backend/test_tensor.py index cc189b73..22c9bd47 100644 --- a/tests/backend/test_tensor.py +++ b/tests/backend/test_tensor.py @@ -6,6 +6,9 @@ import lambeq.backend.grammar as grammar from lambeq.backend.tensor import * +@pytest.fixture +def spider(): + return Spider(Dim(3), 5, 3) def test_Ty(): assert Dim(1,1,1,1,1) == Dim(1) == Dim() @@ -85,3 +88,22 @@ def test_lambdify(): assert bx1.lambdify(a,b,c,d)(1,2,3,4) == bx1_concrete assert (bx1 >> bx2).lambdify(a,b,c,d)(1,2,3,4) == bx1_concrete >> bx2_concrete + +def test_to_tn_spider_unfuse(spider): + nodes, edges = spider.to_tn() + + assert len(edges) == spider.n_legs_in + spider.n_legs_out + assert all(node.rank <= 3 for node in nodes) + +def test_spider_eval(spider): + n_legs = 8 + dim = 3 + + expected = np.zeros(tuple(dim for _ in range(n_legs))) + for i in range(dim): + expected[tuple(i for _ in range(n_legs))] = 1 + + result = spider.eval() + + assert result.shape == tuple(dim for _ in range(n_legs)) + assert np.allclose(result, expected) \ No newline at end of file From e1d555c054b0beab6c2f41a5f5ce5fb620f91304 Mon Sep 17 00:00:00 2001 From: Kin Ian Lo Date: Mon, 30 Sep 2024 16:46:04 +0100 Subject: [PATCH 4/6] Circuit ansatze should be dagger functors (#155) --- lambeq/backend/grammar.py | 6 ++++++ tests/backend/test_grammar.py | 1 + tests/test_circuit.py | 17 +++++++++++++++++ 3 files changed, 24 insertions(+) diff --git a/lambeq/backend/grammar.py b/lambeq/backend/grammar.py index 9b5534e3..2aab6381 100644 --- a/lambeq/backend/grammar.py +++ b/lambeq/backend/grammar.py @@ -362,6 +362,9 @@ def rotate(self, z: int) -> Diagrammable: """ + def dagger(self) -> Diagrammable: + """Implements conjugation of diagrams.""" + def __matmul__(self, rhs: Diagrammable | Ty) -> Diagrammable: """Implements the tensor operator `@` with another diagram.""" @@ -1572,6 +1575,9 @@ def rotate(self, z: int) -> Self: def dagger(self) -> Box: return self.box + def apply_functor(self, functor: Functor) -> Diagrammable: + return functor(self.dagger()).dagger() + @classmethod def from_json(cls, data: _JSONDictT | str) -> Self: data_dict = json.loads(data) if isinstance(data, str) else data diff --git a/tests/backend/test_grammar.py b/tests/backend/test_grammar.py index 1ba68b3f..661d3884 100644 --- a/tests/backend/test_grammar.py +++ b/tests/backend/test_grammar.py @@ -277,6 +277,7 @@ def ar(func, box): assert func(g.r) == func(g).r == g_z.r assert func(f >> g) == f_z >> g_z assert func(f @ g) == f_z @ g_z + assert func(f.dagger()) == func(f).dagger() def bad_ar(func, box): return Box("BOX", a, c) if box.cod == b else box diff --git a/tests/test_circuit.py b/tests/test_circuit.py index 29aa80d2..2f2bd54e 100644 --- a/tests/test_circuit.py +++ b/tests/test_circuit.py @@ -320,3 +320,20 @@ def test_lambeq_tket_conversion(): def test_special_characters(box, expected_sym_count): ansatz = Sim15Ansatz({n_ty: 2, comma_ty: 2, space_ty: 2}, n_layers=1) assert(len(ansatz(box).free_symbols) == expected_sym_count) + + +def test_ansatz_is_dagger_functor(): + ansatz = IQPAnsatz({N: 1, S: 1}, n_layers=1) + diagram = Word('John', N) + circuit1 = ansatz(diagram).dagger() + circuit2 = ansatz(diagram.dagger()) + assert circuit1 == circuit2 + +def test_ansatz_is_dagger_functor_sentence(): + ansatz = IQPAnsatz({N: 1, S: 1}, n_layers=1) + diagram = (Word('Alice', N) @ Word('runs', N >> S) >> + Cup(N, N.r) @ S) + + circuit1 = ansatz(diagram).dagger().normal_form() + circuit2 = ansatz(diagram.dagger()).normal_form() + assert circuit1 == circuit2 From 393733083fc9f4d1a86ae53be145b92dcfc12fa5 Mon Sep 17 00:00:00 2001 From: Dimitri Kartsaklis Date: Tue, 1 Oct 2024 09:18:22 +0100 Subject: [PATCH 5/6] Update README.md Fix broken logo link. --- README.md | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/README.md b/README.md index f2a13ec0..90ca37bf 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,4 @@ -# lambeq - -[![lambeq logo](https://cqcl.github.io/lambeq-docs/_static/lambeq_logo.png)](//cqcl.github.io/lambeq-docs) +# λambeq ![Build status](https://github.com/CQCL/lambeq/actions/workflows/build_test.yml/badge.svg) [![License](https://img.shields.io/github/license/CQCL/lambeq)](LICENSE) From e8f5e0f6c5e6dc77e55afb95957103a99c78f4fb Mon Sep 17 00:00:00 2001 From: Kin Ian Lo Date: Fri, 11 Oct 2024 02:21:50 +0100 Subject: [PATCH 6/6] add spider_chain_reader --- lambeq/__init__.py | 3 ++- lambeq/text2diagram/__init__.py | 4 +++- lambeq/text2diagram/spiders_reader.py | 5 ++++- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/lambeq/__init__.py b/lambeq/__init__.py index 2ff4c656..187962e4 100644 --- a/lambeq/__init__.py +++ b/lambeq/__init__.py @@ -61,6 +61,7 @@ 'bag_of_words_reader', 'cups_reader', 'spiders_reader', + 'spider_chain_reader', 'stairs_reader', 'word_sequence_reader', @@ -124,7 +125,7 @@ WebParseError, WebParser, Reader, LinearReader, TreeReader, TreeReaderMode, bag_of_words_reader, cups_reader, spiders_reader, - stairs_reader, word_sequence_reader) + spider_chain_reader, stairs_reader, word_sequence_reader) from lambeq.tokeniser import Tokeniser, SpacyTokeniser from lambeq.training import (Checkpoint, Dataset, Optimizer, NelderMeadOptimizer, RotosolveOptimizer, diff --git a/lambeq/text2diagram/__init__.py b/lambeq/text2diagram/__init__.py index a5debd31..04ae5be7 100644 --- a/lambeq/text2diagram/__init__.py +++ b/lambeq/text2diagram/__init__.py @@ -34,6 +34,7 @@ 'bag_of_words_reader', 'cups_reader', 'spiders_reader', + 'spider_chain_reader', 'stairs_reader', 'word_sequence_reader'] @@ -53,5 +54,6 @@ stairs_reader, word_sequence_reader) from lambeq.text2diagram.spiders_reader import (bag_of_words_reader, - spiders_reader) + spiders_reader, + spider_chain_reader) from lambeq.text2diagram.tree_reader import TreeReader, TreeReaderMode diff --git a/lambeq/text2diagram/spiders_reader.py b/lambeq/text2diagram/spiders_reader.py index 733b47da..7bb1dd5b 100644 --- a/lambeq/text2diagram/spiders_reader.py +++ b/lambeq/text2diagram/spiders_reader.py @@ -12,12 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -__all__ = ['SpidersReader', 'bag_of_words_reader', 'spiders_reader'] +__all__ = ['SpidersReader', 'bag_of_words_reader', + 'spiders_reader', 'spider_chain_reader'] from lambeq.backend.grammar import Diagram, Id, Spider, Word from lambeq.core.types import AtomicType from lambeq.core.utils import SentenceType, tokenised_sentence_type_check from lambeq.text2diagram.base import Reader +from lambeq.text2diagram.linear_reader import LinearReader S = AtomicType.SENTENCE @@ -46,3 +48,4 @@ def sentence2diagram(self, spiders_reader = SpidersReader() bag_of_words_reader = spiders_reader +spider_chain_reader = LinearReader(Spider(AtomicType.SENTENCE, 2, 1))