Skip to content

Commit

Permalink
implement downsampling for add_negative_coref_relations()
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder committed Sep 15, 2024
1 parent 1d73f8a commit 11de48b
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 2 deletions.
22 changes: 20 additions & 2 deletions src/pie_modules/document/processing/text_pair.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
import random
from collections import defaultdict
from collections.abc import Iterator
from itertools import chain
Expand Down Expand Up @@ -177,7 +178,9 @@ def construct_text_document_from_text_pair_coref_document(


def add_negative_coref_relations(
documents: Iterable[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations], **kwargs
documents: Iterable[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations],
downsampling_factor: Optional[float] = None,
**kwargs,
) -> Iterable[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations]:
positive_tuples = defaultdict(set)
text2spans = defaultdict(set)
Expand Down Expand Up @@ -225,7 +228,22 @@ def add_negative_coref_relations(
for rel in positive_rels:
new_rels2new_docs[rel].binary_coref_relations.append(rel)

# TODO: implement downsampling
# Downsampling of negatives. This requires positive instances!
if downsampling_factor is not None:
if len(positive_rels) == 0:
raise ValueError(
f"downsampling [factor={downsampling_factor}] is enabled, "
f"but no positive relations are available to calculate max_num_negative"
)

max_num_negative = int(len(positive_rels) * downsampling_factor)
if max_num_negative == 0:
raise ValueError(
f"downsampling with factor={downsampling_factor} and number of "
f"positive relations={len(positive_rels)} does not produce any negatives"
)
random.shuffle(negative_rels)
negative_rels = negative_rels[:max_num_negative]
for rel in negative_rels:
new_rels2new_docs[rel].binary_coref_relations.append(rel)

Expand Down
88 changes: 88 additions & 0 deletions tests/document/processing/test_text_pair.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import random
from itertools import chain
from typing import List

Expand Down Expand Up @@ -223,6 +224,16 @@ def test_construct_negative_documents(positive_and_negative_documents):
doc.binary_coref_relations.resolve() for doc in positive_and_negative_documents
]

# check number of all relations
all_rels_flat = [
rel for doc in positive_and_negative_documents for rel in doc.binary_coref_relations
]
assert len(all_rels_flat) == 10
# positives
assert len([rel.score for rel in all_rels_flat if rel.score > 0.0]) == 4
# negatives
assert len([rel.score for rel in all_rels_flat if rel.score == 0.0]) == 6

all_rels_and_scores = [
(texts, list(zip(scores, rels_resolved)))
for texts, scores, rels_resolved in zip(all_texts, all_scores, all_rels_resolved)
Expand Down Expand Up @@ -270,6 +281,83 @@ def test_construct_negative_documents(positive_and_negative_documents):
]


def test_construct_negative_documents_with_downsampling(positive_documents):
# set fixed seed because the negatives will get shuffled
random.seed(42)
docs = list(add_negative_coref_relations(positive_documents, downsampling_factor=1.0))
all_texts = [(doc.text, doc.text_pair) for doc in docs]
all_scores = [[coref_rel.score for coref_rel in doc.binary_coref_relations] for doc in docs]
all_rels_resolved = [doc.binary_coref_relations.resolve() for doc in docs]

all_rels_and_scores = [
(texts, list(zip(scores, rels_resolved)))
for texts, scores, rels_resolved in zip(all_texts, all_scores, all_rels_resolved)
]

# check number relations
all_rels_flat = [rel for doc in docs for rel in doc.binary_coref_relations]
# positives
assert len([rel.score for rel in all_rels_flat if rel.score > 0.0]) == 4
# negatives (same number positives because downsampling_factor=1.0)
assert len([rel.score for rel in all_rels_flat if rel.score == 0.0]) == 4

assert all_rels_and_scores == [
(
("And she founded C.", "Entity A works at B."),
[
(1.0, ("coref", (("PERSON", "she"), ("PERSON", "Entity A")))),
(0.0, ("coref", (("COMPANY", "C"), ("COMPANY", "B")))),
],
),
(
("Bob loves his cat.", "And she founded C."),
[(0.0, ("coref", (("PERSON", "Bob"), ("PERSON", "she"))))],
),
(
("Bob loves his cat.", "Entity A works at B."),
[(0.0, ("coref", (("PERSON", "Bob"), ("PERSON", "Entity A"))))],
),
(
("Bob loves his cat.", "She sleeps a lot."),
[(1.0, ("coref", (("ANIMAL", "his cat"), ("ANIMAL", "She"))))],
),
(
("Entity A works at B.", "And she founded C."),
[
(1.0, ("coref", (("PERSON", "Entity A"), ("PERSON", "she")))),
(0.0, ("coref", (("COMPANY", "B"), ("COMPANY", "C")))),
],
),
(
("She sleeps a lot.", "Bob loves his cat."),
[(1.0, ("coref", (("ANIMAL", "She"), ("ANIMAL", "his cat"))))],
),
]

# no positives
doc2 = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations(
id="0", text="Bob loves his cat.", text_pair="She sleeps a lot."
)
doc2.labeled_spans.append(LabeledSpan(start=0, end=3, label="PERSON"))
doc2.labeled_spans.append(LabeledSpan(start=10, end=17, label="ANIMAL"))
doc2.labeled_spans_pair.append(LabeledSpan(start=0, end=3, label="ANIMAL"))
with pytest.raises(ValueError) as e:
list(add_negative_coref_relations([doc2], downsampling_factor=1.0))
assert (
str(e.value)
== "downsampling [factor=1.0] is enabled, but no positive relations are available to calculate "
"max_num_negative"
)

# sampling target is too low
with pytest.raises(ValueError) as e:
list(add_negative_coref_relations(positive_documents, downsampling_factor=0.0))
assert (
str(e.value)
== "downsampling with factor=0.0 and number of positive relations=4 does not produce any negatives"
)


def test_construct_text_document_from_text_pair_coref_document(positive_and_negative_documents):
glue_text = "<s><s>"
docs = [
Expand Down

0 comments on commit 11de48b

Please sign in to comment.