Skip to content

Commit

Permalink
add test_decode_with_add_reversed_relations
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder committed Nov 13, 2024
1 parent ac92b42 commit 2a9b2d6
Showing 1 changed file with 51 additions and 0 deletions.
51 changes: 51 additions & 0 deletions tests/taskmodules/test_pointer_network_for_end2end_re.py
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,57 @@ def test_encode_with_add_reversed_relations_already_exists(caplog):
}


def test_decode_with_add_reversed_relations():
doc = ExampleDocument(text="This is a dummy text about nothing. Trust me.")
doc.entities.append(LabeledSpan(start=10, end=20, label="content"))
doc.entities.append(LabeledSpan(start=27, end=34, label="topic"))
doc.relations.append(
BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="is_about")
)

taskmodule = PointerNetworkTaskModuleForEnd2EndRE(
tokenizer_name_or_path="facebook/bart-base",
relation_layer_name="relations",
annotation_field_mapping={
"entities": "labeled_spans",
"relations": "binary_relations",
},
add_reversed_relations=True,
)
taskmodule.prepare(documents=[doc])

task_encodings = taskmodule.encode(doc, encode_target=True)
assert len(task_encodings) == 1
decoded_annotations, statistics = taskmodule.decode_annotations(task_encodings[0].targets)
assert decoded_annotations == {
"entities": [
LabeledSpan(start=4, end=6, label="content", score=1.0),
LabeledSpan(start=7, end=8, label="topic", score=1.0),
],
"relations": [
BinaryRelation(
head=LabeledSpan(start=4, end=6, label="content", score=1.0),
tail=LabeledSpan(start=7, end=8, label="topic", score=1.0),
label="is_about",
score=1.0,
),
BinaryRelation(
head=LabeledSpan(start=7, end=8, label="topic", score=1.0),
tail=LabeledSpan(start=4, end=6, label="content", score=1.0),
label="is_about_reversed",
score=1.0,
),
],
}

task_outputs = [task_encoding.targets for task_encoding in task_encodings]
docs_with_predictions = taskmodule.decode(task_encodings, task_outputs)
assert len(docs_with_predictions) == 1
doc_with_predictions: ExampleDocument = docs_with_predictions[0]
assert list(doc_with_predictions.entities.predictions) == list(doc_with_predictions.entities)
assert list(doc_with_predictions.relations.predictions) == list(doc_with_predictions.relations)


@pytest.fixture()
def target_encoding(taskmodule, task_encoding_without_target):
return taskmodule.encode_target(task_encoding_without_target)
Expand Down

0 comments on commit 2a9b2d6

Please sign in to comment.