From e9acb06154185bc6291a75526906c4b01a0ed5d6 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Wed, 13 Nov 2024 15:41:20 +0100 Subject: [PATCH] fix _prepare --- .../taskmodules/pointer_network_for_end2end_re.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py index f3108b3b1..6decf4612 100644 --- a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py +++ b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py @@ -280,14 +280,13 @@ def _prepare(self, documents: Sequence[DocumentType]) -> None: if self.add_reversed_relations: for rel_label in set(labels[self.relation_layer_name]): - reversed_label = rel_label if rel_label not in self.symmetric_relations: - reversed_label += self.REVERSED_RELATION_LABEL_SUFFIX - if reversed_label in labels[self.relation_layer_name]: - raise ValueError( - f"reversed relation label {reversed_label} already exists in relation layer labels" - ) - labels[self.relation_layer_name].add(reversed_label) + reversed_label = rel_label + self.REVERSED_RELATION_LABEL_SUFFIX + if reversed_label in labels[self.relation_layer_name]: + raise ValueError( + f"reversed relation label {reversed_label} already exists in relation layer labels" + ) + labels[self.relation_layer_name].add(reversed_label) self.labels_per_layer = { # sort labels to ensure deterministic order