Skip to content

Commit

Permalink
implement construct_text_document_from_text_pair_coref_document()
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder committed Sep 15, 2024
1 parent 50026b4 commit 8dc7c5e
Show file tree
Hide file tree
Showing 2 changed files with 101 additions and 7 deletions.
46 changes: 46 additions & 0 deletions src/pie_modules/document/processing/text_pair.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import copy
from collections import defaultdict
from collections.abc import Iterator
from itertools import chain
from typing import Dict, Iterable, List, Tuple, TypeVar

from pytorch_ie.annotations import LabeledSpan, Span
from pytorch_ie.documents import (
TextDocumentWithLabeledSpansAndBinaryRelations,
TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions,
)
from tqdm import tqdm
Expand Down Expand Up @@ -119,6 +122,49 @@ def construct_text_pair_coref_documents_from_partitions_via_relations(
)


def shift_span(span: S, offset: int) -> S:
return span.copy(start=span.start + offset, end=span.end + offset)


def construct_text_document_from_text_pair_coref_document(
document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations, glue_text: str
) -> TextDocumentWithLabeledSpansAndBinaryRelations:
if document.text == document.text_pair:
new_doc = TextDocumentWithLabeledSpansAndBinaryRelations(
id=document.id, metadata=copy.deepcopy(document.metadata), text=document.text
)
old2new_spans: Dict[LabeledSpan, LabeledSpan] = {}
new2new_spans: Dict[LabeledSpan, LabeledSpan] = {}
for old_span in chain(document.labeled_spans, document.labeled_spans_pair):
new_span = old_span.copy()
# when detaching / copying the span, it may be the same as a previous span from the other
new_span = new2new_spans.get(new_span, new_span)
new2new_spans[new_span] = new_span
old2new_spans[old_span] = new_span
else:
new_doc = TextDocumentWithLabeledSpansAndBinaryRelations(
text=document.text + glue_text + document.text_pair,
id=document.id,
metadata=copy.deepcopy(document.metadata),
)
old2new_spans = {}
old2new_spans.update({span: span.copy() for span in document.labeled_spans})
offset = len(document.text) + len(glue_text)
old2new_spans.update(
{span: shift_span(span.copy(), offset) for span in document.labeled_spans_pair}
)

# sort to make order deterministic
new_doc.labeled_spans.extend(
sorted(old2new_spans.values(), key=lambda s: (s.start, s.end, s.label))
)
for old_rel in document.binary_coref_relations:
new_rel = old_rel.copy(head=old2new_spans[old_rel.head], tail=old2new_spans[old_rel.tail])
new_doc.binary_relations.append(new_rel)

return new_doc


def add_negative_coref_relations(
documents: Iterable[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations], **kwargs
) -> Iterable[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations]:
Expand Down
62 changes: 55 additions & 7 deletions tests/document/processing/test_text_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from pie_modules.document.processing.text_pair import (
add_negative_coref_relations,
construct_text_document_from_text_pair_coref_document,
construct_text_pair_coref_documents_from_partitions_via_relations,
)
from pie_modules.documents import (
Expand Down Expand Up @@ -196,21 +197,31 @@ def test_positive_documents(positive_documents):
]


def test_construct_negative_documents(positive_documents):
assert len(positive_documents) == 2
@pytest.fixture(scope="module")
def positive_and_negative_documents(positive_documents):
docs = list(add_negative_coref_relations(positive_documents))
return docs


def test_construct_negative_documents(positive_and_negative_documents):
assert len(positive_and_negative_documents) == 16
TEXTS = [
"Entity A works at B.",
"And she founded C.",
"Bob loves his cat.",
"She sleeps a lot.",
]
assert all(doc.text in TEXTS for doc in docs)
assert all(doc.text_pair in TEXTS for doc in docs)
assert all(doc.text in TEXTS for doc in positive_and_negative_documents)
assert all(doc.text_pair in TEXTS for doc in positive_and_negative_documents)

all_texts = [(doc.text, doc.text_pair) for doc in docs]
all_scores = [[coref_rel.score for coref_rel in doc.binary_coref_relations] for doc in docs]
all_rels_resolved = [doc.binary_coref_relations.resolve() for doc in docs]
all_texts = [(doc.text, doc.text_pair) for doc in positive_and_negative_documents]
all_scores = [
[coref_rel.score for coref_rel in doc.binary_coref_relations]
for doc in positive_and_negative_documents
]
all_rels_resolved = [
doc.binary_coref_relations.resolve() for doc in positive_and_negative_documents
]

all_rels_and_scores = [
(texts, list(zip(scores, rels_resolved)))
Expand Down Expand Up @@ -265,3 +276,40 @@ def test_construct_negative_documents(positive_documents):
(("She sleeps a lot.", "Entity A works at B."), []),
(("She sleeps a lot.", "She sleeps a lot."), []),
]


def test_construct_text_document_from_text_pair_coref_document(positive_and_negative_documents):
glue_text = "<s><s>"
docs = [
construct_text_document_from_text_pair_coref_document(doc, glue_text=glue_text)
for doc in positive_and_negative_documents
]
assert len(docs) == 16
doc = docs[0]
assert doc.text == "And she founded C."
assert doc.labeled_spans.resolve() == [("PERSON", "she"), ("COMPANY", "C")]
assert doc.binary_relations.resolve() == []
assert [rel.score for rel in doc.binary_relations] == []

doc = docs[1]
assert doc.text == "And she founded C.<s><s>Bob loves his cat."
assert doc.labeled_spans.resolve() == [
("PERSON", "she"),
("COMPANY", "C"),
("PERSON", "Bob"),
("ANIMAL", "his cat"),
]
assert doc.binary_relations.resolve() == [("coref", (("PERSON", "she"), ("PERSON", "Bob")))]
assert [rel.score for rel in doc.binary_relations] == [0.0]

doc = docs[7]
assert doc.text == "Bob loves his cat.<s><s>She sleeps a lot."
assert doc.labeled_spans.resolve() == [
("PERSON", "Bob"),
("ANIMAL", "his cat"),
("ANIMAL", "She"),
]
assert doc.binary_relations.resolve() == [
("coref", (("ANIMAL", "his cat"), ("ANIMAL", "She")))
]
assert [rel.score for rel in doc.binary_relations] == [1.0]

0 comments on commit 8dc7c5e

Please sign in to comment.