Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add generic brat dataset #12

Merged
merged 15 commits into from
Nov 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions dataset_builders/pie/brat/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# PIE Dataset Card for "brat"

This is a [PyTorch-IE](https://github.com/ChristophAlt/pytorch-ie) wrapper for the
[BRAT Huggingface dataset loading script](https://huggingface.co/datasets/DFKI-SLT/brat).

## Data Schema

The document type for this dataset is `BratDocument` or `BratDocumentWithMergedSpans`, depending on whether the
data was loaded with `merge_fragmented_spans=True` (default: `False`). They define the following data fields:

- `text` (str)
- `id` (str, optional)
- `metadata` (dictionary, optional)

and the following annotation layers:

- `spans` (annotation type: `LabeledMultiSpan` in the case of `BratDocument` and `LabeledSpan` in the case of `BratDocumentWithMergedSpans`, target: `text`)
- `relations` (annotation type: `BinaryRelation`, target: `spans`)
- `span_attributes` (annotation type: `Attribute`, target: `spans`)
- `relation_attributes` (annotation type: `Attribute`, target: `relations`)

The `Attribute` annotation type is defined as follows:

- `annotation` (type: `Annotation`): the annotation to which the attribute is attached
- `label` (type: `str`)
- `value` (type: `str`, optional)
- `score` (type: `float`, optional, not included in comparison)

See [here](https://github.com/ChristophAlt/pytorch-ie/blob/main/src/pytorch_ie/annotations.py) for the remaining annotation type definitions.

## Document Converters

The dataset provides no predefined document converters because the BRAT format is very flexible and can be used
for many different tasks. You can add your own document converter by doing the following:

```python
import dataclasses
from typing import Optional

from pytorch_ie.core import AnnotationList, annotation_field
from pytorch_ie.documents import TextBasedDocument
from pytorch_ie.annotations import LabeledSpan

from pie_datasets import DatasetDict

# define your document class
@dataclasses.dataclass
class MyDocument(TextBasedDocument):
my_field: Optional[str] = None
my_span_annotations: AnnotationList[LabeledSpan] = annotation_field(target="text")

# define your document converter
# (BratDocumentWithMergedSpans is defined in the brat dataset builder script)
def my_converter(document: BratDocumentWithMergedSpans) -> MyDocument:
# create your document with the data from the original document.
# The fields "text", "id" and "metadata" are derived from the TextBasedDocument.
my_document = MyDocument(id=document.id, text=document.text, metadata=document.metadata, my_field="my_value")

# create a new span annotation
new_span = LabeledSpan(label="my_label", start=2, end=10)
# add the new span annotation to your document
my_document.my_span_annotations.append(new_span)

# add annotations from the document to your document
for span in document.spans:
# we need to copy the span because an annotation can only be attached to one document
my_document.my_span_annotations.append(span.copy())

return my_document


# load the dataset. We use the "merge_fragmented_spans" dataset variant here
# because it provides documents of type BratDocumentWithMergedSpans.
dataset = DatasetDict.load_dataset("pie/brat", name="merge_fragmented_spans", data_dir="path/to/brat/data")

# attach your document converter to the dataset
dataset.register_document_converter(my_converter)

# convert the dataset
converted_dataset = dataset.to_document_type(MyDocument)
```
305 changes: 305 additions & 0 deletions dataset_builders/pie/brat/brat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,305 @@
import dataclasses
import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple, Union

import datasets
from pytorch_ie.annotations import BinaryRelation, LabeledMultiSpan, LabeledSpan
from pytorch_ie.core import Annotation, AnnotationList, annotation_field
from pytorch_ie.documents import TextBasedDocument

from pie_datasets import GeneratorBasedBuilder

logger = logging.getLogger(__name__)


def dl2ld(dict_of_lists: Dict[str, List[Any]]) -> List[Dict[str, Any]]:
    """Transpose a dict of equal-length lists into a list of per-row dicts."""
    keys = list(dict_of_lists)
    rows = zip(*(dict_of_lists[key] for key in keys))
    return [dict(zip(keys, row)) for row in rows]


def ld2dl(
    list_fo_dicts: List[Dict[str, Any]], keys: Optional[List[str]] = None
) -> Dict[str, List[Any]]:
    """Transpose a list of per-row dicts into a dict of column lists.

    When `keys` is falsy, the keys of the first entry are used.
    """
    selected = keys or list(list_fo_dicts[0])
    columns: Dict[str, List[Any]] = {}
    for key in selected:
        columns[key] = [entry[key] for entry in list_fo_dicts]
    return columns


@dataclasses.dataclass(eq=True, frozen=True)
class Attribute(Annotation):
    # A BRAT attribute annotation: attaches a label (and an optional value)
    # to another annotation (a span or a relation).
    annotation: Annotation  # the annotation this attribute is attached to
    label: str  # the BRAT attribute type
    value: Optional[str] = None  # optional attribute value
    # optional score; excluded from equality comparison (compare=False)
    score: Optional[float] = dataclasses.field(default=None, compare=False)


@dataclasses.dataclass
class BratDocument(TextBasedDocument):
    # Document type for BRAT data where fragmented (multi-part) spans are kept
    # as LabeledMultiSpans, one (start, end) slice per fragment.
    spans: AnnotationList[LabeledMultiSpan] = annotation_field(target="text")
    relations: AnnotationList[BinaryRelation] = annotation_field(target="spans")
    span_attributes: AnnotationList[Attribute] = annotation_field(target="spans")
    relation_attributes: AnnotationList[Attribute] = annotation_field(target="relations")


@dataclasses.dataclass
class BratDocumentWithMergedSpans(TextBasedDocument):
    # Document type for BRAT data where fragmented (multi-part) spans have been
    # merged into single contiguous LabeledSpans.
    spans: AnnotationList[LabeledSpan] = annotation_field(target="text")
    relations: AnnotationList[BinaryRelation] = annotation_field(target="spans")
    span_attributes: AnnotationList[Attribute] = annotation_field(target="spans")
    relation_attributes: AnnotationList[Attribute] = annotation_field(target="relations")


def example_to_document(
    example: Dict[str, Any], merge_fragmented_spans: bool = False
) -> Union[BratDocument, BratDocumentWithMergedSpans]:
    """Convert a Huggingface BRAT example into a PIE document.

    Args:
        example: a Huggingface example in dict-of-lists format, as produced by
            the DFKI-SLT/brat dataset loading script. Expected keys: "context",
            "file_name", "spans", "relations", "equivalence_relations",
            "events", "attributions", "normalizations", "notes".
        merge_fragmented_spans: if True, fragmented (multi-part) spans are
            merged into single LabeledSpans and a BratDocumentWithMergedSpans
            is returned; otherwise fragments are kept as LabeledMultiSpans in
            a BratDocument.

    Returns:
        The converted document. The original BRAT annotation ids are preserved
        in the document metadata ("span_ids", "span_locations", "span_texts",
        "relation_ids", "attribute_ids") so that `document_to_example` can
        reconstruct the example.

    Raises:
        NotImplementedError: if the example contains equivalence relations,
            events, normalizations or notes (converting these is not yet
            implemented).
    """
    if merge_fragmented_spans:
        doc = BratDocumentWithMergedSpans(text=example["context"], id=example["file_name"])
    else:
        doc = BratDocument(text=example["context"], id=example["file_name"])

    # fix: the mapping can hold LabeledSpan or LabeledMultiSpan values,
    # depending on merge_fragmented_spans (was annotated Dict[str, LabeledSpan])
    spans: Dict[str, Union[LabeledSpan, LabeledMultiSpan]] = dict()
    # fix: each entry is a variable-length tuple of (start, end) slices
    span_locations: List[Tuple[Tuple[int, int], ...]] = []
    span_texts: List[str] = []
    for span_dict in dl2ld(example["spans"]):
        starts: List[int] = span_dict["locations"]["start"]
        ends: List[int] = span_dict["locations"]["end"]
        slices = tuple(zip(starts, ends))
        span_locations.append(slices)
        span_texts.append(span_dict["text"])
        # sanity check: the document text covered by the slices should match
        # the span's own "text" field (modulo surrounding whitespace)
        span_text_parts = [doc.text[start:end] for start, end in slices]
        joined_span_texts_stripped = " ".join(span_text_parts).strip()
        span_text_stripped = span_dict["text"].strip()
        if joined_span_texts_stripped != span_text_stripped:
            logger.warning(
                f"joined span parts do not match stripped span text field content. "
                f'joined_span_texts_stripped: "{joined_span_texts_stripped}" != stripped "text": "{span_text_stripped}"'
            )
        if merge_fragmented_spans:
            if len(starts) > 1:
                # check if the text in between the fragments holds only space
                merged_content_texts = [
                    doc.text[start:end] for start, end in zip(ends[:-1], starts[1:])
                ]
                merged_content_texts_not_empty = [
                    text.strip() for text in merged_content_texts if text.strip() != ""
                ]
                if len(merged_content_texts_not_empty) > 0:
                    logger.warning(
                        f"document '{doc.id}' contains a non-contiguous span with text content in between "
                        f"(will be merged into a single span): "
                        f"newly covered text parts: {merged_content_texts_not_empty}, "
                        f"merged span text: '{doc.text[starts[0]:ends[-1]]}', "
                        f"annotation: {span_dict}"
                    )
            # just take everything from the first start to the last end
            start = min(starts)
            end = max(ends)
            span = LabeledSpan(start=start, end=end, label=span_dict["type"])
        else:
            span = LabeledMultiSpan(slices=slices, label=span_dict["type"])
        spans[span_dict["id"]] = span

    doc.spans.extend(spans.values())
    doc.metadata["span_ids"] = list(spans.keys())
    doc.metadata["span_locations"] = span_locations
    doc.metadata["span_texts"] = span_texts

    relations: Dict[str, BinaryRelation] = dict()
    for rel_dict in dl2ld(example["relations"]):
        arguments = dict(zip(rel_dict["arguments"]["type"], rel_dict["arguments"]["target"]))
        assert set(arguments) == {"Arg1", "Arg2"}
        head = spans[arguments["Arg1"]]
        tail = spans[arguments["Arg2"]]
        rel = BinaryRelation(head=head, tail=tail, label=rel_dict["type"])
        relations[rel_dict["id"]] = rel

    doc.relations.extend(relations.values())
    doc.metadata["relation_ids"] = list(relations.keys())

    equivalence_relations = dl2ld(example["equivalence_relations"])
    if len(equivalence_relations) > 0:
        raise NotImplementedError("converting equivalence_relations is not yet implemented")

    events = dl2ld(example["events"])
    if len(events) > 0:
        raise NotImplementedError("converting events is not yet implemented")

    # attributes are grouped by the layer ("spans" / "relations") their target
    # annotation lives in
    attribute_annotations: Dict[str, Dict[str, Attribute]] = defaultdict(dict)
    attribute_ids = []
    for attribute_dict in dl2ld(example["attributions"]):
        target_id = attribute_dict["target"]
        if target_id in spans:
            target_layer_name = "spans"
            annotation = spans[target_id]
        elif target_id in relations:
            target_layer_name = "relations"
            annotation = relations[target_id]
        else:
            raise Exception("only span and relation attributes are supported yet")
        attribute = Attribute(
            annotation=annotation,
            label=attribute_dict["type"],
            value=attribute_dict["value"],
        )
        attribute_annotations[target_layer_name][attribute_dict["id"]] = attribute
        attribute_ids.append((target_layer_name, attribute_dict["id"]))

    doc.span_attributes.extend(attribute_annotations["spans"].values())
    doc.relation_attributes.extend(attribute_annotations["relations"].values())
    doc.metadata["attribute_ids"] = attribute_ids

    normalizations = dl2ld(example["normalizations"])
    if len(normalizations) > 0:
        raise NotImplementedError("converting normalizations is not yet implemented")

    notes = dl2ld(example["notes"])
    if len(notes) > 0:
        raise NotImplementedError("converting notes is not yet implemented")

    return doc


def document_to_example(
    document: Union[BratDocument, BratDocumentWithMergedSpans]
) -> Dict[str, Any]:
    """Convert a PIE document back into a Huggingface BRAT example dict.

    Inverse of `example_to_document`: reconstructs the dict-of-lists example
    format from the document's annotation layers and the original BRAT
    annotation ids stored in `document.metadata` ("span_ids", "span_locations",
    "span_texts", "relation_ids", "attribute_ids").

    Args:
        document: the document to convert.

    Returns:
        The example dict with keys "context", "file_name", "spans",
        "relations", "equivalence_relations", "events", "attributions",
        "normalizations" and "notes".

    Raises:
        TypeError: if a span annotation is neither a LabeledSpan nor a
            LabeledMultiSpan.
    """
    example = {
        "context": document.text,
        "file_name": document.id,
    }
    span_dicts: Dict[Union[LabeledSpan, LabeledMultiSpan], Dict[str, Any]] = dict()
    assert len(document.metadata["span_locations"]) == len(document.spans)
    assert len(document.metadata["span_texts"]) == len(document.spans)
    assert len(document.metadata["span_ids"]) == len(document.spans)
    for i, span in enumerate(document.spans):
        locations = tuple((start, end) for start, end in document.metadata["span_locations"][i])
        # sanity check: the stored locations must agree with the span annotation
        if isinstance(span, LabeledSpan):
            assert locations[0][0] == span.start
            assert locations[-1][1] == span.end
        elif isinstance(span, LabeledMultiSpan):
            assert span.slices == locations
        else:
            raise TypeError(f"span has unknown type [{type(span)}]: {span}")

        starts, ends = zip(*locations)
        span_dict = {
            "id": document.metadata["span_ids"][i],
            "locations": {
                "start": list(starts),
                "end": list(ends),
            },
            "text": document.metadata["span_texts"][i],
            "type": span.label,
        }
        if span in span_dicts:
            prev_ann_dict = span_dicts[span]
            ann_dict = span_dict
            logger.warning(
                f"document {document.id}: annotation exists twice: {prev_ann_dict['id']} and {ann_dict['id']} "
                f"are identical"
            )
        span_dicts[span] = span_dict
    example["spans"] = ld2dl(list(span_dicts.values()), keys=["id", "type", "locations", "text"])

    relation_dicts: Dict[BinaryRelation, Dict[str, Any]] = dict()
    assert len(document.metadata["relation_ids"]) == len(document.relations)
    for i, rel in enumerate(document.relations):
        arg1_id = span_dicts[rel.head]["id"]
        arg2_id = span_dicts[rel.tail]["id"]
        relation_dict = {
            "id": document.metadata["relation_ids"][i],
            "type": rel.label,
            "arguments": {
                "type": ["Arg1", "Arg2"],
                "target": [arg1_id, arg2_id],
            },
        }
        if rel in relation_dicts:
            prev_ann_dict = relation_dicts[rel]
            ann_dict = relation_dict
            logger.warning(
                f"document {document.id}: annotation exists twice: {prev_ann_dict['id']} and {ann_dict['id']} "
                f"are identical"
            )
        relation_dicts[rel] = relation_dict

    example["relations"] = ld2dl(list(relation_dicts.values()), keys=["id", "type", "arguments"])

    # not yet supported by example_to_document, so these layers are always empty
    example["equivalence_relations"] = ld2dl([], keys=["type", "targets"])
    example["events"] = ld2dl([], keys=["id", "type", "trigger", "arguments"])

    annotation_dicts = {
        "spans": span_dicts,
        "relations": relation_dicts,
    }
    all_attribute_annotations = {
        "spans": document.span_attributes,
        "relations": document.relation_attributes,
    }
    attribute_dicts: Dict[Annotation, Dict[str, Any]] = dict()
    attribute_ids_per_target = defaultdict(list)
    for target_layer, attribute_id in document.metadata["attribute_ids"]:
        attribute_ids_per_target[target_layer].append(attribute_id)

    for target_layer, attribute_ids in attribute_ids_per_target.items():
        attribute_annotations = all_attribute_annotations[target_layer]
        assert len(attribute_ids) == len(attribute_annotations)
        for i, attribute_annotation in enumerate(attribute_annotations):
            target_id = annotation_dicts[target_layer][attribute_annotation.annotation]["id"]
            attribute_dict = {
                "id": attribute_ids[i],
                "type": attribute_annotation.label,
                "target": target_id,
                "value": attribute_annotation.value,
            }
            if attribute_annotation in attribute_dicts:
                prev_ann_dict = attribute_dicts[attribute_annotation]
                # fix: use the freshly built dict (not the annotation object,
                # which is not subscriptable) so the warning can read its "id"
                ann_dict = attribute_dict
                logger.warning(
                    f"document {document.id}: annotation exists twice: {prev_ann_dict['id']} and {ann_dict['id']} "
                    f"are identical"
                )
            attribute_dicts[attribute_annotation] = attribute_dict

    example["attributions"] = ld2dl(
        list(attribute_dicts.values()), keys=["id", "type", "target", "value"]
    )
    example["normalizations"] = ld2dl(
        [], keys=["id", "type", "target", "resource_id", "entity_id"]
    )
    example["notes"] = ld2dl([], keys=["id", "type", "target", "note"])

    return example


class BratConfig(datasets.BuilderConfig):
    """BuilderConfig for BratDatasetLoader."""

    def __init__(self, merge_fragmented_spans: bool = False, **kwargs):
        """BuilderConfig for the BRAT dataset.

        Args:
            merge_fragmented_spans: if True, fragmented (multi-part) spans are
                merged into single spans and documents of type
                BratDocumentWithMergedSpans are produced.
            **kwargs: keyword arguments forwarded to super.
        """
        super().__init__(**kwargs)
        self.merge_fragmented_spans = merge_fragmented_spans


class BratDatasetLoader(GeneratorBasedBuilder):
    """PIE wrapper around the Huggingface BRAT dataset loading script.

    Produces a `BratDocument` per example (config "default") or a
    `BratDocumentWithMergedSpans` (config "merge_fragmented_spans").
    """

    # this requires https://github.com/ChristophAlt/pytorch-ie/pull/288
    DOCUMENT_TYPES = {
        "default": BratDocument,
        "merge_fragmented_spans": BratDocumentWithMergedSpans,
    }

    DEFAULT_CONFIG_NAME = "default"
    BUILDER_CONFIGS = [
        BratConfig(name="default"),
        BratConfig(name="merge_fragmented_spans", merge_fragmented_spans=True),
    ]

    # the Huggingface dataset loading script this builder wraps
    BASE_DATASET_PATH = "DFKI-SLT/brat"

    def _generate_document(self, example, **kwargs):
        # delegate to the conversion helper; the active config decides whether
        # fragmented spans are merged into single LabeledSpans
        return example_to_document(
            example, merge_fragmented_spans=self.config.merge_fragmented_spans
        )
1 change: 1 addition & 0 deletions tests/dataset_builders/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# Base paths for dataset builder scripts and their test fixture data.
HF_BASE_PATH = DATASET_BUILDER_BASE_PATH / "hf"
PIE_BASE_PATH = DATASET_BUILDER_BASE_PATH / "pie"
HF_DS_FIXTURE_DATA_PATH = FIXTURES_ROOT / "dataset_builders" / "hf"
PIE_DS_FIXTURE_DATA_PATH = FIXTURES_ROOT / "dataset_builders" / "pie"

logger = logging.getLogger(__name__)

Expand Down
Loading
Loading