Skip to content

Commit

Permalink
add parameter relation_label_mapping to construct_text_document_from_…
Browse files Browse the repository at this point in the history
…text_pair_coref_document()
  • Loading branch information
ArneBinder committed Sep 15, 2024
1 parent 8dc7c5e commit c5c3161
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 6 deletions.
13 changes: 10 additions & 3 deletions src/pie_modules/document/processing/text_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections import defaultdict
from collections.abc import Iterator
from itertools import chain
from typing import Dict, Iterable, List, Tuple, TypeVar
from typing import Dict, Iterable, List, Optional, Tuple, TypeVar

from pytorch_ie.annotations import LabeledSpan, Span
from pytorch_ie.documents import (
Expand Down Expand Up @@ -127,7 +127,9 @@ def shift_span(span: S, offset: int) -> S:


def construct_text_document_from_text_pair_coref_document(
document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations, glue_text: str
document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations,
glue_text: str,
relation_label_mapping: Optional[Dict[str, str]] = None,
) -> TextDocumentWithLabeledSpansAndBinaryRelations:
if document.text == document.text_pair:
new_doc = TextDocumentWithLabeledSpansAndBinaryRelations(
Expand Down Expand Up @@ -159,7 +161,12 @@ def construct_text_document_from_text_pair_coref_document(
sorted(old2new_spans.values(), key=lambda s: (s.start, s.end, s.label))
)
for old_rel in document.binary_coref_relations:
new_rel = old_rel.copy(head=old2new_spans[old_rel.head], tail=old2new_spans[old_rel.tail])
label = old_rel.label
if relation_label_mapping is not None:
label = relation_label_mapping.get(label, label)
new_rel = old_rel.copy(
head=old2new_spans[old_rel.head], tail=old2new_spans[old_rel.tail], label=label
)
new_doc.binary_relations.append(new_rel)

return new_doc
Expand Down
10 changes: 7 additions & 3 deletions tests/document/processing/test_text_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,9 @@ def test_construct_negative_documents(positive_and_negative_documents):
def test_construct_text_document_from_text_pair_coref_document(positive_and_negative_documents):
glue_text = "<s><s>"
docs = [
construct_text_document_from_text_pair_coref_document(doc, glue_text=glue_text)
construct_text_document_from_text_pair_coref_document(
doc, glue_text=glue_text, relation_label_mapping={"coref": "semantically_same"}
)
for doc in positive_and_negative_documents
]
assert len(docs) == 16
Expand All @@ -299,7 +301,9 @@ def test_construct_text_document_from_text_pair_coref_document(positive_and_nega
("PERSON", "Bob"),
("ANIMAL", "his cat"),
]
assert doc.binary_relations.resolve() == [("coref", (("PERSON", "she"), ("PERSON", "Bob")))]
assert doc.binary_relations.resolve() == [
("semantically_same", (("PERSON", "she"), ("PERSON", "Bob")))
]
assert [rel.score for rel in doc.binary_relations] == [0.0]

doc = docs[7]
Expand All @@ -310,6 +314,6 @@ def test_construct_text_document_from_text_pair_coref_document(positive_and_nega
("ANIMAL", "She"),
]
assert doc.binary_relations.resolve() == [
("coref", (("ANIMAL", "his cat"), ("ANIMAL", "She")))
("semantically_same", (("ANIMAL", "his cat"), ("ANIMAL", "She")))
]
assert [rel.score for rel in doc.binary_relations] == [1.0]

0 comments on commit c5c3161

Please sign in to comment.