Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix KeyError encountered with some texts #181

Merged
merged 2 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions lambeq/backend/pregroup_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,7 @@ def is_same_word(self, other: object) -> bool:
if not isinstance(other, PregroupTreeNode):
return NotImplemented
return (self.word == other.word
and self.ind == other.ind
and self.typ == other.typ)
and self.ind == other.ind)

@cached_property
def _tree_repr(self) -> str:
Expand Down Expand Up @@ -371,3 +370,20 @@ def merge(self) -> None:
else:
print('Cannot perform merge when parent and child '
+ "types don't match or tokens are not consecutive.")

def remove_self_cycles(self) -> None:
"""Removes the children of this node that is the same token,
i.e. self-cycles.

This is used before breaking cycles.
"""

new_children = []
for c in self.children:
if self.is_same_word(c):
c.parent = None
else:
new_children.append(c)
self.children = new_children
for c in self.children:
c.remove_self_cycles()
9 changes: 6 additions & 3 deletions lambeq/experimental/discocirc/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,9 @@ def text2circuit(self,
tree = self._sentence2tree(sentence, break_cycles)

tree_toks = tree.get_words()
reidxr = self._calculate_reindices(sentence, tree_toks)
tree_toks_indxs = tree.get_word_indices()
reidxr = self._calculate_reindices(sentence, tree_toks,
tree_toks_indxs)
reidxr[None] = None

tree = rewriter(tree)
Expand Down Expand Up @@ -400,10 +402,11 @@ def text2circuit(self,

def _calculate_reindices(self,
orig_toks,
parsed_toks):
parsed_toks,
parsed_toks_indxs):
reindexer = {}
j = 0
for i, otok in enumerate(parsed_toks):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@neiljdo in what situation are an enumeration and parsed_tok_indxs going to disagree? I thought in the tree they're guaranteed to be [0 ... n-1]?

Copy link
Collaborator Author

@neiljdo neiljdo Nov 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nikhilkhatri it happens when the tree is modified after, e.g. when adding the '' word for missing nouns. The '' was never in the original tokens. These were what I saw when investigating the sentence Anna gave. You can use the following example:

text112 = "Gently and Bacchus delve into a world of army secrets when a young former soldier, Scott Tanner, commits a murder in a Turkish bath. Whilst investigating Tanner's history, Gently hears of horrific allegations of what some soldiers have to face from their own side. He is forced to question the uncomfortable truth of what it means to serve one's queen and country, as an event from the past presses on his conscience."

for i, otok in zip(parsed_toks_indxs, parsed_toks):
while j < len(orig_toks) and orig_toks[j] != otok:
j += 1
reindexer[i] = j
Expand Down
5 changes: 5 additions & 0 deletions lambeq/text2diagram/pregroup_tree_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,8 +234,13 @@ def remove_cycles(root: PregroupTreeNode) -> None:
root_word_idx = root.ind
nodes = root.get_nodes()
assert len(nodes[root_word_idx]) == 1

root_node = nodes[root_word_idx][0]

# Remove nodes that cycles to itself
# (see https://github.com/CQCL/lambeq/issues/180)
root_node.remove_self_cycles()

for _, nodes_for_idx in enumerate(nodes):
if len(nodes_for_idx) > 1:
# Retain the deepest copy of the node
Expand Down
12 changes: 12 additions & 0 deletions tests/backend/test_pregroup_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@
t4_n0 = PregroupTreeNode(word='0', typ=s, ind=0, children=[t4_n1, t4_n2_2])
t4 = t4_n0

t5_n2 = PregroupTreeNode(word='and', typ=n.r @ s @ n.r.r.r @ s.r.r, ind=2)
t5_n1 = PregroupTreeNode(word='an', typ=n, ind=1)
t5_n2_2 = PregroupTreeNode(word='and', typ=s, ind=2, children=[t5_n1, t5_n2])
t5_n0 = PregroupTreeNode(word='when', typ=s, ind=0, children=[t5_n2_2])
t5 = t5_n0


def test_get_nodes():
assert t1.get_nodes() == [
Expand Down Expand Up @@ -228,3 +234,9 @@ def test_merge():
assert t3_n2.typ == n.r @ s
assert t3_n2.ind == 2
assert len(t3_n2.children) == 1


def test_remove_self_cycles():
t5.remove_self_cycles()
assert t5_n2.parent is None
assert t5_n2_2.children == [t5_n1]
27 changes: 26 additions & 1 deletion tests/text2diagram/test_pregroup_tree_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

tokeniser = SpacyTokeniser()
bobcat_parser = BobcatParser(verbose='suppress')
n, s = map(Ty, 'ns')
n, s, p = map(Ty, 'nsp')

s1 = tokeniser.tokenise_sentence(
"Last year's figures include a one-time loss of $12 million for restructuring and unusual items"
Expand Down Expand Up @@ -99,6 +99,10 @@
(Cup, 7, 12),
]
)
s11_diag = bobcat_parser.sentence2diagram(
'When an event puts Errol in danger and the case in jeopardy',
tokenised=False
)

t1_n1 = PregroupTreeNode(word='year', typ=n, ind=1)
t1_n3 = PregroupTreeNode(word='figures', typ=n, ind=3)
Expand Down Expand Up @@ -191,6 +195,22 @@
t10_n0 = PregroupTreeNode(word='0', typ=s, ind=0, children=[t10_n1])
t10_no_cycle = t10_n0

t11_n11 = PregroupTreeNode(word='jeopardy', typ=n, ind=11)
t11_n9 = PregroupTreeNode(word='case', typ=n, ind=9)
t11_n6 = PregroupTreeNode(word='danger', typ=n, ind=6)
t11_n4 = PregroupTreeNode(word='Errol', typ=n, ind=4)
t11_n3 = PregroupTreeNode(word='puts', typ=n.r @ s @ p.l @ n.l, ind=3)
t11_n2 = PregroupTreeNode(word='event', typ=n, ind=2)
t11_n10 = PregroupTreeNode(word='in', typ=p, ind=10, children=[t11_n11])
t11_n8 = PregroupTreeNode(word='the', typ=n, ind=8, children=[t11_n9])
t11_n5 = PregroupTreeNode(word='in', typ=p, ind=5, children=[t11_n6])
t11_n1 = PregroupTreeNode(word='an', typ=n, ind=1, children=[t11_n2])
t11_n7 = PregroupTreeNode(word='and', typ=s, ind=7,
children=[t11_n1, t11_n3, t11_n4,
t11_n5, t11_n8, t11_n10])
t11_n0 = PregroupTreeNode(word='When', typ=s, ind=0, children=[t11_n7])
t11_no_cycle = t11_n0


def test_diagram2tree():
s1_tree = diagram2tree(s1_diag)
Expand Down Expand Up @@ -229,6 +249,11 @@ def test_diagram2tree_no_cycles():
t10_no_cycle.draw()
assert s10_tree == t10_no_cycle

s11_tree = diagram2tree(s11_diag, break_cycles=True)
s11_tree.draw()
t11_no_cycle.draw()
assert s11_tree == t11_no_cycle


def test_tree2diagram():
assert tree2diagram(t1, t1.get_words()).pregroup_normal_form() == s1_diag.pregroup_normal_form()
Expand Down
Loading