Skip to content

Commit

Permalink
outsource reverse_relation()
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder committed Nov 13, 2024
1 parent 063eb79 commit 29ed676
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 77 deletions.
52 changes: 29 additions & 23 deletions src/pie_modules/taskmodules/pointer_network_for_end2end_re.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,21 @@ def decode_relations(

return encodings, dict(errors), current_encoding

def reverse_relation(self, relation: Annotation) -> BinaryRelation:
    """Return a copy of ``relation`` with head and tail swapped.

    The label of the copy gets ``REVERSED_RELATION_LABEL_SUFFIX`` appended,
    unless the label is listed in ``symmetric_relations`` or equals
    ``none_label``; in those two cases the label is kept as it is.

    Args:
        relation: The relation to reverse. Only ``BinaryRelation`` instances
            are supported.

    Returns:
        The reversed relation, created via ``relation.copy(...)``.

    Raises:
        Exception: If ``relation`` is not a ``BinaryRelation``.
    """
    # Guard clause: anything but a binary relation cannot be reversed here.
    if not isinstance(relation, BinaryRelation):
        raise Exception(f"reversing of relations of type {type(relation)} is not supported")

    label = relation.label
    # Symmetric relations and the "none" label read the same in both directions.
    keeps_label = label in self.symmetric_relations or label == self.none_label
    if not keeps_label:
        label = label + self.REVERSED_RELATION_LABEL_SUFFIX

    return relation.copy(head=relation.tail, tail=relation.head, label=label)

def encode_annotations(
self, layers: Dict[str, List[Annotation]], metadata: Optional[Dict[str, Any]] = None
) -> TaskOutputType:
Expand All @@ -467,11 +482,15 @@ def encode_annotations(
if self.labels_per_layer is None:
raise Exception("labels_per_layer is not defined. Call prepare() first or pass it in.")

relations = list(layers[self.relation_layer_name])
if self.add_reversed_relations:
relations.extend(self.reverse_relation(rel) for rel in relations)

# encode relations
all_relation_arguments = set()
relation_arguments2label = dict()
relation_arguments2label: Dict[Tuple[Annotation, ...], str] = dict()
relation_encodings = dict()
for rel in layers[self.relation_layer_name]:
for rel in relations:
if not isinstance(rel, BinaryRelation):
raise Exception(f"expected BinaryRelation, but got: {rel}")
if rel.label in self.labels_per_layer[self.relation_layer_name]:
Expand All @@ -480,30 +499,17 @@ def encode_annotations(
)
if encoded_relation is None:
raise Exception(f"failed to encode relation: {rel}")
if (rel.head, rel.tail) in relation_arguments2label:
previous_label = relation_arguments2label[(rel.head, rel.tail)]
if previous_label != rel.label:
logger.warning(
f"relation {rel.head} -> {rel.tail} already exists, but has another label: "
f"{previous_label} (previous label: {rel.label}). Skipping."
)
continue
relation_encodings[rel] = encoded_relation
all_relation_arguments.update([rel.head, rel.tail])
relation_arguments2label[(rel.head, rel.tail)] = rel.label
if self.add_reversed_relations:
reversed_label = rel.label
if reversed_label not in self.symmetric_relations:
reversed_label += self.REVERSED_RELATION_LABEL_SUFFIX
if (rel.tail, rel.head) in relation_arguments2label:
previous_label = relation_arguments2label[(rel.tail, rel.head)]
logger.warning(
f"reversed relation already exists: {rel.tail} -> {rel.head} with "
f"label {previous_label} (reversed label: {reversed_label}). Skipping."
)
continue
reversed_rel = BinaryRelation(
head=rel.tail,
tail=rel.head,
label=reversed_label,
)
encoded_reversed_rel = self.relation_encoder_decoder.encode(
annotation=reversed_rel, metadata=metadata
)
if encoded_reversed_rel is not None:
relation_encodings[reversed_rel] = encoded_reversed_rel

# encode spans that are not arguments of any relation
no_relation_spans = [
Expand Down
83 changes: 29 additions & 54 deletions tests/taskmodules/test_pointer_network_for_end2end_re.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,15 @@ def config(config_str):
return CONFIG_DICT[config_str]


# Text-based document type used by the tests in this module. It is defined at
# module level so that individual tests can construct their own documents.
@dataclass
class ExampleDocument(TextBasedDocument):
# Labeled spans; the layer is declared with target="text".
entities: AnnotationList[LabeledSpan] = annotation_field(target="text")
# Binary relations; the layer is declared with target="entities".
relations: AnnotationList[BinaryRelation] = annotation_field(target="entities")
# Labeled spans declared with target="text".
# NOTE(review): presumably sentence partitions - confirm against the taskmodule config.
sentences: AnnotationList[LabeledSpan] = annotation_field(target="text")


@pytest.fixture(scope="module")
def document():
@dataclass
class ExampleDocument(TextBasedDocument):
entities: AnnotationList[LabeledSpan] = annotation_field(target="text")
relations: AnnotationList[BinaryRelation] = annotation_field(target="entities")
sentences: AnnotationList[LabeledSpan] = annotation_field(target="text")

doc = ExampleDocument(text="This is a dummy text about nothing. Trust me.")
span1 = LabeledSpan(start=10, end=20, label="content")
span2 = LabeledSpan(start=27, end=34, label="topic")
Expand Down Expand Up @@ -256,6 +257,19 @@ def task_encoding_without_target(taskmodule, document):
return taskmodule.encode_input(document)[0]


def test_reverse_relation(taskmodule, document):
    """Check that ``reverse_relation`` swaps the arguments and suffixes the label."""
    original = document.relations[0]
    head = ("content", "dummy text")
    tail = ("topic", "nothing")

    # Sanity check of the input relation before reversing it.
    assert original.resolve() == ("is_about", (head, tail))

    flipped = taskmodule.reverse_relation(relation=original)

    # Arguments are swapped and the label carries the reversed suffix.
    assert flipped.resolve() == ("is_about_reversed", (tail, head))


@pytest.fixture(params=[False, True])
def taskmodule_with_reversed_relations(document, request) -> PointerNetworkTaskModuleForEnd2EndRE:
is_about_is_symmetric = request.param
Expand Down Expand Up @@ -321,30 +335,6 @@ def test_encode_with_add_reversed_relations(taskmodule_with_reversed_relations,
"</s>",
]
if not taskmodule_with_reversed_relations.symmetric_relations:
assert task_encoding.targets.labels == [
15,
15,
5,
12,
13,
3,
6,
12,
13,
3,
15,
15,
5,
7,
18,
18,
4,
2,
2,
2,
2,
1,
]
decoded_annotations, statistics = taskmodule_with_reversed_relations.decode_annotations(
task_encoding.targets
)
Expand All @@ -370,30 +360,6 @@ def test_encode_with_add_reversed_relations(taskmodule_with_reversed_relations,
],
}
else:
assert task_encoding.targets.labels == [
14,
14,
5,
11,
12,
3,
6,
11,
12,
3,
14,
14,
5,
6,
17,
17,
4,
2,
2,
2,
2,
1,
]
decoded_annotations, statistics = taskmodule_with_reversed_relations.decode_annotations(
task_encoding.targets
)
Expand All @@ -420,6 +386,15 @@ def test_encode_with_add_reversed_relations(taskmodule_with_reversed_relations,
}


def test_encode_with_add_reversed_relations_already_exists(taskmodule_with_reversed_relations):
    """Encode a locally built document with reversed relations enabled.

    The document consists of two entities and a single ``is_about`` relation
    between them. It is encoded including targets with the parametrized
    ``taskmodule_with_reversed_relations`` fixture.
    """
    doc = ExampleDocument(text="This is a dummy text about nothing. Trust me.")
    doc.entities.append(LabeledSpan(start=10, end=20, label="content"))
    doc.entities.append(LabeledSpan(start=27, end=34, label="topic"))
    rel = BinaryRelation(head=doc.entities[0], tail=doc.entities[1], label="is_about")
    # The relation must be attached to the document; otherwise the relation
    # layer is empty and there is nothing to encode or reverse.
    doc.relations.append(rel)

    # Encode the document built above. The bare name `document` would resolve
    # to the module-level fixture function, because the fixture is not
    # requested as an argument of this test.
    task_encodings = taskmodule_with_reversed_relations.encode(doc, encode_target=True)
    # NOTE(review): no expectations on `task_encodings` are asserted yet; as the
    # test name suggests, a pre-existing reversed relation and the expected
    # outcome still need to be added - confirm the intended behavior.


@pytest.fixture()
def target_encoding(taskmodule, task_encoding_without_target):
    """Provide the result of ``encode_target`` for the target-less task encoding."""
    encoded_target = taskmodule.encode_target(task_encoding_without_target)
    return encoded_target
Expand Down

0 comments on commit 29ed676

Please sign in to comment.