diff --git a/src/pie_modules/document/processing/text_pair.py b/src/pie_modules/document/processing/text_pair.py index 4387177e6..cf3885b21 100644 --- a/src/pie_modules/document/processing/text_pair.py +++ b/src/pie_modules/document/processing/text_pair.py @@ -129,6 +129,7 @@ def shift_span(span: S, offset: int) -> S: def construct_text_document_from_text_pair_coref_document( document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations, glue_text: str, + no_relation_label: str, relation_label_mapping: Optional[Dict[str, str]] = None, ) -> TextDocumentWithLabeledSpansAndBinaryRelations: if document.text == document.text_pair: @@ -161,11 +162,14 @@ def construct_text_document_from_text_pair_coref_document( sorted(old2new_spans.values(), key=lambda s: (s.start, s.end, s.label)) ) for old_rel in document.binary_coref_relations: - label = old_rel.label + label = old_rel.label if old_rel.score > 0.0 else no_relation_label if relation_label_mapping is not None: label = relation_label_mapping.get(label, label) new_rel = old_rel.copy( - head=old2new_spans[old_rel.head], tail=old2new_spans[old_rel.tail], label=label + head=old2new_spans[old_rel.head], + tail=old2new_spans[old_rel.tail], + label=label, + score=1.0, ) new_doc.binary_relations.append(new_rel) diff --git a/tests/document/processing/test_text_pair.py b/tests/document/processing/test_text_pair.py index a9c589545..cd016a117 100644 --- a/tests/document/processing/test_text_pair.py +++ b/tests/document/processing/test_text_pair.py @@ -282,7 +282,10 @@ def test_construct_text_document_from_text_pair_coref_document(positive_and_nega glue_text = "" docs = [ construct_text_document_from_text_pair_coref_document( - doc, glue_text=glue_text, relation_label_mapping={"coref": "semantically_same"} + doc, + glue_text=glue_text, + no_relation_label="no_relation", + relation_label_mapping={"coref": "semantically_same"}, ) for doc in positive_and_negative_documents ] @@ -302,9 +305,9 @@ def test_construct_text_document_from_text_pair_coref_document(positive_and_nega ("ANIMAL", "his cat"), ] assert doc.binary_relations.resolve() == [ - ("semantically_same", (("PERSON", "she"), ("PERSON", "Bob"))) + ("no_relation", (("PERSON", "she"), ("PERSON", "Bob"))) ] - assert [rel.score for rel in doc.binary_relations] == [0.0] + assert [rel.score for rel in doc.binary_relations] == [1.0] doc = docs[7] assert doc.text == "Bob loves his cat.She sleeps a lot."