Skip to content

Commit

Permalink
simplify
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder committed Feb 9, 2024
1 parent 08b0def commit ec19386
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 151 deletions.
224 changes: 103 additions & 121 deletions src/serializer/brat.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import os
from collections import defaultdict
from typing import DefaultDict, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, TypeVar, Union

from pie_datasets.core.dataset_dict import METADATA_FILE_NAME
from pie_modules.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
Expand All @@ -20,148 +20,128 @@


def serialize_labeled_span(
span_id: str,
labeled_span: Union[LabeledMultiSpan, LabeledSpan],
idx: int,
annotation: LabeledSpan,
label_prefix: Optional[str] = None,
) -> str:
) -> Tuple[str, str]:
"""Serialize a labeled span into a string representation.
Args:
span_id (str): The identifier for the labeled span.
labeled_span (Union[LabeledMultiSpan, LabeledSpan]): The labeled span object to serialize.
It can be either a LabeledMultiSpan or a LabeledSpan.
label_prefix (Optional[str], optional): A prefix to be added to the label.
Defaults to None.
idx (str): The index for the labeled span.
annotation (LabeledSpan): The labeled span object to serialize.
label_prefix (Optional[str], optional): A prefix to be added to the label. Defaults to None.
Returns:
str: The serialized representation of the labeled_span.
Raises:
Warning: If the labeled_span has an unknown type.
str: The id and serialized representation of the labeled span.
"""
label = labeled_span.label if label_prefix is None else f"{label_prefix}-{labeled_span.label}"
if isinstance(labeled_span, LabeledMultiSpan):
locations = []
texts = []
for slice in labeled_span.slices:
start_idx = slice[0]
end_idx = slice[1]
texts.append(labeled_span.target[start_idx:end_idx])
locations.append(f"{start_idx} {end_idx}")
location = ";".join(locations)
text = " ".join(texts)
serialized_labeled_span = f"{span_id}\t{label} {location}\t{text}\n"
elif isinstance(labeled_span, LabeledSpan):
start_idx = labeled_span.start
end_idx = labeled_span.end
entity_text = labeled_span.target[start_idx:end_idx]
serialized_labeled_span = f"{span_id}\t{label} {start_idx} {end_idx}\t{entity_text}\n"
else:
raise Warning(f"labeled span has unknown type: {type(labeled_span)}")
return serialized_labeled_span
label = annotation.label if label_prefix is None else f"{label_prefix}-{annotation.label}"
start_idx = annotation.start
end_idx = annotation.end
entity_text = annotation.target[start_idx:end_idx]
serialized_labeled_span = f"{label} {start_idx} {end_idx}\t{entity_text}\n"
return f"T{idx}", serialized_labeled_span


