Skip to content

Commit

Permalink
Merge pull request #123 from ArneBinder/add_negative_coref_relations/max_num_negatives
Browse files Browse the repository at this point in the history

add parameter `max_num_negatives` to `add_negative_coref_relations()`
  • Loading branch information
ArneBinder authored Sep 20, 2024
2 parents 86c85d6 + 5e9fe45 commit 4c4e0a7
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 19 deletions.
49 changes: 31 additions & 18 deletions src/pie_modules/document/processing/text_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ def construct_text_document_from_text_pair_coref_document(

def add_negative_coref_relations(
documents: Iterable[TextPairDocumentWithLabeledSpansAndBinaryCorefRelations],
max_num_negatives: Optional[int] = None,
downsampling_factor: Optional[float] = None,
random_seed: Optional[int] = None,
enforce_same_original_doc_id: bool = False,
Expand Down Expand Up @@ -276,27 +277,39 @@ def add_negative_coref_relations(
for rel in positive_rels:
new_rels2new_docs[rel].binary_coref_relations.append(rel)

# Downsampling of negatives. This requires positive instances!
if downsampling_factor is not None:
if len(positive_rels) == 0:
raise ValueError(
f"downsampling [factor={downsampling_factor}] is enabled, "
f"but no positive relations are available to calculate max_num_negative"
)
if max_num_negatives is None:
# Downsampling of negatives. This requires positive instances!
if downsampling_factor is not None:
if len(positive_rels) == 0:
raise ValueError(
f"downsampling [factor={downsampling_factor}] is enabled, "
f"but no positive relations are available to calculate max_num_negatives"
)

max_num_negatives = int(len(positive_rels) * downsampling_factor)
if max_num_negatives == 0:
logger.warning(
f"downsampling with factor={downsampling_factor} and number of "
f"positive relations={len(positive_rels)} does not produce any negatives"
)
elif downsampling_factor is not None:
raise ValueError(
f"setting max_num_negatives [{max_num_negatives}] and [{downsampling_factor}] "
f"simultaneously is ambiguous and not allowed"
)

if max_num_negatives is not None:
if random_seed is not None:
random.seed(random_seed)
random.shuffle(negative_rels)
negative_rels = negative_rels[:max_num_negatives]

max_num_negative = int(len(positive_rels) * downsampling_factor)
if max_num_negative == 0:
logger.warning(
f"downsampling with factor={downsampling_factor} and number of "
f"positive relations={len(positive_rels)} does not produce any negatives"
)
else:
if random_seed is not None:
random.seed(random_seed)
random.shuffle(negative_rels)
negative_rels = negative_rels[:max_num_negative]
for rel in negative_rels:
new_rels2new_docs[rel].binary_coref_relations.append(rel)

docs_with_rels = [doc for doc in new_docs if len(doc.binary_coref_relations) > 0]
logger.info(
f"constructed {len(negative_rels)} negative for {len(positive_rels)} "
f"positive relations in {len(docs_with_rels)} documents"
)
return docs_with_rels
2 changes: 1 addition & 1 deletion tests/document/processing/test_text_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def test_construct_negative_documents_with_downsampling(positive_documents, capl
assert (
str(e.value)
== "downsampling [factor=1.0] is enabled, but no positive relations are available to calculate "
"max_num_negative"
"max_num_negatives"
)


Expand Down

0 comments on commit 4c4e0a7

Please sign in to comment.