Skip to content

Commit

Permalink
add parameters add_reversed_relations and symmetric_relations to Poin…
Browse files Browse the repository at this point in the history
…terNetworkTaskModuleForEnd2EndRE
  • Loading branch information
ArneBinder committed Nov 13, 2024
1 parent a440655 commit 6c3d1a8
Showing 1 changed file with 49 additions and 1 deletion.
50 changes: 49 additions & 1 deletion src/pie_modules/taskmodules/pointer_network_for_end2end_re.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ class PointerNetworkTaskModuleForEnd2EndRE(
],
):
PREPARED_ATTRIBUTES = ["labels_per_layer"]
REVERSED_RELATION_LABEL_SUFFIX = "_reversed"

def __init__(
self,
Expand All @@ -114,6 +115,8 @@ def __init__(
document_type: str = "pytorch_ie.documents.TextDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions",
tokenized_document_type: str = "pie_modules.documents.TokenDocumentWithLabeledSpansBinaryRelationsAndLabeledPartitions",
relation_layer_name: str = "binary_relations",
add_reversed_relations: bool = False,
symmetric_relations: Optional[List[str]] = None,
none_label: str = "none",
loop_dummy_relation_name: str = "loop",
constrained_generation: bool = False,
Expand Down Expand Up @@ -167,6 +170,8 @@ def __init__(
self.span_layer_name = annotation_field_mapping_inv.get(
relation_layer_target, relation_layer_target
)
self.add_reversed_relations = add_reversed_relations
self.symmetric_relations = set(symmetric_relations or [])
self.none_label = none_label
self.loop_dummy_relation_name = loop_dummy_relation_name
self.constrained_generation = constrained_generation
Expand Down Expand Up @@ -454,6 +459,7 @@ def encode_annotations(

# encode relations
all_relation_arguments = set()
relation_arguments2label = dict()
relation_encodings = dict()
for rel in layers[self.relation_layer_name]:
if not isinstance(rel, BinaryRelation):
Expand All @@ -466,6 +472,28 @@ def encode_annotations(
raise Exception(f"failed to encode relation: {rel}")
relation_encodings[rel] = encoded_relation
all_relation_arguments.update([rel.head, rel.tail])
relation_arguments2label[(rel.head, rel.tail)] = rel.label
if self.add_reversed_relations:
reversed_label = rel.label
if reversed_label not in self.symmetric_relations:
reversed_label += self.REVERSED_RELATION_LABEL_SUFFIX
if (rel.tail, rel.head) in relation_arguments2label:
previous_label = relation_arguments2label[(rel.tail, rel.head)]
logger.warning(
f"reversed relation already exists: {rel.tail} -> {rel.head} with "
f"label {previous_label} (reversed label: {reversed_label}). Skipping."
)
continue
reversed_rel = BinaryRelation(
head=rel.tail,
tail=rel.head,
label=rel.label + self.REVERSED_RELATION_LABEL_SUFFIX,
)
encoded_reversed_rel = self.relation_encoder_decoder.encode(
annotation=reversed_rel, metadata=metadata
)
if encoded_reversed_rel is not None:
relation_encodings[reversed_rel] = encoded_reversed_rel

# encode spans that are not arguments of any relation
no_relation_spans = [
Expand Down Expand Up @@ -846,4 +874,24 @@ def create_annotations_from_output(
for layer_name in layers:
annotations = self.get_mapped_layer(untokenized_document, layer_name=layer_name)
for annotation in annotations:
yield layer_name, annotation.copy()
# handle relations that may be reversed
if layer_name == self.relation_layer_name and self.add_reversed_relations:
if not isinstance(annotation, BinaryRelation):
raise Exception(
f"expected BinaryRelation when handling reversed relations, but got: {annotation}"
)
head, tail = annotation.head, annotation.tail
label = annotation.label
# if the relation is symmetric, we sort head and tail to ensure consistent order
if annotation.label in self.symmetric_relations:
head, tail = sorted(
[annotation.head, annotation.tail], key=lambda x: (x.start, x.end)
)
# if the relation was reversed, we need to reconstruct the original label and swap head and tail
elif annotation.label.endswith(self.REVERSED_RELATION_LABEL_SUFFIX):
# reconstruct the original label and swap head and tail
label = annotation.label[: -len(self.REVERSED_RELATION_LABEL_SUFFIX)]
head, tail = tail, head
yield layer_name, annotation.copy(head=head, tail=tail, label=label)
else:
yield layer_name, annotation.copy()

0 comments on commit 6c3d1a8

Please sign in to comment.