Merge pull request #16 from ArneBinder/fix_tokenize_document_missed_annotations

fix annotations that are reported as missing after tokenization
ArneBinder authored Dec 11, 2023
2 parents 9ea864d + 9ef8330 commit 8d0dbca
Showing 4 changed files with 107 additions and 18 deletions.
8 changes: 4 additions & 4 deletions poetry.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -24,7 +24,7 @@ classifiers = [

[tool.poetry.dependencies]
python = "^3.9"
pytorch-ie = ">=0.29.2,<0.30.0"
pytorch-ie = ">=0.29.4,<0.30.0"
torchmetrics = "^1"
pytorch-crf = ">=0.7.2"

49 changes: 45 additions & 4 deletions src/pie_modules/document/processing/tokenization.py
@@ -1,4 +1,5 @@
import functools
import json
import logging
from collections import defaultdict
from copy import copy, deepcopy
@@ -62,6 +63,7 @@ def text_based_document_to_token_based(
char_to_token: Optional[Callable[[int], Optional[int]]] = None,
strict_span_conversion: bool = True,
verbose: bool = True,
added_annotations: Optional[Dict[str, List[Annotation]]] = None,
) -> ToD:
document_type = resolve_type(
type_or_str=result_document_type, expected_super_type=TokenBasedDocument
@@ -162,16 +164,21 @@ def char_to_token(char_idx: int) -> Optional[int]:
else:
token_span = char_span.copy(start=start_token_idx, end=end_token_idx_inclusive + 1)
override_annotations[text_targeting_layer_name][char_span._id] = token_span
if added_annotations is not None:
added_annotations[text_targeting_layer_name].append(char_span)
valid_spans = set(override_annotations[text_targeting_layer_name].values())
result[text_targeting_layer_name].extend(sorted(valid_spans, key=lambda span: span.start))

result.add_all_annotations_from_other(
added_annotations_from_remaining_layers = result.add_all_annotations_from_other(
doc,
override_annotations=override_annotations,
removed_annotations=removed_annotations,
strict=strict_span_conversion,
verbose=verbose,
)
if added_annotations is not None:
for layer_name, annotations in added_annotations_from_remaining_layers.items():
added_annotations[layer_name].extend(annotations)

return result

@@ -184,6 +191,7 @@ def token_based_document_to_text_based(
join_tokens_with: Optional[str] = None,
strict_span_conversion: bool = True,
verbose: bool = True,
added_annotations: Optional[Dict[str, List[Annotation]]] = None,
) -> TeD:
document_type = resolve_type(
type_or_str=result_document_type, expected_super_type=TextBasedDocument
@@ -258,16 +266,21 @@ def token_based_document_to_text_based(

char_span = token_span.copy(start=start_char_idx, end=end_char_idx)
override_annotations[token_targeting_layer_name][token_span._id] = char_span
if added_annotations is not None:
added_annotations[token_targeting_layer_name].append(token_span)
valid_spans = set(override_annotations[token_targeting_layer_name].values())
result[token_targeting_layer_name].extend(sorted(valid_spans, key=lambda span: span.start))

result.add_all_annotations_from_other(
added_annotations_from_remaining_layers = result.add_all_annotations_from_other(
doc,
override_annotations=override_annotations,
removed_annotations=removed_annotations,
strict=strict_span_conversion,
verbose=verbose,
)
if added_annotations is not None:
for layer_name, annotations in added_annotations_from_remaining_layers.items():
added_annotations[layer_name].extend(annotations)

return result

@@ -281,6 +294,7 @@ def tokenize_document(
verbose: bool = True,
**tokenize_kwargs,
) -> List[ToD]:
added_annotations: Dict[str, List[Annotation]] = defaultdict(list)
result = []
partitions: Iterable[Span]
if partition_layer is None:
@@ -319,9 +333,36 @@
result_document_type=result_document_type,
token_offset_mapping=token_offset_mapping,
char_to_token=char_to_token,
strict_span_conversion=strict_span_conversion,
verbose=verbose,
strict_span_conversion=False,
verbose=False,
added_annotations=added_annotations,
)
tokenized_document.metadata["tokenizer_encoding"] = batch_encoding
result.append(tokenized_document)

missed_annotations = defaultdict(set)
if strict_span_conversion or verbose:
for annotation_field in doc.annotation_fields():
current_missed_annotations = set(doc[annotation_field.name]) - set(
added_annotations[annotation_field.name]
)
if len(current_missed_annotations) > 0:
missed_annotations[annotation_field.name] = current_missed_annotations

if len(missed_annotations) > 0:
missed_annotations_simplified = {k: str(v) for k, v in missed_annotations.items()}
if strict_span_conversion:
raise ValueError(
f"could not convert all annotations from document with id={doc.id} to token based documents, "
f"but strict_span_conversion is True, so raise an error, "
f"missed annotations:\n{json.dumps(missed_annotations_simplified, sort_keys=True, indent=2)}"
)
else:
if verbose:
logger.warning(
f"could not convert all annotations from document with id={doc.id} to token based documents, "
f"missed annotations (disable this message with verbose=False):\n"
f"{json.dumps(missed_annotations_simplified, sort_keys=True, indent=2)}"
)

return result
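
For context, a minimal usage sketch of the changed behavior (not part of the diff): it assumes a `text_document` with `sentences`/`entities`/`relations` layers and a `TokenizedTestDocument` type as in the tests below, and that `tokenize_document` is importable from the module path shown above.

```python
# Sketch only: `text_document` and `TokenizedTestDocument` are assumed fixtures
# (see the test file below); the bert-base-cased tokenizer is an arbitrary choice.
from transformers import AutoTokenizer

from pie_modules.document.processing.tokenization import tokenize_document

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

# With strict_span_conversion=True, tokenize_document now raises a single
# ValueError listing every annotation that could not be mapped into ANY of the
# tokenized fragments; the check is performed once over all fragments instead
# of per fragment.
try:
    tokenized_docs = tokenize_document(
        text_document,
        tokenizer=tokenizer,
        result_document_type=TokenizedTestDocument,
        strict_span_conversion=True,
        max_length=10,
        return_overflowing_tokens=True,
    )
except ValueError as e:
    print(e)  # JSON-formatted listing of the missed annotations per layer

# With strict_span_conversion=False and verbose=True, the same information is
# emitted as one warning and the token-based fragments are still returned.
tokenized_docs = tokenize_document(
    text_document,
    tokenizer=tokenizer,
    result_document_type=TokenizedTestDocument,
    strict_span_conversion=False,
    max_length=10,
    return_overflowing_tokens=True,
)
```

With the aggregated check, an annotation counts as converted as soon as any tokenized fragment contains it; only annotations missing from all fragments are raised or warned about.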
66 changes: 57 additions & 9 deletions tests/document/processing/test_tokenization.py
@@ -1,4 +1,5 @@
import dataclasses
from collections import defaultdict

import pytest
from pytorch_ie.annotations import BinaryRelation, Label, LabeledSpan, Span
@@ -152,11 +153,16 @@ def test_find_token_offset_mapping(text_document, token_document):


def test_text_based_document_to_token_based(text_document, token_document):
added_annotations = defaultdict(list)
doc = text_based_document_to_token_based(
text_document,
tokens=list(token_document.tokens),
result_document_type=TokenizedTestDocument,
added_annotations=added_annotations,
)
for ann_field in text_document.annotation_fields():
layer_name = ann_field.name
assert added_annotations[layer_name] == list(text_document[layer_name])
_test_token_document(doc)


@@ -310,11 +316,17 @@ class WrongAnnotationType(TextBasedDocument):


def test_token_based_document_to_text_based(token_document, text_document):
added_annotations = defaultdict(list)
result = token_based_document_to_text_based(
token_document,
text=text_document.text,
result_document_type=TestDocument,
added_annotations=added_annotations,
)
for ann_field in token_document.annotation_fields():
layer_name = ann_field.name
assert added_annotations[layer_name] == list(token_document[layer_name])

_test_text_document(result)


@@ -440,14 +452,28 @@ def test_tokenize_document(text_document, tokenizer):
]


def test_tokenize_document_max_length(text_document, tokenizer):
tokenized_docs = tokenize_document(
text_document,
tokenizer=tokenizer,
result_document_type=TokenizedTestDocument,
strict_span_conversion=False,
max_length=10,
return_overflowing_tokens=True,
def test_tokenize_document_max_length(text_document, tokenizer, caplog):
caplog.clear()
with caplog.at_level("WARNING"):
tokenized_docs = tokenize_document(
text_document,
tokenizer=tokenizer,
result_document_type=TokenizedTestDocument,
# max_length is set to 10, so the document is split into two parts
strict_span_conversion=False,
max_length=10,
return_overflowing_tokens=True,
)
assert len(caplog.records) == 1
assert (
caplog.records[0].message
== "could not convert all annotations from document with id=None to token based documents, missed annotations "
"(disable this message with verbose=False):\n"
"{\n"
' "relations": "{BinaryRelation(head=LabeledSpan(start=16, end=24, label=\'per\', score=1.0), '
"tail=LabeledSpan(start=34, end=35, label='org', score=1.0), label='per:employee_of', score=1.0)}\",\n"
' "sentences": "{Span(start=16, end=36)}"\n'
"}"
)
assert len(tokenized_docs) == 2
tokenized_doc = tokenized_docs[0]
@@ -515,12 +541,34 @@ def test_tokenize_document_max_length(text_document, tokenizer):
assert relation_tuples == [("('it',)", "per:founder", "('O',)")]


def test_tokenize_document_max_length_strict(text_document, tokenizer):
with pytest.raises(ValueError) as excinfo:
tokenize_document(
text_document,
tokenizer=tokenizer,
result_document_type=TokenizedTestDocument,
# max_length is set to 10, so the document is split into two parts
strict_span_conversion=True,
max_length=10,
return_overflowing_tokens=True,
)
assert (
str(excinfo.value)
== "could not convert all annotations from document with id=None to token based documents, "
"but strict_span_conversion is True, so raise an error, missed annotations:\n"
"{\n"
' "relations": "{BinaryRelation(head=LabeledSpan(start=16, end=24, label=\'per\', score=1.0), '
"tail=LabeledSpan(start=34, end=35, label='org', score=1.0), label='per:employee_of', score=1.0)}\",\n"
' "sentences": "{Span(start=16, end=36)}"\n'
"}"
)


def test_tokenize_document_partition(text_document, tokenizer):
tokenized_docs = tokenize_document(
text_document,
tokenizer=tokenizer,
result_document_type=TokenizedTestDocument,
strict_span_conversion=False,
partition_layer="sentences",
)
assert len(tokenized_docs) == 3
