Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PointerNetworkTaskModuleForEnd2EndRE: add reversed relations #144

Merged
merged 20 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 66 additions & 2 deletions src/pie_modules/taskmodules/pointer_network_for_end2end_re.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ class PointerNetworkTaskModuleForEnd2EndRE(
],
):
PREPARED_ATTRIBUTES = ["labels_per_layer"]
REVERSED_RELATION_LABEL_SUFFIX = "_reversed"

def __init__(
self,
Expand All @@ -114,6 +115,8 @@ def __init__(
document_type: str = "pytorch_ie.documents.TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions",
tokenized_document_type: str = "pie_modules.documents.TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions",
relation_layer_name: str = "binary_relations",
add_reversed_relations: bool = False,
symmetric_relations: Optional[List[str]] = None,
none_label: str = "none",
loop_dummy_relation_name: str = "loop",
constrained_generation: bool = False,
Expand Down Expand Up @@ -167,6 +170,8 @@ def __init__(
self.span_layer_name = annotation_field_mapping_inv.get(
relation_layer_target, relation_layer_target
)
self.add_reversed_relations = add_reversed_relations
self.symmetric_relations = set(symmetric_relations or [])
self.none_label = none_label
self.loop_dummy_relation_name = loop_dummy_relation_name
self.constrained_generation = constrained_generation
Expand Down Expand Up @@ -263,6 +268,18 @@ def _prefix_allowed_tokens_fn_with_maximum(
# convert to a list
return allowed_indices.tolist()

def add_reversed_relation_labels(self, relation_labels: Iterable[str]) -> Set[str]:
result = set(relation_labels)
for rel_label in set(relation_labels):
if rel_label not in self.symmetric_relations:
reversed_label = rel_label + self.REVERSED_RELATION_LABEL_SUFFIX
if reversed_label in result:
raise ValueError(
f"reversed relation label {reversed_label} already exists in relation layer labels"
)
result.add(reversed_label)
return result

def _prepare(self, documents: Sequence[DocumentType]) -> None:
# collect all labels
labels: Dict[str, Set[str]] = {layer_name: set() for layer_name in self.layer_names}
Expand All @@ -273,6 +290,11 @@ def _prepare(self, documents: Sequence[DocumentType]) -> None:
ac.label for ac in doc[layer_name] if ac.label not in exclude_labels
)

if self.add_reversed_relations:
labels[self.relation_layer_name] = self.add_reversed_relation_labels(
relation_labels=labels[self.relation_layer_name]
)

self.labels_per_layer = {
# sort labels to ensure deterministic order
layer_name: sorted(labels)
Expand Down Expand Up @@ -443,8 +465,38 @@ def decode_relations(

return encodings, dict(errors), current_encoding

def reverse_relation(self, relation: Annotation) -> BinaryRelation:
if isinstance(relation, BinaryRelation):
reversed_label = relation.label
if (
reversed_label not in self.symmetric_relations
and reversed_label != self.none_label
):
reversed_label += self.REVERSED_RELATION_LABEL_SUFFIX
reversed_rel = relation.copy(
head=relation.tail, tail=relation.head, label=reversed_label
)
return reversed_rel
else:
raise Exception(f"reversing of relations of type {type(relation)} is not supported")

def unreverse_relation(self, relation: Annotation) -> BinaryRelation:
if isinstance(relation, BinaryRelation):
head, tail, label = relation.head, relation.tail, relation.label
# if the relation is symmetric, we sort head and tail to ensure consistent order
if relation.label in self.symmetric_relations:
head, tail = sorted([head, tail], key=lambda x: (x.start, x.end))
# if the relation was reversed, we need to reconstruct the original label and swap head and tail
elif label.endswith(self.REVERSED_RELATION_LABEL_SUFFIX):
# reconstruct the original label and swap head and tail
label = label[: -len(self.REVERSED_RELATION_LABEL_SUFFIX)]
head, tail = tail, head
return relation.copy(head=head, tail=tail, label=label)
else:
raise Exception(f"un-reversing of relations of type {type(relation)} is not supported")

def encode_annotations(
self, layers: Dict[str, List[Annotation]], metadata: Optional[Dict[str, Any]] = None
self, layers: Dict[str, Iterable[Annotation]], metadata: Optional[Dict[str, Any]] = None
) -> TaskOutputType:
if not set(layers.keys()) == set(self.layer_names):
raise Exception(f"unexpected layers: {layers.keys()}. expected: {self.layer_names}")
Expand Down Expand Up @@ -783,6 +835,13 @@ def encode_target(self, task_encoding: TaskEncodingType) -> Optional[TargetEncod
layer_name: self.get_mapped_layer(document, layer_name=layer_name)
for layer_name in self.layer_names
}

if self.add_reversed_relations:
# create a copy to avoid modifying the annotation layer in the document
relations = list(layers[self.relation_layer_name])
reversed_relations = [self.reverse_relation(rel) for rel in relations]
layers[self.relation_layer_name] = relations + reversed_relations

result = self.encode_annotations(
layers=layers,
metadata={
Expand Down Expand Up @@ -856,4 +915,9 @@ def create_annotations_from_output(
for layer_name in layers:
annotations = self.get_mapped_layer(untokenized_document, layer_name=layer_name)
for annotation in annotations:
yield layer_name, annotation.copy()
# handle relations that may be reversed
if layer_name == self.relation_layer_name and self.add_reversed_relations:
unreversed_relation = self.unreverse_relation(annotation)
yield layer_name, unreversed_relation
else:
yield layer_name, annotation.copy()
Loading
Loading