Skip to content

Commit

Permalink
move adding reversed relations to encode_target()
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder committed Nov 13, 2024
1 parent e597f5c commit 140b026
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions src/pie_modules/taskmodules/pointer_network_for_end2end_re.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,16 +504,11 @@ def encode_annotations(
if self.labels_per_layer is None:
raise Exception("labels_per_layer is not defined. Call prepare() first or pass it in.")

relations = list(layers[self.relation_layer_name])
if self.add_reversed_relations:
reversed_relations = [self.reverse_relation(rel) for rel in relations]
relations.extend(reversed_relations)

# encode relations
all_relation_arguments = set()
relation_arguments2label: Dict[Tuple[Annotation, ...], str] = dict()
relation_encodings = dict()
for rel in relations:
for rel in layers[self.relation_layer_name]:
if not isinstance(rel, BinaryRelation):
raise Exception(f"expected BinaryRelation, but got: {rel}")
if rel.label in self.labels_per_layer[self.relation_layer_name]:
Expand Down Expand Up @@ -840,6 +835,13 @@ def encode_target(self, task_encoding: TaskEncodingType) -> Optional[TargetEncod
layer_name: self.get_mapped_layer(document, layer_name=layer_name)
for layer_name in self.layer_names
}

if self.add_reversed_relations:
reversed_relations = [
self.reverse_relation(rel) for rel in layers[self.relation_layer_name]
]
layers[self.relation_layer_name].extend(reversed_relations)

result = self.encode_annotations(
layers=layers,
metadata={
Expand Down

0 comments on commit 140b026

Please sign in to comment.