diff --git a/src/pie_modules/document/processing/text_pair.py b/src/pie_modules/document/processing/text_pair.py index 9808d042e..756619eda 100644 --- a/src/pie_modules/document/processing/text_pair.py +++ b/src/pie_modules/document/processing/text_pair.py @@ -1,4 +1,5 @@ import copy +import logging import random from collections import defaultdict from collections.abc import Iterator @@ -18,6 +19,8 @@ ) from pie_modules.utils.span import are_nested +logger = logging.getLogger(__name__) + S = TypeVar("S", bound=Span) S2 = TypeVar("S2", bound=Span) @@ -237,11 +240,12 @@ def add_negative_coref_relations( max_num_negative = int(len(positive_rels) * downsampling_factor) if max_num_negative == 0: - raise ValueError( + logger.warning( f"downsampling with factor={downsampling_factor} and number of " f"positive relations={len(positive_rels)} does not produce any negatives" ) - random.shuffle(negative_rels) + else: + random.shuffle(negative_rels) negative_rels = negative_rels[:max_num_negative] for rel in negative_rels: new_rels2new_docs[rel].binary_coref_relations.append(rel) diff --git a/tests/document/processing/test_text_pair.py b/tests/document/processing/test_text_pair.py index b64382180..dd205ce7d 100644 --- a/tests/document/processing/test_text_pair.py +++ b/tests/document/processing/test_text_pair.py @@ -281,10 +281,7 @@ def test_construct_negative_documents(positive_and_negative_documents): ] -def test_construct_negative_documents_with_downsampling(positive_documents): - # set fixed seed because the negatives will get shuffled - random.seed(42) - docs = list(add_negative_coref_relations(positive_documents, downsampling_factor=1.0)) +def _get_all_all_rels_and_scores(docs): all_texts = [(doc.text, doc.text_pair) for doc in docs] all_scores = [[coref_rel.score for coref_rel in doc.binary_coref_relations] for doc in docs] all_rels_resolved = [doc.binary_coref_relations.resolve() for doc in docs] @@ -293,6 +290,14 @@ def test_construct_negative_documents_with_downsampling(positive_documents): (texts, list(zip(scores, rels_resolved))) for texts, scores, rels_resolved in zip(all_texts, all_scores, all_rels_resolved) ] + return all_rels_and_scores + + +def test_construct_negative_documents_with_downsampling(positive_documents, caplog): + # set fixed seed because the negatives will get shuffled + random.seed(42) + docs = list(add_negative_coref_relations(positive_documents, downsampling_factor=1.0)) + all_rels_and_scores = _get_all_all_rels_and_scores(docs) # check number relations all_rels_flat = [rel for doc in docs for rel in doc.binary_coref_relations] @@ -334,6 +339,39 @@ def test_construct_negative_documents_with_downsampling(positive_documents): ), ] + # sampling target is too low + caplog.clear() + docs = list(add_negative_coref_relations(positive_documents, downsampling_factor=0.0)) + assert caplog.messages == [ + "downsampling with factor=0.0 and number of positive relations=4 does not produce any negatives" + ] + # check number relations + all_rels_flat = [rel for doc in docs for rel in doc.binary_coref_relations] + # positives: 2 x number of positives (we add instances with swapped texts) + assert len([rel.score for rel in all_rels_flat if rel.score > 0.0]) == 4 + # negatives + assert len([rel.score for rel in all_rels_flat if rel.score == 0.0]) == 0 + # check actual content + all_rels_and_scores = _get_all_all_rels_and_scores(docs) + assert all_rels_and_scores == [ + ( + ("And she founded C.", "Entity A works at B."), + [(1.0, ("coref", (("PERSON", "she"), ("PERSON", "Entity A"))))], + ), + ( + ("Bob loves his cat.", "She sleeps a lot."), + [(1.0, ("coref", (("ANIMAL", "his cat"), ("ANIMAL", "She"))))], + ), + ( + ("Entity A works at B.", "And she founded C."), + [(1.0, ("coref", (("PERSON", "Entity A"), ("PERSON", "she"))))], + ), + ( + ("She sleeps a lot.", "Bob loves his cat."), + [(1.0, ("coref", (("ANIMAL", "She"), ("ANIMAL", "his cat"))))], + ), + ] + # no positives doc2 = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations( id="0", text="Bob loves his cat.", text_pair="She sleeps a lot." @@ -349,14 +387,6 @@ def test_construct_negative_documents_with_downsampling(positive_documents): "max_num_negative" ) - # sampling target is too low - with pytest.raises(ValueError) as e: - list(add_negative_coref_relations(positive_documents, downsampling_factor=0.0)) - assert ( - str(e.value) - == "downsampling with factor=0.0 and number of positive relations=4 does not produce any negatives" - ) - def test_construct_text_document_from_text_pair_coref_document(positive_and_negative_documents): glue_text = ""