def serialize_labeled_spans(
labeled_spans: BaseAnnotationList,
def serialize_labeled_multi_span(
idx: int,
annotation: LabeledMultiSpan,
label_prefix: Optional[str] = None,
first_span_id: int = 0,
) -> Tuple[List[str], Dict[LabeledSpan, str]]:
"""Converts entity annotations of type LabeledMultiSpan and LabeledSpan to annotations in the
Brat format.
) -> Tuple[str, str]:
"""Serialize a labeled multi span into a string representation.
Parameters:
labeled_spans (BaseAnnotationList): The list of entity annotations.
label_prefix (Optional[str]): An optional prefix to add to entity labels.
first_span_id: An integer value used for creating span annotation IDs. It ensures the proper assignment of IDs
for predicted annotations, particularly when gold annotations have already been included.
Args:
idx (int): The index for the labeled multi span.
annotation (LabeledMultiSpan): The labeled multi span object to serialize.
label_prefix (Optional[str], optional): A prefix to be added to the label. Defaults to None.
Returns:
Tuple[List[str], DefaultDict[LabeledSpan, str]]: A tuple containing a list of strings representing
entity annotations in the Brat format, and a dictionary mapping labeled spans to their IDs.
str: The id and serialized representation of the labeled multi span.
"""
span2id: Dict[LabeledSpan, str] = defaultdict()
serialized_labeled_spans = []
for i, labeled_span in enumerate(labeled_spans, start=first_span_id):
span_id = f"T{i}"
serialized_labeled_span = serialize_labeled_span(
span_id=span_id, labeled_span=labeled_span, label_prefix=label_prefix
)
span2id[labeled_span] = span_id
serialized_labeled_spans.append(serialized_labeled_span)
return serialized_labeled_spans, span2id
label = annotation.label if label_prefix is None else f"{label_prefix}-{annotation.label}"

locations = []
texts = []
for slice in annotation.slices:
start_idx = slice[0]
end_idx = slice[1]
texts.append(annotation.target[start_idx:end_idx])
locations.append(f"{start_idx} {end_idx}")
location = ";".join(locations)
text = " ".join(texts)
serialized_labeled_span = f"{label} {location}\t{text}\n"
return f"T{idx}", serialized_labeled_span


def serialize_binary_relation(
relation_id: str,
binary_relation: Union[LabeledMultiSpan, LabeledSpan],
span2id: Dict[Annotation, str],
idx: int,
annotation: BinaryRelation,
annotation2id: Dict[Annotation, str],
label_prefix: Optional[str] = None,
) -> str:
) -> Tuple[str, str]:
"""Serialize a binary relation into a string representation.
Args:
relation_id (str): The identifier for the binary relation.
binary_relation (Union[LabeledMultiSpan, LabeledSpan]): The binary relation object to serialize.
idx (str): The index for the binary relation.
annotation (Union[LabeledMultiSpan, LabeledSpan]): The binary relation object to serialize.
Labeled Spans in the binary relation can be either a LabeledMultiSpan or a LabeledSpan.
span2id (Dict[Annotation, str]): A dictionary mapping span annotations to their IDs.
annotation2id (Dict[Annotation, str]): A dictionary mapping span annotations to their IDs.
label_prefix (Optional[str], optional): A prefix to be added to the label.
Defaults to None.
Returns:
str: The serialized representation of the binary relation.
Raises:
Warning: If the binary relation has an unknown type.
str: The id and serialized representation of the binary relation.
"""
if not isinstance(binary_relation, BinaryRelation):
raise Warning(f"relation has unknown type: {type(binary_relation)}")

arg1 = span2id[binary_relation.head]
arg2 = span2id[binary_relation.tail]
label = (
binary_relation.label
if label_prefix is None
else f"{label_prefix}-{binary_relation.label}"
)
serialized_binary_relation = f"{relation_id}\t{label} Arg1:{arg1} Arg2:{arg2}\n"
return serialized_binary_relation
if not isinstance(annotation, BinaryRelation):
raise Warning(f"relation has unknown type: {type(annotation)}")

arg1 = annotation2id[annotation.head]
arg2 = annotation2id[annotation.tail]
label = annotation.label if label_prefix is None else f"{label_prefix}-{annotation.label}"
serialized_binary_relation = f"{label} Arg1:{arg1} Arg2:{arg2}\n"
return f"R{idx}", serialized_binary_relation


def serialize_binary_relations(
binary_relations: BaseAnnotationList,
span2id: Dict[Annotation, str],
def serialize_annotation(
idx: int,
annotation: Annotation,
annotation2id: Dict[Annotation, str],
label_prefix: Optional[str] = None,
first_relation_id: int = 0,
) -> List[str]:
"""
Converts relation annotations to annotations in the Brat format.
e.g: R0 Arg1 Arg2 LABEL
) -> Tuple[str, str]:
if isinstance(annotation, LabeledMultiSpan):
return serialize_labeled_multi_span(
idx=idx, annotation=annotation, label_prefix=label_prefix
)
elif isinstance(annotation, LabeledSpan):
return serialize_labeled_span(idx=idx, annotation=annotation, label_prefix=label_prefix)
elif isinstance(annotation, BinaryRelation):
return serialize_binary_relation(
idx=idx, annotation=annotation, label_prefix=label_prefix, annotation2id=annotation2id
)
else:
raise Warning(f"annotation has unknown type: {type(annotation)}")

