Skip to content

Commit

Permalink
Added test cases for relation triplet extensions and multiprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
KasperFyhn committed Oct 31, 2023
1 parent 5568c55 commit 915645c
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,6 @@ def token_span_to_spacy_span(span: Tuple[int, int], doc: Doc):
return doc[span[0] : span[1] + 1]


def spacy_span_to_token_span(span: Span) -> Tuple[int, int]:
return span.start, span.end - 1


def wp_span_to_token(
relation_span: List[List[int]],
wp_tokenid_mapping: Dict,
Expand Down Expand Up @@ -113,17 +109,26 @@ def match_extraction_spans_to_wp(
return matched_extractions


def span_to_idx(span: Span) -> Tuple[int, int]:
return span.start, span.end


def idx_to_span(idx: Tuple[int, int], doc: Doc) -> Span:
start, end = idx
return doc[start:end]


def install_extension(ext):
Doc.set_extension(ext + "_idxs", default=None)
Doc.set_extension(
ext,
setter=lambda doc, spans: setattr(
doc._,
ext + "_idxs",
[spacy_span_to_token_span(span) for span in spans],
[span_to_idx(span) for span in spans],
),
getter=lambda doc: [
token_span_to_spacy_span(idx, doc) for idx in getattr(doc._, ext + "_idxs")
idx_to_span(idx, doc) for idx in getattr(doc._, ext + "_idxs")
],
)

Expand Down
21 changes: 20 additions & 1 deletion tests/test_relationextraction_component.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest

import torch

from .utils import nlp_da # noqa F401

Expand All @@ -24,6 +24,24 @@ def test_relationextraction_component_pipe(nlp_da): # noqa F811
print(d.text, "\n", d._.relation_triplets)


@pytest.mark.skip(reason="Avoid downloading the model on GitHub actions")
def test_relationextraction_component_pipe_multiprocessing(nlp_da): # noqa F811
test_sents = [
"Pernille Blume vinder delt EM-sølv i Ungarn.",
"Pernille Blume blev nummer to ved EM på langbane i disciplinen 50 meter fri.",
] * 5

nlp_da.add_pipe("relation_extractor")

# multiprocessing and torch with multiple threads result in a deadlock, therefore:
torch.set_num_threads(1)

pipe = nlp_da.pipe(test_sents, n_process=2, batch_size=5)

for d in pipe:
print(d.text, "\n", d._.relation_triplets)


@pytest.mark.skip(reason="Avoid downloading the model on GitHub actions")
def test_relation_extraction_component_single(nlp_da): # noqa F811
nlp_da.add_pipe("relation_extractor", config={"confidence_threshold": 1.8})
Expand All @@ -36,6 +54,7 @@ def test_relation_extraction_component_single(nlp_da): # noqa F811
]


@pytest.mark.skip(reason="Avoid downloading the model on GitHub actions")
def test_relation_extraction_multi_sentence(nlp_da): # noqa F811
nlp_da.add_pipe("relation_extractor")
doc = nlp_da(
Expand Down
30 changes: 30 additions & 0 deletions tests/test_relationextraction_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from spacy.tokens import Doc
from spacy import Vocab

from conspiracies.docprocessing.relationextraction.multi2oie import multi2oie_utils


def test_relationextraction_doc_extension():
"""Verifies the behavior of the relation triplet extensions and the lambdas
that back them."""
multi2oie_utils.install_extensions()

words = "this is a test".split()
vocab = Vocab(strings=words)
test_doc = Doc(vocab, words=words)

head = test_doc[0:1]
relation = test_doc[1:2]
tail = test_doc[2:4]

test_doc._.relation_head = [head]
test_doc._.relation_relation = [relation]
test_doc._.relation_tail = [tail]

# verify the extension that contains the actual indices
assert test_doc._.relation_head_idxs == [(0, 1)]
assert test_doc._.relation_relation_idxs == [(1, 2)]
assert test_doc._.relation_tail_idxs == [(2, 4)]

# verify that relation_triplets are correctly built from the other extensions
assert test_doc._.relation_triplets == [(head, relation, tail)]

0 comments on commit 915645c

Please sign in to comment.