Skip to content

Commit

Permalink
allow that downsampling negatives does nto produce negatives at all
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder committed Sep 16, 2024
1 parent 538e9ee commit 833eb8c
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 14 deletions.
8 changes: 6 additions & 2 deletions src/pie_modules/document/processing/text_pair.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
import logging
import random
from collections import defaultdict
from collections.abc import Iterator
Expand All @@ -18,6 +19,8 @@
)
from pie_modules.utils.span import are_nested

logger = logging.getLogger(__name__)

S = TypeVar("S", bound=Span)
S2 = TypeVar("S2", bound=Span)

Expand Down Expand Up @@ -237,11 +240,12 @@ def add_negative_coref_relations(

max_num_negative = int(len(positive_rels) * downsampling_factor)
if max_num_negative == 0:
raise ValueError(
logger.warning(
f"downsampling with factor={downsampling_factor} and number of "
f"positive relations={len(positive_rels)} does not produce any negatives"
)
random.shuffle(negative_rels)
else:
random.shuffle(negative_rels)
negative_rels = negative_rels[:max_num_negative]
for rel in negative_rels:
new_rels2new_docs[rel].binary_coref_relations.append(rel)
Expand Down
54 changes: 42 additions & 12 deletions tests/document/processing/test_text_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,10 +281,7 @@ def test_construct_negative_documents(positive_and_negative_documents):
]


def test_construct_negative_documents_with_downsampling(positive_documents):
# set fixed seed because the negatives will get shuffled
random.seed(42)
docs = list(add_negative_coref_relations(positive_documents, downsampling_factor=1.0))
def _get_all_all_rels_and_scores(docs):
all_texts = [(doc.text, doc.text_pair) for doc in docs]
all_scores = [[coref_rel.score for coref_rel in doc.binary_coref_relations] for doc in docs]
all_rels_resolved = [doc.binary_coref_relations.resolve() for doc in docs]
Expand All @@ -293,6 +290,14 @@ def test_construct_negative_documents_with_downsampling(positive_documents):
(texts, list(zip(scores, rels_resolved)))
for texts, scores, rels_resolved in zip(all_texts, all_scores, all_rels_resolved)
]
return all_rels_and_scores


def test_construct_negative_documents_with_downsampling(positive_documents, caplog):
# set fixed seed because the negatives will get shuffled
random.seed(42)
docs = list(add_negative_coref_relations(positive_documents, downsampling_factor=1.0))
all_rels_and_scores = _get_all_all_rels_and_scores(docs)

# check number relations
all_rels_flat = [rel for doc in docs for rel in doc.binary_coref_relations]
Expand Down Expand Up @@ -334,6 +339,39 @@ def test_construct_negative_documents_with_downsampling(positive_documents):
),
]

# sampling target is too low
caplog.clear()
docs = list(add_negative_coref_relations(positive_documents, downsampling_factor=0.0))
assert caplog.messages == [
"downsampling with factor=0.0 and number of positive relations=4 does not produce any negatives"
]
# check number relations
all_rels_flat = [rel for doc in docs for rel in doc.binary_coref_relations]
# positives: 2 x number of positives (we add instances with swapped texts)
assert len([rel.score for rel in all_rels_flat if rel.score > 0.0]) == 4
# negatives
assert len([rel.score for rel in all_rels_flat if rel.score == 0.0]) == 0
# check actual content
all_rels_and_scores = _get_all_all_rels_and_scores(docs)
assert all_rels_and_scores == [
(
("And she founded C.", "Entity A works at B."),
[(1.0, ("coref", (("PERSON", "she"), ("PERSON", "Entity A"))))],
),
(
("Bob loves his cat.", "She sleeps a lot."),
[(1.0, ("coref", (("ANIMAL", "his cat"), ("ANIMAL", "She"))))],
),
(
("Entity A works at B.", "And she founded C."),
[(1.0, ("coref", (("PERSON", "Entity A"), ("PERSON", "she"))))],
),
(
("She sleeps a lot.", "Bob loves his cat."),
[(1.0, ("coref", (("ANIMAL", "She"), ("ANIMAL", "his cat"))))],
),
]

# no positives
doc2 = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations(
id="0", text="Bob loves his cat.", text_pair="She sleeps a lot."
Expand All @@ -349,14 +387,6 @@ def test_construct_negative_documents_with_downsampling(positive_documents):
"max_num_negative"
)

# sampling target is too low
with pytest.raises(ValueError) as e:
list(add_negative_coref_relations(positive_documents, downsampling_factor=0.0))
assert (
str(e.value)
== "downsampling with factor=0.0 and number of positive relations=4 does not produce any negatives"
)


def test_construct_text_document_from_text_pair_coref_document(positive_and_negative_documents):
glue_text = "<s><s>"
Expand Down

0 comments on commit 833eb8c

Please sign in to comment.