Skip to content

Commit

Permalink
complete test_encode_with_add_reversed_relations_already_exists
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder committed Nov 13, 2024
1 parent b798844 commit ac92b42
Showing 1 changed file with 46 additions and 3 deletions.
49 changes: 46 additions & 3 deletions tests/taskmodules/test_pointer_network_for_end2end_re.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,13 +470,56 @@ def test_encode_with_add_reversed_relations(taskmodule_with_reversed_relations,
}


def test_encode_with_add_reversed_relations_already_exists(taskmodule_with_reversed_relations):
def test_encode_with_add_reversed_relations_already_exists(caplog):
doc = ExampleDocument(text="This is a dummy text about nothing. Trust me.")
doc.entities.append(LabeledSpan(start=10, end=20, label="content"))
doc.entities.append(LabeledSpan(start=27, end=34, label="topic"))
rel = BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="is_about")
doc.relations.append(
BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="is_about")
)
doc.relations.append(
BinaryRelation(head=doc.entities[1], tail=doc.entities[0], label="is_about")
)

task_encodings = taskmodule_with_reversed_relations.encode(document, encode_target=True)
taskmodule = PointerNetworkTaskModuleForEnd2EndRE(
tokenizer_name_or_path="facebook/bart-base",
relation_layer_name="relations",
annotation_field_mapping={
"entities": "labeled_spans",
"relations": "binary_relations",
},
add_reversed_relations=True,
symmetric_relations=["is_about"],
)
taskmodule.prepare(documents=[doc])

with caplog.at_level(logging.WARNING):
task_encodings = taskmodule.encode(doc, encode_target=True)
assert len(caplog.messages) == 0
assert len(task_encodings) == 1
task_encoding = task_encodings[0]

decoded_annotations, statistics = taskmodule.decode_annotations(task_encoding.targets)
assert decoded_annotations == {
"entities": [
LabeledSpan(start=4, end=6, label="content", score=1.0),
LabeledSpan(start=7, end=8, label="topic", score=1.0),
],
"relations": [
BinaryRelation(
head=LabeledSpan(start=4, end=6, label="content", score=1.0),
tail=LabeledSpan(start=7, end=8, label="topic", score=1.0),
label="is_about",
score=1.0,
),
BinaryRelation(
head=LabeledSpan(start=7, end=8, label="topic", score=1.0),
tail=LabeledSpan(start=4, end=6, label="content", score=1.0),
label="is_about",
score=1.0,
),
],
}


@pytest.fixture()
Expand Down

0 comments on commit ac92b42

Please sign in to comment.