Parameters:
binary_relations (BaseAnnotationList): The list of relation annotations.
span2id (Dict[Annotation, str]): A dictionary mapping labeled spans to their annotation IDs.
label_prefix (Optional[str]): An optional prefix to add to relation labels.
first_relation_id: An integer value used for creating relation annotation IDs. It ensures the proper assignment
of IDs for predicted annotations, particularly when gold annotations have already been included.

Returns:
List[str]: A list of strings representing relation annotations in the Brat format.
"""
serialized_binary_relations = []
for i, binary_relation in enumerate(binary_relations, start=first_relation_id):
relation_id = f"R{i}"
serialized_binary_relation = serialize_binary_relation(
relation_id=relation_id,
binary_relation=binary_relation,
span2id=span2id,
def serialize_annotations(
annotations: Iterable[Annotation],
annotation2id: Dict[Annotation, str],
label_prefix: Optional[str] = None,
first_idx: int = 0,
) -> List[str]:
serialized_annotations = []
for i, annotation in enumerate(annotations, start=first_idx):
annotation_id, serialized_annotation = serialize_annotation(
idx=i,
annotation=annotation,
annotation2id=annotation2id,
label_prefix=label_prefix,
)
serialized_binary_relations.append(serialized_binary_relation)
serialized_annotations.append(f"{annotation_id}\t{serialized_annotation}")
annotation2id[annotation] = annotation_id

return serialized_binary_relations
return serialized_annotations


def serialize_annotations(
def serialize_annotation_layers(
labeled_span_layer: AnnotationLayer,
binary_relation_layer: AnnotationLayer,
gold_label_prefix: Optional[str] = None,
Expand All @@ -182,32 +162,34 @@ def serialize_annotations(
"""
serialized_labeled_spans = []
serialized_binary_relations = []
annotation2id: Dict[Annotation, str] = {}
if gold_label_prefix is not None:
serialized_labeled_spans_gold, span2id = serialize_labeled_spans(
labeled_spans=labeled_span_layer,
serialized_labeled_spans_gold = serialize_annotations(
annotations=labeled_span_layer,
label_prefix=gold_label_prefix,
annotation2id=annotation2id,
)
serialized_labeled_spans.extend(serialized_labeled_spans_gold)
serialized_binary_relations_gold = serialize_binary_relations(
binary_relations=binary_relation_layer,
span2id=span2id,
serialized_binary_relations_gold = serialize_annotations(
annotations=binary_relation_layer,
annotation2id=annotation2id,
label_prefix=gold_label_prefix,
)
serialized_binary_relations.extend(serialized_binary_relations_gold)
else:
span2id = {}
serialized_labeled_spans_pred, span2id_pred = serialize_labeled_spans(
labeled_spans=labeled_span_layer.predictions,
annotation2id = {}
serialized_labeled_spans_pred = serialize_annotations(
annotations=labeled_span_layer.predictions,
label_prefix=prediction_label_prefix,
first_span_id=len(serialized_labeled_spans),
first_idx=len(serialized_labeled_spans),
annotation2id=annotation2id,
)
span2id.update(span2id_pred)
serialized_labeled_spans.extend(serialized_labeled_spans_pred)
serialized_binary_relations_pred = serialize_binary_relations(
binary_relations=binary_relation_layer.predictions,
span2id=span2id,
serialized_binary_relations_pred = serialize_annotations(
annotations=binary_relation_layer.predictions,
annotation2id=annotation2id,
label_prefix=prediction_label_prefix,
first_relation_id=len(serialized_binary_relations),
first_idx=len(serialized_binary_relations),
)
serialized_binary_relations.extend(serialized_binary_relations_pred)
return serialized_labeled_spans + serialized_binary_relations
Expand Down Expand Up @@ -299,7 +281,7 @@ def write(
file_name = f"{doc_id}.ann"
metadata_text[f"{file_name}"] = doc.text
ann_path = os.path.join(realpath, file_name)
serialized_annotations = serialize_annotations(
serialized_annotations = serialize_annotation_layers(
labeled_span_layer=doc[entity_layer],
binary_relation_layer=doc[relation_layer],
gold_label_prefix=gold_label_prefix,
Expand Down
Loading

0 comments on commit ec19386

Please sign in to comment.