Skip to content

Commit

Permalink
Merge pull request #144 from ArneBinder/pointer_re_tm/add_reversed_re…
Browse files Browse the repository at this point in the history
…lations

`PointerNetworkTaskModuleForEnd2EndRE`: add reversed relations
  • Loading branch information
ArneBinder authored Nov 13, 2024
2 parents c481fb4 + 2d6d1d0 commit d1ad3a9
Show file tree
Hide file tree
Showing 2 changed files with 380 additions and 15 deletions.
68 changes: 66 additions & 2 deletions src/pie_modules/taskmodules/pointer_network_for_end2end_re.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ class PointerNetworkTaskModuleForEnd2EndRE(
],
):
PREPARED_ATTRIBUTES = ["labels_per_layer"]
REVERSED_RELATION_LABEL_SUFFIX = "_reversed"

def __init__(
self,
Expand All @@ -114,6 +115,8 @@ def __init__(
document_type: str = "pytorch_ie.documents.TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions",
tokenized_document_type: str = "pie_modules.documents.TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions",
relation_layer_name: str = "binary_relations",
add_reversed_relations: bool = False,
symmetric_relations: Optional[List[str]] = None,
none_label: str = "none",
loop_dummy_relation_name: str = "loop",
constrained_generation: bool = False,
Expand Down Expand Up @@ -167,6 +170,8 @@ def __init__(
self.span_layer_name = annotation_field_mapping_inv.get(
relation_layer_target, relation_layer_target
)
self.add_reversed_relations = add_reversed_relations
self.symmetric_relations = set(symmetric_relations or [])
self.none_label = none_label
self.loop_dummy_relation_name = loop_dummy_relation_name
self.constrained_generation = constrained_generation
Expand Down Expand Up @@ -263,6 +268,18 @@ def _prefix_allowed_tokens_fn_with_maximum(
# convert to a list
return allowed_indices.tolist()

def add_reversed_relation_labels(self, relation_labels: Iterable[str]) -> Set[str]:
result = set(relation_labels)
for rel_label in set(relation_labels):
if rel_label not in self.symmetric_relations:
reversed_label = rel_label + self.REVERSED_RELATION_LABEL_SUFFIX
if reversed_label in result:
raise ValueError(
f"reversed relation label {reversed_label} already exists in relation layer labels"
)
result.add(reversed_label)
return result

def _prepare(self, documents: Sequence[DocumentType]) -> None:
# collect all labels
labels: Dict[str, Set[str]] = {layer_name: set() for layer_name in self.layer_names}
Expand All @@ -273,6 +290,11 @@ def _prepare(self, documents: Sequence[DocumentType]) -> None:
ac.label for ac in doc[layer_name] if ac.label not in exclude_labels
)

if self.add_reversed_relations:
labels[self.relation_layer_name] = self.add_reversed_relation_labels(
relation_labels=labels[self.relation_layer_name]
)

self.labels_per_layer = {
# sort labels to ensure deterministic order
layer_name: sorted(labels)
Expand Down Expand Up @@ -443,8 +465,38 @@ def decode_relations(

return encodings, dict(errors), current_encoding

def reverse_relation(self, relation: Annotation) -> BinaryRelation:
if isinstance(relation, BinaryRelation):
reversed_label = relation.label
if (
reversed_label not in self.symmetric_relations
and reversed_label != self.none_label
):
reversed_label += self.REVERSED_RELATION_LABEL_SUFFIX
reversed_rel = relation.copy(
head=relation.tail, tail=relation.head, label=reversed_label
)
return reversed_rel
else:
raise Exception(f"reversing of relations of type {type(relation)} is not supported")

def unreverse_relation(self, relation: Annotation) -> BinaryRelation:
if isinstance(relation, BinaryRelation):
head, tail, label = relation.head, relation.tail, relation.label
# if the relation is symmetric, we sort head and tail to ensure consistent order
if relation.label in self.symmetric_relations:
head, tail = sorted([head, tail], key=lambda x: (x.start, x.end))
# if the relation was reversed, we need to reconstruct the original label and swap head and tail
elif label.endswith(self.REVERSED_RELATION_LABEL_SUFFIX):
# reconstruct the original label and swap head and tail
label = label[: -len(self.REVERSED_RELATION_LABEL_SUFFIX)]
head, tail = tail, head
return relation.copy(head=head, tail=tail, label=label)
else:
raise Exception(f"un-reversing of relations of type {type(relation)} is not supported")

def encode_annotations(
self, layers: Dict[str, List[Annotation]], metadata: Optional[Dict[str, Any]] = None
self, layers: Dict[str, Iterable[Annotation]], metadata: Optional[Dict[str, Any]] = None
) -> TaskOutputType:
if not set(layers.keys()) == set(self.layer_names):
raise Exception(f"unexpected layers: {layers.keys()}. expected: {self.layer_names}")
Expand Down Expand Up @@ -783,6 +835,13 @@ def encode_target(self, task_encoding: TaskEncodingType) -> Optional[TargetEncod
layer_name: self.get_mapped_layer(document, layer_name=layer_name)
for layer_name in self.layer_names
}

if self.add_reversed_relations:
# create a copy to avoid modifying the annotation layer in the document
relations = list(layers[self.relation_layer_name])
reversed_relations = [self.reverse_relation(rel) for rel in relations]
layers[self.relation_layer_name] = relations + reversed_relations

result = self.encode_annotations(
layers=layers,
metadata={
Expand Down Expand Up @@ -856,4 +915,9 @@ def create_annotations_from_output(
for layer_name in layers:
annotations = self.get_mapped_layer(untokenized_document, layer_name=layer_name)
for annotation in annotations:
yield layer_name, annotation.copy()
# handle relations that may be reversed
if layer_name == self.relation_layer_name and self.add_reversed_relations:
unreversed_relation = self.unreverse_relation(annotation)
yield layer_name, unreversed_relation
else:
yield layer_name, annotation.copy()
Loading

0 comments on commit d1ad3a9

Please sign in to comment.