diff --git a/dataset_builders/pie/cdcp/cdcp.py b/dataset_builders/pie/cdcp/cdcp.py index 755557d2..0c27835a 100644 --- a/dataset_builders/pie/cdcp/cdcp.py +++ b/dataset_builders/pie/cdcp/cdcp.py @@ -1,8 +1,9 @@ import dataclasses import logging -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Dict, List, Optional import datasets +from pie_models.document.processing.text_span_trimmer import trim_text_spans from pytorch_ie.annotations import BinaryRelation, LabeledSpan from pytorch_ie.core import Annotation, AnnotationList, annotation_field from pytorch_ie.documents import ( @@ -11,7 +12,6 @@ ) from pie_datasets import GeneratorBasedBuilder -from pie_datasets.document.processing.text_span_trimmer import trim_text_spans log = logging.getLogger(__name__) diff --git a/dataset_builders/pie/scidtb_argmin/scidtb_argmin.py b/dataset_builders/pie/scidtb_argmin/scidtb_argmin.py index 28d5134c..1adb0fc5 100644 --- a/dataset_builders/pie/scidtb_argmin/scidtb_argmin.py +++ b/dataset_builders/pie/scidtb_argmin/scidtb_argmin.py @@ -1,10 +1,11 @@ import dataclasses import logging -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Tuple import datasets +from pie_models.document.processing import token_based_document_to_text_based from pytorch_ie.annotations import BinaryRelation, LabeledSpan -from pytorch_ie.core import AnnotationList, Document, annotation_field +from pytorch_ie.core import AnnotationList, annotation_field from pytorch_ie.documents import ( TextDocumentWithLabeledSpansAndBinaryRelations, TokenBasedDocument, @@ -12,7 +13,6 @@ from pytorch_ie.utils.span import bio_tags_to_spans from pie_datasets import GeneratorBasedBuilder -from pie_datasets.document.processing import token_based_document_to_text_based log = logging.getLogger(__name__) diff --git a/dataset_builders/pie/tacred/tacred.py b/dataset_builders/pie/tacred/tacred.py index 213746bb..68970724 100644 --- a/dataset_builders/pie/tacred/tacred.py +++ b/dataset_builders/pie/tacred/tacred.py @@ -2,6 +2,7 @@ from typing import Any, Dict, Optional import datasets +from pie_models.document.processing import token_based_document_to_text_based from pytorch_ie.annotations import BinaryRelation, LabeledSpan from pytorch_ie.core import Annotation, AnnotationList, annotation_field from pytorch_ie.documents import ( @@ -10,7 +11,6 @@ ) from pie_datasets import GeneratorBasedBuilder -from pie_datasets.document.processing import token_based_document_to_text_based @dataclass(eq=True, frozen=True) diff --git a/poetry.lock b/poetry.lock index 0bf22e22..4adcf983 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1026,6 +1026,22 @@ sql-other = ["SQLAlchemy (>=1.4.36)"] test = ["hypothesis (>=6.46.1)", "pytest (>=7.3.2)", "pytest-asyncio (>=0.17.0)", "pytest-xdist (>=2.2.0)"] xml = ["lxml (>=4.8.0)"] +[[package]] +name = "pie-models" +version = "0.7.6" +description = "Model and Taskmodule implementations for PyTorch-IE" +optional = false +python-versions = ">=3.9,<4.0" +files = [ + {file = "pie_models-0.7.6-py3-none-any.whl", hash = "sha256:3f8245b6c6e07a1aa86aa2ff35fbd2a4b80637b3bb9d5a5cdb23362a649d1582"}, + {file = "pie_models-0.7.6.tar.gz", hash = "sha256:6b7233a4ac7e640810595f702c2d4bf6422e12e4796f54371bd163fb828579bc"}, +] + +[package.dependencies] +pytorch-crf = ">=0.7.2" +pytorch-ie = ">=0.29.2,<0.30.0" +torchmetrics = ">=1,<2" + [[package]] name = "platformdirs" version = "3.11.0" @@ -1169,15 +1185,26 @@ files = [ [package.dependencies] six = ">=1.5" +[[package]] +name = 
"pytorch-crf" +version = "0.7.2" +description = "Conditional random field in PyTorch" +optional = false +python-versions = ">=3.6, <4" +files = [ + {file = "pytorch-crf-0.7.2.tar.gz", hash = "sha256:e6456e22ccfc99a3d4fe1e03e996103b1b39e9830bf3c7e12e7a9077d3be866d"}, + {file = "pytorch_crf-0.7.2-py3-none-any.whl", hash = "sha256:1b2d7d5eea3255f6e0cac09ab8b645472e76ff70d9333bc88762cf7317a4992d"}, +] + [[package]] name = "pytorch-ie" -version = "0.29.1" +version = "0.29.2" description = "State-of-the-art Information Extraction in PyTorch" optional = false python-versions = ">=3.9,<4.0" files = [ - {file = "pytorch_ie-0.29.1-py3-none-any.whl", hash = "sha256:e56498570346b0cda165e49fe595bc6e4119af70db7b33c4673282c779353e45"}, - {file = "pytorch_ie-0.29.1.tar.gz", hash = "sha256:33a992eaa643ebe2dd98196930b21b7fd91d525d7f1b75006053a39781abc6d1"}, + {file = "pytorch_ie-0.29.2-py3-none-any.whl", hash = "sha256:73908f8a6b43e9484a8a463ce601e50367fbdca9e7be44d83608b1f035502bb1"}, + {file = "pytorch_ie-0.29.2.tar.gz", hash = "sha256:9c0b43b43307107c963e927336a0083a7f4fa4ca224b1447999848eedddde0d7"}, ] [package.dependencies] @@ -2050,4 +2077,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "1ff9a1791ab26f0b329fb3f9bb0ead4b6b01a73ae853b717bbdb61defaecb1bf" +content-hash = "e60362fb728c14102301c68b6c9735bd6fd50e30852c3b626dc427edd96b6639" diff --git a/pyproject.toml b/pyproject.toml index 7e395b48..1ca56b7f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ classifiers = [ [tool.poetry.dependencies] python = "^3.9" -pytorch-ie = ">=0.29.1,<0.30.0" +pie-models = "^0.7.6" datasets = "^2.14" # this was manually added because we get a conflict with pyarrow otherwise pyarrow = "^13" diff --git a/src/pie_datasets/document/processing/__init__.py b/src/pie_datasets/document/processing/__init__.py index c37ac786..641fa1c0 100644 --- a/src/pie_datasets/document/processing/__init__.py +++ b/src/pie_datasets/document/processing/__init__.py @@ -1,9 +1 @@ from .generic import Caster, Converter, Pipeline -from .regex_partitioner import RegexPartitioner -from .relation_argument_sorter import RelationArgumentSorter -from .text_span_trimmer import TextSpanTrimmer -from .tokenization import ( - text_based_document_to_token_based, - token_based_document_to_text_based, - tokenize_document, -) diff --git a/src/pie_datasets/document/processing/regex_partitioner.py b/src/pie_datasets/document/processing/regex_partitioner.py deleted file mode 100644 index 7252d6f4..00000000 --- a/src/pie_datasets/document/processing/regex_partitioner.py +++ /dev/null @@ -1,217 +0,0 @@ -from __future__ import annotations - -import json -import logging -import re -import statistics -from typing import Any, Callable, Iterable, Iterator, Match, TypeVar - -from pytorch_ie.annotations import LabeledSpan -from pytorch_ie.documents import TextBasedDocument - -from pie_datasets import Dataset, IterableDataset -from pie_datasets.core.dataset_dict import EnterDatasetMixin, ExitDatasetMixin - -logger = logging.getLogger(__name__) - - -D = TypeVar("D", bound=TextBasedDocument) - - -def create_regex_matcher(pattern): - return re.compile(pattern).finditer - - -def strip_span(start: int, end: int, text: str) -> tuple[int, int]: - """This method strips the leading and trailing whitespaces from the span. - - :param start: An integer value that represents the start index of the span. - :param end: An integer value that represents the end index of the span. 
- :param text: A string value that represents the text from which the span is extracted. - """ - span_text = text[start:end] - new_start = start + len(span_text) - len(span_text.lstrip()) - new_end = end - len(span_text) + len(span_text.rstrip()) - # if the span is empty, then create a span of length 0 at the start index - if new_start >= new_end: - new_start = start - new_end = start - return new_start, new_end - - -def _get_partitions_with_matcher( - text: str, - matcher_or_pattern: Callable[[str], Iterable[Match]] | str, - label_group_id: int | None = None, # = 1, - label_whitelist: list[str] | None = None, - skip_initial_partition: bool = False, # = True - default_partition_label: str = "partition", - initial_partition_label: str | None = None, - strip_whitespace: bool = False, - verbose: bool = True, -) -> Iterator[LabeledSpan]: - """This method yields LabeledSpans as partitions of the given text. matcher is used to search - for a pattern in the text. If the pattern is found, it returns a Match object that contains - matched groups. A partition is then created using a span in the matched groups. The span of a - partition starts from the first match (inclusive) and ends at the next match (exclusive) or at - the end of the text. A partition is labeled either using the default_partition_label or using - the list of labels available in label_whitelist. It should be noted that none of the partitions - overlap. - - :param text: A text that is to be partitioned - :param matcher_or_pattern: A method or a string. In the former case, that method is used to - find a pattern in the text and return an iterator yielding the Match objects, e.g. - re.compile(PATTERN).finditer. In the latter, the string is used as a pattern to find the - matches in the text. - :param label_group_id: An integer value (default:None) to select the desired match group from - the Match object. This match group is then used to create a label for the partition. - :param label_whitelist: An optional list of labels (default:None) which are allowed to form a - partition if label_group_id is not None. label_whitelist is the whitelist for the labels - created using label_group_id. If label_whitelist is None, then all the labels created using - label_group_id will form a partition. - :param skip_initial_partition: A boolean value (default:False) that prevents the initial - partition to be saved. - :param default_partition_label: A string value (default:partition) to be used as the default - label for the parts if no label_group_id for the match object is provided. - :param initial_partition_label: A string value (default:None) to be used as a label for the - initial partition. This is only used when skip_initial_partition is False. If it is None - then default_partition_label is used as initial_partition_label. 
- """ - if isinstance(matcher_or_pattern, str): - matcher = create_regex_matcher(matcher_or_pattern) - else: - matcher = matcher_or_pattern - if initial_partition_label is None: - initial_partition_label = default_partition_label - previous_start = previous_label = None - if not skip_initial_partition: - if label_whitelist is None or initial_partition_label in label_whitelist: - previous_start = 0 - previous_label = initial_partition_label - for match in matcher(text): - if label_group_id is not None: - start = match.start(label_group_id) - end = match.end(label_group_id) - label = text[start:end] - else: - label = default_partition_label - if label_whitelist is None or label in label_whitelist: - if previous_start is not None and previous_label is not None: - start = previous_start - end = match.start() - if strip_whitespace: - start, end = strip_span(start=start, end=end, text=text) - if end - start == 0: - if verbose: - logger.warning( - f"Found empty partition in text at [{previous_start}:{match.start()}] " - f"with potential label: '{previous_label}'. It will be skipped." - ) - else: - span = LabeledSpan(start=start, end=end, label=previous_label) - yield span - - previous_start = match.start() - previous_label = label - - if previous_start is not None and previous_label is not None: - start = previous_start - end = len(text) - if strip_whitespace: - start, end = strip_span(start=start, end=end, text=text) - if end - start == 0: - if verbose: - logger.warning( - f"Found empty partition in text at [{previous_start}:{len(text)}] with potential label: " - f"'{previous_label}'. It will be skipped." - ) - else: - span = LabeledSpan(start=start, end=end, label=previous_label) - yield span - - -class RegexPartitioner(EnterDatasetMixin, ExitDatasetMixin): - """RegexPartitioner partitions a document into multiple partitions using a regular expression. - For more information, refer to get_partitions_with_matcher() method. - - :param pattern: A regular expression to search for in the text. It is also included at the beginning of each partition. - :param collect_statistics: A boolean value (default:False) that allows to collect relevant statistics of the - document after partitioning. When this parameter is enabled, following stats are - collected: - 1. partition_lengths: list of lengths of all partitions - 2. num_partitions: list of number of partitions in each document - 3. document_lengths: list of document lengths - show_statistics can be used to get statistical insight over these lists. 
- :param partitioner_kwargs: keyword arguments for get_partitions_with_matcher() method - """ - - def __init__( - self, - pattern: str, - collect_statistics: bool = False, - partition_layer_name: str = "partitions", - text_field_name: str = "text", - **partitioner_kwargs, - ): - self.matcher = create_regex_matcher(pattern) - self.partition_layer_name = partition_layer_name - self.text_field_name = text_field_name - self.collect_statistics = collect_statistics - self.reset_statistics() - self.partitioner_kwargs = partitioner_kwargs - - def reset_statistics(self): - self._statistics: dict[str, Any] = { - "partition_lengths": [], - "num_partitions": [], - "document_lengths": [], - } - - def show_statistics(self, description: str | None = None): - description = description or "Statistics" - statistics_show = { - key: { - "min": min(values), - "max": max(values), - "mean": statistics.mean(values), - "stddev": statistics.pstdev(values), - } - for key, values in self._statistics.items() - } - - logger.info(f"{description}: \n{json.dumps(statistics_show, indent=2)}") - - def update_statistics(self, key: str, value: int | str | list): - if self.collect_statistics: - if isinstance(value, list): - self._statistics[key] += value - elif isinstance(value, str) or isinstance(value, int): - self._statistics[key].append(value) - else: - raise TypeError( - f"type of given key [{type(key)}] or value [{type(value)}] is incorrect." - ) - - def __call__(self, document: D) -> D: - partition_lengths = [] - text: str = getattr(document, self.text_field_name) - for partition in _get_partitions_with_matcher( - text=text, matcher_or_pattern=self.matcher, **self.partitioner_kwargs - ): - document[self.partition_layer_name].append(partition) - partition_lengths.append(partition.end - partition.start) - - if self.collect_statistics: - self.update_statistics("num_partitions", len(document[self.partition_layer_name])) - self.update_statistics("partition_lengths", partition_lengths) - self.update_statistics("document_lengths", len(text)) - - return document - - def enter_dataset(self, dataset: Dataset | IterableDataset, name: str | None = None) -> None: - if self.collect_statistics: - self.reset_statistics() - - def exit_dataset(self, dataset: Dataset | IterableDataset, name: str | None = None) -> None: - if self.collect_statistics: - self.show_statistics(description=name) diff --git a/src/pie_datasets/document/processing/relation_argument_sorter.py b/src/pie_datasets/document/processing/relation_argument_sorter.py deleted file mode 100644 index a6e3ad99..00000000 --- a/src/pie_datasets/document/processing/relation_argument_sorter.py +++ /dev/null @@ -1,107 +0,0 @@ -from __future__ import annotations - -import logging -from typing import TypeVar - -from pytorch_ie.annotations import BinaryRelation, LabeledSpan -from pytorch_ie.core import Annotation, AnnotationList, Document - -logger = logging.getLogger(__name__) - - -D = TypeVar("D", bound=Document) - - -def get_relation_args(relation: Annotation) -> tuple[Annotation, ...]: - if isinstance(relation, BinaryRelation): - return relation.head, relation.tail - else: - raise TypeError( - f"relation {relation} has unknown type [{type(relation)}], cannot get arguments from it" - ) - - -def construct_relation_with_new_args( - relation: Annotation, new_args: tuple[Annotation, ...] 
-) -> BinaryRelation: - if isinstance(relation, BinaryRelation): - return BinaryRelation( - head=new_args[0], - tail=new_args[1], - label=relation.label, - score=relation.score, - ) - else: - raise TypeError( - f"original relation {relation} has unknown type [{type(relation)}], " - f"cannot reconstruct it with new arguments" - ) - - -def has_dependent_layers(document: D, layer: str) -> bool: - return layer not in document._annotation_graph["_artificial_root"] - - -class RelationArgumentSorter: - """Sorts the arguments of the relations in the given relation layer. The sorting is done by the - start and end positions of the arguments. The relations with the same sorted arguments are - merged into one relation. - - Args: - relation_layer: the name of the relation layer - label_whitelist: if not None, only the relations with the label in the whitelist are sorted - inplace: if True, the sorting is done in place, otherwise the document is copied and the sorting is done - on the copy - """ - - def __init__( - self, relation_layer: str, label_whitelist: list[str] | None = None, inplace: bool = True - ): - self.relation_layer = relation_layer - self.label_whitelist = label_whitelist - self.inplace = inplace - - def __call__(self, doc: D) -> D: - if not self.inplace: - doc = doc.copy() - - rel_layer: AnnotationList[BinaryRelation] = doc[self.relation_layer] - args2relations: dict[tuple[LabeledSpan, ...], BinaryRelation] = { - get_relation_args(rel): rel for rel in rel_layer - } - - # assert that no other layers depend on the relation layer - if has_dependent_layers(document=doc, layer=self.relation_layer): - raise ValueError( - f"the relation layer {self.relation_layer} has dependent layers, " - f"cannot sort the arguments of the relations" - ) - - rel_layer.clear() - for args, rel in args2relations.items(): - if self.label_whitelist is not None and rel.label not in self.label_whitelist: - # just add the relations whose label is not in the label whitelist (if a whitelist is present) - rel_layer.append(rel) - else: - args_sorted = tuple(sorted(args, key=lambda arg: (arg.start, arg.end))) - if args == args_sorted: - # if the relation args are already sorted, just add the relation - rel_layer.append(rel) - else: - if args_sorted not in args2relations: - new_rel = construct_relation_with_new_args(rel, args_sorted) - rel_layer.append(new_rel) - else: - prev_rel = args2relations[args_sorted] - if prev_rel.label != rel.label: - raise ValueError( - f"there is already a relation with sorted args {args_sorted} " - f"but with a different label: {prev_rel.label} != {rel.label}" - ) - else: - logger.warning( - f"do not add the new relation with sorted arguments, because it is already there: " - f"{prev_rel}" - ) - - return doc diff --git a/src/pie_datasets/document/processing/text_span_trimmer.py b/src/pie_datasets/document/processing/text_span_trimmer.py deleted file mode 100644 index 8a2ec918..00000000 --- a/src/pie_datasets/document/processing/text_span_trimmer.py +++ /dev/null @@ -1,116 +0,0 @@ -from __future__ import annotations - -import logging -from typing import TypeVar - -from pytorch_ie.annotations import LabeledSpan -from pytorch_ie.core import AnnotationList, Document - -logger = logging.getLogger(__name__) - - -D = TypeVar("D", bound=Document) - - -def trim_text_spans( - document: D, - layer: str, - skip_empty: bool = True, - verbose: bool = True, -) -> D: - """Remove the whitespace at the beginning and end of span annotations that target a text field. 
- - Args: - document: The document to trim its span annotations. - layer: The name of the span layer to trim. - skip_empty: If True, empty spans will be skipped. Otherwise, an error will be raised. - verbose: If True, log warnings for trimmed spans. - - Returns: - The document with trimmed spans. - """ - annotation_layer_names = {f.name for f in document.annotation_fields()} - result = type(document).fromdict( - {k: v for k, v in document.asdict().items() if k not in annotation_layer_names} - ) - - spans: AnnotationList[LabeledSpan] = document[layer] - - old2new_spans = {} - removed_span_ids = [] - - text = spans.target - - for span in spans: - span_text = text[span.start : span.end] - new_start = span.start + len(span_text) - len(span_text.lstrip()) - new_end = span.end - len(span_text) + len(span_text.rstrip()) - - if new_end <= new_start: - if skip_empty: - if verbose: - logger.warning( - f'Span "{span}" is empty after trimming. Skipping it. (disable this warning with verbose=False)' - ) - removed_span_ids.append(span._id) - continue - else: - if verbose: - logger.warning( - f'Span "{span}" is empty after trimming. Keep it. (disable this warning with verbose=False)' - ) - # if there was only whitespace, we create a span with length 0 at the start of the original span - if new_end < new_start: - new_start = span.start - new_end = span.start - - new_span = LabeledSpan( - start=new_start, - end=new_end, - label=span.label, - score=span.score, - ) - if (span.start != new_span.start or span.end != new_span.end) and verbose: - logger.debug( - f'Trimmed span "{span}" to "{new_span}" (disable this warning with verbose=False)' - ) - old2new_spans[span._id] = new_span - - result[layer].extend(old2new_spans.values()) - result.add_all_annotations_from_other( - document, - override_annotations={layer: old2new_spans}, - removed_annotations={layer: set(removed_span_ids)}, - verbose=verbose, - strict=True, - ) - - return result - - -class TextSpanTrimmer: - """Remove the whitespace at the beginning and end of span annotations that target a text field. - - Args: - layer: The name of the text span layer to trim. - skip_empty: If True, empty spans will be skipped. Otherwise, an error will be raised. - verbose: If True, log warnings for trimmed spans. 
- """ - - def __init__( - self, - layer: str, - skip_empty: bool = True, - verbose: bool = True, - ): - self.layer = layer - self.skip_empty = skip_empty - self.verbose = verbose - - def __call__(self, document: D) -> D: - return trim_text_spans( - document=document, - layer=self.layer, - skip_empty=self.skip_empty, - verbose=self.verbose, - ) diff --git a/src/pie_datasets/document/processing/tokenization.py b/src/pie_datasets/document/processing/tokenization.py deleted file mode 100644 index e2e81679..00000000 --- a/src/pie_datasets/document/processing/tokenization.py +++ /dev/null @@ -1,302 +0,0 @@ -import functools -import logging -from collections import defaultdict -from copy import copy, deepcopy -from typing import ( - Callable, - Dict, - Iterable, - List, - Optional, - Set, - Tuple, - Type, - TypeVar, - Union, -) - -from pytorch_ie.annotations import Span -from pytorch_ie.core import Annotation -from pytorch_ie.documents import TextBasedDocument, TokenBasedDocument -from pytorch_ie.utils.hydra import resolve_target -from transformers import PreTrainedTokenizer - -logger = logging.getLogger(__name__) - -ToD = TypeVar("ToD", bound=TokenBasedDocument) -TeD = TypeVar("TeD", bound=TextBasedDocument) - - -def text_based_document_to_token_based( - doc: TextBasedDocument, - result_document_type: Union[Type[ToD], str], - tokens: Optional[List[str]] = None, - token_offset_mapping: Optional[List[Tuple[int, int]]] = None, - char_to_token: Optional[Callable[[int], Optional[int]]] = None, - strict_span_conversion: bool = True, - verbose: bool = True, -) -> ToD: - document_type: Type[ToD] - if isinstance(result_document_type, str): - document_type = resolve_target(result_document_type) # type: ignore - else: - document_type = result_document_type - if not (isinstance(document_type, type) and issubclass(document_type, TokenBasedDocument)): - raise TypeError( - f"result_document_type must be a subclass of TokenBasedDocument or a string that resolves to that, " - f"but got {result_document_type}" - ) - if tokens is None: - tokens = doc.metadata.get("tokens") - if tokens is None: - raise ValueError( - "tokens must be provided to convert a text based document to token based, but got None" - ) - result = document_type(tokens=tuple(tokens), id=doc.id, metadata=deepcopy(doc.metadata)) - - # save text, token_offset_mapping and char_to_token (if available) in metadata - result.metadata["text"] = doc.text - token_offset_mapping_lists: Optional[List[List[int]]] - if token_offset_mapping is not None: - # convert offset tuples to lists because serialization and deserialization again - # will produce lists in any way (json does not know tuples) - token_offset_mapping_lists = [list(offsets) for offsets in token_offset_mapping] - if ( - "token_offset_mapping" in doc.metadata - and doc.metadata["token_offset_mapping"] != token_offset_mapping_lists - ): - logger.warning( - "token_offset_mapping in metadata is different from the new token_offset_mapping, " - "overwrite the metadata" - ) - result.metadata["token_offset_mapping"] = token_offset_mapping_lists - else: - token_offset_mapping_lists = doc.metadata.get("token_offset_mapping") - if token_offset_mapping_lists is not None: - token_offset_mapping = [tuple(offsets) for offsets in token_offset_mapping_lists] # type: ignore - if char_to_token is not None: - if "char_to_token" in doc.metadata and doc.metadata["char_to_token"] != char_to_token: - logger.warning( - "char_to_token in metadata is different from the new char_to_token, overwrite the metadata" - ) - 
result.metadata["char_to_token"] = char_to_token - else: - char_to_token = doc.metadata.get("char_to_token") - - # construct the char_to_token function, if not provided, from the token_offset_mapping - if char_to_token is None: - if token_offset_mapping is None: - raise ValueError( - "either token_offset_mapping or char_to_token must be provided to convert a text " - "based document to token based, but both are None" - ) - char_to_token_dict: Dict[int, int] = {} - for token_idx, (start, end) in enumerate(token_offset_mapping): - for char_idx in range(start, end): - char_to_token_dict[char_idx] = token_idx - - def char_to_token(char_idx: int) -> Optional[int]: - return char_to_token_dict.get(char_idx) - - text_targeting_layers = [ - annotation_field.name - for annotation_field in doc.annotation_fields() - if "text" in annotation_field.metadata["targets"] - ] - - override_annotations: Dict[str, Dict[int, Annotation]] = {} - removed_annotations: Dict[str, Set[int]] = defaultdict(set) - for text_targeting_layer_name in text_targeting_layers: - override_annotations[text_targeting_layer_name] = {} - char_span: Span - for char_span in doc[text_targeting_layer_name]: - if not isinstance(char_span, Span): - raise ValueError( - f"can not convert layers that target the text but contain non-span annotations, " - f"but found {type(char_span)} in layer {text_targeting_layer_name}" - ) - start_token_idx = char_to_token(char_span.start) - end_token_idx_inclusive = char_to_token(char_span.end - 1) - if start_token_idx is None or end_token_idx_inclusive is None: - if strict_span_conversion: - raise ValueError( - f'cannot find token span for character span: "{char_span}", text="{doc.text}", ' - f"token_offset_mapping={token_offset_mapping}" - ) - else: - if verbose: - logger.warning( - f'cannot find token span for character span "{char_span}", skip it (disable this ' - f"warning with verbose=False)" - ) - removed_annotations[text_targeting_layer_name].add(char_span._id) - else: - token_span = char_span.copy(start=start_token_idx, end=end_token_idx_inclusive + 1) - override_annotations[text_targeting_layer_name][char_span._id] = token_span - valid_spans = set(override_annotations[text_targeting_layer_name].values()) - result[text_targeting_layer_name].extend(sorted(valid_spans, key=lambda span: span.start)) - - result.add_all_annotations_from_other( - doc, - override_annotations=override_annotations, - removed_annotations=removed_annotations, - strict=strict_span_conversion, - verbose=verbose, - ) - - return result - - -def token_based_document_to_text_based( - doc: TokenBasedDocument, - result_document_type: Union[Type[TeD], str], - text: Optional[str] = None, - token_offset_mapping: Optional[List[Tuple[int, int]]] = None, - join_tokens_with: Optional[str] = None, - strict_span_conversion: bool = True, - verbose: bool = True, -) -> TeD: - document_type: Type[TeD] - if isinstance(result_document_type, str): - document_type = resolve_target(result_document_type) # type: ignore - else: - document_type = result_document_type - if not (isinstance(document_type, type) and issubclass(document_type, TextBasedDocument)): - raise TypeError( - f"result_document_type must be a subclass of TextBasedDocument or a string that resolves to that, " - f"but got {result_document_type}" - ) - # if a token_separator is provided, we construct the text from the tokens - if join_tokens_with is not None: - start = 0 - token_offset_mapping = [] - tokens = doc.tokens - for token in tokens: - end = start + len(token) - 
token_offset_mapping.append((start, end)) - # we add the separator after each token - start = end + len(join_tokens_with) - text = join_tokens_with.join(tokens) - else: - text = doc.metadata.get("text") if text is None else text - if text is None: - raise ValueError( - "if join_tokens_with is None, text must be provided, but got None as well" - ) - token_offset_mapping_lists = ( - doc.metadata.get("token_offset_mapping") - if token_offset_mapping is None - else token_offset_mapping - ) - if token_offset_mapping_lists is None: - raise ValueError( - "if join_tokens_with is None, token_offsets must be provided, but got None as well" - ) - else: - token_offset_mapping = [tuple(offsets) for offsets in token_offset_mapping_lists] # type: ignore - - result = document_type(text=text, id=doc.id, metadata=deepcopy(doc.metadata)) - if "tokens" in doc.metadata and doc.metadata["tokens"] != list(doc.tokens): - logger.warning("tokens in metadata are different from new tokens, overwrite the metadata") - result.metadata["tokens"] = list(doc.tokens) - # convert offset tuples to lists because serialization and deserialization again - # will produce lists in any way (json does not know tuples) - token_offset_mapping_lists = [list(offsets) for offsets in token_offset_mapping] - if ( - "token_offset_mapping" in doc.metadata - and doc.metadata["token_offset_mapping"] != token_offset_mapping_lists - ): - logger.warning( - "token_offset_mapping in metadata is different from the new token_offset_mapping, " - "overwrite the metadata" - ) - result.metadata["token_offset_mapping"] = token_offset_mapping_lists - - token_targeting_layers = [ - annotation_field.name - for annotation_field in doc.annotation_fields() - if "tokens" in annotation_field.metadata["targets"] - ] - - override_annotations: Dict[str, Dict[int, Annotation]] = {} - removed_annotations: Dict[str, Set[int]] = defaultdict(set) - for token_targeting_layer_name in token_targeting_layers: - override_annotations[token_targeting_layer_name] = {} - for token_span in doc[token_targeting_layer_name]: - if not isinstance(token_span, Span): - raise ValueError( - f"can not convert layers that target the tokens but contain non-span annotations, " - f"but found {type(token_span)} in layer {token_targeting_layer_name}" - ) - start_char_idx = token_offset_mapping[token_span.start][0] - end_char_idx = token_offset_mapping[token_span.end - 1][1] - - char_span = token_span.copy(start=start_char_idx, end=end_char_idx) - override_annotations[token_targeting_layer_name][token_span._id] = char_span - valid_spans = set(override_annotations[token_targeting_layer_name].values()) - result[token_targeting_layer_name].extend(sorted(valid_spans, key=lambda span: span.start)) - - result.add_all_annotations_from_other( - doc, - override_annotations=override_annotations, - removed_annotations=removed_annotations, - strict=strict_span_conversion, - verbose=verbose, - ) - - return result - - -def tokenize_document( - doc: TextBasedDocument, - tokenizer: PreTrainedTokenizer, - result_document_type: Type[ToD], - partition_layer: Optional[str] = None, - strict_span_conversion: bool = True, - verbose: bool = True, - **tokenize_kwargs, -) -> List[ToD]: - result = [] - partitions: Iterable[Span] - if partition_layer is None: - partitions = [Span(start=0, end=len(doc.text))] - else: - partitions = doc[partition_layer] - for partition in partitions: - text = doc.text[partition.start : partition.end] - current_tokenize_kwargs = copy(tokenize_kwargs) - if "text" in tokenize_kwargs: - 
current_tokenize_kwargs["text_pair"] = text - sequence_index = 1 - else: - current_tokenize_kwargs["text"] = text - sequence_index = 0 - tokenized_text = tokenizer(**current_tokenize_kwargs) - for batch_encoding in tokenized_text.encodings: - token_offset_mapping = batch_encoding.offsets - char_to_token: Optional[Callable[[int], Optional[int]]] - char_to_token = functools.partial( - batch_encoding.char_to_token, sequence_index=sequence_index - ) - token_offset_mapping = [ - offsets if s_id == sequence_index else (0, 0) - for s_id, offsets in zip(batch_encoding.sequence_ids, token_offset_mapping) - ] - if partition.start > 0: - token_offset_mapping = [ - (start + partition.start, end + partition.start) - for start, end in token_offset_mapping - ] - char_to_token = None - tokenized_document = text_based_document_to_token_based( - doc, - tokens=batch_encoding.tokens, - result_document_type=result_document_type, - token_offset_mapping=token_offset_mapping, - char_to_token=char_to_token, - strict_span_conversion=strict_span_conversion, - verbose=verbose, - ) - tokenized_document.metadata["tokenizer_encoding"] = batch_encoding - result.append(tokenized_document) - return result diff --git a/src/pie_datasets/statistics/__init__.py b/src/pie_datasets/statistics/__init__.py deleted file mode 100644 index 27081d5b..00000000 --- a/src/pie_datasets/statistics/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .span_length_collector import SpanLengthCollector diff --git a/src/pie_datasets/statistics/span_length_collector.py b/src/pie_datasets/statistics/span_length_collector.py deleted file mode 100644 index 667d19d0..00000000 --- a/src/pie_datasets/statistics/span_length_collector.py +++ /dev/null @@ -1,122 +0,0 @@ -import logging -from collections import defaultdict -from typing import Any, Callable, Dict, List, Optional, Type, Union - -from pytorch_ie.annotations import Span -from pytorch_ie.core import Document, DocumentStatistic -from pytorch_ie.documents import TextBasedDocument, TokenBasedDocument -from pytorch_ie.utils.hydra import resolve_optional_document_type -from transformers import AutoTokenizer, PreTrainedTokenizer - -from pie_datasets.document.processing import tokenize_document - -logger = logging.getLogger(__name__) - - -class SpanLengthCollector(DocumentStatistic): - """Collects the lengths of Span annotations. If labels are provided, the lengths collected per - label. - - If a tokenizer is provided, the span length is calculated in means of tokens, otherwise in - means of characters. - """ - - DEFAULT_AGGREGATION_FUNCTIONS = ["len", "mean", "std", "min", "max"] - - def __init__( - self, - layer: str, - tokenize: bool = False, - tokenizer: Optional[Union[str, PreTrainedTokenizer]] = None, - tokenized_document_type: Optional[Union[str, Type[TokenBasedDocument]]] = None, - labels: Optional[Union[List[str], str]] = None, - label_attribute: str = "label", - tokenize_kwargs: Optional[Dict[str, Any]] = None, - **kwargs, - ): - super().__init__(**kwargs) - self.layer = layer - if isinstance(labels, str) and labels != "INFERRED": - raise ValueError("labels must be a list of strings or 'INFERRED'") - if labels == "INFERRED": - logger.warning( - f"Inferring labels with {self.__class__.__name__} from data produces wrong results " - f"for certain aggregation functions (e.g. 'mean', 'std', 'min') because zero values " - f"are not included in the calculation. 
We remove these aggregation functions from " - f"this collector, but be aware that the results may be wrong for your own aggregation " - f"functions that rely on zero values." - ) - self.aggregation_functions: Dict[str, Callable[[List], Any]] = { - name: func - for name, func in self.aggregation_functions.items() - if name not in ["mean", "std", "min"] - } - self.labels = labels - self.label_field = label_attribute - self.tokenize = tokenize - if self.tokenize: - if tokenizer is None: - raise ValueError( - "tokenizer must be provided to calculate the span length in means of tokens" - ) - if isinstance(tokenizer, str): - tokenizer = AutoTokenizer.from_pretrained(tokenizer) - self.tokenizer = tokenizer - resolved_tokenized_document_type = resolve_optional_document_type( - tokenized_document_type - ) - if resolved_tokenized_document_type is None: - raise ValueError( - "tokenized_document_type must be provided to calculate the span length in means of tokens" - ) - if not ( - isinstance(resolved_tokenized_document_type, type) - and issubclass(resolved_tokenized_document_type, TokenBasedDocument) - ): - raise TypeError( - f"tokenized_document_type must be a subclass of TokenBasedDocument, but it is: " - f"{resolved_tokenized_document_type}" - ) - self.tokenized_document_type = resolved_tokenized_document_type - self.tokenize_kwargs = tokenize_kwargs or {} - - def _collect(self, doc: Document) -> Union[List[int], Dict[str, List[int]]]: - docs: Union[List[Document], List[TokenBasedDocument]] - if self.tokenize: - if not isinstance(doc, TextBasedDocument): - raise ValueError( - "doc must be a TextBasedDocument to calculate the span length in means of tokens" - ) - if not isinstance(doc, TextBasedDocument): - raise ValueError( - "doc must be a TextBasedDocument to calculate the span length in means of tokens" - ) - docs = tokenize_document( - doc, - tokenizer=self.tokenizer, - result_document_type=self.tokenized_document_type, - **self.tokenize_kwargs, - ) - else: - docs = [doc] - - values: Dict[str, List[int]] - if isinstance(self.labels, str): - values = defaultdict(list) - else: - values = {label: [] for label in self.labels or ["ALL"]} - for doc in docs: - layer_obj = getattr(doc, self.layer) - for span in layer_obj: - if not isinstance(span, Span): - raise TypeError( - f"span length calculation is not yet supported for {type(span)}" - ) - length = span.end - span.start - if self.labels is None: - label = "ALL" - else: - label = getattr(span, self.label_field) - values[label].append(length) - - return values if self.labels is not None else values["ALL"] diff --git a/tests/dataset_builders/pie/test_cdcp.py b/tests/dataset_builders/pie/test_cdcp.py index 8963b233..22176d96 100644 --- a/tests/dataset_builders/pie/test_cdcp.py +++ b/tests/dataset_builders/pie/test_cdcp.py @@ -3,6 +3,7 @@ import pytest from datasets import disable_caching, load_dataset +from pie_models.document.processing import tokenize_document from pytorch_ie.annotations import LabeledSpan from pytorch_ie.core import AnnotationList, Document, annotation_field from pytorch_ie.documents import ( @@ -19,7 +20,6 @@ example_to_document, ) from pie_datasets import DatasetDict -from pie_datasets.document.processing import tokenize_document from pie_datasets.document.types import TokenDocumentWithLabeledSpansAndBinaryRelations from tests import FIXTURES_ROOT from tests.dataset_builders.common import PIE_BASE_PATH, _deep_compare diff --git a/tests/dataset_builders/pie/test_scidtb_argmin.py b/tests/dataset_builders/pie/test_scidtb_argmin.py 
index df7c3338..cb1af028 100644 --- a/tests/dataset_builders/pie/test_scidtb_argmin.py +++ b/tests/dataset_builders/pie/test_scidtb_argmin.py @@ -3,6 +3,7 @@ import pytest from datasets import disable_caching, load_dataset +from pie_models.document.processing import tokenize_document from pytorch_ie.core import Document from pytorch_ie.documents import TextDocumentWithLabeledSpansAndBinaryRelations from transformers import AutoTokenizer, PreTrainedTokenizer @@ -15,7 +16,6 @@ example_to_document, ) from pie_datasets import DatasetDict -from pie_datasets.document.processing import tokenize_document from pie_datasets.document.types import TokenDocumentWithLabeledSpansAndBinaryRelations from tests import FIXTURES_ROOT from tests.dataset_builders.common import HF_DS_FIXTURE_DATA_PATH, PIE_BASE_PATH diff --git a/tests/unit/document/processing/test_regex_partitioner.py b/tests/unit/document/processing/test_regex_partitioner.py deleted file mode 100644 index 23f79779..00000000 --- a/tests/unit/document/processing/test_regex_partitioner.py +++ /dev/null @@ -1,409 +0,0 @@ -import dataclasses -import json -import logging -from typing import Tuple - -import pytest -from pytorch_ie.annotations import LabeledSpan -from pytorch_ie.core import AnnotationList, annotation_field -from pytorch_ie.documents import TextBasedDocument - -from pie_datasets.document.processing import RegexPartitioner -from pie_datasets.document.processing.regex_partitioner import ( - _get_partitions_with_matcher, -) - - -@dataclasses.dataclass -class TextDocumentWithPartitions(TextBasedDocument): - partitions: AnnotationList[LabeledSpan] = annotation_field(target="text") - - -def have_overlap(start_end: Tuple[int, int], other_start_end: Tuple[int, int]) -> bool: - other_start_overlaps = start_end[0] <= other_start_end[0] < start_end[1] - other_end_overlaps = start_end[0] < other_start_end[1] <= start_end[1] - start_overlaps_other = other_start_end[0] <= start_end[0] < other_start_end[1] - end_overlaps_other = other_start_end[0] < start_end[1] <= other_start_end[1] - return other_start_overlaps or other_end_overlaps or start_overlaps_other or end_overlaps_other - - -def test_regex_partitioner(): - TEXT1 = ( - "This is initial text.Jane lives in Berlin. this is no sentence about Karl." - "Seattle is a rainy city. Jenny Durkan is the city's mayor." - "Karl enjoys sunny days in Berlin." - ) - regex_partitioner = RegexPartitioner( - pattern="(||)", - ) - # The document contains a text separated by some markers like , and . RegexPartitioner - # partitions the text based on the given pattern. After partitioning, there are be four partitions with same label. - document = TextDocumentWithPartitions(text=TEXT1) - new_document = regex_partitioner(document) - - partitions = new_document.partitions - labels = [partition.label for partition in partitions] - assert len(partitions) == 4 - assert labels == ["partition"] * len(partitions) - assert str(partitions[0]) == "This is initial text." - assert str(partitions[1]) == "Jane lives in Berlin. this is no sentence about Karl." - assert ( - str(partitions[2]) == "Seattle is a rainy city. Jenny Durkan is the city's mayor." - ) - assert str(partitions[3]) == "Karl enjoys sunny days in Berlin." - - -def test_regex_partitioner_with_statistics(caplog): - TEXT1 = ( - "This is initial text.Jane lives in Berlin. this is no sentence about Karl." - "Seattle is a rainy city. Jenny Durkan is the city's mayor." - "Karl enjoys sunny days in Berlin." 
- ) - TEXT2 = "This is initial text.Lily is mother of Harry.Beth greets Emma." - - regex_partitioner = RegexPartitioner( - pattern="(||)", - label_group_id=0, - label_whitelist=["", "", ""], - skip_initial_partition=True, - collect_statistics=True, - ) - - # The document contains a text separated by some markers like , and . After partitioning, there - # are three partitions excluding initial part. Therefore, document length is not be equal to sum of partitions. - document = TextDocumentWithPartitions(text=TEXT1) - caplog.set_level(logging.INFO) - caplog.clear() - regex_partitioner.enter_dataset(None) - new_document = regex_partitioner(document) - regex_partitioner.exit_dataset(None) - partitions = new_document.partitions - assert len(partitions) == 3 - - assert len(caplog.records) == 1 - log_description, log_json = caplog.records[0].message.split("\n", maxsplit=1) - assert log_description.strip() == "Statistics:" - assert json.loads(log_json) == { - "partition_lengths": { - "min": 38, - "max": 66, - "mean": 54.666666666666664, - "stddev": 12.036980056845191, - }, - "num_partitions": {"min": 3, "max": 3, "mean": 3, "stddev": 0.0}, - "document_lengths": {"min": 185, "max": 185, "mean": 185, "stddev": 0.0}, - } - - # The document contains a text separated by some markers like and . RegexPartitioner appends statistics - # from each document, therefore statistics contains information from previous document as well. After partitioning, - # there are two partitions excluding initial part. Therefore, the sum of document lengths is not be equal to sum of - # partitions. - document = TextDocumentWithPartitions(text=TEXT2) - caplog.set_level(logging.INFO) - caplog.clear() - regex_partitioner.enter_dataset(None) - new_document = regex_partitioner(document) - regex_partitioner.exit_dataset(None) - partitions = new_document.partitions - assert len(partitions) == 2 - - assert len(caplog.records) == 1 - log_description, log_json = caplog.records[0].message.split("\n", maxsplit=1) - assert log_description.strip() == "Statistics:" - assert json.loads(log_json) == { - "partition_lengths": {"min": 22, "max": 31, "mean": 26.5, "stddev": 4.5}, - "num_partitions": {"min": 2, "max": 2, "mean": 2, "stddev": 0.0}, - "document_lengths": {"min": 74, "max": 74, "mean": 74, "stddev": 0.0}, - } - - with pytest.raises( - TypeError, - match=r"type of given key \[\] or value \[\] is incorrect.", - ): - regex_partitioner.update_statistics("num_partitions", 1.0) - - regex_partitioner.show_statistics() - - -@pytest.mark.parametrize("label_whitelist", [["", "", ""], [], None]) -@pytest.mark.parametrize("skip_initial_partition", [True, False]) -def test_regex_partitioner_without_label_group_id(label_whitelist, skip_initial_partition): - TEXT1 = ( - "This is initial text.Jane lives in Berlin. this is no sentence about Karl." - "Seattle is a rainy city. Jenny Durkan is the city's mayor." - "Karl enjoys sunny days in Berlin." - ) - regex_partitioner = RegexPartitioner( - pattern="(||)", - label_whitelist=label_whitelist, - skip_initial_partition=skip_initial_partition, - ) - # The document contains a text separated by some markers like , and . Since label_group_id is - # None, the partitions (if any) will have same label. 
- document = TextDocumentWithPartitions(text=TEXT1) - new_document = regex_partitioner(document) - partitions = new_document.partitions - assert [partition.label for partition in partitions] == ["partition"] * len(partitions) - if skip_initial_partition: - if label_whitelist == ["", "", ""] or label_whitelist == []: - # Since label_group_id is None, no label will be created using the matched pattern. Therefore, the default - # partition label is used but since it is not in label_whitelist, no partition is created. - assert len(partitions) == 0 - else: # label_whitelist is None - # since label_whitelist and label_group_id is None and skip_initial_partition is True, three partitions are - # created with the same label - assert len(partitions) == 3 - assert ( - str(partitions[0]) - == "Jane lives in Berlin. this is no sentence about Karl." - ) - assert ( - str(partitions[1]) - == "Seattle is a rainy city. Jenny Durkan is the city's mayor." - ) - assert str(partitions[2]) == "Karl enjoys sunny days in Berlin." - else: # skip_initial_partition is False - if label_whitelist == ["", "", ""] or label_whitelist == []: - # Since label_group_id is None, no label will be created using the matched pattern. Therefore, the default - # partition label is used but since it is not in label_whitelist, no partition is created. - assert len(partitions) == 0 - else: # label_whitelist is None - # since label_whitelist and label_group_id is None and skip_initial_partition is False, four partitions are - # created with the same label. - assert len(partitions) == 4 - assert str(partitions[0]) == "This is initial text." - assert ( - str(partitions[1]) - == "Jane lives in Berlin. this is no sentence about Karl." - ) - assert ( - str(partitions[2]) - == "Seattle is a rainy city. Jenny Durkan is the city's mayor." - ) - assert str(partitions[3]) == "Karl enjoys sunny days in Berlin." - - -@pytest.mark.parametrize( - "label_whitelist", [["partition", "", ""], ["", ""], [], None] -) -@pytest.mark.parametrize("skip_initial_partition", [True, False]) -def test_regex_partitioner_with_label_group_id(label_whitelist, skip_initial_partition): - TEXT1 = ( - "This is initial text.Jane lives in Berlin. this is no sentence about Karl." - "Seattle is a rainy city. Jenny Durkan is the city's mayor." - "Karl enjoys sunny days in Berlin." - ) - regex_partitioner = RegexPartitioner( - pattern="(||)", - label_group_id=0, - label_whitelist=label_whitelist, - skip_initial_partition=skip_initial_partition, - ) - # The document contains a text separated by some markers like , and . Possible partitions can - # be four including the initial partition. - document = TextDocumentWithPartitions(text=TEXT1) - new_document = regex_partitioner(document) - partitions = new_document.partitions - labels = [partition.label for partition in partitions] - if skip_initial_partition: - if label_whitelist == ["", ""] or label_whitelist == [ - "partition", - "", - "", - ]: - # Since skip_initial_partition is True, therefore even if initial_partition_label is in label_whitelist, it - # will not be added as a partition. - assert len(partitions) == 2 - assert labels == ["", ""] - assert ( - str(partitions[0]) - == "Jane lives in Berlin. this is no sentence about Karl.Seattle is a rainy city. Jenny Durkan is the city's mayor." - ) - assert str(partitions[1]) == "Karl enjoys sunny days in Berlin." - elif label_whitelist == []: - # Even though labels are created using label_group_id, label_whitelist is empty. Therefore, no partition will - # be created. 
- assert len(partitions) == 0 - else: # label_whitelist is None - # Since label_whitelist is None, all the labels formed using label_group_id will create a partition. - assert len(partitions) == 3 - assert labels == ["", "", ""] - assert ( - str(partitions[0]) - == "Jane lives in Berlin. this is no sentence about Karl." - ) - assert ( - str(partitions[1]) - == "Seattle is a rainy city. Jenny Durkan is the city's mayor." - ) - assert str(partitions[2]) == "Karl enjoys sunny days in Berlin." - else: # skip_initial_partition is False - if label_whitelist == ["", ""]: - # Though skip_initial_partition is False it is not in label_whitelist, therefore not added as a partition. - assert len(partitions) == 2 - assert labels == ["", ""] - assert ( - str(partitions[0]) - == "Jane lives in Berlin. this is no sentence about Karl.Seattle is a rainy city. Jenny Durkan is the city's mayor." - ) - assert str(partitions[1]) == "Karl enjoys sunny days in Berlin." - elif label_whitelist == ["partition", "", ""]: - # Since initial partition label is in label_whitelist, therefore it will form a partition in the document. - assert len(partitions) == 3 - assert labels == ["partition", "", ""] - assert str(partitions[0]) == "This is initial text." - assert ( - str(partitions[1]) - == "Jane lives in Berlin. this is no sentence about Karl.Seattle is a rainy city. Jenny Durkan is the city's mayor." - ) - assert str(partitions[2]) == "Karl enjoys sunny days in Berlin." - elif label_whitelist == []: - # Even though labels are created using label_group_id, label_whitelist is empty. Therefore, no partition will - # be created. - assert len(partitions) == 0 - else: # label_whitelist is None - # Since label_whitelist is None, all the labels formed using label_group_id will create a partition. In - # addition to that the initial partition will also be added to the document. - assert len(partitions) == 4 - assert labels == ["partition", "", "", ""] - assert str(partitions[0]) == "This is initial text." - assert ( - str(partitions[1]) - == "Jane lives in Berlin. this is no sentence about Karl." - ) - assert ( - str(partitions[2]) - == "Seattle is a rainy city. Jenny Durkan is the city's mayor." - ) - assert str(partitions[3]) == "Karl enjoys sunny days in Berlin." - - -@pytest.mark.parametrize("label_whitelist", [["partition"], [], None]) -@pytest.mark.parametrize("skip_initial_partition", [True, False]) -def test_regex_partitioner_with_no_match_found(skip_initial_partition, label_whitelist): - TEXT2 = "This is initial text.Lily is mother of Harry.Beth greets Emma." - regex_partitioner = RegexPartitioner( - pattern="()", - label_group_id=0, - label_whitelist=label_whitelist, - skip_initial_partition=skip_initial_partition, - ) - # The document contains a text separated by some markers like and . Only possible partition in the - # document based on the given pattern is the initial partition. - document = TextDocumentWithPartitions(text=TEXT2) - new_document = regex_partitioner(document) - - partitions = new_document.partitions - if skip_initial_partition: - # No matter what the value of label_whitelist is, there will be no partition created, since the given pattern - # is not in the document and skip_initial_partition is True. 
- if label_whitelist == ["partition"]: - assert len(partitions) == 0 - elif label_whitelist == []: - assert len(partitions) == 0 - else: # label_whitelist is None - assert len(partitions) == 0 - else: - if label_whitelist == ["partition"]: - # Since initial_partition_label is contained in label_whitelist, the initial partition will be added to the - # document. - assert len(partitions) == 1 - assert str(partitions[0]) == TEXT2 - assert partitions[0].label == "partition" - elif label_whitelist == []: - # Even though skip_initial_partition is False, initial_partition_label is not contained in label_whitelist. - # Therefore, the initial partition will not be added to the document. - assert len(partitions) == 0 - else: # label_whitelist is None - # Since label_whitelist is None and skip_initial_partition is False, the initial partition will be added to - # the document. - assert len(partitions) == 1 - assert str(partitions[0]) == TEXT2 - assert partitions[0].label == "partition" - - -def test_get_partitions_with_matcher(): - TEXT1 = ( - "This is initial text.Jane lives in Berlin. this is no sentence about Karl." - "Seattle is a rainy city. Jenny Durkan is the city's mayor." - "Karl enjoys sunny days in Berlin." - ) - # The document contains a text separated by some markers like , and . finditer method is used - # which returns non overlapping match from the text. Therefore, none of the partition created should have overlapped - # span and all of them should be instances of LabeledSpan. - document = TextDocumentWithPartitions(text=TEXT1) - partitions = [] - for partition in _get_partitions_with_matcher( - text=document.text, - matcher_or_pattern="(||)", - label_group_id=0, - label_whitelist=["", "", ""], - ): - assert isinstance(partition, LabeledSpan) - for p in partitions: - assert not have_overlap((p.start, p.end), (partition.start, partition.end)) - partitions.append(partition) - - -@pytest.mark.parametrize( - "strip_whitespace, verbose", - [ - (False, False), - (False, True), - (True, False), - (True, True), - ], -) -def test_regex_partitioner_with_strip_whitespace(strip_whitespace, verbose, caplog): - TEXT1 = ( - "\nThis is initial text. Jane lives in Berlin. this is no sentence about Karl.\n" - "Seattle is a rainy city. Jenny Durkan is the city's mayor.\n\n" - "Karl enjoys sunny days in Berlin.\n" - ) - regex_partitioner = RegexPartitioner( - pattern="\n", - strip_whitespace=strip_whitespace, - verbose=verbose, - ) - document = TextDocumentWithPartitions(text=TEXT1) - new_document = regex_partitioner(document) - - partitions = new_document.partitions - labels = [partition.label for partition in partitions] - if strip_whitespace: - assert len(partitions) == 3 - assert labels == ["partition"] * len(partitions) - assert ( - str(partitions[0]) - == "This is initial text. Jane lives in Berlin. this is no sentence about Karl." - ) - assert str(partitions[1]) == "Seattle is a rainy city. Jenny Durkan is the city's mayor." - assert str(partitions[2]) == "Karl enjoys sunny days in Berlin." - if verbose: - assert len(caplog.messages) == 3 - assert caplog.messages[0] == ( - "Found empty partition in text at [0:0] with potential label: 'partition'. It will be skipped." - ) - assert caplog.messages[1] == ( - "Found empty partition in text at [135:136] with potential label: 'partition'. It will be skipped." - ) - assert caplog.messages[2] == ( - "Found empty partition in text at [170:171] with potential label: 'partition'. It will be skipped." 
- ) - else: - assert len(partitions) == 5 - assert labels == ["partition"] * len(partitions) - assert ( - str(partitions[0]) - == "\nThis is initial text. Jane lives in Berlin. this is no sentence about Karl." - ) - assert str(partitions[1]) == "\nSeattle is a rainy city. Jenny Durkan is the city's mayor." - assert str(partitions[2]) == "\n" - assert str(partitions[3]) == "\nKarl enjoys sunny days in Berlin." - assert str(partitions[4]) == "\n" - if verbose: - assert len(caplog.messages) == 1 - assert ( - caplog.messages[0] - == "Found empty partition in text at [0:0] with potential label: 'partition'. It will be skipped." - ) diff --git a/tests/unit/document/processing/test_relation_argument_sorter.py b/tests/unit/document/processing/test_relation_argument_sorter.py deleted file mode 100644 index 6112deec..00000000 --- a/tests/unit/document/processing/test_relation_argument_sorter.py +++ /dev/null @@ -1,256 +0,0 @@ -import dataclasses -import logging - -import pytest -from pytorch_ie import Annotation, AnnotationLayer, annotation_field -from pytorch_ie.annotations import BinaryRelation, LabeledSpan, NaryRelation -from pytorch_ie.documents import ( - TextBasedDocument, - TextDocumentWithLabeledSpans, - TextDocumentWithLabeledSpansAndBinaryRelations, -) - -from pie_datasets.document.processing import RelationArgumentSorter -from pie_datasets.document.processing.relation_argument_sorter import ( - construct_relation_with_new_args, - get_relation_args, -) - - -@pytest.fixture -def document(): - doc = TextDocumentWithLabeledSpansAndBinaryRelations( - text="Entity G works at H. And founded I." - ) - doc.labeled_spans.append(LabeledSpan(start=0, end=8, label="PER")) - assert str(doc.labeled_spans[0]) == "Entity G" - doc.labeled_spans.append(LabeledSpan(start=18, end=19, label="ORG")) - assert str(doc.labeled_spans[1]) == "H" - doc.labeled_spans.append(LabeledSpan(start=33, end=34, label="ORG")) - assert str(doc.labeled_spans[2]) == "I" - - return doc - - -@pytest.mark.parametrize("inplace", [True, False]) -def test_relation_argument_sorter(document, inplace): - # these arguments are not sorted - document.binary_relations.append( - BinaryRelation( - head=document.labeled_spans[1], tail=document.labeled_spans[0], label="worksAt" - ) - ) - # these arguments are sorted - document.binary_relations.append( - BinaryRelation( - head=document.labeled_spans[0], tail=document.labeled_spans[2], label="founded" - ) - ) - - arg_sorter = RelationArgumentSorter(relation_layer="binary_relations", inplace=inplace) - doc_sorted_args = arg_sorter(document) - - assert document.text == doc_sorted_args.text - assert document.labeled_spans == doc_sorted_args.labeled_spans - assert len(doc_sorted_args.binary_relations) == len(document.binary_relations) - - # this relation should be sorted - assert str(doc_sorted_args.binary_relations[0].head) == "Entity G" - assert str(doc_sorted_args.binary_relations[0].tail) == "H" - assert doc_sorted_args.binary_relations[0].label == "worksAt" - - # this relation should be the same as before - assert str(doc_sorted_args.binary_relations[1].head) == "Entity G" - assert str(doc_sorted_args.binary_relations[1].tail) == "I" - assert doc_sorted_args.binary_relations[1].label == "founded" - - if inplace: - assert document == doc_sorted_args - else: - assert document != doc_sorted_args - - -@pytest.fixture -def document_with_nary_relation(): - @dataclasses.dataclass - class TextDocumentWithLabeledSpansAndNaryRelations(TextDocumentWithLabeledSpans): - nary_relations: 
AnnotationLayer[NaryRelation] = annotation_field(target="labeled_spans")
-
-    doc = TextDocumentWithLabeledSpansAndNaryRelations(text="Entity G works at H. And founded I.")
-    doc.labeled_spans.append(LabeledSpan(start=0, end=8, label="PER"))
-    assert str(doc.labeled_spans[0]) == "Entity G"
-    doc.labeled_spans.append(LabeledSpan(start=18, end=19, label="ORG"))
-    assert str(doc.labeled_spans[1]) == "H"
-    doc.labeled_spans.append(LabeledSpan(start=33, end=34, label="ORG"))
-    assert str(doc.labeled_spans[2]) == "I"
-
-    doc.nary_relations.append(
-        NaryRelation(
-            arguments=(doc.labeled_spans[0], doc.labeled_spans[1], doc.labeled_spans[2]),
-            roles=("person", "worksAt", "founded"),
-            label="event",
-        )
-    )
-
-    return doc
-
-
-def test_get_args_wrong_type(document_with_nary_relation):
-    with pytest.raises(TypeError) as excinfo:
-        get_relation_args(document_with_nary_relation.nary_relations[0])
-    assert (
-        str(excinfo.value)
-        == "relation NaryRelation(arguments=(LabeledSpan(start=0, end=8, label='PER', score=1.0), "
-        "LabeledSpan(start=18, end=19, label='ORG', score=1.0), LabeledSpan(start=33, end=34, "
-        "label='ORG', score=1.0)), roles=('person', 'worksAt', 'founded'), label='event', score=1.0) "
-        "has unknown type [<class 'pytorch_ie.annotations.NaryRelation'>], cannot get arguments from it"
-    )
-
-
-def test_construct_relation_with_new_args_wrong_type(document_with_nary_relation):
-    with pytest.raises(TypeError) as excinfo:
-        construct_relation_with_new_args(
-            document_with_nary_relation.nary_relations[0],
-            (
-                document_with_nary_relation.labeled_spans[0],
-                document_with_nary_relation.labeled_spans[1],
-            ),
-        )
-    assert (
-        str(excinfo.value)
-        == "original relation NaryRelation(arguments=(LabeledSpan(start=0, end=8, label='PER', score=1.0), "
-        "LabeledSpan(start=18, end=19, label='ORG', score=1.0), LabeledSpan(start=33, end=34, label='ORG', "
-        "score=1.0)), roles=('person', 'worksAt', 'founded'), label='event', score=1.0) has unknown type "
-        "[<class 'pytorch_ie.annotations.NaryRelation'>], cannot reconstruct it with new arguments"
-    )
-
-
-def test_relation_argument_sorter_with_label_whitelist(document):
-    # argument of both relations are not sorted
-    document.binary_relations.append(
-        BinaryRelation(
-            head=document.labeled_spans[1], tail=document.labeled_spans[0], label="worksAt"
-        )
-    )
-    document.binary_relations.append(
-        BinaryRelation(
-            head=document.labeled_spans[2], tail=document.labeled_spans[0], label="founded"
-        )
-    )
-
-    # we only want to sort the relations with the label "founded"
-    arg_sorter = RelationArgumentSorter(
-        relation_layer="binary_relations", label_whitelist=["founded"], inplace=False
-    )
-    doc_sorted_args = arg_sorter(document)
-
-    assert document.text == doc_sorted_args.text
-    assert document.labeled_spans == doc_sorted_args.labeled_spans
-
-    # this relation should be the same as before
-    assert doc_sorted_args.binary_relations[0] == document.binary_relations[0]
-
-    # this relation should be sorted
-    assert doc_sorted_args.binary_relations[1] != document.binary_relations[1]
-    assert str(doc_sorted_args.binary_relations[1].head) == "Entity G"
-    assert str(doc_sorted_args.binary_relations[1].tail) == "I"
-    assert doc_sorted_args.binary_relations[1].label == "founded"
-
-
-def test_relation_argument_sorter_sorted_rel_already_exists_with_same_label(document, caplog):
-    document.binary_relations.append(
-        BinaryRelation(
-            head=document.labeled_spans[1], tail=document.labeled_spans[0], label="worksAt"
-        )
-    )
-    document.binary_relations.append(
-        BinaryRelation(
-            head=document.labeled_spans[0], tail=document.labeled_spans[1], label="worksAt"
) - ) - - arg_sorter = RelationArgumentSorter(relation_layer="binary_relations", inplace=False) - - caplog.clear() - with caplog.at_level(logging.WARNING): - doc_sorted_args = arg_sorter(document) - - assert len(caplog.records) == 1 - assert caplog.records[0].levelname == "WARNING" - assert ( - caplog.records[0].message - == "do not add the new relation with sorted arguments, because it is already there: " - "BinaryRelation(head=LabeledSpan(start=0, end=8, label='PER', score=1.0), " - "tail=LabeledSpan(start=18, end=19, label='ORG', score=1.0), label='worksAt', score=1.0)" - ) - - assert document.text == doc_sorted_args.text - assert document.labeled_spans == doc_sorted_args.labeled_spans - - # since there is already a relation with the same label and sorted arguments, - # there should be just one relation in the end - assert len(doc_sorted_args.binary_relations) == 1 - assert str(doc_sorted_args.binary_relations[0].head) == "Entity G" - assert str(doc_sorted_args.binary_relations[0].tail) == "H" - - -def test_relation_argument_sorter_sorted_rel_already_exists_with_different_label(document): - document.binary_relations.append( - BinaryRelation( - head=document.labeled_spans[1], tail=document.labeled_spans[0], label="worksAt" - ) - ) - document.binary_relations.append( - BinaryRelation( - head=document.labeled_spans[0], tail=document.labeled_spans[1], label="founded" - ) - ) - - arg_sorter = RelationArgumentSorter(relation_layer="binary_relations", inplace=False) - - with pytest.raises(ValueError) as excinfo: - arg_sorter(document) - assert ( - str(excinfo.value) - == "there is already a relation with sorted args (LabeledSpan(start=0, end=8, label='PER', score=1.0), " - "LabeledSpan(start=18, end=19, label='ORG', score=1.0)) but with a different label: founded != worksAt" - ) - - -def test_relation_argument_sorter_with_dependent_layers(): - @dataclasses.dataclass(frozen=True) - class Attribute(Annotation): - annotation: Annotation - label: str - - @dataclasses.dataclass - class ExampleDocument(TextBasedDocument): - labeled_spans: AnnotationLayer[LabeledSpan] = annotation_field(target="text") - binary_relations: AnnotationLayer[BinaryRelation] = annotation_field( - target="labeled_spans" - ) - relation_attributes: AnnotationLayer[Attribute] = annotation_field( - target="binary_relations" - ) - - doc = ExampleDocument(text="Entity G works at H. 
And founded I.") - doc.labeled_spans.append(LabeledSpan(start=0, end=8, label="PER")) - assert str(doc.labeled_spans[0]) == "Entity G" - doc.labeled_spans.append(LabeledSpan(start=18, end=19, label="ORG")) - assert str(doc.labeled_spans[1]) == "H" - doc.binary_relations.append( - BinaryRelation(head=doc.labeled_spans[1], tail=doc.labeled_spans[0], label="worksAt") - ) - doc.relation_attributes.append( - Attribute(annotation=doc.binary_relations[0], label="some_attribute") - ) - - arg_sorter = RelationArgumentSorter(relation_layer="binary_relations", inplace=False) - - with pytest.raises(ValueError) as excinfo: - arg_sorter(doc) - - assert ( - str(excinfo.value) - == "the relation layer binary_relations has dependent layers, cannot sort the arguments of the relations" - ) diff --git a/tests/unit/document/processing/test_text_span_trimmer.py b/tests/unit/document/processing/test_text_span_trimmer.py deleted file mode 100644 index 1b148040..00000000 --- a/tests/unit/document/processing/test_text_span_trimmer.py +++ /dev/null @@ -1,119 +0,0 @@ -import dataclasses - -import pytest -from pytorch_ie.annotations import BinaryRelation, LabeledSpan -from pytorch_ie.core import AnnotationList, annotation_field -from pytorch_ie.documents import TextBasedDocument - -from pie_datasets.document.processing import TextSpanTrimmer - - -@dataclasses.dataclass -class DocumentWithEntitiesRelationsAndPartitions(TextBasedDocument): - entities: AnnotationList[LabeledSpan] = annotation_field(target="text") - relations: AnnotationList[BinaryRelation] = annotation_field(target="entities") - partitions: AnnotationList[LabeledSpan] = annotation_field(target="text") - - -@pytest.fixture -def document1() -> DocumentWithEntitiesRelationsAndPartitions: - TEXT1 = "Jane lives in Berlin. this is a truncated sentence about Karl\n " - ENTITY_JANE_TEXT1 = LabeledSpan(start=0, end=4, label="person") - ENTITY_BERLIN_TEXT1 = LabeledSpan(start=13, end=20, label="city") - ENTITY_KARL_TEXT1 = LabeledSpan(start=57, end=61, label="person") - ENTITY_EMPTY_TEXT1 = LabeledSpan(start=62, end=65, label="other") - SENTENCE1_TEXT1 = LabeledSpan(start=0, end=21, label="sentence") - SENTENCE2_TEXT1 = LabeledSpan(start=22, end=65, label="sentence") - REL_JANE_LIVES_IN_BERLIN = BinaryRelation( - head=ENTITY_JANE_TEXT1, tail=ENTITY_BERLIN_TEXT1, label="lives_in" - ) - REL_KARL_HAS_NOTHING = BinaryRelation( - head=ENTITY_KARL_TEXT1, tail=ENTITY_EMPTY_TEXT1, label="has_nothing" - ) - - document = DocumentWithEntitiesRelationsAndPartitions(text=TEXT1) - document.entities.extend( - [ENTITY_JANE_TEXT1, ENTITY_BERLIN_TEXT1, ENTITY_KARL_TEXT1, ENTITY_EMPTY_TEXT1] - ) - document.partitions.extend([SENTENCE1_TEXT1, SENTENCE2_TEXT1]) - document.relations.extend([REL_JANE_LIVES_IN_BERLIN, REL_KARL_HAS_NOTHING]) - - assert str(document.entities[0]) == "Jane" - assert str(document.entities[1]) == " Berlin" - assert str(document.entities[2]) == "Karl" - assert str(document.entities[3]) == " " - assert str(document.partitions[0]) == "Jane lives in Berlin." 
- assert str(document.partitions[1]) == "this is a truncated sentence about Karl\n " - - assert str(document.relations[0].tail) == " Berlin" - assert str(document.relations[0].head) == "Jane" - assert str(document.relations[0].label) == "lives_in" - assert str(document.relations[1].tail) == " " - assert str(document.relations[1].head) == "Karl" - assert str(document.relations[1].label) == "has_nothing" - - return document - - -@pytest.mark.parametrize( - "layer,skip_empty", - [ - ("entities", False), - ("partitions", False), - ("partitions", True), - ], -) -def test_text_span_trimmer(document1, layer, skip_empty): - trimmer = TextSpanTrimmer(layer=layer, skip_empty=skip_empty) - processed_document = trimmer(document1) - - assert len(document1.entities) == 4 - assert len(document1.relations) == 2 - assert len(processed_document.partitions) == len(document1.partitions) == 2 - - if layer == "entities" and not skip_empty: - assert len(processed_document.entities) == 4 - assert len(processed_document.relations) == 2 - assert str(processed_document.entities[0]) == "Jane" - assert str(processed_document.entities[1]) == "Berlin" - assert str(processed_document.entities[2]) == "Karl" - assert str(processed_document.entities[3]) == "" - assert str(processed_document.partitions[0]) == "Jane lives in Berlin." - assert ( - str(processed_document.partitions[1]) == "this is a truncated sentence about Karl\n " - ) - assert str(processed_document.relations[0].tail) == "Berlin" - assert str(processed_document.relations[0].head) == "Jane" - assert str(processed_document.relations[0].label) == "lives_in" - assert str(processed_document.relations[1].tail) == "" - assert str(processed_document.relations[1].head) == "Karl" - assert str(processed_document.relations[1].label) == "has_nothing" - elif layer == "partitions": - assert len(processed_document.entities) == 4 - assert str(processed_document.entities[0]) == "Jane" - assert str(processed_document.entities[1]) == " Berlin" - assert str(processed_document.entities[2]) == "Karl" - assert str(processed_document.entities[3]) == " " - assert str(processed_document.partitions[0]) == "Jane lives in Berlin." - assert str(processed_document.partitions[1]) == "this is a truncated sentence about Karl" - assert str(processed_document.relations[0].tail) == " Berlin" - assert str(processed_document.relations[0].head) == "Jane" - assert str(processed_document.relations[0].label) == "lives_in" - assert str(processed_document.relations[1].tail) == " " - assert str(processed_document.relations[1].head) == "Karl" - assert str(processed_document.relations[1].label) == "has_nothing" - else: - raise ValueError(f"Unknown parameter combination: layer={layer}, skip_empty={skip_empty}") - - -def test_text_span_trimmer_remove_entity_of_relations(document1): - trimmer = TextSpanTrimmer(layer="entities", skip_empty=True) - with pytest.raises(ValueError) as excinfo: - processed_document = trimmer(document1) - assert ( - str(excinfo.value) - == "Could not add annotation BinaryRelation(head=LabeledSpan(start=57, end=61, label='person', score=1.0), " - "tail=LabeledSpan(start=62, end=65, label='other', score=1.0), label='has_nothing', score=1.0) " - "to DocumentWithEntitiesRelationsAndPartitions because it depends on annotations that are not present " - "in the document." 
- ) diff --git a/tests/unit/document/processing/test_tokenization.py b/tests/unit/document/processing/test_tokenization.py deleted file mode 100644 index c80ee260..00000000 --- a/tests/unit/document/processing/test_tokenization.py +++ /dev/null @@ -1,558 +0,0 @@ -import dataclasses - -import pytest -from pytorch_ie.annotations import BinaryRelation, LabeledSpan, Span -from pytorch_ie.core import AnnotationList, annotation_field -from pytorch_ie.documents import TokenBasedDocument -from transformers import AutoTokenizer, PreTrainedTokenizer - -from pie_datasets.document.processing import ( - text_based_document_to_token_based, - token_based_document_to_text_based, - tokenize_document, -) -from tests.conftest import TestDocument - - -@dataclasses.dataclass -class TokenizedTestDocument(TokenBasedDocument): - sentences: AnnotationList[Span] = annotation_field(target="tokens") - entities: AnnotationList[LabeledSpan] = annotation_field(target="tokens") - relations: AnnotationList[BinaryRelation] = annotation_field(target="entities") - - -@pytest.fixture(scope="module") -def tokenizer() -> PreTrainedTokenizer: - return AutoTokenizer.from_pretrained("bert-base-cased") - - -def test_text_based_document_to_token_based(documents, tokenizer): - assert len(documents) >= 3 - for i, doc in enumerate(documents[:3]): - tokenized_text = tokenizer(doc.text, return_offsets_mapping=True) - tokenized_doc = text_based_document_to_token_based( - doc, - tokens=tokenized_text.tokens(), - result_document_type=TokenizedTestDocument, - # to increase test coverage - token_offset_mapping=None if i == 1 else tokenized_text.offset_mapping, - # to increase test coverage - char_to_token=None if i == 0 else tokenized_text.char_to_token, - ) - assert tokenized_doc is not None - - # check (de-)serialization - tokenized_doc.copy() - - offset_mapping_lists = [list(offsets) for offsets in tokenized_text.offset_mapping] - if i == 0: - assert doc.id == "train_doc1" - assert tokenized_doc.metadata["text"] == doc.text == "A single sentence." - assert tokenized_doc.metadata["token_offset_mapping"] == offset_mapping_lists - assert tokenized_doc.metadata.get("char_to_token") is None - assert tokenized_doc.tokens == ("[CLS]", "A", "single", "sentence", ".", "[SEP]") - assert len(tokenized_doc.sentences) == len(doc.sentences) == 1 - assert str(doc.sentences[0]) == "A single sentence." - assert str(tokenized_doc.sentences[0]) == "('A', 'single', 'sentence', '.')" - assert len(tokenized_doc.entities) == len(doc.entities) == 0 - assert len(tokenized_doc.relations) == len(doc.relations) == 0 - elif i == 1: - assert doc.id == "train_doc2" - assert tokenized_doc.metadata["text"] == doc.text == "Entity A works at B." - assert tokenized_doc.metadata.get("token_offset_mapping") is None - assert tokenized_doc.metadata["char_to_token"] == tokenized_text.char_to_token - assert tokenized_doc.tokens == ( - "[CLS]", - "En", - "##ti", - "##ty", - "A", - "works", - "at", - "B", - ".", - "[SEP]", - ) - assert len(tokenized_doc.sentences) == len(doc.sentences) == 1 - assert str(doc.sentences[0]) == "Entity A works at B." 
- assert ( - str(tokenized_doc.sentences[0]) - == "('En', '##ti', '##ty', 'A', 'works', 'at', 'B', '.')" - ) - assert len(tokenized_doc.entities) == len(doc.entities) == 2 - assert str(doc.entities[0]) == "Entity A" - assert str(tokenized_doc.entities[0]) == "('En', '##ti', '##ty', 'A')" - assert str(doc.entities[1]) == "B" - assert str(tokenized_doc.entities[1]) == "('B',)" - assert len(tokenized_doc.relations) == len(doc.relations) == 1 - assert doc.relations[0].head == doc.entities[0] - assert tokenized_doc.relations[0].head == tokenized_doc.entities[0] - assert doc.relations[0].tail == doc.entities[1] - assert tokenized_doc.relations[0].tail == tokenized_doc.entities[1] - elif i == 2: - assert doc.id == "train_doc3" - assert tokenized_doc.metadata["text"] == doc.text == "Entity C and D." - assert tokenized_doc.metadata["token_offset_mapping"] == offset_mapping_lists - assert tokenized_doc.metadata["char_to_token"] == tokenized_text.char_to_token - assert tokenized_doc.tokens == ( - "[CLS]", - "En", - "##ti", - "##ty", - "C", - "and", - "D", - ".", - "[SEP]", - ) - assert len(tokenized_doc.sentences) == len(doc.sentences) == 1 - assert str(doc.sentences[0]) == "Entity C and D." - assert ( - str(tokenized_doc.sentences[0]) == "('En', '##ti', '##ty', 'C', 'and', 'D', '.')" - ) - assert len(tokenized_doc.entities) == len(doc.entities) == 2 - assert str(doc.entities[0]) == "Entity C" - assert str(tokenized_doc.entities[0]) == "('En', '##ti', '##ty', 'C')" - assert str(doc.entities[1]) == "D" - assert str(tokenized_doc.entities[1]) == "('D',)" - assert len(tokenized_doc.relations) == len(doc.relations) == 0 - else: - raise ValueError(f"Unexpected document: {doc.id}") - - -def test_text_based_document_to_token_based_missing_args(documents, tokenizer): - with pytest.raises(ValueError) as excinfo: - doc = documents[0] - tokenized_text = tokenizer(doc.text) - tokenized_doc = text_based_document_to_token_based( - doc, - tokens=tokenized_text.tokens(), - result_document_type=TokenizedTestDocument, - ) - assert ( - str(excinfo.value) - == "either token_offset_mapping or char_to_token must be provided to convert a text based document " - "to token based, but both are None" - ) - - -def test_text_based_document_to_token_based_unaligned_span_strict(documents, tokenizer): - doc = documents[0].copy() - # add a span that is not aligned with the tokenization - doc.entities.append(LabeledSpan(start=0, end=2, label="unaligned")) - assert str(doc.entities[-1]) == "A " - tokenized_text = tokenizer(doc.text, return_offsets_mapping=True) - with pytest.raises(ValueError) as excinfo: - tokenized_doc = text_based_document_to_token_based( - doc, - tokens=tokenized_text.tokens(), - result_document_type=TokenizedTestDocument, - # to increase test coverage - token_offset_mapping=tokenized_text.offset_mapping, - # to increase test coverage - char_to_token=tokenized_text.char_to_token, - ) - assert ( - str(excinfo.value) - == 'cannot find token span for character span: "A ", text="A single sentence.", ' - "token_offset_mapping=[(0, 0), (0, 1), (2, 8), (9, 17), (17, 18), (0, 0)]" - ) - - -def test_text_based_document_to_token_based_unaligned_span_not_strict(documents, tokenizer): - doc = documents[0].copy() - doc.entities.append(LabeledSpan(start=0, end=2, label="unaligned")) - assert str(doc.entities[-1]) == "A " - tokenized_text = tokenizer(doc.text, return_offsets_mapping=True) - tokenized_doc = text_based_document_to_token_based( - doc, - tokens=tokenized_text.tokens(), - result_document_type=TokenizedTestDocument, - 
# to increase test coverage - token_offset_mapping=tokenized_text.offset_mapping, - # to increase test coverage - char_to_token=tokenized_text.char_to_token, - strict_span_conversion=False, - ) - - # check (de-)serialization - tokenized_doc.copy() - - assert len(doc.entities) == 1 - # the unaligned span is not included in the tokenized document - assert len(tokenized_doc.entities) == 0 - - -@pytest.fixture -def token_documents(documents, tokenizer): - result = [] - for doc in documents: - tokenized_text = tokenizer(doc.text, return_offsets_mapping=True) - tokenized_doc = text_based_document_to_token_based( - doc, - tokens=tokenized_text.tokens(), - result_document_type=TokenizedTestDocument, - char_to_token=tokenized_text.char_to_token, - token_offset_mapping=tokenized_text.offset_mapping, - ) - result.append(tokenized_doc) - return result - - -def test_token_based_document_to_text_based(documents, token_documents): - for doc, tokenized_doc in zip(documents, token_documents): - reconstructed_doc = token_based_document_to_text_based( - tokenized_doc, - result_document_type=TestDocument, - ) - assert reconstructed_doc is not None - doc_dict = doc.asdict() - reconstructed_doc_dict = reconstructed_doc.asdict() - # remove all added metadata (original text, token_offset_mapping, char_to_token, tokens) - reconstructed_doc_dict["metadata"] = { - k: reconstructed_doc_dict["metadata"][k] for k in doc_dict["metadata"] - } - assert reconstructed_doc_dict == doc_dict - - -def test_token_based_document_to_text_based_with_join_tokens_with(documents): - for doc in documents: - # split the text by individual whitespace characters - # so that we can reconstruct the original text via " ".join(tokens) - tokens = [] - token_offset_mapping = [] - start = 0 - for token in doc.text.split(" "): - tokens.append(token) - end = start + len(token) - token_offset_mapping.append((start, end)) - start = end + 1 - - tokenized_doc = text_based_document_to_token_based( - doc, - tokens=tokens, - result_document_type=TokenizedTestDocument, - token_offset_mapping=token_offset_mapping, - ) - reconstructed_doc = token_based_document_to_text_based( - tokenized_doc, - result_document_type=TestDocument, - join_tokens_with=" ", - ) - assert reconstructed_doc is not None - assert reconstructed_doc.text == doc.text - - if doc.id in ["train_doc1", "train_doc7"]: - doc_dict = doc.asdict() - reconstructed_doc_dict = reconstructed_doc.asdict() - # remove all added metadata (original text, token_offset_mapping, char_to_token, tokens) - reconstructed_doc_dict["metadata"] = { - k: reconstructed_doc_dict["metadata"][k] for k in doc_dict["metadata"] - } - assert reconstructed_doc_dict == doc_dict - elif doc.id == "train_doc2": - assert reconstructed_doc.sentences == doc.sentences - assert len(reconstructed_doc.entities) == len(doc.entities) == 2 - assert str(reconstructed_doc.entities[0]) == str(doc.entities[0]) == "Entity A" - assert str(doc.entities[1]) == "B" - assert str(reconstructed_doc.entities[1]) == "B." 
- assert len(reconstructed_doc.relations) == len(doc.relations) == 1 - assert ( - reconstructed_doc.relations[0].label == doc.relations[0].label == "per:employee_of" - ) - assert doc.relations[0].head == doc.entities[0] - assert reconstructed_doc.relations[0].head == reconstructed_doc.entities[0] - assert doc.relations[0].tail == doc.entities[1] - assert reconstructed_doc.relations[0].tail == reconstructed_doc.entities[1] - elif doc.id == "train_doc3": - assert reconstructed_doc.sentences == doc.sentences - assert len(reconstructed_doc.entities) == len(doc.entities) == 2 - assert str(reconstructed_doc.entities[0]) == str(doc.entities[0]) == "Entity C" - assert str(doc.entities[1]) == "D" - assert str(reconstructed_doc.entities[1]) == "D." - assert len(reconstructed_doc.relations) == len(doc.relations) == 0 - elif doc.id == "train_doc4": - assert reconstructed_doc.sentences == doc.sentences - assert len(reconstructed_doc.entities) == len(doc.entities) == 2 - assert str(reconstructed_doc.entities[0]) == str(doc.entities[0]) == "Entity E" - assert str(doc.entities[1]) == "F" - assert str(reconstructed_doc.entities[1]) == "F." - assert len(reconstructed_doc.relations) == len(doc.relations) == 0 - elif doc.id == "train_doc5": - assert reconstructed_doc.sentences == doc.sentences - assert len(reconstructed_doc.entities) == len(doc.entities) == 3 - assert str(reconstructed_doc.entities[0]) == str(doc.entities[0]) == "Entity G" - assert str(doc.entities[1]) == "H" - assert str(reconstructed_doc.entities[1]) == "H." - assert str(doc.entities[2]) == "I" - assert str(reconstructed_doc.entities[2]) == "I." - assert len(reconstructed_doc.relations) == len(doc.relations) == 3 - assert ( - reconstructed_doc.relations[0].label == doc.relations[0].label == "per:employee_of" - ) - assert doc.relations[0].head == doc.entities[0] - assert reconstructed_doc.relations[0].head == reconstructed_doc.entities[0] - assert doc.relations[0].tail == doc.entities[1] - assert reconstructed_doc.relations[0].tail == reconstructed_doc.entities[1] - assert reconstructed_doc.relations[1].label == doc.relations[1].label == "per:founder" - assert doc.relations[1].head == doc.entities[0] - assert reconstructed_doc.relations[1].head == reconstructed_doc.entities[0] - assert doc.relations[1].tail == doc.entities[2] - assert reconstructed_doc.relations[1].tail == reconstructed_doc.entities[2] - assert ( - reconstructed_doc.relations[2].label == doc.relations[2].label == "org:founded_by" - ) - assert doc.relations[2].head == doc.entities[2] - assert reconstructed_doc.relations[2].head == reconstructed_doc.entities[2] - assert doc.relations[2].tail == doc.entities[1] - assert reconstructed_doc.relations[2].tail == reconstructed_doc.entities[1] - elif doc.id == "train_doc6": - assert reconstructed_doc.sentences == doc.sentences - assert len(reconstructed_doc.entities) == len(doc.entities) == 3 - assert str(doc.entities[0]) == "Entity J" - assert str(reconstructed_doc.entities[0]) == "Entity J," - assert str(doc.entities[1]) == "K" - assert str(reconstructed_doc.entities[1]) == "K," - assert str(doc.entities[2]) == "L" - assert str(reconstructed_doc.entities[2]) == "L." - assert len(reconstructed_doc.relations) == len(doc.relations) == 0 - elif doc.id == "train_doc8": - assert len(reconstructed_doc.sentences) == len(doc.sentences) == 3 - assert ( - str(reconstructed_doc.sentences[0]) == str(doc.sentences[0]) == "First sentence." - ) - assert ( - str(reconstructed_doc.sentences[1]) - == str(doc.sentences[1]) - == "Entity M works at N." 
- ) - assert str(doc.sentences[2]) == "And it founded O" - assert str(reconstructed_doc.sentences[2]) == "And it founded O." - assert len(reconstructed_doc.entities) == len(doc.entities) == 4 - assert str(reconstructed_doc.entities[0]) == str(doc.entities[0]) == "Entity M" - assert str(doc.entities[1]) == "N" - assert str(reconstructed_doc.entities[1]) == "N." - assert str(reconstructed_doc.entities[2]) == str(doc.entities[2]) == "it" - assert str(doc.entities[3]) == "O" - assert str(reconstructed_doc.entities[3]) == "O." - assert len(reconstructed_doc.relations) == len(doc.relations) == 3 - assert ( - reconstructed_doc.relations[0].label == doc.relations[0].label == "per:employee_of" - ) - assert doc.relations[0].head == doc.entities[0] - assert reconstructed_doc.relations[0].head == reconstructed_doc.entities[0] - assert doc.relations[0].tail == doc.entities[1] - assert reconstructed_doc.relations[0].tail == reconstructed_doc.entities[1] - assert reconstructed_doc.relations[1].label == doc.relations[1].label == "per:founder" - assert doc.relations[1].head == doc.entities[2] - assert reconstructed_doc.relations[1].head == reconstructed_doc.entities[2] - assert doc.relations[1].tail == doc.entities[3] - assert reconstructed_doc.relations[1].tail == reconstructed_doc.entities[3] - assert ( - reconstructed_doc.relations[2].label == doc.relations[2].label == "org:founded_by" - ) - assert doc.relations[2].head == doc.entities[3] - assert reconstructed_doc.relations[2].head == reconstructed_doc.entities[3] - assert doc.relations[2].tail == doc.entities[2] - assert reconstructed_doc.relations[2].tail == reconstructed_doc.entities[2] - else: - raise ValueError(f"Unexpected document: {doc.id}") - - -def test_tokenize_document(documents, tokenizer): - doc = documents[1] - tokenized_docs = tokenize_document( - doc, - tokenizer=tokenizer, - result_document_type=TokenizedTestDocument, - ) - assert len(tokenized_docs) == 1 - tokenized_doc = tokenized_docs[0] - - # check (de-)serialization - tokenized_doc.copy() - - assert doc.id == "train_doc2" - assert tokenized_doc.metadata["text"] == doc.text == "Entity A works at B." - assert tokenized_doc.tokens == ( - "[CLS]", - "En", - "##ti", - "##ty", - "A", - "works", - "at", - "B", - ".", - "[SEP]", - ) - assert len(tokenized_doc.sentences) == len(doc.sentences) == 1 - assert str(doc.sentences[0]) == "Entity A works at B." - assert ( - str(tokenized_doc.sentences[0]) == "('En', '##ti', '##ty', 'A', 'works', 'at', 'B', '.')" - ) - assert len(tokenized_doc.entities) == len(doc.entities) == 2 - assert str(doc.entities[0]) == "Entity A" - assert str(tokenized_doc.entities[0]) == "('En', '##ti', '##ty', 'A')" - assert str(doc.entities[1]) == "B" - assert str(tokenized_doc.entities[1]) == "('B',)" - assert len(tokenized_doc.relations) == len(doc.relations) == 1 - assert tokenized_doc.relations[0].label == doc.relations[0].label == "per:employee_of" - assert doc.relations[0].head == doc.entities[0] - assert tokenized_doc.relations[0].head == tokenized_doc.entities[0] - assert doc.relations[0].tail == doc.entities[1] - assert tokenized_doc.relations[0].tail == tokenized_doc.entities[1] - - -def test_tokenize_document_max_length(documents, tokenizer): - doc = documents[1] - assert doc.id == "train_doc2" - assert doc.text == "Entity A works at B." - assert len(doc.sentences) == 1 - assert str(doc.sentences[0]) == "Entity A works at B." 
- assert len(doc.entities) == 2 - assert str(doc.entities[0]) == "Entity A" - assert str(doc.entities[1]) == "B" - assert len(doc.relations) == 1 - assert doc.relations[0].label == "per:employee_of" - assert doc.relations[0].head == doc.entities[0] - assert doc.relations[0].tail == doc.entities[1] - - tokenized_docs = tokenize_document( - doc, - tokenizer=tokenizer, - result_document_type=TokenizedTestDocument, - strict_span_conversion=False, - # This will cut out the second entity. Also, the sentence annotation will be removed, - # because the sentence is not complete anymore. - max_length=8, - return_overflowing_tokens=True, - ) - assert len(tokenized_docs) == 2 - tokenized_doc = tokenized_docs[0] - - # check (de-)serialization - tokenized_doc.copy() - - assert tokenized_doc.id == doc.id == "train_doc2" - assert tokenized_doc.metadata["text"] == doc.text == "Entity A works at B." - assert tokenized_doc.tokens == ("[CLS]", "En", "##ti", "##ty", "A", "works", "at", "[SEP]") - assert len(tokenized_doc.sentences) == 0 - assert len(tokenized_doc.entities) == 1 - assert str(tokenized_doc.entities[0]) == "('En', '##ti', '##ty', 'A')" - assert len(tokenized_doc.relations) == 0 - - tokenized_doc = tokenized_docs[1] - - # check (de-)serialization - tokenized_doc.copy() - - assert tokenized_doc.id == doc.id == "train_doc2" - assert tokenized_doc.metadata["text"] == doc.text == "Entity A works at B." - assert tokenized_doc.tokens == ("[CLS]", "B", ".", "[SEP]") - assert len(tokenized_doc.sentences) == 0 - assert len(tokenized_doc.entities) == 1 - assert str(tokenized_doc.entities[0]) == "('B',)" - assert len(tokenized_doc.relations) == 0 - - -def test_tokenize_document_partition(documents, tokenizer): - doc = documents[7] - assert doc.id == "train_doc8" - assert doc.text == "First sentence. Entity M works at N. And it founded O." - assert len(doc.sentences) == 3 - assert str(doc.sentences[0]) == "First sentence." - assert str(doc.sentences[1]) == "Entity M works at N." - assert str(doc.sentences[2]) == "And it founded O" - assert len(doc.entities) == 4 - assert str(doc.entities[0]) == "Entity M" - assert str(doc.entities[1]) == "N" - assert str(doc.entities[2]) == "it" - assert str(doc.entities[3]) == "O" - assert len(doc.relations) == 3 - assert doc.relations[0].head == doc.entities[0] - assert doc.relations[0].tail == doc.entities[1] - assert doc.relations[1].head == doc.entities[2] - assert doc.relations[1].tail == doc.entities[3] - assert doc.relations[2].head == doc.entities[3] - assert doc.relations[2].tail == doc.entities[2] - - tokenized_docs = tokenize_document( - doc, - tokenizer=tokenizer, - result_document_type=TokenizedTestDocument, - strict_span_conversion=False, - partition_layer="sentences", - ) - assert len(tokenized_docs) == 3 - tokenized_doc = tokenized_docs[0] - - # check (de-)serialization - tokenized_doc.copy() - - assert tokenized_doc.id == doc.id == "train_doc8" - assert ( - tokenized_doc.metadata["text"] - == doc.text - == "First sentence. Entity M works at N. And it founded O." - ) - assert tokenized_doc.tokens == ("[CLS]", "First", "sentence", ".", "[SEP]") - assert len(tokenized_doc.sentences) == 1 - assert len(tokenized_doc.entities) == 0 - assert len(tokenized_doc.relations) == 0 - - tokenized_doc = tokenized_docs[1] - - # check (de-)serialization - tokenized_doc.copy() - - assert tokenized_doc.id == doc.id == "train_doc8" - assert ( - tokenized_doc.metadata["text"] - == doc.text - == "First sentence. Entity M works at N. And it founded O." 
- ) - assert tokenized_doc.tokens == ( - "[CLS]", - "En", - "##ti", - "##ty", - "M", - "works", - "at", - "N", - ".", - "[SEP]", - ) - assert len(tokenized_doc.sentences) == 1 - assert len(tokenized_doc.entities) == 2 - assert str(tokenized_doc.entities[0]) == "('En', '##ti', '##ty', 'M')" - assert str(tokenized_doc.entities[1]) == "('N',)" - assert len(tokenized_doc.relations) == 1 - assert tokenized_doc.relations[0].label == "per:employee_of" - assert tokenized_doc.relations[0].head == tokenized_doc.entities[0] - assert tokenized_doc.relations[0].tail == tokenized_doc.entities[1] - - tokenized_doc = tokenized_docs[2] - - # check (de-)serialization - tokenized_doc.copy() - - assert tokenized_doc.id == doc.id == "train_doc8" - assert ( - tokenized_doc.metadata["text"] - == doc.text - == "First sentence. Entity M works at N. And it founded O." - ) - assert tokenized_doc.tokens == ("[CLS]", "And", "it", "founded", "O", "[SEP]") - assert len(tokenized_doc.sentences) == 1 - assert len(tokenized_doc.entities) == 2 - assert str(tokenized_doc.entities[0]) == "('it',)" - assert str(tokenized_doc.entities[1]) == "('O',)" - assert len(tokenized_doc.relations) == 2 - assert tokenized_doc.relations[0].label == "per:founder" - assert tokenized_doc.relations[0].head == tokenized_doc.entities[0] - assert tokenized_doc.relations[0].tail == tokenized_doc.entities[1] - assert tokenized_doc.relations[1].label == "org:founded_by" - assert tokenized_doc.relations[1].head == tokenized_doc.entities[1] - assert tokenized_doc.relations[1].tail == tokenized_doc.entities[0] diff --git a/tests/unit/test_statistics.py b/tests/unit/test_statistics.py deleted file mode 100644 index 1129adfb..00000000 --- a/tests/unit/test_statistics.py +++ /dev/null @@ -1,87 +0,0 @@ -import dataclasses - -import pytest -from pytorch_ie.annotations import LabeledSpan -from pytorch_ie.core import AnnotationList, annotation_field -from pytorch_ie.documents import TextBasedDocument, TokenBasedDocument - -from pie_datasets import DatasetDict -from pie_datasets.statistics import SpanLengthCollector -from tests import FIXTURES_ROOT - - -@pytest.fixture -def dataset(): - @dataclasses.dataclass - class Conll2003Document(TextBasedDocument): - entities: AnnotationList[LabeledSpan] = annotation_field(target="text") - - return DatasetDict.from_json( - data_dir=FIXTURES_ROOT / "dataset_dict" / "conll2003_extract", - document_type=Conll2003Document, - ) - - -def test_statistics(dataset): - statistic = SpanLengthCollector(layer="entities") - values = statistic(dataset) - assert values == { - "train": {"len": 5, "mean": 7.6, "std": 4.223742416388575, "min": 2, "max": 15}, - "validation": { - "len": 6, - "mean": 10.833333333333334, - "std": 2.9674156357941426, - "min": 6, - "max": 14, - }, - "test": {"len": 5, "mean": 9.4, "std": 5.748043145279966, "min": 5, "max": 20}, - } - - statistic = SpanLengthCollector(layer="entities", labels="INFERRED") - values = statistic(dataset) - assert values == { - "train": { - "ORG": {"max": 2, "len": 1}, - "MISC": {"max": 7, "len": 2}, - "PER": {"max": 15, "len": 1}, - "LOC": {"max": 8, "len": 1}, - }, - "test": { - "LOC": { - "max": 20, - "len": 3, - }, - "PER": {"max": 11, "len": 2}, - }, - "validation": { - "ORG": {"max": 14, "len": 3}, - "LOC": {"max": 6, "len": 1}, - "MISC": {"max": 11, "len": 1}, - "PER": {"max": 12, "len": 1}, - }, - } - - -def test_statistics_with_tokenize(dataset): - @dataclasses.dataclass - class TokenDocumentWithLabeledEntities(TokenBasedDocument): - entities: AnnotationList[LabeledSpan] = 
annotation_field(target="tokens") - - statistic = SpanLengthCollector( - layer="entities", - tokenize=True, - tokenizer="bert-base-uncased", - tokenized_document_type=TokenDocumentWithLabeledEntities, - ) - values = statistic(dataset) - assert values == { - "test": {"len": 5, "max": 4, "mean": 2.4, "min": 1, "std": 1.2000000000000002}, - "train": {"len": 5, "max": 2, "mean": 1.2, "min": 1, "std": 0.4}, - "validation": { - "len": 6, - "max": 2, - "mean": 1.3333333333333333, - "min": 1, - "std": 0.4714045207910317, - }, - }