diff --git a/src/pie_modules/document/processing/text_pair.py b/src/pie_modules/document/processing/text_pair.py
index e052342cc..1117b349a 100644
--- a/src/pie_modules/document/processing/text_pair.py
+++ b/src/pie_modules/document/processing/text_pair.py
@@ -1,4 +1,5 @@
 import copy
+import random
 from collections import defaultdict
 from collections.abc import Iterator
 from itertools import chain
@@ -177,7 +178,9 @@ def construct_text_document_from_text_pair_coref_document(
 
 
 def add_negative_coref_relations(
-    documents: Iterable[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations], **kwargs
+    documents: Iterable[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations],
+    downsampling_factor: Optional[float] = None,
+    **kwargs,
 ) -> Iterable[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations]:
     positive_tuples = defaultdict(set)
     text2spans = defaultdict(set)
@@ -225,7 +228,27 @@ def add_negative_coref_relations(
     for rel in positive_rels:
         new_rels2new_docs[rel].binary_coref_relations.append(rel)
 
-    # TODO: implement downsampling
+    # Downsampling of negatives. This requires positive instances!
+    # At most int(len(positive_rels) * downsampling_factor) randomly chosen
+    # negatives are kept; downsampling_factor=None keeps all of them.
+    if downsampling_factor is not None:
+        if len(positive_rels) == 0:
+            raise ValueError(
+                f"downsampling [factor={downsampling_factor}] is enabled, "
+                f"but no positive relations are available to calculate max_num_negative"
+            )
+
+        max_num_negative = int(len(positive_rels) * downsampling_factor)
+        # Reject non-positive targets: zero yields no negatives at all, and a
+        # negative value (from a negative factor) would make the slice below
+        # count from the end of the sequence instead of limiting its length.
+        if max_num_negative <= 0:
+            raise ValueError(
+                f"downsampling with factor={downsampling_factor} and number of "
+                f"positive relations={len(positive_rels)} does not produce any negatives"
+            )
+        random.shuffle(negative_rels)
+        negative_rels = negative_rels[:max_num_negative]
     for rel in negative_rels:
         new_rels2new_docs[rel].binary_coref_relations.append(rel)
 
diff --git a/tests/document/processing/test_text_pair.py b/tests/document/processing/test_text_pair.py
index e452219de..b64382180 100644
--- a/tests/document/processing/test_text_pair.py
+++ b/tests/document/processing/test_text_pair.py
@@ -1,3 +1,4 @@
+import random
 from itertools import chain
 from typing import List
 
@@ -223,6 +224,16 @@ def test_construct_negative_documents(positive_and_negative_documents):
         doc.binary_coref_relations.resolve() for doc in positive_and_negative_documents
     ]
 
+    # check number of all relations
+    all_rels_flat = [
+        rel for doc in positive_and_negative_documents for rel in doc.binary_coref_relations
+    ]
+    assert len(all_rels_flat) == 10
+    # positives
+    assert len([rel.score for rel in all_rels_flat if rel.score > 0.0]) == 4
+    # negatives
+    assert len([rel.score for rel in all_rels_flat if rel.score == 0.0]) == 6
+
     all_rels_and_scores = [
         (texts, list(zip(scores, rels_resolved)))
         for texts, scores, rels_resolved in zip(all_texts, all_scores, all_rels_resolved)
@@ -270,6 +281,91 @@ def test_construct_negative_documents(positive_and_negative_documents):
     ]
 
 
+def test_construct_negative_documents_with_downsampling(positive_documents):
+    # set fixed seed because the negatives will get shuffled
+    random.seed(42)
+    docs = list(add_negative_coref_relations(positive_documents, downsampling_factor=1.0))
+    all_texts = [(doc.text, doc.text_pair) for doc in docs]
+    all_scores = [[coref_rel.score for coref_rel in doc.binary_coref_relations] for doc in docs]
+    all_rels_resolved = [doc.binary_coref_relations.resolve() for doc in docs]
+
+    all_rels_and_scores = [
+        (texts, list(zip(scores, rels_resolved)))
+        for texts, scores, rels_resolved in zip(all_texts, all_scores, all_rels_resolved)
+    ]
+
+    # check number of relations
+    all_rels_flat = [rel for doc in docs for rel in doc.binary_coref_relations]
+    # positives
+    assert len([rel.score for rel in all_rels_flat if rel.score > 0.0]) == 4
+    # negatives (same number as positives because downsampling_factor=1.0)
+    assert len([rel.score for rel in all_rels_flat if rel.score == 0.0]) == 4
+
+    assert all_rels_and_scores == [
+        (
+            ("And she founded C.", "Entity A works at B."),
+            [
+                (1.0, ("coref", (("PERSON", "she"), ("PERSON", "Entity A")))),
+                (0.0, ("coref", (("COMPANY", "C"), ("COMPANY", "B")))),
+            ],
+        ),
+        (
+            ("Bob loves his cat.", "And she founded C."),
+            [(0.0, ("coref", (("PERSON", "Bob"), ("PERSON", "she"))))],
+        ),
+        (
+            ("Bob loves his cat.", "Entity A works at B."),
+            [(0.0, ("coref", (("PERSON", "Bob"), ("PERSON", "Entity A"))))],
+        ),
+        (
+            ("Bob loves his cat.", "She sleeps a lot."),
+            [(1.0, ("coref", (("ANIMAL", "his cat"), ("ANIMAL", "She"))))],
+        ),
+        (
+            ("Entity A works at B.", "And she founded C."),
+            [
+                (1.0, ("coref", (("PERSON", "Entity A"), ("PERSON", "she")))),
+                (0.0, ("coref", (("COMPANY", "B"), ("COMPANY", "C")))),
+            ],
+        ),
+        (
+            ("She sleeps a lot.", "Bob loves his cat."),
+            [(1.0, ("coref", (("ANIMAL", "She"), ("ANIMAL", "his cat"))))],
+        ),
+    ]
+
+    # no positives
+    doc2 = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations(
+        id="0", text="Bob loves his cat.", text_pair="She sleeps a lot."
+    )
+    doc2.labeled_spans.append(LabeledSpan(start=0, end=3, label="PERSON"))
+    doc2.labeled_spans.append(LabeledSpan(start=10, end=17, label="ANIMAL"))
+    doc2.labeled_spans_pair.append(LabeledSpan(start=0, end=3, label="ANIMAL"))
+    with pytest.raises(ValueError) as e:
+        list(add_negative_coref_relations([doc2], downsampling_factor=1.0))
+    assert (
+        str(e.value)
+        == "downsampling [factor=1.0] is enabled, but no positive relations are available to calculate "
+        "max_num_negative"
+    )
+
+    # sampling target is too low
+    with pytest.raises(ValueError) as e:
+        list(add_negative_coref_relations(positive_documents, downsampling_factor=0.0))
+    assert (
+        str(e.value)
+        == "downsampling with factor=0.0 and number of positive relations=4 does not produce any negatives"
+    )
+
+    # a negative factor would give a negative slice bound, so it is rejected as well
+    with pytest.raises(ValueError) as e:
+        list(add_negative_coref_relations(positive_documents, downsampling_factor=-1.0))
+    assert (
+        str(e.value)
+        == "downsampling with factor=-1.0 and number of positive relations=4 does not produce any negatives"
+    )
+
+
 def test_construct_text_document_from_text_pair_coref_document(positive_and_negative_documents):
     glue_text = ""
     docs = [