diff --git a/src/pie_modules/document/processing/text_pair.py b/src/pie_modules/document/processing/text_pair.py index a5afc4c34..a9902293f 100644 --- a/src/pie_modules/document/processing/text_pair.py +++ b/src/pie_modules/document/processing/text_pair.py @@ -1,9 +1,12 @@ +import copy from collections import defaultdict from collections.abc import Iterator +from itertools import chain from typing import Dict, Iterable, List, Tuple, TypeVar from pytorch_ie.annotations import LabeledSpan, Span from pytorch_ie.documents import ( + TextDocumentWithLabeledSpansAndBinaryRelations, TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions, ) from tqdm import tqdm @@ -119,6 +122,49 @@ def construct_text_pair_coref_documents_from_partitions_via_relations( ) +def shift_span(span: S, offset: int) -> S: + return span.copy(start=span.start + offset, end=span.end + offset) + + +def construct_text_document_from_text_pair_coref_document( + document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations, glue_text: str +) -> TextDocumentWithLabeledSpansAndBinaryRelations: + if document.text == document.text_pair: + new_doc = TextDocumentWithLabeledSpansAndBinaryRelations( + id=document.id, metadata=copy.deepcopy(document.metadata), text=document.text + ) + old2new_spans: Dict[LabeledSpan, LabeledSpan] = {} + new2new_spans: Dict[LabeledSpan, LabeledSpan] = {} + for old_span in chain(document.labeled_spans, document.labeled_spans_pair): + new_span = old_span.copy() + # when detaching / copying the span, it may be the same as a previous span from the other + new_span = new2new_spans.get(new_span, new_span) + new2new_spans[new_span] = new_span + old2new_spans[old_span] = new_span + else: + new_doc = TextDocumentWithLabeledSpansAndBinaryRelations( + text=document.text + glue_text + document.text_pair, + id=document.id, + metadata=copy.deepcopy(document.metadata), + ) + old2new_spans = {} + old2new_spans.update({span: span.copy() for span in document.labeled_spans}) + offset = len(document.text) + len(glue_text) + old2new_spans.update( + {span: shift_span(span.copy(), offset) for span in document.labeled_spans_pair} + ) + + # sort to make order deterministic + new_doc.labeled_spans.extend( + sorted(old2new_spans.values(), key=lambda s: (s.start, s.end, s.label)) + ) + for old_rel in document.binary_coref_relations: + new_rel = old_rel.copy(head=old2new_spans[old_rel.head], tail=old2new_spans[old_rel.tail]) + new_doc.binary_relations.append(new_rel) + + return new_doc + + def add_negative_coref_relations( documents: Iterable[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations], **kwargs ) -> Iterable[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations]: diff --git a/tests/document/processing/test_text_pair.py b/tests/document/processing/test_text_pair.py index 9d4aa1a20..c4ade49ee 100644 --- a/tests/document/processing/test_text_pair.py +++ b/tests/document/processing/test_text_pair.py @@ -9,6 +9,7 @@ from pie_modules.document.processing.text_pair import ( add_negative_coref_relations, + construct_text_document_from_text_pair_coref_document, construct_text_pair_coref_documents_from_partitions_via_relations, ) from pie_modules.documents import ( @@ -196,21 +197,31 @@ def test_positive_documents(positive_documents): ] -def test_construct_negative_documents(positive_documents): - assert len(positive_documents) == 2 +@pytest.fixture(scope="module") +def positive_and_negative_documents(positive_documents): docs = list(add_negative_coref_relations(positive_documents)) + return docs + + +def test_construct_negative_documents(positive_and_negative_documents): + assert len(positive_and_negative_documents) == 16 TEXTS = [ "Entity A works at B.", "And she founded C.", "Bob loves his cat.", "She sleeps a lot.", ] - assert all(doc.text in TEXTS for doc in docs) - assert all(doc.text_pair in TEXTS for doc in docs) + assert all(doc.text in TEXTS for doc in positive_and_negative_documents) + assert all(doc.text_pair in TEXTS for doc in positive_and_negative_documents) - all_texts = [(doc.text, doc.text_pair) for doc in docs] - all_scores = [[coref_rel.score for coref_rel in doc.binary_coref_relations] for doc in docs] - all_rels_resolved = [doc.binary_coref_relations.resolve() for doc in docs] + all_texts = [(doc.text, doc.text_pair) for doc in positive_and_negative_documents] + all_scores = [ + [coref_rel.score for coref_rel in doc.binary_coref_relations] + for doc in positive_and_negative_documents + ] + all_rels_resolved = [ + doc.binary_coref_relations.resolve() for doc in positive_and_negative_documents + ] all_rels_and_scores = [ (texts, list(zip(scores, rels_resolved))) @@ -265,3 +276,40 @@ def test_construct_negative_documents(positive_documents): (("She sleeps a lot.", "Entity A works at B."), []), (("She sleeps a lot.", "She sleeps a lot."), []), ] + + +def test_construct_text_document_from_text_pair_coref_document(positive_and_negative_documents): + glue_text = "" + docs = [ + construct_text_document_from_text_pair_coref_document(doc, glue_text=glue_text) + for doc in positive_and_negative_documents + ] + assert len(docs) == 16 + doc = docs[0] + assert doc.text == "And she founded C." + assert doc.labeled_spans.resolve() == [("PERSON", "she"), ("COMPANY", "C")] + assert doc.binary_relations.resolve() == [] + assert [rel.score for rel in doc.binary_relations] == [] + + doc = docs[1] + assert doc.text == "And she founded C.Bob loves his cat." + assert doc.labeled_spans.resolve() == [ + ("PERSON", "she"), + ("COMPANY", "C"), + ("PERSON", "Bob"), + ("ANIMAL", "his cat"), + ] + assert doc.binary_relations.resolve() == [("coref", (("PERSON", "she"), ("PERSON", "Bob")))] + assert [rel.score for rel in doc.binary_relations] == [0.0] + + doc = docs[7] + assert doc.text == "Bob loves his cat.She sleeps a lot." + assert doc.labeled_spans.resolve() == [ + ("PERSON", "Bob"), + ("ANIMAL", "his cat"), + ("ANIMAL", "She"), + ] + assert doc.binary_relations.resolve() == [ + ("coref", (("ANIMAL", "his cat"), ("ANIMAL", "She"))) + ] + assert [rel.score for rel in doc.binary_relations] == [1.0]