Skip to content

Commit

Permalink
add test_encode_with_add_reversed_relations
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder committed Nov 13, 2024
1 parent c2ad563 commit 063eb79
Showing 1 changed file with 158 additions and 13 deletions.
171 changes: 158 additions & 13 deletions tests/taskmodules/test_pointer_network_for_end2end_re.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,23 +256,168 @@ def task_encoding_without_target(taskmodule, document):
return taskmodule.encode_input(document)[0]


def test_input_encoding(task_encoding_without_target, taskmodule):
assert task_encoding_without_target is not None
tokens = taskmodule.tokenizer.convert_ids_to_tokens(
task_encoding_without_target.inputs.input_ids
@pytest.fixture(params=[False, True])
def taskmodule_with_reversed_relations(document, request) -> PointerNetworkTaskModuleForEnd2EndRE:
    """Return a prepared taskmodule created with ``add_reversed_relations=True``.

    The fixture is parametrized over a boolean flag: when it is ``True``,
    ``"is_about"`` is passed via ``symmetric_relations``; otherwise
    ``symmetric_relations`` is ``None``. After ``prepare()`` the collected
    labels per layer are checked for the respective case.

    Args:
        document: the example document the taskmodule is prepared with.
        request: the pytest request object that carries the current parameter.

    Returns:
        The prepared taskmodule.
    """
    is_about_is_symmetric = request.param
    taskmodule = PointerNetworkTaskModuleForEnd2EndRE(
        tokenizer_name_or_path="facebook/bart-base",
        relation_layer_name="relations",
        exclude_labels_per_layer={"relations": ["no_relation"]},
        annotation_field_mapping={
            "entities": "labeled_spans",
            "relations": "binary_relations",
        },
        create_constraints=True,
        tokenizer_kwargs={"strict_span_conversion": False},
        add_reversed_relations=True,
        symmetric_relations=["is_about"] if is_about_is_symmetric else None,
    )

    taskmodule.prepare(documents=[document])
    assert taskmodule.is_prepared
    # A dedicated "*_reversed" label is only expected when "is_about" is not
    # declared symmetric; the entity labels are the same in both cases.
    if is_about_is_symmetric:
        expected_relation_labels = ["is_about"]
    else:
        expected_relation_labels = ["is_about", "is_about_reversed"]
    assert taskmodule.prepared_attributes == {
        "labels_per_layer": {
            "entities": ["content", "person", "topic"],
            "relations": expected_relation_labels,
        }
    }

    return taskmodule


def test_encode_with_add_reversed_relations(taskmodule_with_reversed_relations, document):
    """Check encoding and decoding of a document when reversed relations are added.

    The input encoding is expected to be identical for both fixture variants.
    The target label ids and the label of the decoded reversed relation depend
    on whether the taskmodule has ``symmetric_relations`` set.
    """
    taskmodule = taskmodule_with_reversed_relations

    task_encodings = taskmodule.encode(document, encode_target=True)
    assert len(task_encodings) == 1
    task_encoding = task_encodings[0]
    assert task_encoding is not None
    assert asdict(task_encoding.inputs) == {
        "input_ids": [0, 713, 16, 10, 34759, 2788, 59, 1085, 4, 3101, 162, 4, 2],
        "attention_mask": [1] * 13,
    }
    tokens = taskmodule.tokenizer.convert_ids_to_tokens(task_encoding.inputs.input_ids)
    assert tokens == [
        "<s>",
        "This",
        "Ġis",
        "Ġa",
        "Ġdummy",
        "Ġtext",
        "Ġabout",
        "Ġnothing",
        ".",
        "ĠTrust",
        "Ġme",
        ".",
        "</s>",
    ]

    # Only the target ids and the label of the reversed relation differ
    # between the two fixture variants; everything else is shared below.
    if not taskmodule.symmetric_relations:
        expected_labels = [
            15, 15, 5, 12, 13, 3,
            6,
            12, 13, 3, 15, 15, 5,
            7,
            18, 18, 4, 2, 2, 2, 2,
            1,
        ]
        expected_reversed_label = "is_about_reversed"
    else:
        expected_labels = [
            14, 14, 5, 11, 12, 3,
            6,
            11, 12, 3, 14, 14, 5,
            6,
            17, 17, 4, 2, 2, 2, 2,
            1,
        ]
        expected_reversed_label = "is_about"
    assert task_encoding.targets.labels == expected_labels

    # The second return value (decoding statistics) is not checked here.
    decoded_annotations, _ = taskmodule.decode_annotations(task_encoding.targets)
    assert decoded_annotations == {
        "entities": [
            LabeledSpan(start=4, end=6, label="content", score=1.0),
            LabeledSpan(start=7, end=8, label="topic", score=1.0),
            LabeledSpan(start=10, end=11, label="person", score=1.0),
        ],
        "relations": [
            BinaryRelation(
                head=LabeledSpan(start=4, end=6, label="content", score=1.0),
                tail=LabeledSpan(start=7, end=8, label="topic", score=1.0),
                label="is_about",
                score=1.0,
            ),
            BinaryRelation(
                head=LabeledSpan(start=7, end=8, label="topic", score=1.0),
                tail=LabeledSpan(start=4, end=6, label="content", score=1.0),
                label=expected_reversed_label,
                score=1.0,
            ),
        ],
    }


@pytest.fixture()
Expand Down

0 comments on commit 063eb79

Please sign in to comment.