Skip to content

Commit

Permalink
Add frequency-based noun pruning to DisCoCircReader (#192)
Browse files Browse the repository at this point in the history
Co-authored-by: Colin Krawchuk <[email protected]>
  • Loading branch information
nikhilkhatri authored Dec 3, 2024
1 parent e219768 commit 8ec8c89
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 6 deletions.
24 changes: 21 additions & 3 deletions lambeq/experimental/discocirc/coref_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ def dict_from_corefs(self,
for scrf in scoref]

for scoref in scorefs:
corefd[scoref] = scorefs[0]
if scoref not in corefd:
corefd[scoref] = scorefs[0]

return corefd

Expand All @@ -87,14 +88,21 @@ class SpacyCoreferenceResolver(CoreferenceResolver):
"""Corefence resolution and tokenisation based on spaCy."""

def __init__(self):
self.nlp = spacy.load('en_coreference_web_trf',
exclude=('span_resolver', 'span_cleaner'))
# Create basic tokenisation pipeline, for POS
self.nlp = spacy.load('en_core_web_sm')

# Add coreference resolver pipe stage
coref_stage = spacy.load('en_coreference_web_trf',
exclude=('span_resolver', 'span_cleaner'))
self.nlp.add_pipe('transformer', source=coref_stage)
self.nlp.add_pipe('coref', source=coref_stage)

def tokenise_and_coref(self, text):
text = self._clean_text(text)
doc = self.nlp(text)
coreferences = []

# Add all coreference instances
for cluster in doc.spans.values():
sent_clusters = [[] for _ in doc.sents]
for span in cluster:
Expand All @@ -104,4 +112,14 @@ def tokenise_and_coref(self, text):
break
coreferences.append(sent_clusters)

# Add trivial coreferences for all nouns, determined by spacy POS
spacy_noun_pos = {'NOUN', 'PROPN', 'PRON'}

for i, sent in enumerate(doc.sents):
for tok in sent:
if tok.pos_ in spacy_noun_pos:
sent_clusters = [[] for _ in doc.sents]
sent_clusters[i] = [tok.i - sent.start]
coreferences.append(sent_clusters)

return [[str(w) for w in s] for s in doc.sents], coreferences
15 changes: 12 additions & 3 deletions lambeq/experimental/discocirc/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ def text2circuit(self,
sandwich: bool = False,
break_cycles: bool = True,
pruned_nouns: Iterable[str] = (),
min_noun_freq: int = 0,
min_noun_freq: int = 1,
rewrite_rules: (
Iterable[TreeRewriteRule | str] | None
) = ('determiner', 'auxiliary'),
Expand All @@ -313,6 +313,9 @@ def text2circuit(self,
If any of the nouns in this list are present in the diagram,
the corresponding state and wire are removed from the
diagram.
min_noun_freq: int, default: 1
Minimum number of times a noun needs to be referenced to
appear in the circuit.
rewrite_rules : list of `TreeRewriteRule` or str
List of rewrite rules to apply to the pregroup tree
before conversion to a circuit.
Expand All @@ -329,10 +332,16 @@ def text2circuit(self,
"""

sentences, corefs = self.coref_resolver.tokenise_and_coref(text)
corefd = self.coref_resolver.dict_from_corefs(corefs)

pruned_ids = self._prune_indices(sentences, corefs, pruned_nouns)
noun_counts = Counter(corefd.values())
freq_pruned_ids = [nid for nid, count in noun_counts.items()
if count < min_noun_freq]

corefd = self.coref_resolver.dict_from_corefs(corefs)
pruned_nouns = set(pruned_nouns).union(
{sentences[i][j] for (i, j) in freq_pruned_ids})

pruned_ids = self._prune_indices(sentences, corefs, pruned_nouns)

rewriter = TreeRewriter(rewrite_rules)

Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ experimental =
numpy == 1.26.4
spacy ~= 3.4.0
spacy-experimental ~= 0.6.4
en-core-web-sm
en-coreference-web-trf @ https://github.com/explosion/spacy-experimental/releases/download/v0.6.1/en_coreference_web_trf-3.4.0a2-py3-none-any.whl

[options.entry_points]
Expand Down

0 comments on commit 8ec8c89

Please sign in to comment.