Skip to content

Commit

Permalink
Merge pull request #146 from ArneBinder/pointer_re_tm/skip_duplicated…
Browse files Browse the repository at this point in the history
…_rels

`PointerNetworkTaskModuleForEnd2EndRE`: skip duplicated / conflicting relations druing encoding
  • Loading branch information
ArneBinder authored Nov 13, 2024
2 parents a440655 + 8d4b1e3 commit c481fb4
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 6 deletions.
10 changes: 10 additions & 0 deletions src/pie_modules/taskmodules/pointer_network_for_end2end_re.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,18 +454,28 @@ def encode_annotations(

# encode relations
all_relation_arguments = set()
relation_arguments2label: Dict[Tuple[Annotation, ...], str] = dict()
relation_encodings = dict()
for rel in layers[self.relation_layer_name]:
if not isinstance(rel, BinaryRelation):
raise Exception(f"expected BinaryRelation, but got: {rel}")
if rel.label in self.labels_per_layer[self.relation_layer_name]:
if (rel.head, rel.tail) in relation_arguments2label:
previous_label = relation_arguments2label[(rel.head, rel.tail)]
if previous_label != rel.label:
raise ValueError(
f"relation {rel.head} -> {rel.tail} already exists, but has another label: "
f"{previous_label} (current label: {rel.label})."
)
continue
encoded_relation = self.relation_encoder_decoder.encode(
annotation=rel, metadata=metadata
)
if encoded_relation is None:
raise Exception(f"failed to encode relation: {rel}")
relation_encodings[rel] = encoded_relation
all_relation_arguments.update([rel.head, rel.tail])
relation_arguments2label[(rel.head, rel.tail)] = rel.label

# encode spans that are not arguments of any relation
no_relation_spans = [
Expand Down
126 changes: 120 additions & 6 deletions tests/taskmodules/test_pointer_network_for_end2end_re.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,15 @@ def config(config_str):
return CONFIG_DICT[config_str]


@dataclass
class ExampleDocument(TextBasedDocument):
entities: AnnotationList[LabeledSpan] = annotation_field(target="text")
relations: AnnotationList[BinaryRelation] = annotation_field(target="entities")
sentences: AnnotationList[LabeledSpan] = annotation_field(target="text")


@pytest.fixture(scope="module")
def document():
@dataclass
class ExampleDocument(TextBasedDocument):
entities: AnnotationList[LabeledSpan] = annotation_field(target="text")
relations: AnnotationList[BinaryRelation] = annotation_field(target="entities")
sentences: AnnotationList[LabeledSpan] = annotation_field(target="text")

doc = ExampleDocument(text="This is a dummy text about nothing. Trust me.")
span1 = LabeledSpan(start=10, end=20, label="content")
span2 = LabeledSpan(start=27, end=34, label="topic")
Expand Down Expand Up @@ -317,6 +318,119 @@ def test_target_encoding(target_encoding, taskmodule):
raise Exception(f"unknown partition_layer_name: {taskmodule.partition_layer_name}")


def test_task_encoding_with_deduplicated_relations(caplog):
doc = ExampleDocument(text="This is a dummy text about nothing. Trust me.")
doc.entities.append(LabeledSpan(start=10, end=20, label="content"))
doc.entities.append(LabeledSpan(start=27, end=34, label="topic"))
doc.entities.append(LabeledSpan(start=42, end=44, label="person"))
assert doc.entities.resolve() == [
("content", "dummy text"),
("topic", "nothing"),
("person", "me"),
]
# add the same relation twice (just use a different score, but that should not matter)
doc.relations.append(
BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="is_about")
)
doc.relations.append(
BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="is_about", score=0.9)
)
assert doc.relations.resolve() == [
("is_about", (("content", "dummy text"), ("topic", "nothing"))),
("is_about", (("content", "dummy text"), ("topic", "nothing"))),
]
taskmodule = PointerNetworkTaskModuleForEnd2EndRE(
tokenizer_name_or_path="facebook/bart-base",
relation_layer_name="relations",
annotation_field_mapping={
"entities": "labeled_spans",
"relations": "binary_relations",
},
)
taskmodule.prepare(documents=[doc])
caplog.clear()
with caplog.at_level(logging.WARNING):
task_encodings = taskmodule.encode(doc, encode_target=True)
messages = list(caplog.messages)

assert len(task_encodings) == 1
decoded_annotations, statistics = taskmodule.decode_annotations(task_encodings[0].targets)
assert decoded_annotations == {
"entities": [
LabeledSpan(start=4, end=6, label="content", score=1.0),
LabeledSpan(start=7, end=8, label="topic", score=1.0),
LabeledSpan(start=10, end=11, label="person", score=1.0),
],
"relations": [
BinaryRelation(
head=LabeledSpan(start=4, end=6, label="content", score=1.0),
tail=LabeledSpan(start=7, end=8, label="topic", score=1.0),
label="is_about",
score=1.0,
)
],
}

assert messages == [
(
"encoding errors: {'correct': 2}, skipped annotations:\n"
"{\n"
' "relations": [\n'
' "BinaryRelation('
"head=LabeledSpan(start=4, end=6, label='content', score=1.0), "
"tail=LabeledSpan(start=7, end=8, label='topic', score=1.0), "
"label='is_about', score=0.9"
')"\n'
" ]\n"
"}"
)
]


def test_task_encoding_with_conflicting_relations(caplog):
doc = ExampleDocument(text="This is a dummy text about nothing. Trust me.")
doc.entities.append(LabeledSpan(start=10, end=20, label="content"))
doc.entities.append(LabeledSpan(start=27, end=34, label="topic"))
doc.entities.append(LabeledSpan(start=42, end=44, label="person"))
assert doc.entities.resolve() == [
("content", "dummy text"),
("topic", "nothing"),
("person", "me"),
]
# add two relations with the same head and tail, but different labels
doc.relations.append(
BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="is_about")
)
doc.relations.append(
BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="wrong_relation")
)
assert doc.relations.resolve() == [
("is_about", (("content", "dummy text"), ("topic", "nothing"))),
("wrong_relation", (("content", "dummy text"), ("topic", "nothing"))),
]
taskmodule = PointerNetworkTaskModuleForEnd2EndRE(
tokenizer_name_or_path="facebook/bart-base",
relation_layer_name="relations",
annotation_field_mapping={
"entities": "labeled_spans",
"relations": "binary_relations",
},
)
taskmodule.prepare(documents=[doc])
caplog.clear()
with caplog.at_level(logging.ERROR):
task_encodings = taskmodule.encode(doc, encode_target=True)
messages = list(caplog.messages)

assert len(task_encodings) == 0

assert messages == [
"failed to encode target, it will be skipped: "
"relation ('Ġdummy', 'Ġtext') -> ('Ġnothing',) already exists, but has "
"another label: is_about (current label: wrong_relation)."
]


@pytest.fixture()
def task_encoding(task_encoding_without_target, target_encoding):
task_encoding_without_target.targets = target_encoding
Expand Down

0 comments on commit c481fb4

Please sign in to comment.