From 6a6efd6e89f6d9c636400a36e7c4c7a3452ab2ca Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Mon, 6 Nov 2023 23:09:17 +0100 Subject: [PATCH 01/12] integrate from PyTorch-IE: Dataset, IterableDataset, DatasetDict, GeneratorBasedBuilder, and ArrowBasedBuilder (and necessary helper classes) --- dataset_builders/pie/conll2003/conll2003.py | 4 +- dataset_builders/pie/tacred/tacred.py | 7 +- src/pie_datasets/__init__.py | 22 + src/pie_datasets/builder.py | 256 +++++++ src/pie_datasets/common.py | 40 ++ src/pie_datasets/dataset.py | 590 ++++++++++++++++ src/pie_datasets/dataset_dict.py | 641 ++++++++++++++++++ src/pie_datasets/document/conversion.py | 302 +++++++++ .../document/processing/regex_partitioner.py | 3 +- src/pie_datasets/document_formatter.py | 22 + tests/conftest.py | 2 +- tests/dataset_builders/pie/test_conll2003.py | 2 +- 12 files changed, 1884 insertions(+), 7 deletions(-) create mode 100644 src/pie_datasets/builder.py create mode 100644 src/pie_datasets/common.py create mode 100644 src/pie_datasets/dataset.py create mode 100644 src/pie_datasets/dataset_dict.py create mode 100644 src/pie_datasets/document/conversion.py create mode 100644 src/pie_datasets/document_formatter.py diff --git a/dataset_builders/pie/conll2003/conll2003.py b/dataset_builders/pie/conll2003/conll2003.py index ebfe7191..add860d5 100644 --- a/dataset_builders/pie/conll2003/conll2003.py +++ b/dataset_builders/pie/conll2003/conll2003.py @@ -7,13 +7,15 @@ from pytorch_ie.documents import TextDocument, TextDocumentWithLabeledSpans from pytorch_ie.utils.span import tokens_and_tags_to_text_and_labeled_spans +from pie_datasets import GeneratorBasedBuilder + @dataclass class CoNLL2003Document(TextDocument): entities: AnnotationList[LabeledSpan] = annotation_field(target="text") -class Conll2003(pytorch_ie.data.builder.GeneratorBasedBuilder): +class Conll2003(GeneratorBasedBuilder): DOCUMENT_TYPE = CoNLL2003Document BASE_DATASET_PATH = "conll2003" diff --git a/dataset_builders/pie/tacred/tacred.py b/dataset_builders/pie/tacred/tacred.py index 5dcf6a83..b6e44afb 100644 --- a/dataset_builders/pie/tacred/tacred.py +++ b/dataset_builders/pie/tacred/tacred.py @@ -2,8 +2,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple import datasets -import pytorch_ie.data.builder -from pytorch_ie import token_based_document_to_text_based from pytorch_ie.annotations import BinaryRelation, LabeledSpan, _post_init_single_label from pytorch_ie.core import Annotation, AnnotationList, Document, annotation_field from pytorch_ie.documents import ( @@ -11,6 +9,9 @@ TokenBasedDocument, ) +from pie_datasets import GeneratorBasedBuilder +from pie_datasets.document.conversion import token_based_document_to_text_based + @dataclass(eq=True, frozen=True) class TokenRelation(Annotation): @@ -172,7 +173,7 @@ def __init__(self, **kwargs): super().__init__(**kwargs) -class Tacred(pytorch_ie.data.builder.GeneratorBasedBuilder): +class Tacred(GeneratorBasedBuilder): DOCUMENT_TYPE = TacredDocument DOCUMENT_CONVERTERS = { diff --git a/src/pie_datasets/__init__.py b/src/pie_datasets/__init__.py index e69de29b..6c3d2550 100644 --- a/src/pie_datasets/__init__.py +++ b/src/pie_datasets/__init__.py @@ -0,0 +1,22 @@ +from .builder import GeneratorBasedBuilder +from .common import ( + EnterDatasetDictMixin, + EnterDatasetMixin, + ExitDatasetDictMixin, + ExitDatasetMixin, +) +from .dataset import Dataset, IterableDataset +from .dataset_dict import DatasetDict +from .document_formatter import DocumentFormatter + +__all__ = [ + 
"GeneratorBasedBuilder", + "Dataset", + "IterableDataset", + "DatasetDict", + "DocumentFormatter", + "EnterDatasetMixin", + "ExitDatasetMixin", + "EnterDatasetDictMixin", + "ExitDatasetDictMixin", +] diff --git a/src/pie_datasets/builder.py b/src/pie_datasets/builder.py new file mode 100644 index 00000000..fc4773fb --- /dev/null +++ b/src/pie_datasets/builder.py @@ -0,0 +1,256 @@ +import abc +from typing import Any, Callable, Dict, Optional, Type, Union, overload + +import datasets +from pytorch_ie.core.document import Document +from pytorch_ie.utils.hydra import resolve_target + +from .dataset import ( + Dataset, + DocumentConvertersType, + IterableDataset, + decorate_convert_to_dict_of_lists, + get_pie_dataset_type, +) + + +def get_general_dataset_builder_parent_class( + obj: datasets.builder.DatasetBuilder, +) -> Type[datasets.builder.DatasetBuilder]: + general_dataset_builder_parent_classes = [ + cls + for cls in datasets.builder.DatasetBuilder.__subclasses__() + if cls != PieDatasetBuilder and isinstance(obj, cls) + ] + if len(general_dataset_builder_parent_classes) != 1: + raise TypeError("can not determine general dataset builder parent class of the object") + return general_dataset_builder_parent_classes[0] + + +class PieDatasetBuilder(datasets.builder.DatasetBuilder): + # The default pytorch-ie document type for the dataset. + DOCUMENT_TYPE: Optional[Type[Document]] = None + # A mapping from config names to PIE document types. Use this to specify individual + # document types per config. + DOCUMENT_TYPES: Dict[str, Type[Document]] = {} + + # The default path to the Huggingface dataset loading script that will be used as base dataset. + BASE_DATASET_PATH: Optional[str] = None + # A mapping from config names to Huggingface dataset loading script paths. Use this to specify individual + # base datasets for each config. + BASE_DATASET_PATHS: Dict[str, str] = {} + + # Define kwargs to create base configs. This should contain config names as keys + # and the respective config kwargs dicts as values. If the config name is not contained, a new entry + # {"name": config_name} will be created for it, i.e. the config name is passed as base config name. + # This default behaviour can be disabled by setting BASE_CONFIG_KWARGS_DICT to None. + BASE_CONFIG_KWARGS_DICT: Optional[Dict[Optional[str], Dict[str, Any]]] = {} + # Define base builder kwargs. This should contain config names as keys and the respective + # builder kwargs dicts as values. + BASE_BUILDER_KWARGS_DICT: Optional[Dict[Optional[str], Dict[str, Any]]] = None + + # Define document converters. This should be a mapping from document types as keys to the respective + # document converters as values. The document converters can be either callables or dicts + # that map from original field names to new field names. If a callable is provided, it will be used to + # convert the document. If a dict is provided, it will be used to rename the fields of the + # document (this is done by renaming the columns which is much more efficient). 
+    DOCUMENT_CONVERTERS: DocumentConvertersType = {}
+
+    def __init__(
+        self,
+        base_dataset_kwargs: Optional[Dict[str, Any]] = None,
+        document_converters: Optional[
+            Dict[Union[Type[Document], str], Union[Callable[..., Document], Dict[str, str], str]]
+        ] = None,
+        **kwargs,
+    ):
+        self.base_builder = None
+        config_name = kwargs.get("config_name", None)
+        base_dataset_path = self.BASE_DATASET_PATHS.get(config_name, self.BASE_DATASET_PATH)
+        if base_dataset_path is not None:
+            base_dataset_kwargs = base_dataset_kwargs or {}
+            base_builder_kwargs: Dict[str, Any] = {}
+
+            # get base config kwargs from mapping
+            if self.BASE_CONFIG_KWARGS_DICT is not None:
+                if config_name in self.BASE_CONFIG_KWARGS_DICT:
+                    config_kwargs = self.BASE_CONFIG_KWARGS_DICT[config_name]
+                else:
+                    # if the config name is not in BASE_CONFIG_KWARGS_DICT,
+                    # we pass it as base config name
+                    config_kwargs = {"name": config_name}
+                base_builder_kwargs.update(config_kwargs)
+
+            # get base builder kwargs from mapping
+            if self.BASE_BUILDER_KWARGS_DICT is not None:
+                base_builder_kwargs.update(self.BASE_BUILDER_KWARGS_DICT[config_name])
+
+            base_builder_kwargs.update(base_dataset_kwargs)
+            self.base_builder = datasets.load.load_dataset_builder(
+                path=base_dataset_path,
+                **base_builder_kwargs,
+            )
+            # Ensure that self and self.base_builder are derived from the same subclass of
+            # datasets.builder.DatasetBuilder.
+            base_builder_general_parent_class = get_general_dataset_builder_parent_class(
+                self.base_builder
+            )
+            self_general_parent_class = get_general_dataset_builder_parent_class(self)
+            if base_builder_general_parent_class != self_general_parent_class:
+                raise TypeError(
+                    f"The PyTorch-IE dataset builder class '{type(self).__name__}' is derived from "
+                    f"{self_general_parent_class}, but the base builder is not, which is not allowed. "
+                    f"The base builder is of type '{type(self.base_builder).__name__}', which is derived from "
+                    f"{base_builder_general_parent_class}. Consider deriving your PyTorch-IE dataset builder "
+                    f"'{type(self).__name__}' from a PyTorch-IE variant of "
+                    f"'{base_builder_general_parent_class.__name__}'."
+                )
+
+            # append the base_builder config_id to the hash, otherwise the base_builder config arguments
+            # are not respected in the cache fingerprint
+            if "hash" in kwargs:
+                kwargs["hash"] = f"{kwargs['hash']}-{self.base_builder.config_id}"
+
+            # set base path to the base builder's base path. This is required so that the download manager
+            # works correctly with relative paths.
+            kwargs["base_path"] = self.base_builder.base_path
+
+        super().__init__(**kwargs)
+
+        self.document_converters = dict(self.DOCUMENT_CONVERTERS)
+        if document_converters is not None:
+            for document_type_or_str, document_converter_or_str in document_converters.items():
+                document_type = resolve_target(document_type_or_str)
+                if isinstance(document_type, type) and issubclass(document_type, Document):
+                    document_converter: Union[Callable[..., Any], dict[str, str]]
+                    if isinstance(document_converter_or_str, str):
+                        document_converter = resolve_target(document_converter_or_str)
+                    else:
+                        document_converter = document_converter_or_str
+
+                    self.document_converters[document_type] = document_converter
+                else:
+                    raise TypeError(
+                        f"The key '{document_type_or_str}' for one of the converters "
+                        f"cannot be resolved to a document type."
+ ) + + def _info(self): + return self.base_builder._info() + + def _split_generators(self, dl_manager): + return self.base_builder._split_generators(dl_manager) + + @property + def document_type(self) -> Optional[Type[Document]]: + return self.DOCUMENT_TYPES.get(self.config.name, self.DOCUMENT_TYPE) + + @abc.abstractmethod + def _generate_document(self, example, **kwargs): + pass + + def _generate_document_kwargs(self, dataset): + return None + + @overload # type: ignore + def _convert_dataset_single(self, dataset: datasets.IterableDataset) -> IterableDataset: + ... + + @overload # type: ignore + def _convert_dataset_single(self, dataset: datasets.Dataset) -> Dataset: + ... + + def _convert_dataset_single( + self, dataset: Union[datasets.Dataset, datasets.IterableDataset] + ) -> Union[Dataset, IterableDataset]: + document_type = self.document_type + if document_type is None: + raise TypeError( + f"the builder has no DOCUMENT_TYPE or DOCUMENT_TYPES[{self.config.name}] defined" + ) + + fn = decorate_convert_to_dict_of_lists(self._generate_document) + fn_kwargs = self._generate_document_kwargs(dataset) + mapped_dataset = dataset.map(fn, fn_kwargs=fn_kwargs) + dataset_type = get_pie_dataset_type(mapped_dataset) + result = dataset_type.from_hf_dataset( + dataset=mapped_dataset, + document_type=document_type, + document_converters=dict(self.document_converters), + ) + return result + + @overload # type: ignore + def _convert_datasets(self, datasets: datasets.DatasetDict) -> datasets.DatasetDict: + ... + + @overload # type: ignore + def _convert_datasets( + self, datasets: datasets.IterableDatasetDict + ) -> datasets.IterableDatasetDict: + ... + + @overload # type: ignore + def _convert_datasets(self, datasets: datasets.IterableDataset) -> IterableDataset: + ... + + @overload # type: ignore + def _convert_datasets(self, datasets: datasets.Dataset) -> Dataset: + ... 
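# Illustrative sketch, not applied by this patch: a minimal example of how a dataset script is
# expected to use the GeneratorBasedBuilder introduced here, mirroring the conll2003.py and
# tacred.py changes earlier in this diff. The names NerDocument, MyNerDataset, and the base
# dataset path "some_hf_ner_dataset" are placeholders, and the sketch assumes the base dataset
# yields example dicts with "text" and "id" fields.

from dataclasses import dataclass

from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.core import AnnotationList, annotation_field
from pytorch_ie.documents import TextDocument

from pie_datasets import GeneratorBasedBuilder


@dataclass
class NerDocument(TextDocument):
    entities: AnnotationList[LabeledSpan] = annotation_field(target="text")


class MyNerDataset(GeneratorBasedBuilder):
    DOCUMENT_TYPE = NerDocument  # PIE document type produced by this builder
    BASE_DATASET_PATH = "some_hf_ner_dataset"  # Hugging Face dataset used as the source

    def _generate_document(self, example, **kwargs):
        # convert one example dict of the base dataset into a PIE document
        return NerDocument(text=example["text"], id=example["id"])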
+ + def _convert_datasets( + self, + datasets: Union[ + datasets.Dataset, + datasets.IterableDataset, + datasets.DatasetDict, + datasets.IterableDatasetDict, + ], + ) -> Union[Dataset, IterableDataset, datasets.DatasetDict, datasets.IterableDatasetDict]: + if isinstance(datasets, dict): + return type(datasets)( + {k: self._convert_dataset_single(v) for k, v in datasets.items()} + ) + else: + return self._convert_dataset_single(datasets) + + def as_dataset( + self, + split: Optional[datasets.Split] = None, + run_post_process=True, + verification_mode: Optional[Union[datasets.VerificationMode, str]] = None, + ignore_verifications="deprecated", + in_memory=False, + ) -> Union[Dataset, datasets.DatasetDict]: + dataset = super().as_dataset( + split=split, + run_post_process=run_post_process, + ignore_verifications=ignore_verifications, + in_memory=in_memory, + verification_mode=verification_mode, + ) + converted_datasets = self._convert_datasets(datasets=dataset) + return converted_datasets + + def as_streaming_dataset( + self, + split: Optional[str] = None, + base_path: Optional[str] = None, + ) -> Union[IterableDataset, datasets.IterableDatasetDict]: # type: ignore + dataset: Union[ + datasets.IterableDataset, datasets.IterableDatasetDict + ] = super().as_streaming_dataset( + split=split, base_path=base_path + ) # type: ignore + converted_datasets = self._convert_datasets(datasets=dataset) + return converted_datasets + + +class GeneratorBasedBuilder(PieDatasetBuilder, datasets.builder.GeneratorBasedBuilder): + def _generate_examples(self, *args, **kwargs): + return self.base_builder._generate_examples(*args, **kwargs) + + +class ArrowBasedBuilder(PieDatasetBuilder, datasets.builder.ArrowBasedBuilder): + def _generate_tables(self, *args, **kwargs): + return self.base_builder._generate_tables(*args, **kwargs) diff --git a/src/pie_datasets/common.py b/src/pie_datasets/common.py new file mode 100644 index 00000000..e3213b2e --- /dev/null +++ b/src/pie_datasets/common.py @@ -0,0 +1,40 @@ +from abc import ABC, abstractmethod +from typing import Optional, Union + +from .dataset import Dataset, IterableDataset + + +class EnterDatasetMixin(ABC): + """Mixin for processors that enter a dataset context.""" + + @abstractmethod + def enter_dataset( + self, dataset: Union[Dataset, IterableDataset], name: Optional[str] = None + ) -> None: + """Enter dataset context.""" + + +class ExitDatasetMixin(ABC): + """Mixin for processors that exit a dataset context.""" + + @abstractmethod + def exit_dataset( + self, dataset: Union[Dataset, IterableDataset], name: Optional[str] = None + ) -> None: + """Exit dataset context.""" + + +class EnterDatasetDictMixin(ABC): + """Mixin for processors that enter a dataset dict context.""" + + @abstractmethod + def enter_dataset_dict(self, dataset_dict) -> None: + """Enter dataset dict context.""" + + +class ExitDatasetDictMixin(ABC): + """Mixin for processors that exit a dataset dict context.""" + + @abstractmethod + def exit_dataset_dict(self, dataset_dict) -> None: + """Exit dataset dict context.""" diff --git a/src/pie_datasets/dataset.py b/src/pie_datasets/dataset.py new file mode 100644 index 00000000..ca18f0b0 --- /dev/null +++ b/src/pie_datasets/dataset.py @@ -0,0 +1,590 @@ +import logging +from collections.abc import Iterable, Sequence +from functools import wraps +from inspect import Signature, isclass, signature +from typing import ( + Callable, + Dict, + List, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, + overload, +) + +import datasets +import pandas as pd 
+from datasets.formatting import _register_formatter +from pytorch_ie.core.document import Document + +from .document_formatter import DocumentFormatter + +logger = logging.getLogger(__name__) + +_register_formatter(DocumentFormatter, "document") + + +def decorate_convert_to_dict_of_lists(f): + """Decorate the mapped function, so that converts a single Document to a dict, and a list of + Documents into a dict of lists.""" + + @wraps(f) + def decorated(item, *args, **kwargs): + if isinstance(item, list): + # Convert a list of dicts into a dict of lists. + return pd.DataFrame([e.asdict() for e in f(item, *args, **kwargs)]).to_dict( + orient="list" + ) + else: + return f(item, *args, **kwargs).asdict() + + return decorated + + +E = TypeVar("E") + + +def dl_to_ld(dict_list: Dict[str, List[E]]) -> List[Dict[str, E]]: + # Convert a dict of lists to a list of dicts + return [dict(zip(dict_list, t)) for t in zip(*dict_list.values())] + + +def ld_to_dl( + list_dict: List[Dict[str, E]], keys: Optional[Iterable[str]] = None +) -> Dict[str, List[E]]: + # Convert a list of dicts to a dict of lists. + # Provide keys to create the expected format when lists are empty. + if keys is None: + keys = list_dict[0] + return {k: [dic[k] for dic in list_dict] for k in keys} + + +def decorate_convert_to_document_and_back(f, document_type: Type[Document], batched: bool): + @wraps(f) + def decorated(item, *args, **kwargs): + if batched: + # Convert a list of dicts into a dict of lists. + return ld_to_dl( + [ + e.asdict() + for e in f( + [document_type.fromdict(x) for x in dl_to_ld(item)], *args, **kwargs + ) + ], + # passing the keys allows to work correctly with empty lists + keys=item.keys(), + ) + else: + return f(document_type.fromdict(item), *args, **kwargs).asdict() + + return decorated + + +def _check_fields_for_casting( + field_mapping: Dict[str, str], + current_document_type: Type[Document], + new_document_type: Type[Document], + column_names: list[str], +) -> Tuple[Set[str], Set[str]]: + original_fields = {field.name: field for field in current_document_type.fields()} + new_fields = {field.name: field for field in new_document_type.fields()} + hidden_fields = set(column_names) - set(original_fields) + fields_to_map_not_in_original_fields = ( + set(field_mapping) - set(original_fields) - set(hidden_fields) + ) + if len(fields_to_map_not_in_original_fields) > 0: + raise ValueError( + f"some fields to rename are not in the original document_type or hidden fields: " + f"{fields_to_map_not_in_original_fields}" + ) + mapped_but_not_in_new_fields = set(field_mapping.values()) - set(new_fields) + if len(mapped_but_not_in_new_fields) > 0: + raise ValueError( + f"some renamed fields are not in the new document_type: {mapped_but_not_in_new_fields}" + ) + original_fields_mapped = { + field_mapping.get(f_name, f_name): f for f_name, f in original_fields.items() + } + added_field_names = set(new_fields) - set(original_fields_mapped) + removed_field_names = set(original_fields) - set(new_fields) - set(field_mapping) + + # Sanity checks + kept_field_names = set(original_fields_mapped) & set(new_fields) + for f_name_mapped in kept_field_names: + f = original_fields_mapped[f_name_mapped] + new_f = new_fields[f_name_mapped] + if not ( + f.type == new_f.type + and f.default == new_f.default + and f.default_factory == new_f.default_factory + ): + raise ValueError(f"new field is not the same as old field:\n{new_f}\nvs\n{f}") + + return removed_field_names, added_field_names + + +def _infer_document_type_from_function_return( + 
function: Callable, strict: bool = True +) -> Optional[Type[Document]]: + # try to infer the document type from the return type annotation of function + return_signature = signature(function).return_annotation + if not return_signature == Signature.empty: + if not isclass(return_signature) or not issubclass(return_signature, Document): + msg = "the return type annotation of the function used with map is not a subclass of Document" + if strict: + raise TypeError(msg) + else: + logger.warning(msg) + return None + return return_signature + return None + + +D = TypeVar("D", bound=Document) +DocumentConvertersType = Dict[Type[D], Union[Callable[..., D], Dict[str, str]]] + + +def _get_best_dataset_converter_with_types( + dataset: Union["IterableDataset", "Dataset"], + document_type: Union[Type[Document]], +) -> Tuple[Union[Callable[..., Document], Dict[str, str]], Type[Document], Type[Document]]: + # first try to find an exact match + if document_type in dataset.document_converters: + return dataset.document_converters[document_type], document_type, document_type + + # then try to find a match with a superclass + for registered_dt, candidate_converter in dataset.document_converters.items(): + if issubclass(registered_dt, document_type): + return candidate_converter, document_type, registered_dt + + # then try to find a match with a subclass + for registered_dt, candidate_converter in dataset.document_converters.items(): + if issubclass(document_type, registered_dt): + return candidate_converter, document_type, registered_dt + + raise ValueError( + f"No valid key (either subclass or superclass) was found for the document type '{document_type}' " + f"in the document_converters of the dataset. Available keys: {set(dataset.document_converters)}. " + f"Consider adding a respective converter to the dataset with " + f"dataset.register_document_converter(my_converter_method) where my_converter_method should accept " + f"{dataset.document_type} as input and return '{document_type}'." + ) + + +@overload +def dataset_to_document_type( + dataset: "Dataset", + document_type: Type[Document], + **kwargs, +) -> "Dataset": + ... + + +@overload +def dataset_to_document_type( + dataset: "IterableDataset", + document_type: Type[Document], + **kwargs, +) -> "IterableDataset": + ... 
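# Illustrative sketch, not applied by this patch: how the converter lookup above is meant to be
# used, assuming `dataset` is a pie_datasets.Dataset obtained elsewhere (e.g. from
# DatasetDict.load_dataset(...)["train"]). The document types and the field mapping are
# placeholders; register_document_converter() and to_document_type() are the APIs defined below.

from dataclasses import dataclass

from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.core import AnnotationList, annotation_field
from pytorch_ie.documents import TextDocument


@dataclass
class DocWithEntities(TextDocument):
    entities: AnnotationList[LabeledSpan] = annotation_field(target="text")


@dataclass
class DocWithMentions(TextDocument):
    mentions: AnnotationList[LabeledSpan] = annotation_field(target="text")


def entities_to_mentions(doc: DocWithEntities) -> DocWithMentions:
    # a callable converter; the target type is inferred from the return annotation
    result = DocWithMentions(text=doc.text, id=doc.id)
    result.mentions.extend(span.copy() for span in doc.entities)
    return result


# register the callable converter and request the new document type:
# dataset.register_document_converter(entities_to_mentions)
# converted = dataset.to_document_type(DocWithMentions)
#
# alternatively, a plain field mapping can be registered; it is applied via cast_document_type():
# dataset.register_document_converter({"entities": "mentions"}, document_type=DocWithMentions)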
+ + +def dataset_to_document_type( + dataset: Union["IterableDataset", "Dataset"], + document_type: Type[Document], + **kwargs, +) -> Union["IterableDataset", "Dataset"]: + # do nothing if the document type is already the requested type + if document_type == dataset.document_type: + logger.info(f"The dataset has already the requested document type {document_type}.") + return dataset + + converter, requested_type, registered_type = _get_best_dataset_converter_with_types( + dataset=dataset, + document_type=document_type, + ) + + result = dataset + if callable(converter): + result = result.map( + function=converter, + result_document_type=registered_type, + fn_kwargs=kwargs, + ) + else: + result = result.cast_document_type( + new_document_type=registered_type, field_mapping=converter, **kwargs + ) + # if the type is not the same or a subclass of the requested type, try to cast (again) + if not issubclass(registered_type, requested_type): + result = result.cast_document_type(new_document_type=requested_type) + + # remove the document converters because they are not valid anymore + result.document_converters = {} + + return result + + +def dataset_register_document_converter( + dataset: Union["Dataset", "IterableDataset"], + converter: Union[Callable[..., D], Dict[str, str]], + document_type: Optional[Type[D]] = None, +) -> None: + if callable(converter) and document_type is None: + dt = _infer_document_type_from_function_return(converter) + else: + dt = document_type + if not (isinstance(dt, type) and issubclass(dt, Document)): + raise TypeError( + f"the (inferred) document_type {dt} is not a subclass of Document. " + "Please provide a document_type or a converter with a return type annotation." + ) + dataset.document_converters[dt] = converter + + +class Dataset(datasets.Dataset, Sequence[D]): + def __init__( + self, + document_type: Type[D], + arrow_table: datasets.table.Table, + info: Optional[datasets.DatasetInfo] = None, + split: Optional[datasets.NamedSplit] = None, + indices_table: Optional[datasets.table.Table] = None, + fingerprint: Optional[str] = None, + document_converters: Optional[DocumentConvertersType] = None, + ): + super().__init__( + arrow_table=arrow_table, + info=info, + split=split, + indices_table=indices_table, + fingerprint=fingerprint, + ) + + self.document_type = document_type + self.set_format("document", document_type=document_type) + self.document_converters = document_converters or {} + + @classmethod + def get_base_kwargs(cls, dataset: datasets.Dataset): + return dict( + arrow_table=dataset._data, + info=dataset.info, + split=dataset.split, + indices_table=dataset._indices, + fingerprint=dataset._fingerprint, + ) + + @classmethod + def from_hf_dataset( + cls, + dataset: datasets.Dataset, + document_type: Type[D], + document_converters: Optional[DocumentConvertersType] = None, + ) -> "Dataset": + document_dataset = cls( + document_type=document_type, + document_converters=document_converters, + **cls.get_base_kwargs(dataset), + ) + return document_dataset + + def apply_hf_func(self, func, **kwargs) -> "Dataset": + return Dataset.from_hf_dataset( + func(self, **kwargs), + document_type=self.document_type, + document_converters=self.document_converters, + ) + + def register_document_converter( + self, + converter: Union[Callable[..., D], Dict[str, str]], + document_type: Optional[Type[D]] = None, + ) -> None: + dataset_register_document_converter( + dataset=self, + converter=converter, + document_type=document_type, + ) + + def to_document_type( + self, + 
document_type: Type[Document], + **kwargs, + ) -> "Dataset": + return dataset_to_document_type( + dataset=self, + document_type=document_type, + **kwargs, + ) + + def map( + self, + function: Optional[Callable] = None, + with_indices: bool = False, + with_rank: bool = False, + input_columns: Optional[Union[str, List[str]]] = None, + batched: bool = False, + batch_size: Optional[int] = 1000, + drop_last_batch: bool = False, + remove_columns: Optional[Union[str, List[str]]] = None, + keep_in_memory: bool = False, + load_from_cache_file: Optional[bool] = None, + cache_file_name: Optional[str] = None, + writer_batch_size: Optional[int] = 1000, + features: Optional[datasets.Features] = None, + disable_nullable: bool = False, + fn_kwargs: Optional[dict] = None, + num_proc: Optional[int] = None, + suffix_template: str = "_{rank:05d}_of_{num_proc:05d}", + new_fingerprint: Optional[str] = None, + desc: Optional[str] = None, + as_documents: bool = True, + result_document_type: Optional[Type[Document]] = None, + ) -> "Dataset": + dataset = super().map( + function=decorate_convert_to_dict_of_lists(function) if as_documents else function, + with_indices=with_indices, + with_rank=with_rank, + input_columns=input_columns, + batched=batched, + batch_size=batch_size, + drop_last_batch=drop_last_batch, + remove_columns=remove_columns, + keep_in_memory=keep_in_memory, + # ignore typing because typing in Huggingface Dataset.map() is incorrect + load_from_cache_file=load_from_cache_file, # type: ignore + cache_file_name=cache_file_name, + writer_batch_size=writer_batch_size, + features=features, + disable_nullable=disable_nullable, + fn_kwargs=fn_kwargs, + num_proc=num_proc, + suffix_template=suffix_template, + new_fingerprint=new_fingerprint, + desc=desc, + ) + + if result_document_type is None: + result_document_type = self.document_type + + return Dataset.from_hf_dataset( + dataset, + document_type=result_document_type, + document_converters=self.document_converters, + ) + + def cast_document_type( + self, + new_document_type: Type[D], + remove_columns: bool = False, + field_mapping: Optional[Dict[str, str]] = None, + ) -> "Dataset": + field_mapping = field_mapping or {} + + removed_field_names, added_field_names = _check_fields_for_casting( + field_mapping=field_mapping, + current_document_type=self.document_type, + new_document_type=new_document_type, + column_names=self.column_names, + ) + + new_hf_dataset = datasets.Dataset(**self.get_base_kwargs(self)) + + if remove_columns: + new_hf_dataset = new_hf_dataset.remove_columns(list(removed_field_names)) + + rename_targets_already_in_columns = ( + set(field_mapping.values()) - set(field_mapping) + ) & set(new_hf_dataset.column_names) + if len(rename_targets_already_in_columns) > 0: + raise ValueError( + f"rename targets are already in column names: {rename_targets_already_in_columns}. Did you miss " + f"to set remove_columns=True in a previous call of cast_document_type?" 
+ ) + + new_hf_dataset = new_hf_dataset.rename_columns(field_mapping) + for f_name in added_field_names: + if f_name not in new_hf_dataset.column_names: + # add empty columns + new_hf_dataset = new_hf_dataset.add_column( + name=f_name, column=len(new_hf_dataset) * [{}] + ) + new_dataset = Dataset.from_hf_dataset( + new_hf_dataset, + document_type=new_document_type, + document_converters=self.document_converters, + ) + + return new_dataset + + +class IterableDataset(datasets.IterableDataset): + def __init__( + self, + document_type: Type[Document], + hidden_columns: Optional[Set[str]] = None, + document_converters: Optional[DocumentConvertersType] = None, + **kwargs, + ): + super().__init__(**kwargs) + self.document_type = document_type + self._document_field_names = [field.name for field in document_type.fields()] + self.hidden_columns = set() + if hidden_columns is not None: + self.hidden_columns.update(hidden_columns) + self.document_converters = document_converters or {} + + @property + def column_names(self) -> List[str]: + return self._document_field_names + list(self.hidden_columns) + + @classmethod + def get_base_kwargs(cls, dataset: datasets.IterableDataset): + return dict( + ex_iterable=dataset._ex_iterable, + info=dataset.info, + split=dataset.split, + formatting=dataset._formatting, + shuffling=dataset._shuffling, + distributed=dataset._distributed, + token_per_repo_id=dataset._token_per_repo_id, + ) + + @classmethod + def from_hf_dataset( + cls, + dataset: datasets.IterableDataset, + document_type: Type[Document], + hidden_columns: Optional[Set[str]] = None, + document_converters: Optional[DocumentConvertersType] = None, + ) -> "IterableDataset": + dataset = cls( + document_type=document_type, + hidden_columns=hidden_columns, + document_converters=document_converters, + **cls.get_base_kwargs(dataset), + ) + return dataset + + def __iter__(self): + for example in iter(super().__iter__()): + yield self.document_type.fromdict(example) + + def register_document_converter( + self, + converter: Union[Callable[..., D], Dict[str, str]], + document_type: Optional[Type[D]] = None, + ) -> None: + dataset_register_document_converter( + dataset=self, + converter=converter, + document_type=document_type, + ) + + def to_document_type( + self, + document_type: Type[Document], + **kwargs, + ) -> "IterableDataset": + return dataset_to_document_type( + dataset=self, + document_type=document_type, + **kwargs, + ) + + def map( # type: ignore + self, + function: Optional[Callable] = None, + batched: bool = False, + as_documents: bool = True, + result_document_type: Optional[Type[Document]] = None, + **kwargs, + ) -> "IterableDataset": + dataset_mapped = super().map( + function=decorate_convert_to_document_and_back( + function, document_type=self.document_type, batched=batched + ) + if as_documents + else function, + batched=batched, + **kwargs, + ) + + if result_document_type is None: + result_document_type = self.document_type + + return IterableDataset.from_hf_dataset( + dataset_mapped, + document_type=result_document_type, + document_converters=self.document_converters, + ) + + def apply_hf_func(self, func, **kwargs) -> "IterableDataset": + return IterableDataset.from_hf_dataset( + func(self, **kwargs), + document_type=self.document_type, + hidden_columns=self.hidden_columns, + document_converters=self.document_converters, + ) + + def cast_document_type( + self, + new_document_type: Type[D], + remove_columns: bool = False, + field_mapping: Optional[Dict[str, str]] = None, + ) -> 
"IterableDataset": + field_mapping = field_mapping or {} + + removed_field_names, added_field_names = _check_fields_for_casting( + field_mapping=field_mapping, + current_document_type=self.document_type, + new_document_type=new_document_type, + column_names=self.column_names, + ) + hidden_columns = set(self.hidden_columns) + new_hf_dataset = datasets.IterableDataset(**self.get_base_kwargs(self)) + + if remove_columns: + new_hf_dataset = new_hf_dataset.remove_columns(column_names=list(removed_field_names)) + else: + hidden_columns.update(removed_field_names) + + rename_targets_already_in_columns = ( + set(field_mapping.values()) - set(field_mapping) + ) & hidden_columns + if len(rename_targets_already_in_columns) > 0: + raise ValueError( + f"rename targets are already in column names: {rename_targets_already_in_columns}. Did you " + f"miss to set remove_columns=True in a previous call of cast_document_type?" + ) + + new_hf_dataset = new_hf_dataset.rename_columns(column_mapping=field_mapping) + + new_dataset = IterableDataset.from_hf_dataset( + new_hf_dataset, + hidden_columns=hidden_columns, + document_type=new_document_type, + document_converters=self.document_converters, + ) + + return new_dataset + + def take(self, n) -> "IterableDataset": + return self.apply_hf_func(datasets.IterableDataset.take, n=n) + + +def get_pie_dataset_type( + hf_dataset: Union[datasets.Dataset, datasets.IterableDataset] +) -> Union[Type[Dataset], Type[IterableDataset]]: + if isinstance(hf_dataset, datasets.Dataset): + return Dataset + elif isinstance(hf_dataset, datasets.IterableDataset): + return IterableDataset + else: + raise TypeError( + f"the dataset must be of type Dataset or IterableDataset, but is of type {type(hf_dataset)}" + ) diff --git a/src/pie_datasets/dataset_dict.py b/src/pie_datasets/dataset_dict.py new file mode 100644 index 00000000..ef0d5467 --- /dev/null +++ b/src/pie_datasets/dataset_dict.py @@ -0,0 +1,641 @@ +import json +import logging +import os +from pathlib import Path +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + SupportsIndex, + Type, + TypeVar, + Union, +) + +import datasets +from pytorch_ie.core import Document +from pytorch_ie.utils.hydra import resolve_target, serialize_document_type + +from .common import ( + EnterDatasetDictMixin, + EnterDatasetMixin, + ExitDatasetDictMixin, + ExitDatasetMixin, +) +from .dataset import Dataset, IterableDataset, get_pie_dataset_type + +logger = logging.getLogger(__name__) + +METADATA_FILE_NAME = "metadata.json" + + +D = TypeVar("D", bound=Document) + + +class DatasetDict(datasets.DatasetDict): + def __getitem__(self, k) -> Union[Dataset, IterableDataset]: # type: ignore + """Returns an individual dataset split.""" + + dataset = super().__getitem__(k) + if isinstance(dataset, (Dataset, IterableDataset)): + return dataset + else: + raise TypeError(f"dataset must be of type Dataset, but is {type(dataset)}") + + @classmethod + def load_dataset(cls, *args, split=None, **kwargs) -> "DatasetDict": + dataset_or_dataset_dict = datasets.load_dataset(*args, split=split, **kwargs) + if isinstance(dataset_or_dataset_dict, (Dataset, IterableDataset)): + if split is None: + raise ValueError( + f"split must be provided if the loaded dataset is not a (Iterable)DatasetDict, " + f"but is {type(dataset_or_dataset_dict)}" + ) + return cls({split: dataset_or_dataset_dict}) + elif isinstance( + dataset_or_dataset_dict, (datasets.DatasetDict, datasets.IterableDatasetDict) + ): + for dataset in dataset_or_dataset_dict.values(): + if not 
isinstance(dataset, (Dataset, IterableDataset)): + raise TypeError( + f"expected pie_datasets.Dataset or pie_datasets.IterableDataset, but got {type(dataset)}" + ) + return cls(dataset_or_dataset_dict) + else: + raise TypeError( + f"expected pie_datasets.DatasetDict, pie_datasets.IterableDatasetDict, pie_datasets.Dataset, " + f"or pie_datasets.IterableDataset, but got {type(dataset_or_dataset_dict)}" + ) + + @classmethod + def from_hf( + cls, + hf_dataset: Union[ + datasets.DatasetDict, + datasets.IterableDatasetDict, + Dict[str, datasets.Dataset], + Dict[str, datasets.IterableDataset], + ], + document_type: Union[str, Type[Document]], + ) -> "DatasetDict": + """Creates a PIE DatasetDict from a HuggingFace DatasetDict, or IterableDatasetDict. If the + input is a Dataset or IterableDataset, we create a DatasetDict with one split named + "train". + + Args: + hf_dataset: HuggingFace (Iterable)Dataset(Dict) + document_type: document type of the dataset. Can be a subclass of Document or string that can be + resolved to such a type. + """ + + doc_type = resolve_target(document_type) + if not isinstance(doc_type, type) or not issubclass(doc_type, Document): + raise TypeError(f"document_type must be a subclass of Document, but is {doc_type}") + + res = cls( + { + k: get_pie_dataset_type(v).from_hf_dataset(v, document_type=doc_type) + for k, v in hf_dataset.items() + } + ) + return res + + @classmethod + def from_json( # type: ignore + cls, + document_type: Optional[Union[Type[Document], str]] = None, + metadata_path: Optional[Union[str, Path]] = None, + data_dir: Optional[str] = None, + split: Optional[str] = None, + **kwargs, + ) -> "DatasetDict": + """Creates a PIE DatasetDict from JSONLINE files. Uses `datasets.load_dataset("json")` + under the hood. Requires a document type to be provided. If the document type is not + provided, we try to load it from the metadata file. + + Args: + document_type: document type of the dataset + data_dir: Defining the `data_dir` of the dataset configuration. See datasets.load_dataset() for more + information. + metadata_path: path to the metadata file. Should point to a directory containing the metadata file + `metadata.json`. Defaults to the value of the `data_dir` parameter. + split: if provided, only the specified split is loaded. see `datasets.load_dataset()` for more information. + **kwargs: additional keyword arguments for `datasets.load_dataset()` + """ + + # try to load metadata + if metadata_path is None: + metadata_path = data_dir + if metadata_path is not None: + metadata_file_name = Path(metadata_path) / METADATA_FILE_NAME + if os.path.exists(metadata_file_name): + with open(metadata_file_name) as f: + metadata = json.load(f) + document_type = document_type or metadata.get("document_type", None) + + if document_type is None: + raise ValueError( + "document_type must be provided if it cannot be loaded from the metadata file" + ) + + hf_dataset = datasets.load_dataset("json", data_dir=data_dir, split=split, **kwargs) + if isinstance(hf_dataset, (datasets.Dataset, datasets.IterableDataset)): + if split is None: + raise ValueError( + f"split must be provided if the loaded dataset is not a (Iterable)DatasetDict, " + f"but is {type(hf_dataset)}" + ) + hf_dataset = {split: hf_dataset} + return cls.from_hf(hf_dataset, document_type=document_type) + + def to_json(self, path: Union[str, Path], **kwargs) -> None: + """Serializes the DatasetDict. We convert all documents with `.asdict()` and dump them with + `json.dump()` to one JSONLINE file per split. 
+ + Args: + path: path to the output directory + **kwargs: additional keyword arguments for `json.dump()` + """ + + path = Path(path) + + # save the metadata + metadata = {"document_type": serialize_document_type(self.document_type)} + os.makedirs(path, exist_ok=True) + if os.path.exists(path / METADATA_FILE_NAME): + logger.warning( + f"metadata file '{path / METADATA_FILE_NAME}' already exists, overwriting it" + ) + with open(path / METADATA_FILE_NAME, "w") as f: + json.dump(metadata, f, indent=2) + + # save the splits + for split, dataset in self.items(): + split_path = path / split + logger.info(f'serialize documents to "{split_path}" ...') + os.makedirs(split_path, exist_ok=True) + file_name = split_path / "documents.jsonl" + with open(file_name, "w") as f: + for doc in dataset: + f.write(json.dumps(doc.asdict(), **kwargs) + "\n") + + @property + def document_type(self) -> Type[Document]: + """Returns the document type of the dataset splits. + + Raises an error if there are no splits in the dataset or if the dataset splits have + different document types. + """ + + if len(self) == 0: + raise ValueError("dataset does not contain any splits, cannot determine document type") + document_types = {ds.document_type for ds in self.values()} + if len(document_types) > 1: + raise ValueError( + f"dataset contains splits with different document types: {document_types}" + ) + return next(iter(document_types)) + + @property + def dataset_type(self) -> Union[Type[Dataset], Type[IterableDataset]]: + """Returns the dataset type of the dataset splits, i.e. either `Dataset` or + `IterableDataset`. + + Raises an error if there are no splits in the dataset or if the dataset splits have + different dataset types. + """ + + if len(self) == 0: + raise ValueError( + "dataset does not contain any splits, cannot determine the dataset type" + ) + dataset_types = {type(ds) for ds in self.values()} + if len(dataset_types) > 1: + raise ValueError( + f"dataset contains splits with different dataset types: {dataset_types}" + ) + return next(iter(dataset_types)) + + def register_document_converter( + self, + converter: Union[Callable[..., D], Dict[str, str], str], + document_type: Optional[Union[Type[D], str]] = None, + ) -> "DatasetDict": + """Register a converter function or field mapping for a target document type. + + Args: + document_type: The target document type for which the converter should be registered. Can be a subclass + of Document or string that can be resolved to such a type. If `None`, the document type is tried to be + inferred from the converter function signature. + converter: Either a function that converts a document of the document type of this dataset to a document + of the target document_type, a string that can be resolved to such a function, or a field mapping + (dict[str, str]) that maps fields of the document type of this dataset to fields of the target + document_type. 
+ """ + resolved_document_type: Optional[Union[Type[D], Callable]] = None + if document_type is not None: + if isinstance(document_type, str): + resolved_document_type = resolve_target(document_type) + else: + resolved_document_type = document_type + if not ( + isinstance(resolved_document_type, type) + and issubclass(resolved_document_type, Document) + ): + raise TypeError( + f"document_type must be or resolve to a subclass of Document, but is '{document_type}'" + ) + + resolved_converter: Union[Callable[..., Any], dict[str, str]] + if isinstance(converter, str): + resolved_converter = resolve_target(converter) + else: + resolved_converter = converter + if not (callable(resolved_converter) or isinstance(resolved_converter, dict)): + raise TypeError( + f"converter must be a callable or a dict, but is {type(resolved_converter)}" + ) + + for ds in self.values(): + ds.register_document_converter( + document_type=resolved_document_type, converter=resolved_converter + ) + return self + + def to_document_type( + self, + document_type: Union[Type[Document], str], + **kwargs, + ) -> "DatasetDict": + """Converts all documents in the dataset to a new document type using the best registered + document converter. + + Args: + document_type: document type to convert the documents to. Can be a subclass of Document or string that + can be resolved to such a type. + """ + + if isinstance(document_type, str): + resolved_document_type = resolve_target(document_type) + else: + resolved_document_type = document_type + if not ( + isinstance(resolved_document_type, type) + and issubclass(resolved_document_type, Document) + ): + raise TypeError( + f"document_type must be a document type or a string that can be resolved to such a type, " + f"but got {document_type}." + ) + + if resolved_document_type == self.document_type: + logger.info(f"The dataset has already the requested document type {document_type}.") + return self + + result = type(self)( + { + name: ds.to_document_type(document_type=resolved_document_type, **kwargs) + for name, ds in self.items() + } + ) + return result + + def map( # type: ignore + self, + function: Optional[Union[Callable, str]] = None, + result_document_type: Optional[Union[str, Type[Document]]] = None, + **kwargs, + ) -> "DatasetDict": + """Applies a function to all documents in the dataset. + + If the function is an object and is derived from the following mixins, the respective logic + is applied: + - EnterDatasetMixin: `enter_dataset(dataset_split, split_name)` is called before the function is + applied to a dataset split + - ExitDatasetMixin: `exit_dataset(processed_dataset_split, split_name)` is called after the function + is applied to a dataset split + - EnterDatasetDictMixin: `enter_dataset_dict(dataset_dict)` is called before any dataset split is + processed (and before any `enter_dataset()` is called) + - ExitDatasetDictMixin: `exit_dataset_dict(processed_dataset_dict)` is called after all dataset splits + are processed (and after all `exit_dataset()` are called) + + Args: + function: function to apply to the documents. If `None`, the identity function is used. If `str`, + the function is resolved from the global namespace. + result_document_type: optional document type of the resulting dataset. Can be a subclass of Document or + string that can be resolved to such a type. If not provided, it is tried to infer it from the + function signature. If this is not possible, the document type of the input dataset + is used. 
+ **kwargs: additional keyword arguments for `datasets.Dataset.map()` + """ + + if function is not None: + func = resolve_target(function) + if not callable(func): + raise TypeError(f"function must be callable, but is of type {type(func)}") + else: + + def identity(x): + # exclude from coverage because its usage happens in the map which is not collected + return x # pragma: no cover + + func = identity + map_kwargs = dict(function=func, **kwargs) + if result_document_type is not None: + map_kwargs["result_document_type"] = resolve_target(result_document_type) + + if isinstance(func, EnterDatasetDictMixin): + func.enter_dataset_dict(self) + + result_dict = {} + for split, dataset in self.items(): + if isinstance(func, EnterDatasetMixin): + func.enter_dataset(dataset=dataset, name=split) + result_dict[split] = dataset.map(**map_kwargs) + if isinstance(func, ExitDatasetMixin): + func.exit_dataset(dataset=result_dict[split], name=split) + + result = type(self)(result_dict) + + if isinstance(func, ExitDatasetDictMixin): + func.exit_dataset_dict(result) + + return result + + def select( + self, + split: str, + start: Optional[SupportsIndex] = None, + stop: Optional[SupportsIndex] = None, + step: Optional[SupportsIndex] = None, + **kwargs, + ) -> "DatasetDict": + """Reduce a certain dataset split to a selection of its documents. This is similar to the + Huggingface `select()`, but adds optional parameters `start`, `stop`, `step` that will be + used to create indices, if available. + + Args: + split: name of the dataset split to modify + start: optional start index of the selection + stop: optional stop index of the selection + step: optional step size of the selection + **kwargs: additional keyword arguments for `datasets.Dataset.select()` + """ + + if stop is not None: + range_args = [stop] + if start is not None: + range_args = [start] + range_args + if step is not None: + range_args = range_args + [step] + kwargs["indices"] = range(*range_args) + + if "indices" in kwargs: + result = type(self)(self) + pie_split = result[split] + if not isinstance(pie_split, Dataset): + raise TypeError( + f"can only select from a Dataset, but the split '{split}' is of type {type(pie_split)}" + ) + result[split] = Dataset.from_hf_dataset( + dataset=pie_split.select(**kwargs), document_type=pie_split.document_type + ) + return result + else: + if len(kwargs) > 0: + logger.warning( + f"arguments for dataset.select() available, but they do not contain 'indices' which is required, " + f"so we do not call select. provided arguments: \n{json.dumps(kwargs, indent=2)}" + ) + return self + + def rename_splits( + self, + mapping: Optional[Dict[str, str]] = None, + keep_other_splits: bool = True, + ) -> "DatasetDict": + """Renames the dataset splits. + + Args: + mapping: mapping from old split names to new split names. + keep_other_splits: if `True` (default), splits not contained in `mapping` are kept in the dataset + """ + + if mapping is None: + mapping = {} + result = type(self)( + { + mapping.get(name, name): data + for name, data in self.items() + if name in mapping or keep_other_splits + } + ) + return result + + def add_test_split( + self, + source_split: str = "train", + target_split: str = "test", + **kwargs, + ) -> "DatasetDict": + """Adds a test split to the dataset by splitting the source split. + + Uses the Huggingface `train_test_split()` method. 
+ """ + + pie_split = self[source_split] + if not isinstance(pie_split, Dataset): + raise TypeError( + f"can only create a train-test-split from a Dataset, but the source split '{source_split}' is of type " + f"{type(pie_split)}" + ) + split_result_hf = pie_split.train_test_split(**kwargs) + split_result = type(self)( + { + name: Dataset.from_hf_dataset( + ds, + document_type=pie_split.document_type, + document_converters=pie_split.document_converters, + ) + for name, ds in split_result_hf.items() + } + ) + res = type(self)(self) + res[source_split] = split_result["train"] + res[target_split] = split_result["test"] + split_sizes = {k: len(v) for k, v in res.items()} + logger.info(f"dataset size after adding the split: {split_sizes}") + return res + + def drop_splits(self, split_names: List[str]) -> "DatasetDict": + """Drops splits from the dataset. + + Args: + split_names: names of the splits to drop + """ + + result = type(self)({name: ds for name, ds in self.items() if name not in split_names}) + return result + + def concat_splits(self, splits: List[str], target: str) -> "DatasetDict": + """Concatenates selected splits into a new split. + + Args: + splits: names of the splits to concatenate + target: name of the new split + """ + + if any(split not in self for split in splits): + raise ValueError( + f"not all splits to concatenate are present in the dataset: {splits}, {self.keys()}" + ) + if len(splits) == 0: + raise ValueError("please provide at least one split to concatenate") + result = type(self)({name: ds for name, ds in self.items() if name not in splits}) + if not issubclass(self.dataset_type, Dataset): + raise TypeError( + f"can only concatenate splits if the dataset type is a Dataset, but it is {self.dataset_type}" + ) + splits_to_concat: List[Dataset] = [self[name] for name in splits] # type: ignore + if any(self.dataset_type != type(ds) for ds in splits_to_concat): + raise ValueError( + f"not all splits to concatenate have the same dataset type: " + f"{({name: type(self[name]) for name in splits})}" + ) + document_converters = None + for ds in splits_to_concat: + if ds.document_converters is not None: + if document_converters is None: + document_converters = {} + document_converters.update(ds.document_converters) + # TODO: why do we need to ignore the typing here? + concatenated = datasets.concatenate_datasets(splits_to_concat) # type: ignore + if not issubclass(self.dataset_type, type(concatenated)): + raise ValueError( + f"concatenated dataset is not of the same type as the original dataset: " + f"{self.dataset_type}, {type(concatenated)}" + ) + result[target] = self.dataset_type.from_hf_dataset( + concatenated, document_type=self.document_type, document_converters=document_converters + ) + split_sizes = {k: len(v) for k, v in result.items()} + logger.info(f"dataset size after concatenating splits: {split_sizes}") + return result + + def filter( # type: ignore + self, + split: str, + function: Optional[Union[Callable[[Dict], bool], str]] = None, + result_split_name: Optional[str] = None, + **kwargs, + ) -> "DatasetDict": + """Filters a dataset split using a filter function. + + Note: In contrast to `map`, the filter function gets the example dict instead of a document as input + because the PIE variant of `Dataset.filter()` is not yet implemented and, thus, the Huggingface + variant is internally used instead. + + Args: + split: name of the split to filter + function: filter function that is called on each example dict. 
Can be provided as a callable or as a + string that is resolved to a callable using `resolve_target()`. + result_split_name: name of the split to store the filtered examples in. If `None`, the filtered examples + are stored in the same split as the original examples. + """ + + if function is not None: + # create a shallow copy to not modify the input + result = type(self)(self) + function = resolve_target(function) + pie_split = result[split] + # TODO: Implement pytorch_ie.Dataset.filter() in a similar way such as map() to make use of the + # document type. For now, the filter function is called directly on the HF dataset and thus needs to + # accept a dict as input. + # we need to convert the dataset back to HF because the filter function internally uses map() which will + # break if the PIE variant is used + hf_split: Union[datasets.Dataset, datasets.IterableDataset] + if isinstance(pie_split, Dataset): + hf_split = datasets.Dataset(**Dataset.get_base_kwargs(pie_split)) + elif isinstance(pie_split, IterableDataset): + hf_split = datasets.IterableDataset(**IterableDataset.get_base_kwargs(pie_split)) + else: + raise ValueError(f"dataset split has unknown type: {type(pie_split)}") + hf_split_filtered = hf_split.filter(function=function, **kwargs) + target_split_name = result_split_name or split + target_split = type(pie_split).from_hf_dataset( + dataset=hf_split_filtered, # type: ignore + document_type=pie_split.document_type, + document_converters=pie_split.document_converters, + ) + # iterable datasets do not have a length + if not isinstance(target_split, IterableDataset): + logger.info( + f"filtered split [{target_split_name}] has {len(target_split)} entries" + ) + result[target_split_name] = target_split + return result + else: + return self + + def move_to_new_split( + self, + ids: Optional[List[str]] = None, + filter_function: Optional[Union[Callable[[Dict[str, Any]], bool], str]] = None, + source_split: str = "train", + target_split: str = "test", + ) -> "DatasetDict": + """Moves examples from one split to another split. ids or a filter function can be provided + to select the examples to move. + + Args: + ids: list of ids of the examples to move + filter_function: filter function that is called on each example dict. Can be provided as a callable or as a + string that can be resolved to such a callable. 
+ source_split: name of the split to move the examples from + target_split: name of the split to move the examples to + """ + + if filter_function is not None: + filter_func = resolve_target(filter_function) + else: + if ids is None: + raise ValueError("please provide either a list of ids or a filter function") + + ids_set = set(ids) + + def filter_with_ids(ex: Dict[str, Any]): + # exclude from coverage because its usage happens in the map which is not collected + return ex["id"] in ids_set # pragma: no cover + + filter_func = filter_with_ids + + dataset_with_only_ids = self.filter( + split=source_split, + function=filter_func, + ) + dataset_without_ids = self.filter( + split=source_split, + function=lambda ex: not filter_func(ex), + ) + dataset_without_ids[target_split] = dataset_with_only_ids[source_split] + + split_sizes = {k: len(v) for k, v in dataset_without_ids.items()} + logger.info(f"dataset size after moving to new split: {split_sizes}") + return dataset_without_ids + + def cast_document_type( + self, new_document_type: Union[Type[Document], str], **kwargs + ) -> "DatasetDict": + """Casts the document type of all splits to a new document type.""" + + new_type = resolve_target(new_document_type) + + result = type(self)( + { + name: ds.cast_document_type(new_document_type=new_type, **kwargs) + for name, ds in self.items() + } + ) + return result diff --git a/src/pie_datasets/document/conversion.py b/src/pie_datasets/document/conversion.py new file mode 100644 index 00000000..e2e81679 --- /dev/null +++ b/src/pie_datasets/document/conversion.py @@ -0,0 +1,302 @@ +import functools +import logging +from collections import defaultdict +from copy import copy, deepcopy +from typing import ( + Callable, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, +) + +from pytorch_ie.annotations import Span +from pytorch_ie.core import Annotation +from pytorch_ie.documents import TextBasedDocument, TokenBasedDocument +from pytorch_ie.utils.hydra import resolve_target +from transformers import PreTrainedTokenizer + +logger = logging.getLogger(__name__) + +ToD = TypeVar("ToD", bound=TokenBasedDocument) +TeD = TypeVar("TeD", bound=TextBasedDocument) + + +def text_based_document_to_token_based( + doc: TextBasedDocument, + result_document_type: Union[Type[ToD], str], + tokens: Optional[List[str]] = None, + token_offset_mapping: Optional[List[Tuple[int, int]]] = None, + char_to_token: Optional[Callable[[int], Optional[int]]] = None, + strict_span_conversion: bool = True, + verbose: bool = True, +) -> ToD: + document_type: Type[ToD] + if isinstance(result_document_type, str): + document_type = resolve_target(result_document_type) # type: ignore + else: + document_type = result_document_type + if not (isinstance(document_type, type) and issubclass(document_type, TokenBasedDocument)): + raise TypeError( + f"result_document_type must be a subclass of TokenBasedDocument or a string that resolves to that, " + f"but got {result_document_type}" + ) + if tokens is None: + tokens = doc.metadata.get("tokens") + if tokens is None: + raise ValueError( + "tokens must be provided to convert a text based document to token based, but got None" + ) + result = document_type(tokens=tuple(tokens), id=doc.id, metadata=deepcopy(doc.metadata)) + + # save text, token_offset_mapping and char_to_token (if available) in metadata + result.metadata["text"] = doc.text + token_offset_mapping_lists: Optional[List[List[int]]] + if token_offset_mapping is not None: + # convert offset tuples to lists 
because serialization and deserialization again + # will produce lists in any way (json does not know tuples) + token_offset_mapping_lists = [list(offsets) for offsets in token_offset_mapping] + if ( + "token_offset_mapping" in doc.metadata + and doc.metadata["token_offset_mapping"] != token_offset_mapping_lists + ): + logger.warning( + "token_offset_mapping in metadata is different from the new token_offset_mapping, " + "overwrite the metadata" + ) + result.metadata["token_offset_mapping"] = token_offset_mapping_lists + else: + token_offset_mapping_lists = doc.metadata.get("token_offset_mapping") + if token_offset_mapping_lists is not None: + token_offset_mapping = [tuple(offsets) for offsets in token_offset_mapping_lists] # type: ignore + if char_to_token is not None: + if "char_to_token" in doc.metadata and doc.metadata["char_to_token"] != char_to_token: + logger.warning( + "char_to_token in metadata is different from the new char_to_token, overwrite the metadata" + ) + result.metadata["char_to_token"] = char_to_token + else: + char_to_token = doc.metadata.get("char_to_token") + + # construct the char_to_token function, if not provided, from the token_offset_mapping + if char_to_token is None: + if token_offset_mapping is None: + raise ValueError( + "either token_offset_mapping or char_to_token must be provided to convert a text " + "based document to token based, but both are None" + ) + char_to_token_dict: Dict[int, int] = {} + for token_idx, (start, end) in enumerate(token_offset_mapping): + for char_idx in range(start, end): + char_to_token_dict[char_idx] = token_idx + + def char_to_token(char_idx: int) -> Optional[int]: + return char_to_token_dict.get(char_idx) + + text_targeting_layers = [ + annotation_field.name + for annotation_field in doc.annotation_fields() + if "text" in annotation_field.metadata["targets"] + ] + + override_annotations: Dict[str, Dict[int, Annotation]] = {} + removed_annotations: Dict[str, Set[int]] = defaultdict(set) + for text_targeting_layer_name in text_targeting_layers: + override_annotations[text_targeting_layer_name] = {} + char_span: Span + for char_span in doc[text_targeting_layer_name]: + if not isinstance(char_span, Span): + raise ValueError( + f"can not convert layers that target the text but contain non-span annotations, " + f"but found {type(char_span)} in layer {text_targeting_layer_name}" + ) + start_token_idx = char_to_token(char_span.start) + end_token_idx_inclusive = char_to_token(char_span.end - 1) + if start_token_idx is None or end_token_idx_inclusive is None: + if strict_span_conversion: + raise ValueError( + f'cannot find token span for character span: "{char_span}", text="{doc.text}", ' + f"token_offset_mapping={token_offset_mapping}" + ) + else: + if verbose: + logger.warning( + f'cannot find token span for character span "{char_span}", skip it (disable this ' + f"warning with verbose=False)" + ) + removed_annotations[text_targeting_layer_name].add(char_span._id) + else: + token_span = char_span.copy(start=start_token_idx, end=end_token_idx_inclusive + 1) + override_annotations[text_targeting_layer_name][char_span._id] = token_span + valid_spans = set(override_annotations[text_targeting_layer_name].values()) + result[text_targeting_layer_name].extend(sorted(valid_spans, key=lambda span: span.start)) + + result.add_all_annotations_from_other( + doc, + override_annotations=override_annotations, + removed_annotations=removed_annotations, + strict=strict_span_conversion, + verbose=verbose, + ) + + return result + + +def 
token_based_document_to_text_based( + doc: TokenBasedDocument, + result_document_type: Union[Type[TeD], str], + text: Optional[str] = None, + token_offset_mapping: Optional[List[Tuple[int, int]]] = None, + join_tokens_with: Optional[str] = None, + strict_span_conversion: bool = True, + verbose: bool = True, +) -> TeD: + document_type: Type[TeD] + if isinstance(result_document_type, str): + document_type = resolve_target(result_document_type) # type: ignore + else: + document_type = result_document_type + if not (isinstance(document_type, type) and issubclass(document_type, TextBasedDocument)): + raise TypeError( + f"result_document_type must be a subclass of TextBasedDocument or a string that resolves to that, " + f"but got {result_document_type}" + ) + # if a token_separator is provided, we construct the text from the tokens + if join_tokens_with is not None: + start = 0 + token_offset_mapping = [] + tokens = doc.tokens + for token in tokens: + end = start + len(token) + token_offset_mapping.append((start, end)) + # we add the separator after each token + start = end + len(join_tokens_with) + text = join_tokens_with.join(tokens) + else: + text = doc.metadata.get("text") if text is None else text + if text is None: + raise ValueError( + "if join_tokens_with is None, text must be provided, but got None as well" + ) + token_offset_mapping_lists = ( + doc.metadata.get("token_offset_mapping") + if token_offset_mapping is None + else token_offset_mapping + ) + if token_offset_mapping_lists is None: + raise ValueError( + "if join_tokens_with is None, token_offsets must be provided, but got None as well" + ) + else: + token_offset_mapping = [tuple(offsets) for offsets in token_offset_mapping_lists] # type: ignore + + result = document_type(text=text, id=doc.id, metadata=deepcopy(doc.metadata)) + if "tokens" in doc.metadata and doc.metadata["tokens"] != list(doc.tokens): + logger.warning("tokens in metadata are different from new tokens, overwrite the metadata") + result.metadata["tokens"] = list(doc.tokens) + # convert offset tuples to lists because serialization and deserialization again + # will produce lists in any way (json does not know tuples) + token_offset_mapping_lists = [list(offsets) for offsets in token_offset_mapping] + if ( + "token_offset_mapping" in doc.metadata + and doc.metadata["token_offset_mapping"] != token_offset_mapping_lists + ): + logger.warning( + "token_offset_mapping in metadata is different from the new token_offset_mapping, " + "overwrite the metadata" + ) + result.metadata["token_offset_mapping"] = token_offset_mapping_lists + + token_targeting_layers = [ + annotation_field.name + for annotation_field in doc.annotation_fields() + if "tokens" in annotation_field.metadata["targets"] + ] + + override_annotations: Dict[str, Dict[int, Annotation]] = {} + removed_annotations: Dict[str, Set[int]] = defaultdict(set) + for token_targeting_layer_name in token_targeting_layers: + override_annotations[token_targeting_layer_name] = {} + for token_span in doc[token_targeting_layer_name]: + if not isinstance(token_span, Span): + raise ValueError( + f"can not convert layers that target the tokens but contain non-span annotations, " + f"but found {type(token_span)} in layer {token_targeting_layer_name}" + ) + start_char_idx = token_offset_mapping[token_span.start][0] + end_char_idx = token_offset_mapping[token_span.end - 1][1] + + char_span = token_span.copy(start=start_char_idx, end=end_char_idx) + override_annotations[token_targeting_layer_name][token_span._id] = char_span + 
valid_spans = set(override_annotations[token_targeting_layer_name].values()) + result[token_targeting_layer_name].extend(sorted(valid_spans, key=lambda span: span.start)) + + result.add_all_annotations_from_other( + doc, + override_annotations=override_annotations, + removed_annotations=removed_annotations, + strict=strict_span_conversion, + verbose=verbose, + ) + + return result + + +def tokenize_document( + doc: TextBasedDocument, + tokenizer: PreTrainedTokenizer, + result_document_type: Type[ToD], + partition_layer: Optional[str] = None, + strict_span_conversion: bool = True, + verbose: bool = True, + **tokenize_kwargs, +) -> List[ToD]: + result = [] + partitions: Iterable[Span] + if partition_layer is None: + partitions = [Span(start=0, end=len(doc.text))] + else: + partitions = doc[partition_layer] + for partition in partitions: + text = doc.text[partition.start : partition.end] + current_tokenize_kwargs = copy(tokenize_kwargs) + if "text" in tokenize_kwargs: + current_tokenize_kwargs["text_pair"] = text + sequence_index = 1 + else: + current_tokenize_kwargs["text"] = text + sequence_index = 0 + tokenized_text = tokenizer(**current_tokenize_kwargs) + for batch_encoding in tokenized_text.encodings: + token_offset_mapping = batch_encoding.offsets + char_to_token: Optional[Callable[[int], Optional[int]]] + char_to_token = functools.partial( + batch_encoding.char_to_token, sequence_index=sequence_index + ) + token_offset_mapping = [ + offsets if s_id == sequence_index else (0, 0) + for s_id, offsets in zip(batch_encoding.sequence_ids, token_offset_mapping) + ] + if partition.start > 0: + token_offset_mapping = [ + (start + partition.start, end + partition.start) + for start, end in token_offset_mapping + ] + char_to_token = None + tokenized_document = text_based_document_to_token_based( + doc, + tokens=batch_encoding.tokens, + result_document_type=result_document_type, + token_offset_mapping=token_offset_mapping, + char_to_token=char_to_token, + strict_span_conversion=strict_span_conversion, + verbose=verbose, + ) + tokenized_document.metadata["tokenizer_encoding"] = batch_encoding + result.append(tokenized_document) + return result diff --git a/src/pie_datasets/document/processing/regex_partitioner.py b/src/pie_datasets/document/processing/regex_partitioner.py index 49df118e..743c0fc5 100644 --- a/src/pie_datasets/document/processing/regex_partitioner.py +++ b/src/pie_datasets/document/processing/regex_partitioner.py @@ -8,9 +8,10 @@ from pytorch_ie import Dataset, IterableDataset from pytorch_ie.annotations import LabeledSpan -from pytorch_ie.data.common import EnterDatasetMixin, ExitDatasetMixin from pytorch_ie.documents import TextBasedDocument +from pie_datasets import EnterDatasetMixin, ExitDatasetMixin + logger = logging.getLogger(__name__) diff --git a/src/pie_datasets/document_formatter.py b/src/pie_datasets/document_formatter.py new file mode 100644 index 00000000..6e66013d --- /dev/null +++ b/src/pie_datasets/document_formatter.py @@ -0,0 +1,22 @@ +from typing import List + +import pyarrow as pa +from datasets.formatting.formatting import Formatter +from pytorch_ie.core.document import Document + + +class DocumentFormatter(Formatter[Document, list, List[Document]]): + def __init__(self, document_type, features=None, **kwargs): + super().__init__(features=None) + self.document_type = document_type + + def format_row(self, pa_table: pa.Table) -> Document: + row = self.python_arrow_extractor().extract_row(pa_table) + return self.document_type.fromdict(row) + + def 
format_column(self, pa_table: pa.Table) -> list: + return [] + + def format_batch(self, pa_table: pa.Table) -> List[Document]: + batch = self.simple_arrow_extractor().extract_batch(pa_table).to_pylist() + return [self.document_type.fromdict(b) for b in batch] diff --git a/tests/conftest.py b/tests/conftest.py index 87b72e22..3eda6d56 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,11 +6,11 @@ import pkg_resources import pytest from datasets import load_dataset -from pytorch_ie import DatasetDict from pytorch_ie.annotations import BinaryRelation, LabeledSpan, Span from pytorch_ie.core import AnnotationList, annotation_field from pytorch_ie.documents import TextBasedDocument +from pie_datasets import DatasetDict from tests import FIXTURES_ROOT from tests.dataset_builders.common import DATASET_BUILDER_BASE_PATH diff --git a/tests/dataset_builders/pie/test_conll2003.py b/tests/dataset_builders/pie/test_conll2003.py index 017eb8b8..a613283c 100644 --- a/tests/dataset_builders/pie/test_conll2003.py +++ b/tests/dataset_builders/pie/test_conll2003.py @@ -1,10 +1,10 @@ import datasets import pytest -from pytorch_ie import DatasetDict from pytorch_ie.core import Document from pytorch_ie.documents import TextDocumentWithLabeledSpans from dataset_builders.pie.conll2003.conll2003 import Conll2003 +from pie_datasets import DatasetDict from tests.dataset_builders.common import PIE_BASE_PATH DATASET_NAME = "conll2003" From be62129f314e98c0705f6537e247fc20d4fee8ba Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Tue, 7 Nov 2023 00:26:20 +0100 Subject: [PATCH 02/12] use pytorch-ie from https://github.com/ChristophAlt/pytorch-ie/pull/363 --- poetry.lock | 26 +++++++++++++++----------- pyproject.toml | 3 ++- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/poetry.lock b/poetry.lock index 3cd8a62e..b9fcea87 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1220,20 +1220,24 @@ name = "pytorch-ie" version = "0.26.0" description = "State-of-the-art Information Extraction in PyTorch" optional = false -python-versions = ">=3.9,<4.0" -files = [ - {file = "pytorch_ie-0.26.0-py3-none-any.whl", hash = "sha256:a6661231bed64aa882c5d2dd8391461c2fffff7aedd5af745d1207e6c9ca9513"}, - {file = "pytorch_ie-0.26.0.tar.gz", hash = "sha256:7e1adb972e2995dcfd01943a63a1d2385e215551fa5989320c44b0a7dfc1f58d"}, -] +python-versions = "^3.9" +files = [] +develop = false [package.dependencies] -absl-py = ">=1.0.0,<2.0.0" -datasets = ">=2.13,<3.0" +absl-py = "^1.0.0" +datasets = "^2.13" fsspec = "<2023.9.0" -pytorch-lightning = ">=2,<3" +pytorch-lightning = "^2" torch = ">=1.10" -torchmetrics = ">=1,<2" -transformers = ">=4.18,<5.0" +torchmetrics = "^1" +transformers = "^4.18" + +[package.source] +type = "git" +url = "https://github.com/ChristophAlt/pytorch-ie.git" +reference = "decouple_pie_dataset" +resolved_reference = "15862d54de0066d2dee0f69fd7bf27527bbfb81d" [[package]] name = "pytorch-lightning" @@ -2157,4 +2161,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "8934ab1306c4ec0c0bcab988e1fea5e760643a2fe6bf0d029392f6532480a493" +content-hash = "405005251a71a59558ec8b03c8b5bc44e84682b9035fe6a1053f864dec2c7e54" diff --git a/pyproject.toml b/pyproject.toml index 2d05d282..47162a67 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,8 @@ classifiers = [ [tool.poetry.dependencies] python = "^3.9" -pytorch-ie = ">=0.26.0,<0.27.0" +#pytorch-ie = ">=0.26.0,<0.27.0" +pytorch-ie = { git = "https://github.com/ChristophAlt/pytorch-ie.git", branch = 
"decouple_pie_dataset" } [tool.poetry.group.dev.dependencies] torch = {version = "^2.1.0+cpu", source = "pytorch"} From eb91f98639b9a87815888ba3119d105c61a65b48 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Tue, 7 Nov 2023 00:26:53 +0100 Subject: [PATCH 03/12] exclude tests/fixtures from spelling checks --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 423a3f22..3c6c8ee5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -79,7 +79,7 @@ repos: hooks: - id: codespell args: - - --skip=logs/**,data/** + - --skip=logs/**,data/**,tests/fixtures/** # hist: required for plotext.hist() # ba: denotes beginning of an encoding with label as 'a'. More details at src/pie_utils/sequence_tagging/ill_formed.py - --ignore-words-list=hist,ba From ac7514838ea80628cd39893ed13e2b08fc19a12f Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Tue, 7 Nov 2023 00:28:43 +0100 Subject: [PATCH 04/12] add tests for migrated classes (Dataset and friends) --- tests/__init__.py | 18 +- tests/conftest.py | 36 +- .../base_multi_config/base_multi_config.py | 249 ++++++++ .../base_single_config/base_single_config.py | 250 ++++++++ .../default_config_kwargs.py | 58 ++ .../datasets/multi_config/multi_config.py | 54 ++ .../datasets/name_mapping/name_mapping.py | 57 ++ .../name_mapping_disabled.py | 57 ++ .../datasets/single_config/single_config.py | 52 ++ .../wrong_builder_class_config.py | 52 ++ .../conll2003_extract/test/documents.jsonl | 3 + .../conll2003_extract/train/documents.jsonl | 3 + .../validation/documents.jsonl | 3 + tests/fixtures/hf_datasets/json/train.json | 4 +- tests/unit/__init__.py | 0 tests/unit/document/__init__.py | 0 tests/unit/document/processing/__init__.py | 0 tests/unit/document/test_conversion.py | 558 ++++++++++++++++++ tests/unit/test_builder.py | 224 +++++++ tests/unit/test_dataset.py | 460 +++++++++++++++ tests/unit/test_dataset_casting.py | 238 ++++++++ tests/unit/test_dataset_dict.py | 534 +++++++++++++++++ 22 files changed, 2905 insertions(+), 5 deletions(-) create mode 100644 tests/fixtures/builder/datasets/base_multi_config/base_multi_config.py create mode 100644 tests/fixtures/builder/datasets/base_single_config/base_single_config.py create mode 100644 tests/fixtures/builder/datasets/default_config_kwargs/default_config_kwargs.py create mode 100644 tests/fixtures/builder/datasets/multi_config/multi_config.py create mode 100644 tests/fixtures/builder/datasets/name_mapping/name_mapping.py create mode 100644 tests/fixtures/builder/datasets/name_mapping_disabled/name_mapping_disabled.py create mode 100644 tests/fixtures/builder/datasets/single_config/single_config.py create mode 100644 tests/fixtures/builder/datasets/wrong_builder_class_config/wrong_builder_class_config.py create mode 100644 tests/fixtures/dataset_dict/conll2003_extract/test/documents.jsonl create mode 100644 tests/fixtures/dataset_dict/conll2003_extract/train/documents.jsonl create mode 100644 tests/fixtures/dataset_dict/conll2003_extract/validation/documents.jsonl create mode 100644 tests/unit/__init__.py create mode 100644 tests/unit/document/__init__.py create mode 100644 tests/unit/document/processing/__init__.py create mode 100644 tests/unit/document/test_conversion.py create mode 100644 tests/unit/test_builder.py create mode 100644 tests/unit/test_dataset.py create mode 100644 tests/unit/test_dataset_casting.py create mode 100644 tests/unit/test_dataset_dict.py diff --git a/tests/__init__.py 
b/tests/__init__.py index 16ee5afc..f9869874 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,10 +1,24 @@ -import pathlib +from pathlib import Path from typing import Any, Dict -TESTS_ROOT = pathlib.Path(__file__).parent +from datasets import DownloadMode, load_dataset + +TESTS_ROOT = Path(__file__).parent FIXTURES_ROOT = TESTS_ROOT / "fixtures" +DATASET_BUILDERS_ROOT = Path("dataset_builders") def _config_to_str(cfg: Dict[str, Any]) -> str: result = "-".join([f"{k}={cfg[k]}" for k in sorted(cfg)]) return result + + +def _check_hf_conll2003_is_available(): + try: + load_dataset("conll2003", download_mode=DownloadMode.FORCE_REDOWNLOAD) + return True + except ConnectionError: + return False + + +_HF_CONLL2003_IS_AVAILABLE = _check_hf_conll2003_is_available() diff --git a/tests/conftest.py b/tests/conftest.py index 3eda6d56..6fcd7a12 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -29,6 +29,11 @@ _TABULATE_AVAILABLE = "tabulate" in {pkg.key for pkg in pkg_resources.working_set} +@pytest.fixture +def documents(dataset): + return list(dataset["train"]) + + @dataclasses.dataclass class TestDocument(TextBasedDocument): sentences: AnnotationList[Span] = annotation_field(target="text") @@ -78,7 +83,7 @@ def test_hf_dataset(hf_dataset): assert len(hf_dataset[split]) == SPLIT_SIZES[split] -@pytest.fixture(scope="session") +@pytest.fixture() def dataset(hf_dataset): mapped_dataset = hf_dataset.map(example_to_doc_dict) dataset = DatasetDict.from_hf(hf_dataset=mapped_dataset, document_type=TestDocument) @@ -97,3 +102,32 @@ def test_dataset(dataset): doc0 = d_train[0] assert doc0 is not None assert isinstance(doc0, TestDocument) + + +@pytest.fixture(scope="session") +def iterable_hf_dataset(): + result = load_dataset( + "json", + field="data", + data_dir=str(FIXTURES_ROOT / "hf_datasets" / "json"), + streaming=True, + ) + + return result + + +@pytest.fixture() +def iterable_dataset(iterable_hf_dataset): + mapped_dataset = iterable_hf_dataset.map(example_to_doc_dict) + dataset = DatasetDict.from_hf(hf_dataset=mapped_dataset, document_type=TestDocument) + return dataset + + +def test_iterable_dataset(iterable_dataset): + assert iterable_dataset is not None + assert set(iterable_dataset) == set(SPLIT_SIZES) + + +@pytest.fixture(params=["dataset", "iterable_dataset"]) +def maybe_iterable_dataset(request): + return request.getfixturevalue(request.param) diff --git a/tests/fixtures/builder/datasets/base_multi_config/base_multi_config.py b/tests/fixtures/builder/datasets/base_multi_config/base_multi_config.py new file mode 100644 index 00000000..1942a14a --- /dev/null +++ b/tests/fixtures/builder/datasets/base_multi_config/base_multi_config.py @@ -0,0 +1,249 @@ +# Copyright 2020 HuggingFace Datasets Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
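+#
+# Note: this script is based on the Hugging Face "conll2002" dataset loading script and is
+# only included here as a test fixture (a base dataset that provides multiple configs).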
+ +# Lint as: python3 +"""Introduction to the CoNLL-2002 Shared Task: Language-Independent Named Entity Recognition""" + +import datasets + +logger = datasets.logging.get_logger(__name__) + + +_CITATION = """\ +@inproceedings{tjong-kim-sang-2002-introduction, + title = "Introduction to the {C}o{NLL}-2002 Shared Task: Language-Independent Named Entity Recognition", + author = "Tjong Kim Sang, Erik F.", + booktitle = "{COLING}-02: The 6th Conference on Natural Language Learning 2002 ({C}o{NLL}-2002)", + year = "2002", + url = "https://www.aclweb.org/anthology/W02-2024", +} +""" + +_DESCRIPTION = """\ +Named entities are phrases that contain the names of persons, organizations, locations, times and quantities. + +Example: +[PER Wolff] , currently a journalist in [LOC Argentina] , played with [PER Del Bosque] in the final years of the seventies in [ORG Real Madrid] . + +The shared task of CoNLL-2002 concerns language-independent named entity recognition. +We will concentrate on four types of named entities: persons, locations, organizations and names of miscellaneous entities that do not belong to the previous three groups. +The participants of the shared task will be offered training and test data for at least two languages. +They will use the data for developing a named-entity recognition system that includes a machine learning component. +Information sources other than the training data may be used in this shared task. +We are especially interested in methods that can use additional unannotated data for improving their performance (for example co-training). + +The train/validation/test sets are available in Spanish and Dutch. + +For more details see https://www.clips.uantwerpen.be/conll2002/ner/ and https://www.aclweb.org/anthology/W02-2024/ +""" + +_URL = "https://raw.githubusercontent.com/teropa/nlp/master/resources/corpora/conll2002/" +_ES_TRAINING_FILE = "esp.train" +_ES_DEV_FILE = "esp.testa" +_ES_TEST_FILE = "esp.testb" +_NL_TRAINING_FILE = "ned.train" +_NL_DEV_FILE = "ned.testa" +_NL_TEST_FILE = "ned.testb" + + +class Conll2002Config(datasets.BuilderConfig): + """BuilderConfig for Conll2002.""" + + def __init__(self, **kwargs): + """BuilderConfig forConll2002. + + Args: + **kwargs: keyword arguments forwarded to super. 
+ """ + super().__init__(**kwargs) + + +class Conll2002(datasets.GeneratorBasedBuilder): + """Conll2002 dataset.""" + + BUILDER_CONFIGS = [ + Conll2002Config( + name="es", version=datasets.Version("1.0.0"), description="Conll2002 Spanish dataset" + ), + Conll2002Config( + name="nl", version=datasets.Version("1.0.0"), description="Conll2002 Dutch dataset" + ), + ] + + def _info(self): + return datasets.DatasetInfo( + description=_DESCRIPTION, + features=datasets.Features( + { + "id": datasets.Value("string"), + "tokens": datasets.Sequence(datasets.Value("string")), + "pos_tags": datasets.Sequence( + datasets.features.ClassLabel( + names=[ + "AO", + "AQ", + "CC", + "CS", + "DA", + "DE", + "DD", + "DI", + "DN", + "DP", + "DT", + "Faa", + "Fat", + "Fc", + "Fd", + "Fe", + "Fg", + "Fh", + "Fia", + "Fit", + "Fp", + "Fpa", + "Fpt", + "Fs", + "Ft", + "Fx", + "Fz", + "I", + "NC", + "NP", + "P0", + "PD", + "PI", + "PN", + "PP", + "PR", + "PT", + "PX", + "RG", + "RN", + "SP", + "VAI", + "VAM", + "VAN", + "VAP", + "VAS", + "VMG", + "VMI", + "VMM", + "VMN", + "VMP", + "VMS", + "VSG", + "VSI", + "VSM", + "VSN", + "VSP", + "VSS", + "Y", + "Z", + ] + ) + if self.config.name == "es" + else datasets.features.ClassLabel( + names=[ + "Adj", + "Adv", + "Art", + "Conj", + "Int", + "Misc", + "N", + "Num", + "Prep", + "Pron", + "Punc", + "V", + ] + ) + ), + "ner_tags": datasets.Sequence( + datasets.features.ClassLabel( + names=[ + "O", + "B-PER", + "I-PER", + "B-ORG", + "I-ORG", + "B-LOC", + "I-LOC", + "B-MISC", + "I-MISC", + ] + ) + ), + } + ), + supervised_keys=None, + homepage="https://www.aclweb.org/anthology/W02-2024/", + citation=_CITATION, + ) + + def _split_generators(self, dl_manager): + """Returns SplitGenerators.""" + urls_to_download = { + "train": f"{_URL}{_ES_TRAINING_FILE if self.config.name == 'es' else _NL_TRAINING_FILE}", + "dev": f"{_URL}{_ES_DEV_FILE if self.config.name == 'es' else _NL_DEV_FILE}", + "test": f"{_URL}{_ES_TEST_FILE if self.config.name == 'es' else _NL_TEST_FILE}", + } + downloaded_files = dl_manager.download_and_extract(urls_to_download) + + return [ + datasets.SplitGenerator( + name=datasets.Split.TRAIN, gen_kwargs={"filepath": downloaded_files["train"]} + ), + datasets.SplitGenerator( + name=datasets.Split.VALIDATION, gen_kwargs={"filepath": downloaded_files["dev"]} + ), + datasets.SplitGenerator( + name=datasets.Split.TEST, gen_kwargs={"filepath": downloaded_files["test"]} + ), + ] + + def _generate_examples(self, filepath): + logger.info("⏳ Generating examples from = %s", filepath) + with open(filepath, encoding="utf-8") as f: + guid = 0 + tokens = [] + pos_tags = [] + ner_tags = [] + for line in f: + if line.startswith("-DOCSTART-") or line == "" or line == "\n": + if tokens: + yield guid, { + "id": str(guid), + "tokens": tokens, + "pos_tags": pos_tags, + "ner_tags": ner_tags, + } + guid += 1 + tokens = [] + pos_tags = [] + ner_tags = [] + else: + # conll2002 tokens are space separated + splits = line.split(" ") + tokens.append(splits[0]) + pos_tags.append(splits[1]) + ner_tags.append(splits[2].rstrip()) + # last example + yield guid, { + "id": str(guid), + "tokens": tokens, + "pos_tags": pos_tags, + "ner_tags": ner_tags, + } diff --git a/tests/fixtures/builder/datasets/base_single_config/base_single_config.py b/tests/fixtures/builder/datasets/base_single_config/base_single_config.py new file mode 100644 index 00000000..169b16bf --- /dev/null +++ b/tests/fixtures/builder/datasets/base_single_config/base_single_config.py @@ -0,0 +1,250 @@ +# Copyright 2020 HuggingFace 
Datasets Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Lint as: python3 +"""Introduction to the CoNLL-2003 Shared Task: Language-Independent Named Entity Recognition""" + +import os + +import datasets + +logger = datasets.logging.get_logger(__name__) + + +_CITATION = """\ +@inproceedings{tjong-kim-sang-de-meulder-2003-introduction, + title = "Introduction to the {C}o{NLL}-2003 Shared Task: Language-Independent Named Entity Recognition", + author = "Tjong Kim Sang, Erik F. and + De Meulder, Fien", + booktitle = "Proceedings of the Seventh Conference on Natural Language Learning at {HLT}-{NAACL} 2003", + year = "2003", + url = "https://www.aclweb.org/anthology/W03-0419", + pages = "142--147", +} +""" + +_DESCRIPTION = """\ +The shared task of CoNLL-2003 concerns language-independent named entity recognition. We will concentrate on +four types of named entities: persons, locations, organizations and names of miscellaneous entities that do +not belong to the previous three groups. + +The CoNLL-2003 shared task data files contain four columns separated by a single space. Each word has been put on +a separate line and there is an empty line after each sentence. The first item on each line is a word, the second +a part-of-speech (POS) tag, the third a syntactic chunk tag and the fourth the named entity tag. The chunk tags +and the named entity tags have the format I-TYPE which means that the word is inside a phrase of type TYPE. Only +if two phrases of the same type immediately follow each other, the first word of the second phrase will have tag +B-TYPE to show that it starts a new phrase. A word with tag O is not part of a phrase. Note the dataset uses IOB2 +tagging scheme, whereas the original dataset uses IOB1. + +For more details see https://www.clips.uantwerpen.be/conll2003/ner/ and https://www.aclweb.org/anthology/W03-0419 +""" + +_URL = "https://data.deepai.org/conll2003.zip" +_TRAINING_FILE = "train.txt" +_DEV_FILE = "valid.txt" +_TEST_FILE = "test.txt" + + +class Conll2003Config(datasets.BuilderConfig): + """BuilderConfig for Conll2003.""" + + def __init__(self, **kwargs): + """BuilderConfig forConll2003. + + Args: + **kwargs: keyword arguments forwarded to super. 
+ """ + super().__init__(**kwargs) + + +class Conll2003(datasets.GeneratorBasedBuilder): + """Conll2003 dataset.""" + + BUILDER_CONFIGS = [ + Conll2003Config( + name="conll2003", version=datasets.Version("1.0.0"), description="Conll2003 dataset" + ), + ] + + def _info(self): + return datasets.DatasetInfo( + description=_DESCRIPTION, + features=datasets.Features( + { + "id": datasets.Value("string"), + "tokens": datasets.Sequence(datasets.Value("string")), + "pos_tags": datasets.Sequence( + datasets.features.ClassLabel( + names=[ + '"', + "''", + "#", + "$", + "(", + ")", + ",", + ".", + ":", + "``", + "CC", + "CD", + "DT", + "EX", + "FW", + "IN", + "JJ", + "JJR", + "JJS", + "LS", + "MD", + "NN", + "NNP", + "NNPS", + "NNS", + "NN|SYM", + "PDT", + "POS", + "PRP", + "PRP$", + "RB", + "RBR", + "RBS", + "RP", + "SYM", + "TO", + "UH", + "VB", + "VBD", + "VBG", + "VBN", + "VBP", + "VBZ", + "WDT", + "WP", + "WP$", + "WRB", + ] + ) + ), + "chunk_tags": datasets.Sequence( + datasets.features.ClassLabel( + names=[ + "O", + "B-ADJP", + "I-ADJP", + "B-ADVP", + "I-ADVP", + "B-CONJP", + "I-CONJP", + "B-INTJ", + "I-INTJ", + "B-LST", + "I-LST", + "B-NP", + "I-NP", + "B-PP", + "I-PP", + "B-PRT", + "I-PRT", + "B-SBAR", + "I-SBAR", + "B-UCP", + "I-UCP", + "B-VP", + "I-VP", + ] + ) + ), + "ner_tags": datasets.Sequence( + datasets.features.ClassLabel( + names=[ + "O", + "B-PER", + "I-PER", + "B-ORG", + "I-ORG", + "B-LOC", + "I-LOC", + "B-MISC", + "I-MISC", + ] + ) + ), + } + ), + supervised_keys=None, + homepage="https://www.aclweb.org/anthology/W03-0419/", + citation=_CITATION, + ) + + def _split_generators(self, dl_manager): + """Returns SplitGenerators.""" + downloaded_file = dl_manager.download_and_extract(_URL) + data_files = { + "train": os.path.join(downloaded_file, _TRAINING_FILE), + "dev": os.path.join(downloaded_file, _DEV_FILE), + "test": os.path.join(downloaded_file, _TEST_FILE), + } + + return [ + datasets.SplitGenerator( + name=datasets.Split.TRAIN, gen_kwargs={"filepath": data_files["train"]} + ), + datasets.SplitGenerator( + name=datasets.Split.VALIDATION, gen_kwargs={"filepath": data_files["dev"]} + ), + datasets.SplitGenerator( + name=datasets.Split.TEST, gen_kwargs={"filepath": data_files["test"]} + ), + ] + + def _generate_examples(self, filepath): + logger.info("⏳ Generating examples from = %s", filepath) + with open(filepath, encoding="utf-8") as f: + guid = 0 + tokens = [] + pos_tags = [] + chunk_tags = [] + ner_tags = [] + for line in f: + if line.startswith("-DOCSTART-") or line == "" or line == "\n": + if tokens: + yield guid, { + "id": str(guid), + "tokens": tokens, + "pos_tags": pos_tags, + "chunk_tags": chunk_tags, + "ner_tags": ner_tags, + } + guid += 1 + tokens = [] + pos_tags = [] + chunk_tags = [] + ner_tags = [] + else: + # conll2003 tokens are space separated + splits = line.split(" ") + tokens.append(splits[0]) + pos_tags.append(splits[1]) + chunk_tags.append(splits[2]) + ner_tags.append(splits[3].rstrip()) + # last example + if tokens: + yield guid, { + "id": str(guid), + "tokens": tokens, + "pos_tags": pos_tags, + "chunk_tags": chunk_tags, + "ner_tags": ner_tags, + } diff --git a/tests/fixtures/builder/datasets/default_config_kwargs/default_config_kwargs.py b/tests/fixtures/builder/datasets/default_config_kwargs/default_config_kwargs.py new file mode 100644 index 00000000..60c9d48b --- /dev/null +++ b/tests/fixtures/builder/datasets/default_config_kwargs/default_config_kwargs.py @@ -0,0 +1,58 @@ +from dataclasses import dataclass + +import datasets +from 
pytorch_ie.annotations import LabeledSpan +from pytorch_ie.core import AnnotationList, annotation_field +from pytorch_ie.documents import TextDocument + +from pie_datasets import GeneratorBasedBuilder +from tests import FIXTURES_ROOT + + +class ExampleConfig(datasets.BuilderConfig): + """BuilderConfig for CoNLL2002.""" + + def __init__(self, parameter: str, **kwargs): + """BuilderConfig for CoNLL2002. + + Args: + **kwargs: keyword arguments forwarded to super. + """ + super().__init__(**kwargs) + self.parameter = parameter + + +@dataclass +class ExampleDocument(TextDocument): + entities: AnnotationList[LabeledSpan] = annotation_field(target="text") + + +class Example(GeneratorBasedBuilder): + DOCUMENT_TYPE = ExampleDocument + + BASE_DATASET_PATH = str(FIXTURES_ROOT / "builder" / "datasets" / "base_multi_config") + + BASE_CONFIG_KWARGS_DICT = { + "nl": {"version": datasets.Version("0.0.0"), "description": "new description"}, + } + + BUILDER_CONFIGS = [ + ExampleConfig( + name="es", + version=datasets.Version("1.0.0"), + description="CoNLL2002 Spanish dataset", + parameter="test", + ), + ExampleConfig( + name="nl", + version=datasets.Version("1.0.0"), + description="CoNLL2002 Dutch dataset", + parameter="test", + ), + ] + + def _generate_document_kwargs(self, dataset): + pass + + def _generate_document(self, example, int_to_str): + pass diff --git a/tests/fixtures/builder/datasets/multi_config/multi_config.py b/tests/fixtures/builder/datasets/multi_config/multi_config.py new file mode 100644 index 00000000..9204705c --- /dev/null +++ b/tests/fixtures/builder/datasets/multi_config/multi_config.py @@ -0,0 +1,54 @@ +from dataclasses import dataclass + +import datasets +from pytorch_ie.annotations import LabeledSpan +from pytorch_ie.core import AnnotationList, annotation_field +from pytorch_ie.documents import TextDocument + +from pie_datasets import GeneratorBasedBuilder +from tests import FIXTURES_ROOT + + +class ExampleConfig(datasets.BuilderConfig): + """BuilderConfig for CoNLL2002.""" + + def __init__(self, parameter: str, **kwargs): + """BuilderConfig for CoNLL2002. + + Args: + **kwargs: keyword arguments forwarded to super. 
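+            parameter: additional example parameter, stored as `self.parameter`.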
+ """ + super().__init__(**kwargs) + self.parameter = parameter + + +@dataclass +class ExampleDocument(TextDocument): + entities: AnnotationList[LabeledSpan] = annotation_field(target="text") + + +class Example(GeneratorBasedBuilder): + DOCUMENT_TYPE = ExampleDocument + + BASE_DATASET_PATH = str(FIXTURES_ROOT / "builder" / "datasets" / "base_multi_config") + + BUILDER_CONFIGS = [ + ExampleConfig( + name="es", + version=datasets.Version("1.0.0"), + description="CoNLL2002 Spanish dataset", + parameter="test", + ), + ExampleConfig( + name="nl", + version=datasets.Version("1.0.0"), + description="CoNLL2002 Dutch dataset", + parameter="test", + ), + ] + + def _generate_document_kwargs(self, dataset): + pass + + def _generate_document(self, example, int_to_str): + pass diff --git a/tests/fixtures/builder/datasets/name_mapping/name_mapping.py b/tests/fixtures/builder/datasets/name_mapping/name_mapping.py new file mode 100644 index 00000000..8d891317 --- /dev/null +++ b/tests/fixtures/builder/datasets/name_mapping/name_mapping.py @@ -0,0 +1,57 @@ +from dataclasses import dataclass + +import datasets +from pytorch_ie.annotations import LabeledSpan +from pytorch_ie.core import AnnotationList, annotation_field +from pytorch_ie.documents import TextDocument + +from pie_datasets import GeneratorBasedBuilder +from tests import FIXTURES_ROOT + + +class ExampleConfig(datasets.BuilderConfig): + """BuilderConfig for CoNLL2002.""" + + def __init__(self, parameter: str, **kwargs): + """BuilderConfig for CoNLL2002. + + Args: + **kwargs: keyword arguments forwarded to super. + """ + super().__init__(**kwargs) + self.parameter = parameter + + +@dataclass +class ExampleDocument(TextDocument): + entities: AnnotationList[LabeledSpan] = annotation_field(target="text") + + +class Example(GeneratorBasedBuilder): + DOCUMENT_TYPE = ExampleDocument + + BASE_DATASET_PATH = str(FIXTURES_ROOT / "builder" / "datasets" / "base_multi_config") + + # map everything to "nl" + BASE_CONFIG_KWARGS_DICT = {"es": {"name": "nl"}} + + BUILDER_CONFIGS = [ + ExampleConfig( + name="es", + version=datasets.Version("1.0.0"), + description="CoNLL2002 Spanish dataset", + parameter="test", + ), + ExampleConfig( + name="nl", + version=datasets.Version("1.0.0"), + description="CoNLL2002 Dutch dataset", + parameter="test", + ), + ] + + def _generate_document_kwargs(self, dataset): + pass + + def _generate_document(self, example, int_to_str): + pass diff --git a/tests/fixtures/builder/datasets/name_mapping_disabled/name_mapping_disabled.py b/tests/fixtures/builder/datasets/name_mapping_disabled/name_mapping_disabled.py new file mode 100644 index 00000000..2c5092ba --- /dev/null +++ b/tests/fixtures/builder/datasets/name_mapping_disabled/name_mapping_disabled.py @@ -0,0 +1,57 @@ +from dataclasses import dataclass + +import datasets +from pytorch_ie.annotations import LabeledSpan +from pytorch_ie.core import AnnotationList, annotation_field +from pytorch_ie.documents import TextDocument + +from pie_datasets import GeneratorBasedBuilder +from tests import FIXTURES_ROOT + + +class ExampleConfig(datasets.BuilderConfig): + """BuilderConfig for CoNLL2002.""" + + def __init__(self, parameter: str, **kwargs): + """BuilderConfig for CoNLL2002. + + Args: + **kwargs: keyword arguments forwarded to super. 
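+            parameter: additional example parameter, stored as `self.parameter`.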
+ """ + super().__init__(**kwargs) + self.parameter = parameter + + +@dataclass +class ExampleDocument(TextDocument): + entities: AnnotationList[LabeledSpan] = annotation_field(target="text") + + +class Example(GeneratorBasedBuilder): + DOCUMENT_TYPE = ExampleDocument + + BASE_DATASET_PATH = str(FIXTURES_ROOT / "builder" / "datasets" / "base_multi_config") + + # disable any mapping + BASE_CONFIG_KWARGS_DICT = None + + BUILDER_CONFIGS = [ + ExampleConfig( + name="es", + version=datasets.Version("1.0.0"), + description="CoNLL2002 Spanish dataset", + parameter="test", + ), + ExampleConfig( + name="nl", + version=datasets.Version("1.0.0"), + description="CoNLL2002 Dutch dataset", + parameter="test", + ), + ] + + def _generate_document_kwargs(self, dataset): + pass + + def _generate_document(self, example, int_to_str): + pass diff --git a/tests/fixtures/builder/datasets/single_config/single_config.py b/tests/fixtures/builder/datasets/single_config/single_config.py new file mode 100644 index 00000000..4a7f503e --- /dev/null +++ b/tests/fixtures/builder/datasets/single_config/single_config.py @@ -0,0 +1,52 @@ +from dataclasses import dataclass +from typing import Type + +import datasets +from pytorch_ie.annotations import LabeledSpan +from pytorch_ie.core import AnnotationList, annotation_field +from pytorch_ie.documents import TextDocument + +from pie_datasets import GeneratorBasedBuilder +from tests import FIXTURES_ROOT + + +class ExampleConfig(datasets.BuilderConfig): + """BuilderConfig for CoNLL2003.""" + + def __init__(self, parameter: str, **kwargs): + """BuilderConfig for CoNLL2003. + + Args: + **kwargs: keyword arguments forwarded to super. + """ + super().__init__(**kwargs) + self.parameter = parameter + + +@dataclass +class ExampleDocument(TextDocument): + entities: AnnotationList[LabeledSpan] = annotation_field(target="text") + + +class Example(GeneratorBasedBuilder): + DOCUMENT_TYPE = ExampleDocument + + BASE_DATASET_PATH = str(FIXTURES_ROOT / "builder" / "datasets" / "base_single_config") + + BUILDER_CONFIGS = [ + ExampleConfig( + name="conll2003", + version=datasets.Version("1.0.0"), + description="Example dataset", + parameter="test", + ), + ] + + # required to create config from scratch via kwargs + BUILDER_CONFIG_CLASS: Type[datasets.BuilderConfig] = ExampleConfig + + def _generate_document_kwargs(self, dataset): + pass + + def _generate_document(self, example, int_to_str): + pass diff --git a/tests/fixtures/builder/datasets/wrong_builder_class_config/wrong_builder_class_config.py b/tests/fixtures/builder/datasets/wrong_builder_class_config/wrong_builder_class_config.py new file mode 100644 index 00000000..be9b1334 --- /dev/null +++ b/tests/fixtures/builder/datasets/wrong_builder_class_config/wrong_builder_class_config.py @@ -0,0 +1,52 @@ +from dataclasses import dataclass +from typing import Type + +import datasets +from pytorch_ie.annotations import LabeledSpan +from pytorch_ie.core import AnnotationList, annotation_field +from pytorch_ie.documents import TextDocument + +from pie_datasets.builder import ArrowBasedBuilder +from tests import FIXTURES_ROOT + + +class ExampleConfig(datasets.BuilderConfig): + """BuilderConfig for CoNLL2003.""" + + def __init__(self, parameter: str, **kwargs): + """BuilderConfig for CoNLL2003. + + Args: + **kwargs: keyword arguments forwarded to super. 
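+            parameter: additional example parameter, stored as `self.parameter`.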
+ """ + super().__init__(**kwargs) + self.parameter = parameter + + +@dataclass +class ExampleDocument(TextDocument): + entities: AnnotationList[LabeledSpan] = annotation_field(target="text") + + +class Example(ArrowBasedBuilder): + DOCUMENT_TYPE = ExampleDocument + + BASE_DATASET_PATH = str(FIXTURES_ROOT / "builder" / "datasets" / "base_single_config") + + BUILDER_CONFIGS = [ + ExampleConfig( + name="conll2003", + version=datasets.Version("1.0.0"), + description="Example dataset", + parameter="test", + ), + ] + + # required to create config from scratch via kwargs + BUILDER_CONFIG_CLASS: Type[datasets.BuilderConfig] = ExampleConfig + + def _generate_document_kwargs(self, dataset): + pass + + def _generate_document(self, example, int_to_str): + pass diff --git a/tests/fixtures/dataset_dict/conll2003_extract/test/documents.jsonl b/tests/fixtures/dataset_dict/conll2003_extract/test/documents.jsonl new file mode 100644 index 00000000..2d642f44 --- /dev/null +++ b/tests/fixtures/dataset_dict/conll2003_extract/test/documents.jsonl @@ -0,0 +1,3 @@ +{"text": "SOCCER - JAPAN GET LUCKY WIN , CHINA IN SURPRISE DEFEAT .", "id": "0", "metadata": null, "entities": {"annotations": [{"start": 9, "end": 14, "label": "LOC", "score": 1.0, "_id": -2619339436438505339}, {"start": 31, "end": 36, "label": "PER", "score": 1.0, "_id": -8138508157508680512}], "predictions": []}} +{"text": "Nadim Ladki", "id": "1", "metadata": null, "entities": {"annotations": [{"start": 0, "end": 11, "label": "PER", "score": 1.0, "_id": -888938487910085717}], "predictions": []}} +{"text": "AL-AIN , United Arab Emirates 1996-12-06", "id": "2", "metadata": null, "entities": {"annotations": [{"start": 0, "end": 6, "label": "LOC", "score": 1.0, "_id": -7372491258908974083}, {"start": 9, "end": 29, "label": "LOC", "score": 1.0, "_id": -8605924337576328352}], "predictions": []}} diff --git a/tests/fixtures/dataset_dict/conll2003_extract/train/documents.jsonl b/tests/fixtures/dataset_dict/conll2003_extract/train/documents.jsonl new file mode 100644 index 00000000..09f96137 --- /dev/null +++ b/tests/fixtures/dataset_dict/conll2003_extract/train/documents.jsonl @@ -0,0 +1,3 @@ +{"text": "EU rejects German call to boycott British lamb .", "id": "0", "metadata": null, "entities": {"annotations": [{"start": 0, "end": 2, "label": "ORG", "score": 1.0, "_id": -6025869170090849777}, {"start": 11, "end": 17, "label": "MISC", "score": 1.0, "_id": -8712404795926495516}, {"start": 34, "end": 41, "label": "MISC", "score": 1.0, "_id": 173163560486985535}], "predictions": []}} +{"text": "Peter Blackburn", "id": "1", "metadata": null, "entities": {"annotations": [{"start": 0, "end": 15, "label": "PER", "score": 1.0, "_id": 464505172076656073}], "predictions": []}} +{"text": "BRUSSELS 1996-08-22", "id": "2", "metadata": null, "entities": {"annotations": [{"start": 0, "end": 8, "label": "LOC", "score": 1.0, "_id": -3556694396036444869}], "predictions": []}} diff --git a/tests/fixtures/dataset_dict/conll2003_extract/validation/documents.jsonl b/tests/fixtures/dataset_dict/conll2003_extract/validation/documents.jsonl new file mode 100644 index 00000000..63c95d46 --- /dev/null +++ b/tests/fixtures/dataset_dict/conll2003_extract/validation/documents.jsonl @@ -0,0 +1,3 @@ +{"text": "CRICKET - LEICESTERSHIRE TAKE OVER AT TOP AFTER INNINGS VICTORY .", "id": "0", "metadata": null, "entities": {"annotations": [{"start": 10, "end": 24, "label": "ORG", "score": 1.0, "_id": 6901678984913972450}], "predictions": []}} +{"text": "LONDON 1996-08-30", "id": "1", 
"metadata": null, "entities": {"annotations": [{"start": 0, "end": 6, "label": "LOC", "score": 1.0, "_id": -7372491258908974083}], "predictions": []}} +{"text": "West Indian all-rounder Phil Simmons took four for 38 on Friday as Leicestershire beat Somerset by an innings and 39 runs in two days to take over at the head of the county championship .", "id": "2", "metadata": null, "entities": {"annotations": [{"start": 0, "end": 11, "label": "MISC", "score": 1.0, "_id": 6242261418838464680}, {"start": 24, "end": 36, "label": "PER", "score": 1.0, "_id": 6271522231709659741}, {"start": 67, "end": 81, "label": "ORG", "score": 1.0, "_id": -963921602163885151}, {"start": 87, "end": 95, "label": "ORG", "score": 1.0, "_id": 4752197907666177189}], "predictions": []}} diff --git a/tests/fixtures/hf_datasets/json/train.json b/tests/fixtures/hf_datasets/json/train.json index 8a6614fa..372a8607 100644 --- a/tests/fixtures/hf_datasets/json/train.json +++ b/tests/fixtures/hf_datasets/json/train.json @@ -89,7 +89,7 @@ } }, { - "id": "val_doc1", + "id": "train_doc7", "text": "A single sentence.", "sentences": [{ "start": 0, "end": 18 }], "entities": [], @@ -99,7 +99,7 @@ } }, { - "id": "val_doc2", + "id": "train_doc8", "text": "First sentence. Entity M works at N. And it founded O.", "sentences": [ { "start": 0, "end": 15 }, diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/document/__init__.py b/tests/unit/document/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/document/processing/__init__.py b/tests/unit/document/processing/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit/document/test_conversion.py b/tests/unit/document/test_conversion.py new file mode 100644 index 00000000..730e8c12 --- /dev/null +++ b/tests/unit/document/test_conversion.py @@ -0,0 +1,558 @@ +import dataclasses + +import pytest +from pytorch_ie.annotations import BinaryRelation, LabeledSpan, Span +from pytorch_ie.core import AnnotationList, annotation_field +from pytorch_ie.documents import TokenBasedDocument +from transformers import AutoTokenizer, PreTrainedTokenizer + +from pie_datasets.document.conversion import ( + text_based_document_to_token_based, + token_based_document_to_text_based, + tokenize_document, +) +from tests.conftest import TestDocument + + +@dataclasses.dataclass +class TokenizedTestDocument(TokenBasedDocument): + sentences: AnnotationList[Span] = annotation_field(target="tokens") + entities: AnnotationList[LabeledSpan] = annotation_field(target="tokens") + relations: AnnotationList[BinaryRelation] = annotation_field(target="entities") + + +@pytest.fixture(scope="module") +def tokenizer() -> PreTrainedTokenizer: + return AutoTokenizer.from_pretrained("bert-base-cased") + + +def test_text_based_document_to_token_based(documents, tokenizer): + assert len(documents) >= 3 + for i, doc in enumerate(documents[:3]): + tokenized_text = tokenizer(doc.text, return_offsets_mapping=True) + tokenized_doc = text_based_document_to_token_based( + doc, + tokens=tokenized_text.tokens(), + result_document_type=TokenizedTestDocument, + # to increase test coverage + token_offset_mapping=None if i == 1 else tokenized_text.offset_mapping, + # to increase test coverage + char_to_token=None if i == 0 else tokenized_text.char_to_token, + ) + assert tokenized_doc is not None + + # check (de-)serialization + tokenized_doc.copy() + + offset_mapping_lists = [list(offsets) for offsets in 
tokenized_text.offset_mapping] + if i == 0: + assert doc.id == "train_doc1" + assert tokenized_doc.metadata["text"] == doc.text == "A single sentence." + assert tokenized_doc.metadata["token_offset_mapping"] == offset_mapping_lists + assert tokenized_doc.metadata.get("char_to_token") is None + assert tokenized_doc.tokens == ("[CLS]", "A", "single", "sentence", ".", "[SEP]") + assert len(tokenized_doc.sentences) == len(doc.sentences) == 1 + assert str(doc.sentences[0]) == "A single sentence." + assert str(tokenized_doc.sentences[0]) == "('A', 'single', 'sentence', '.')" + assert len(tokenized_doc.entities) == len(doc.entities) == 0 + assert len(tokenized_doc.relations) == len(doc.relations) == 0 + elif i == 1: + assert doc.id == "train_doc2" + assert tokenized_doc.metadata["text"] == doc.text == "Entity A works at B." + assert tokenized_doc.metadata.get("token_offset_mapping") is None + assert tokenized_doc.metadata["char_to_token"] == tokenized_text.char_to_token + assert tokenized_doc.tokens == ( + "[CLS]", + "En", + "##ti", + "##ty", + "A", + "works", + "at", + "B", + ".", + "[SEP]", + ) + assert len(tokenized_doc.sentences) == len(doc.sentences) == 1 + assert str(doc.sentences[0]) == "Entity A works at B." + assert ( + str(tokenized_doc.sentences[0]) + == "('En', '##ti', '##ty', 'A', 'works', 'at', 'B', '.')" + ) + assert len(tokenized_doc.entities) == len(doc.entities) == 2 + assert str(doc.entities[0]) == "Entity A" + assert str(tokenized_doc.entities[0]) == "('En', '##ti', '##ty', 'A')" + assert str(doc.entities[1]) == "B" + assert str(tokenized_doc.entities[1]) == "('B',)" + assert len(tokenized_doc.relations) == len(doc.relations) == 1 + assert doc.relations[0].head == doc.entities[0] + assert tokenized_doc.relations[0].head == tokenized_doc.entities[0] + assert doc.relations[0].tail == doc.entities[1] + assert tokenized_doc.relations[0].tail == tokenized_doc.entities[1] + elif i == 2: + assert doc.id == "train_doc3" + assert tokenized_doc.metadata["text"] == doc.text == "Entity C and D." + assert tokenized_doc.metadata["token_offset_mapping"] == offset_mapping_lists + assert tokenized_doc.metadata["char_to_token"] == tokenized_text.char_to_token + assert tokenized_doc.tokens == ( + "[CLS]", + "En", + "##ti", + "##ty", + "C", + "and", + "D", + ".", + "[SEP]", + ) + assert len(tokenized_doc.sentences) == len(doc.sentences) == 1 + assert str(doc.sentences[0]) == "Entity C and D." 
+ assert ( + str(tokenized_doc.sentences[0]) == "('En', '##ti', '##ty', 'C', 'and', 'D', '.')" + ) + assert len(tokenized_doc.entities) == len(doc.entities) == 2 + assert str(doc.entities[0]) == "Entity C" + assert str(tokenized_doc.entities[0]) == "('En', '##ti', '##ty', 'C')" + assert str(doc.entities[1]) == "D" + assert str(tokenized_doc.entities[1]) == "('D',)" + assert len(tokenized_doc.relations) == len(doc.relations) == 0 + else: + raise ValueError(f"Unexpected document: {doc.id}") + + +def test_text_based_document_to_token_based_missing_args(documents, tokenizer): + with pytest.raises(ValueError) as excinfo: + doc = documents[0] + tokenized_text = tokenizer(doc.text) + tokenized_doc = text_based_document_to_token_based( + doc, + tokens=tokenized_text.tokens(), + result_document_type=TokenizedTestDocument, + ) + assert ( + str(excinfo.value) + == "either token_offset_mapping or char_to_token must be provided to convert a text based document " + "to token based, but both are None" + ) + + +def test_text_based_document_to_token_based_unaligned_span_strict(documents, tokenizer): + doc = documents[0].copy() + # add a span that is not aligned with the tokenization + doc.entities.append(LabeledSpan(start=0, end=2, label="unaligned")) + assert str(doc.entities[-1]) == "A " + tokenized_text = tokenizer(doc.text, return_offsets_mapping=True) + with pytest.raises(ValueError) as excinfo: + tokenized_doc = text_based_document_to_token_based( + doc, + tokens=tokenized_text.tokens(), + result_document_type=TokenizedTestDocument, + # to increase test coverage + token_offset_mapping=tokenized_text.offset_mapping, + # to increase test coverage + char_to_token=tokenized_text.char_to_token, + ) + assert ( + str(excinfo.value) + == 'cannot find token span for character span: "A ", text="A single sentence.", ' + "token_offset_mapping=[(0, 0), (0, 1), (2, 8), (9, 17), (17, 18), (0, 0)]" + ) + + +def test_text_based_document_to_token_based_unaligned_span_not_strict(documents, tokenizer): + doc = documents[0].copy() + doc.entities.append(LabeledSpan(start=0, end=2, label="unaligned")) + assert str(doc.entities[-1]) == "A " + tokenized_text = tokenizer(doc.text, return_offsets_mapping=True) + tokenized_doc = text_based_document_to_token_based( + doc, + tokens=tokenized_text.tokens(), + result_document_type=TokenizedTestDocument, + # to increase test coverage + token_offset_mapping=tokenized_text.offset_mapping, + # to increase test coverage + char_to_token=tokenized_text.char_to_token, + strict_span_conversion=False, + ) + + # check (de-)serialization + tokenized_doc.copy() + + assert len(doc.entities) == 1 + # the unaligned span is not included in the tokenized document + assert len(tokenized_doc.entities) == 0 + + +@pytest.fixture +def token_documents(documents, tokenizer): + result = [] + for doc in documents: + tokenized_text = tokenizer(doc.text, return_offsets_mapping=True) + tokenized_doc = text_based_document_to_token_based( + doc, + tokens=tokenized_text.tokens(), + result_document_type=TokenizedTestDocument, + char_to_token=tokenized_text.char_to_token, + token_offset_mapping=tokenized_text.offset_mapping, + ) + result.append(tokenized_doc) + return result + + +def test_token_based_document_to_text_based(documents, token_documents): + for doc, tokenized_doc in zip(documents, token_documents): + reconstructed_doc = token_based_document_to_text_based( + tokenized_doc, + result_document_type=TestDocument, + ) + assert reconstructed_doc is not None + doc_dict = doc.asdict() + reconstructed_doc_dict = 
reconstructed_doc.asdict() + # remove all added metadata (original text, token_offset_mapping, char_to_token, tokens) + reconstructed_doc_dict["metadata"] = { + k: reconstructed_doc_dict["metadata"][k] for k in doc_dict["metadata"] + } + assert reconstructed_doc_dict == doc_dict + + +def test_token_based_document_to_text_based_with_join_tokens_with(documents): + for doc in documents: + # split the text by individual whitespace characters + # so that we can reconstruct the original text via " ".join(tokens) + tokens = [] + token_offset_mapping = [] + start = 0 + for token in doc.text.split(" "): + tokens.append(token) + end = start + len(token) + token_offset_mapping.append((start, end)) + start = end + 1 + + tokenized_doc = text_based_document_to_token_based( + doc, + tokens=tokens, + result_document_type=TokenizedTestDocument, + token_offset_mapping=token_offset_mapping, + ) + reconstructed_doc = token_based_document_to_text_based( + tokenized_doc, + result_document_type=TestDocument, + join_tokens_with=" ", + ) + assert reconstructed_doc is not None + assert reconstructed_doc.text == doc.text + + if doc.id in ["train_doc1", "train_doc7"]: + doc_dict = doc.asdict() + reconstructed_doc_dict = reconstructed_doc.asdict() + # remove all added metadata (original text, token_offset_mapping, char_to_token, tokens) + reconstructed_doc_dict["metadata"] = { + k: reconstructed_doc_dict["metadata"][k] for k in doc_dict["metadata"] + } + assert reconstructed_doc_dict == doc_dict + elif doc.id == "train_doc2": + assert reconstructed_doc.sentences == doc.sentences + assert len(reconstructed_doc.entities) == len(doc.entities) == 2 + assert str(reconstructed_doc.entities[0]) == str(doc.entities[0]) == "Entity A" + assert str(doc.entities[1]) == "B" + assert str(reconstructed_doc.entities[1]) == "B." + assert len(reconstructed_doc.relations) == len(doc.relations) == 1 + assert ( + reconstructed_doc.relations[0].label == doc.relations[0].label == "per:employee_of" + ) + assert doc.relations[0].head == doc.entities[0] + assert reconstructed_doc.relations[0].head == reconstructed_doc.entities[0] + assert doc.relations[0].tail == doc.entities[1] + assert reconstructed_doc.relations[0].tail == reconstructed_doc.entities[1] + elif doc.id == "train_doc3": + assert reconstructed_doc.sentences == doc.sentences + assert len(reconstructed_doc.entities) == len(doc.entities) == 2 + assert str(reconstructed_doc.entities[0]) == str(doc.entities[0]) == "Entity C" + assert str(doc.entities[1]) == "D" + assert str(reconstructed_doc.entities[1]) == "D." + assert len(reconstructed_doc.relations) == len(doc.relations) == 0 + elif doc.id == "train_doc4": + assert reconstructed_doc.sentences == doc.sentences + assert len(reconstructed_doc.entities) == len(doc.entities) == 2 + assert str(reconstructed_doc.entities[0]) == str(doc.entities[0]) == "Entity E" + assert str(doc.entities[1]) == "F" + assert str(reconstructed_doc.entities[1]) == "F." + assert len(reconstructed_doc.relations) == len(doc.relations) == 0 + elif doc.id == "train_doc5": + assert reconstructed_doc.sentences == doc.sentences + assert len(reconstructed_doc.entities) == len(doc.entities) == 3 + assert str(reconstructed_doc.entities[0]) == str(doc.entities[0]) == "Entity G" + assert str(doc.entities[1]) == "H" + assert str(reconstructed_doc.entities[1]) == "H." + assert str(doc.entities[2]) == "I" + assert str(reconstructed_doc.entities[2]) == "I." 
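+            # the relation labels must be preserved and their head/tail arguments must
+            # point at the reconstructed entities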
+ assert len(reconstructed_doc.relations) == len(doc.relations) == 3 + assert ( + reconstructed_doc.relations[0].label == doc.relations[0].label == "per:employee_of" + ) + assert doc.relations[0].head == doc.entities[0] + assert reconstructed_doc.relations[0].head == reconstructed_doc.entities[0] + assert doc.relations[0].tail == doc.entities[1] + assert reconstructed_doc.relations[0].tail == reconstructed_doc.entities[1] + assert reconstructed_doc.relations[1].label == doc.relations[1].label == "per:founder" + assert doc.relations[1].head == doc.entities[0] + assert reconstructed_doc.relations[1].head == reconstructed_doc.entities[0] + assert doc.relations[1].tail == doc.entities[2] + assert reconstructed_doc.relations[1].tail == reconstructed_doc.entities[2] + assert ( + reconstructed_doc.relations[2].label == doc.relations[2].label == "org:founded_by" + ) + assert doc.relations[2].head == doc.entities[2] + assert reconstructed_doc.relations[2].head == reconstructed_doc.entities[2] + assert doc.relations[2].tail == doc.entities[1] + assert reconstructed_doc.relations[2].tail == reconstructed_doc.entities[1] + elif doc.id == "train_doc6": + assert reconstructed_doc.sentences == doc.sentences + assert len(reconstructed_doc.entities) == len(doc.entities) == 3 + assert str(doc.entities[0]) == "Entity J" + assert str(reconstructed_doc.entities[0]) == "Entity J," + assert str(doc.entities[1]) == "K" + assert str(reconstructed_doc.entities[1]) == "K," + assert str(doc.entities[2]) == "L" + assert str(reconstructed_doc.entities[2]) == "L." + assert len(reconstructed_doc.relations) == len(doc.relations) == 0 + elif doc.id == "train_doc8": + assert len(reconstructed_doc.sentences) == len(doc.sentences) == 3 + assert ( + str(reconstructed_doc.sentences[0]) == str(doc.sentences[0]) == "First sentence." + ) + assert ( + str(reconstructed_doc.sentences[1]) + == str(doc.sentences[1]) + == "Entity M works at N." + ) + assert str(doc.sentences[2]) == "And it founded O" + assert str(reconstructed_doc.sentences[2]) == "And it founded O." + assert len(reconstructed_doc.entities) == len(doc.entities) == 4 + assert str(reconstructed_doc.entities[0]) == str(doc.entities[0]) == "Entity M" + assert str(doc.entities[1]) == "N" + assert str(reconstructed_doc.entities[1]) == "N." + assert str(reconstructed_doc.entities[2]) == str(doc.entities[2]) == "it" + assert str(doc.entities[3]) == "O" + assert str(reconstructed_doc.entities[3]) == "O." 
+ assert len(reconstructed_doc.relations) == len(doc.relations) == 3 + assert ( + reconstructed_doc.relations[0].label == doc.relations[0].label == "per:employee_of" + ) + assert doc.relations[0].head == doc.entities[0] + assert reconstructed_doc.relations[0].head == reconstructed_doc.entities[0] + assert doc.relations[0].tail == doc.entities[1] + assert reconstructed_doc.relations[0].tail == reconstructed_doc.entities[1] + assert reconstructed_doc.relations[1].label == doc.relations[1].label == "per:founder" + assert doc.relations[1].head == doc.entities[2] + assert reconstructed_doc.relations[1].head == reconstructed_doc.entities[2] + assert doc.relations[1].tail == doc.entities[3] + assert reconstructed_doc.relations[1].tail == reconstructed_doc.entities[3] + assert ( + reconstructed_doc.relations[2].label == doc.relations[2].label == "org:founded_by" + ) + assert doc.relations[2].head == doc.entities[3] + assert reconstructed_doc.relations[2].head == reconstructed_doc.entities[3] + assert doc.relations[2].tail == doc.entities[2] + assert reconstructed_doc.relations[2].tail == reconstructed_doc.entities[2] + else: + raise ValueError(f"Unexpected document: {doc.id}") + + +def test_tokenize_document(documents, tokenizer): + doc = documents[1] + tokenized_docs = tokenize_document( + doc, + tokenizer=tokenizer, + result_document_type=TokenizedTestDocument, + ) + assert len(tokenized_docs) == 1 + tokenized_doc = tokenized_docs[0] + + # check (de-)serialization + tokenized_doc.copy() + + assert doc.id == "train_doc2" + assert tokenized_doc.metadata["text"] == doc.text == "Entity A works at B." + assert tokenized_doc.tokens == ( + "[CLS]", + "En", + "##ti", + "##ty", + "A", + "works", + "at", + "B", + ".", + "[SEP]", + ) + assert len(tokenized_doc.sentences) == len(doc.sentences) == 1 + assert str(doc.sentences[0]) == "Entity A works at B." + assert ( + str(tokenized_doc.sentences[0]) == "('En', '##ti', '##ty', 'A', 'works', 'at', 'B', '.')" + ) + assert len(tokenized_doc.entities) == len(doc.entities) == 2 + assert str(doc.entities[0]) == "Entity A" + assert str(tokenized_doc.entities[0]) == "('En', '##ti', '##ty', 'A')" + assert str(doc.entities[1]) == "B" + assert str(tokenized_doc.entities[1]) == "('B',)" + assert len(tokenized_doc.relations) == len(doc.relations) == 1 + assert tokenized_doc.relations[0].label == doc.relations[0].label == "per:employee_of" + assert doc.relations[0].head == doc.entities[0] + assert tokenized_doc.relations[0].head == tokenized_doc.entities[0] + assert doc.relations[0].tail == doc.entities[1] + assert tokenized_doc.relations[0].tail == tokenized_doc.entities[1] + + +def test_tokenize_document_max_length(documents, tokenizer): + doc = documents[1] + assert doc.id == "train_doc2" + assert doc.text == "Entity A works at B." + assert len(doc.sentences) == 1 + assert str(doc.sentences[0]) == "Entity A works at B." + assert len(doc.entities) == 2 + assert str(doc.entities[0]) == "Entity A" + assert str(doc.entities[1]) == "B" + assert len(doc.relations) == 1 + assert doc.relations[0].label == "per:employee_of" + assert doc.relations[0].head == doc.entities[0] + assert doc.relations[0].tail == doc.entities[1] + + tokenized_docs = tokenize_document( + doc, + tokenizer=tokenizer, + result_document_type=TokenizedTestDocument, + strict_span_conversion=False, + # This will cut out the second entity. Also, the sentence annotation will be removed, + # because the sentence is not complete anymore. 
+ max_length=8, + return_overflowing_tokens=True, + ) + assert len(tokenized_docs) == 2 + tokenized_doc = tokenized_docs[0] + + # check (de-)serialization + tokenized_doc.copy() + + assert tokenized_doc.id == doc.id == "train_doc2" + assert tokenized_doc.metadata["text"] == doc.text == "Entity A works at B." + assert tokenized_doc.tokens == ("[CLS]", "En", "##ti", "##ty", "A", "works", "at", "[SEP]") + assert len(tokenized_doc.sentences) == 0 + assert len(tokenized_doc.entities) == 1 + assert str(tokenized_doc.entities[0]) == "('En', '##ti', '##ty', 'A')" + assert len(tokenized_doc.relations) == 0 + + tokenized_doc = tokenized_docs[1] + + # check (de-)serialization + tokenized_doc.copy() + + assert tokenized_doc.id == doc.id == "train_doc2" + assert tokenized_doc.metadata["text"] == doc.text == "Entity A works at B." + assert tokenized_doc.tokens == ("[CLS]", "B", ".", "[SEP]") + assert len(tokenized_doc.sentences) == 0 + assert len(tokenized_doc.entities) == 1 + assert str(tokenized_doc.entities[0]) == "('B',)" + assert len(tokenized_doc.relations) == 0 + + +def test_tokenize_document_partition(documents, tokenizer): + doc = documents[7] + assert doc.id == "train_doc8" + assert doc.text == "First sentence. Entity M works at N. And it founded O." + assert len(doc.sentences) == 3 + assert str(doc.sentences[0]) == "First sentence." + assert str(doc.sentences[1]) == "Entity M works at N." + assert str(doc.sentences[2]) == "And it founded O" + assert len(doc.entities) == 4 + assert str(doc.entities[0]) == "Entity M" + assert str(doc.entities[1]) == "N" + assert str(doc.entities[2]) == "it" + assert str(doc.entities[3]) == "O" + assert len(doc.relations) == 3 + assert doc.relations[0].head == doc.entities[0] + assert doc.relations[0].tail == doc.entities[1] + assert doc.relations[1].head == doc.entities[2] + assert doc.relations[1].tail == doc.entities[3] + assert doc.relations[2].head == doc.entities[3] + assert doc.relations[2].tail == doc.entities[2] + + tokenized_docs = tokenize_document( + doc, + tokenizer=tokenizer, + result_document_type=TokenizedTestDocument, + strict_span_conversion=False, + partition_layer="sentences", + ) + assert len(tokenized_docs) == 3 + tokenized_doc = tokenized_docs[0] + + # check (de-)serialization + tokenized_doc.copy() + + assert tokenized_doc.id == doc.id == "train_doc8" + assert ( + tokenized_doc.metadata["text"] + == doc.text + == "First sentence. Entity M works at N. And it founded O." + ) + assert tokenized_doc.tokens == ("[CLS]", "First", "sentence", ".", "[SEP]") + assert len(tokenized_doc.sentences) == 1 + assert len(tokenized_doc.entities) == 0 + assert len(tokenized_doc.relations) == 0 + + tokenized_doc = tokenized_docs[1] + + # check (de-)serialization + tokenized_doc.copy() + + assert tokenized_doc.id == doc.id == "train_doc8" + assert ( + tokenized_doc.metadata["text"] + == doc.text + == "First sentence. Entity M works at N. And it founded O." 
+ ) + assert tokenized_doc.tokens == ( + "[CLS]", + "En", + "##ti", + "##ty", + "M", + "works", + "at", + "N", + ".", + "[SEP]", + ) + assert len(tokenized_doc.sentences) == 1 + assert len(tokenized_doc.entities) == 2 + assert str(tokenized_doc.entities[0]) == "('En', '##ti', '##ty', 'M')" + assert str(tokenized_doc.entities[1]) == "('N',)" + assert len(tokenized_doc.relations) == 1 + assert tokenized_doc.relations[0].label == "per:employee_of" + assert tokenized_doc.relations[0].head == tokenized_doc.entities[0] + assert tokenized_doc.relations[0].tail == tokenized_doc.entities[1] + + tokenized_doc = tokenized_docs[2] + + # check (de-)serialization + tokenized_doc.copy() + + assert tokenized_doc.id == doc.id == "train_doc8" + assert ( + tokenized_doc.metadata["text"] + == doc.text + == "First sentence. Entity M works at N. And it founded O." + ) + assert tokenized_doc.tokens == ("[CLS]", "And", "it", "founded", "O", "[SEP]") + assert len(tokenized_doc.sentences) == 1 + assert len(tokenized_doc.entities) == 2 + assert str(tokenized_doc.entities[0]) == "('it',)" + assert str(tokenized_doc.entities[1]) == "('O',)" + assert len(tokenized_doc.relations) == 2 + assert tokenized_doc.relations[0].label == "per:founder" + assert tokenized_doc.relations[0].head == tokenized_doc.entities[0] + assert tokenized_doc.relations[0].tail == tokenized_doc.entities[1] + assert tokenized_doc.relations[1].label == "org:founded_by" + assert tokenized_doc.relations[1].head == tokenized_doc.entities[1] + assert tokenized_doc.relations[1].tail == tokenized_doc.entities[0] diff --git a/tests/unit/test_builder.py b/tests/unit/test_builder.py new file mode 100644 index 00000000..f05fd511 --- /dev/null +++ b/tests/unit/test_builder.py @@ -0,0 +1,224 @@ +import re +import tempfile +from dataclasses import dataclass +from typing import Type + +import pytest +from datasets import DatasetBuilder, Version +from datasets.load import dataset_module_factory, import_main_class +from pytorch_ie.annotations import LabeledSpan, Span +from pytorch_ie.core import AnnotationList, annotation_field +from pytorch_ie.documents import TextBasedDocument, TextDocumentWithSpans + +from pie_datasets.builder import PieDatasetBuilder +from tests import FIXTURES_ROOT + +DATASETS_ROOT = FIXTURES_ROOT / "builder" / "datasets" + + +def test_builder_class(): + dataset_module = dataset_module_factory(str(DATASETS_ROOT / "single_config")) + builder_cls = import_main_class(dataset_module.module_path) + with tempfile.TemporaryDirectory() as tmp_cache_dir: + builder = builder_cls(cache_dir=tmp_cache_dir) + assert isinstance(builder, DatasetBuilder) + + +def test_builder_class_with_kwargs(): + dataset_module = dataset_module_factory(str(DATASETS_ROOT / "single_config")) + builder_cls = import_main_class(dataset_module.module_path) + with tempfile.TemporaryDirectory() as tmp_cache_dir: + builder = builder_cls(cache_dir=tmp_cache_dir, parameter="test") + assert isinstance(builder, DatasetBuilder) + assert builder.config.parameter == "test" + + +def test_builder_class_with_kwargs_wrong_parameter(): + dataset_module = dataset_module_factory(str(DATASETS_ROOT / "single_config")) + builder_cls = import_main_class(dataset_module.module_path) + with tempfile.TemporaryDirectory() as tmp_cache_dir: + # this should raise an exception because the base config does not know the parameter + with pytest.raises( + TypeError, + match=re.escape("__init__() got an unexpected keyword argument 'unknown_parameter'"), + ): + builder = builder_cls( + cache_dir=tmp_cache_dir, 
parameter="test", unknown_parameter="test_unknown" + ) + + +def test_builder_class_with_base_dataset_kwargs(): + dataset_module = dataset_module_factory(str(DATASETS_ROOT / "single_config")) + builder_cls = import_main_class(dataset_module.module_path) + base_dataset_kwargs = dict(version=Version("0.0.0"), description="new description") + with tempfile.TemporaryDirectory() as tmp_cache_dir: + builder = builder_cls(cache_dir=tmp_cache_dir, base_dataset_kwargs=base_dataset_kwargs) + assert isinstance(builder, DatasetBuilder) + assert builder.base_builder.config.version == "0.0.0" + assert builder.base_builder.config.description == "new description" + + +def test_builder_class_with_base_dataset_kwargs_wrong_parameter(): + dataset_module = dataset_module_factory(str(DATASETS_ROOT / "single_config")) + builder_cls = import_main_class(dataset_module.module_path) + base_dataset_kwargs = dict(unknown_base_parameter="base_parameter_value") + with tempfile.TemporaryDirectory() as tmp_cache_dir: + # this should raise an exception because the base config does not know the parameter + with pytest.raises( + TypeError, + match=re.escape( + "__init__() got an unexpected keyword argument 'unknown_base_parameter'" + ), + ): + builder = builder_cls(cache_dir=tmp_cache_dir, base_dataset_kwargs=base_dataset_kwargs) + + +def test_builder_class_multi_configs(): + dataset_module = dataset_module_factory(str(DATASETS_ROOT / "multi_config")) + builder_cls = import_main_class(dataset_module.module_path) + with tempfile.TemporaryDirectory() as tmp_cache_dir: + with pytest.raises(ValueError, match="Config name is missing."): + builder = builder_cls(cache_dir=tmp_cache_dir) + + builder = builder_cls(config_name="es", cache_dir=tmp_cache_dir) + assert isinstance(builder, DatasetBuilder) + + +def test_builder_class_name_mapping(): + dataset_module = dataset_module_factory(str(DATASETS_ROOT / "name_mapping")) + builder_cls = import_main_class(dataset_module.module_path) + with tempfile.TemporaryDirectory() as tmp_cache_dir: + builder = builder_cls(config_name="es", cache_dir=tmp_cache_dir) + assert builder.info.config_name == "es" + assert builder.base_builder.info.config_name == "nl" + + builder = builder_cls(config_name="nl", cache_dir=tmp_cache_dir) + assert builder.info.config_name == "nl" + assert builder.base_builder.info.config_name == "nl" + + +def test_builder_class_name_mapping_disabled(): + dataset_module = dataset_module_factory(str(DATASETS_ROOT / "name_mapping_disabled")) + builder_cls = import_main_class(dataset_module.module_path) + with tempfile.TemporaryDirectory() as tmp_cache_dir: + # this should raise an exception because the config name is not passed + with pytest.raises(ValueError, match="Config name is missing."): + builder = builder_cls(config_name="es", cache_dir=tmp_cache_dir) + + # here we set the base config name via base_dataset_kwargs + builder = builder_cls( + config_name="es", cache_dir=tmp_cache_dir, base_dataset_kwargs=dict(name="nl") + ) + assert builder.info.config_name == "es" + assert builder.base_builder.info.config_name == "nl" + + +def test_builder_class_name_mapping_and_defaults(): + dataset_module = dataset_module_factory(str(DATASETS_ROOT / "default_config_kwargs")) + builder_cls = import_main_class(dataset_module.module_path) + with tempfile.TemporaryDirectory() as tmp_cache_dir: + # this comes from passing the config as base config name + builder = builder_cls(config_name="es", cache_dir=tmp_cache_dir) + assert builder.info.config_name == "es" + assert 
builder.base_builder.info.config_name == "es" + + # this gets created by the default setting from BASE_CONFIG_KWARGS_DICT + builder = builder_cls(config_name="nl", cache_dir=tmp_cache_dir) + assert builder.info.config_name == "nl" + assert builder.base_builder.info.config_name == "default" + assert builder.base_builder.info.version == "0.0.0" + + +def test_wrong_builder_class_config(): + dataset_module = dataset_module_factory(str(DATASETS_ROOT / "wrong_builder_class_config")) + builder_cls = import_main_class(dataset_module.module_path) + with tempfile.TemporaryDirectory() as tmp_cache_dir: + # This should raise an exception because the base builder is derived from GeneratorBasedBuilder, + # but the PIE dataset builder is derived from ArrowBasedBuilder. + with pytest.raises( + TypeError, + match=re.escape( + "The PyTorch-IE dataset builder class 'Example' is derived from " + ", but the base builder is not which is not allowed. " + "The base builder is of type 'Conll2003' that is derived from " + ". Consider to derive your PyTorch-IE dataset builder " + "'Example' from a PyTorch-IE variant of 'GeneratorBasedBuilder'." + ), + ): + builder_cls(cache_dir=tmp_cache_dir) + + +def test_builder_with_document_converters_rename(): + @dataclass + class RenamedExampleDocument(TextBasedDocument): + spans: AnnotationList[LabeledSpan] = annotation_field(target="text") + + dataset_module = dataset_module_factory(str(DATASETS_ROOT / "single_config")) + builder_cls: Type[PieDatasetBuilder] = import_main_class(dataset_module.module_path) + with tempfile.TemporaryDirectory() as tmp_cache_dir: + builder = builder_cls( + cache_dir=tmp_cache_dir, + document_converters={ + RenamedExampleDocument: {"entities": "spans"}, + }, + ) + assert isinstance(builder, PieDatasetBuilder) + assert builder.document_converters == { + RenamedExampleDocument: {"entities": "spans"}, + } + + +@dataclass +class ExampleDocumentWithSimpleSpans(TextBasedDocument): + spans: AnnotationList[Span] = annotation_field(target="text") + + +def convert_example_document_to_example_document_with_simple_spans( + document: TextDocumentWithSpans, +) -> ExampleDocumentWithSimpleSpans: + result = ExampleDocumentWithSimpleSpans(text=document.text, spans=document.spans) + for entity in document.spans: + result.spans.append(Span(start=entity.start, end=entity.end)) + return result + + +def test_builder_with_document_converters_resolve_document_type_and_converter(): + @dataclass + class RenamedExampleDocument(TextBasedDocument): + spans: AnnotationList[LabeledSpan] = annotation_field(target="text") + + dataset_module = dataset_module_factory(str(DATASETS_ROOT / "single_config")) + builder_cls: Type[PieDatasetBuilder] = import_main_class(dataset_module.module_path) + with tempfile.TemporaryDirectory() as tmp_cache_dir: + builder = builder_cls( + cache_dir=tmp_cache_dir, + document_converters={ + "tests.unit.test_builder.ExampleDocumentWithSimpleSpans": "tests.unit.test_builder.convert_example_document_to_example_document_with_simple_spans", + }, + ) + assert isinstance(builder, PieDatasetBuilder) + assert builder.document_converters == { + ExampleDocumentWithSimpleSpans: convert_example_document_to_example_document_with_simple_spans, + } + + +class NoDocumentType: + pass + + +def test_builder_with_document_converters_resolve_wrong_document_type(): + dataset_module = dataset_module_factory(str(DATASETS_ROOT / "single_config")) + builder_cls: Type[PieDatasetBuilder] = import_main_class(dataset_module.module_path) + with tempfile.TemporaryDirectory() as 
tmp_cache_dir: + with pytest.raises( + TypeError, + match=re.escape( + "The key 'tests.unit.test_builder.NoDocumentType' for one of the converters can not be resolved to a document type." + ), + ): + builder = builder_cls( + cache_dir=tmp_cache_dir, + document_converters={ + "tests.unit.test_builder.NoDocumentType": convert_example_document_to_example_document_with_simple_spans, + }, + ) diff --git a/tests/unit/test_dataset.py b/tests/unit/test_dataset.py new file mode 100644 index 00000000..0b9641a6 --- /dev/null +++ b/tests/unit/test_dataset.py @@ -0,0 +1,460 @@ +from collections.abc import Iterator, Sequence +from dataclasses import dataclass +from typing import Union + +import datasets +import numpy +import pytest +import torch +from pytorch_ie.annotations import BinaryRelation, Label, LabeledSpan, Span +from pytorch_ie.core import AnnotationList, annotation_field +from pytorch_ie.core.taskmodule import ( + IterableTaskEncodingDataset, + TaskEncodingDataset, + TaskEncodingSequence, +) +from pytorch_ie.documents import TextDocument +from pytorch_ie.taskmodules import TransformerSpanClassificationTaskModule + +from pie_datasets import Dataset, IterableDataset +from pie_datasets.dataset import get_pie_dataset_type +from tests import _HF_CONLL2003_IS_AVAILABLE, DATASET_BUILDERS_ROOT +from tests.conftest import TestDocument + +DATASET_NAME = "conll2003" +PIE_DATASET_PATH = DATASET_BUILDERS_ROOT / "pie" / DATASET_NAME +HF_DATASET_PATH = DATASET_NAME + + +@pytest.fixture(scope="module") +def taskmodule(): + tokenizer_name_or_path = "bert-base-cased" + taskmodule = TransformerSpanClassificationTaskModule( + tokenizer_name_or_path=tokenizer_name_or_path, + entity_annotation="entities", + ) + return taskmodule + + +@pytest.fixture +def model_output(): + return { + "logits": torch.from_numpy( + numpy.log( + [ + # O, ORG, PER + [0.5, 0.2, 0.3], + [0.1, 0.1, 0.8], + [0.1, 0.5, 0.4], + [0.1, 0.4, 0.5], + [0.1, 0.6, 0.3], + ] + ) + ), + "start_indices": torch.tensor([1, 1, 7, 1, 6]), + "end_indices": torch.tensor([2, 4, 7, 4, 6]), + "batch_indices": torch.tensor([0, 1, 1, 2, 2]), + } + + +def test_dataset(maybe_iterable_dataset): + dataset = { + k: list(v) if isinstance(v, IterableDataset) else v + for k, v in maybe_iterable_dataset.items() + } + assert set(dataset.keys()) == {"train", "validation", "test"} + + assert len(dataset["train"]) == 8 + assert len(dataset["validation"]) == 2 + assert len(dataset["test"]) == 2 + + train_doc5 = dataset["train"][4] + assert train_doc5.id == "train_doc5" + assert len(train_doc5.sentences) == 3 + assert len(train_doc5.entities) == 3 + assert len(train_doc5.relations) == 3 + + assert str(train_doc5.sentences[1]) == "Entity G works at H." 
+ + +def test_dataset_index(dataset): + train_dataset = dataset["train"] + assert train_dataset[4].id == "train_doc5" + assert [doc.id for doc in train_dataset[0, 3, 5]] == ["train_doc1", "train_doc4", "train_doc6"] + assert [doc.id for doc in train_dataset[2:5]] == ["train_doc3", "train_doc4", "train_doc5"] + + +def test_dataset_map(maybe_iterable_dataset): + train_dataset = maybe_iterable_dataset["train"] + + def clear_relations(document): + document.relations.clear() + return document + + assert sum(len(doc.relations) for doc in train_dataset) == 7 + + mapped_dataset1 = train_dataset.map(clear_relations) + + assert sum(len(doc.relations) for doc in mapped_dataset1) == 0 + assert sum(len(doc.relations) for doc in train_dataset) == 7 + + +def test_dataset_map_batched(maybe_iterable_dataset): + train_dataset = maybe_iterable_dataset["train"] + + def clear_relations_batched(documents): + assert len(documents) == 2 + for document in documents: + document.relations.clear() + return documents + + assert sum(len(doc.relations) for doc in train_dataset) == 7 + + mapped_dataset1 = train_dataset.map(clear_relations_batched, batched=True, batch_size=2) + + assert sum(len(doc.relations) for doc in mapped_dataset1) == 0 + assert sum(len(doc.relations) for doc in train_dataset) == 7 + + +def test_dataset_map_with_result_document_type(maybe_iterable_dataset): + @dataclass + class TestDocument(TextDocument): + sentences: AnnotationList[Span] = annotation_field(target="text") + entities: AnnotationList[LabeledSpan] = annotation_field(target="text") + relations: AnnotationList[BinaryRelation] = annotation_field(target="entities") + + @dataclass + class TestDocumentWithTokensButNoRelations(TextDocument): + sentences: AnnotationList[Span] = annotation_field(target="text") + tokens: AnnotationList[Span] = annotation_field(target="text") + entities: AnnotationList[LabeledSpan] = annotation_field(target="text") + + def clear_relations_and_add_one_token( + document: TestDocument, + ) -> TestDocumentWithTokensButNoRelations: + document.relations.clear() + # the conversion here is not really necessary, but to have correct typing + result = document.as_type(TestDocumentWithTokensButNoRelations) + # subtract 1 to create a Span different from the sentence to account for + # https://github.com/ChristophAlt/pytorch-ie/pull/222 + result.tokens.append(Span(0, len(document.text) - 1)) + return result + + train_dataset = maybe_iterable_dataset["train"] + + assert sum(len(doc.relations) for doc in train_dataset) == 7 + + mapped_dataset1 = train_dataset.map( + clear_relations_and_add_one_token, + result_document_type=TestDocumentWithTokensButNoRelations, + ) + + assert sum(len(doc.relations) for doc in train_dataset) == 7 + + doc0 = list(train_dataset)[0] + doc0_mapped = list(mapped_dataset1)[0] + assert len(doc0_mapped.tokens) == 1 + token = doc0_mapped.tokens[0] + assert token.start == 0 + assert token.end == len(doc0.text) - 1 + # check field names because isinstance does not work (the code of the document types + # is the same, but lives at different locations) + assert {f.name for f in doc0.fields()} == {f.name for f in TestDocument.fields()} + assert {f.name for f in doc0_mapped.fields()} == { + f.name for f in TestDocumentWithTokensButNoRelations.fields() + } + + +@pytest.mark.parametrize("encode_target", [False, True]) +@pytest.mark.parametrize("inplace", [False, True]) +@pytest.mark.parametrize("as_dataset", [False, True]) +def test_dataset_with_taskmodule( + maybe_iterable_dataset, taskmodule, model_output, 
encode_target, inplace, as_dataset +): + train_dataset = maybe_iterable_dataset["train"] + + taskmodule.prepare(train_dataset) + assert set(taskmodule.label_to_id.keys()) == {"PER", "ORG", "O"} + assert [taskmodule.id_to_label[i] for i in range(3)] == ["O", "ORG", "PER"] + assert taskmodule.label_to_id["O"] == 0 + + as_task_encoding_sequence = not encode_target + as_iterator = isinstance(train_dataset, (IterableDataset, Iterator)) + if as_task_encoding_sequence: + if as_iterator: + with pytest.raises( + ValueError, match="can not return a TaskEncodingSequence as Iterator" + ): + taskmodule.encode( + train_dataset, encode_target=encode_target, as_dataset=as_dataset + ) + return + if as_dataset: + with pytest.raises( + ValueError, match="can not return a TaskEncodingSequence as a dataset" + ): + taskmodule.encode( + train_dataset, encode_target=encode_target, as_dataset=as_dataset + ) + return + + task_encodings = taskmodule.encode( + train_dataset, encode_target=encode_target, as_dataset=as_dataset + ) + + if as_iterator: + if as_task_encoding_sequence: + raise NotImplementedError("this is not yet implemented") + if as_dataset: + assert isinstance(task_encodings, IterableTaskEncodingDataset) + else: + assert isinstance(task_encodings, Iterator) + else: + if as_dataset: + if as_task_encoding_sequence: + raise NotImplementedError("this is not yet implemented") + else: + assert isinstance(task_encodings, TaskEncodingDataset) + else: + if as_task_encoding_sequence: + assert isinstance(task_encodings, TaskEncodingSequence) + else: + assert isinstance(task_encodings, Sequence) + + task_encoding_list = list(task_encodings) + assert len(task_encoding_list) == 8 + task_encoding = task_encoding_list[5] + document = list(train_dataset)[5] + assert task_encoding.document == document + assert "input_ids" in task_encoding.inputs + assert ( + taskmodule.tokenizer.decode(task_encoding.inputs["input_ids"], skip_special_tokens=True) + == document.text + ) + + if encode_target: + assert task_encoding.targets == [ + (1, 4, taskmodule.label_to_id["PER"]), + (6, 6, taskmodule.label_to_id["ORG"]), + (9, 9, taskmodule.label_to_id["ORG"]), + ] + else: + assert not task_encoding.has_targets + + unbatched_outputs = taskmodule.unbatch_output(model_output) + + decoded_documents = taskmodule.decode( + task_encodings=task_encodings, + task_outputs=unbatched_outputs, + inplace=inplace, + ) + + if isinstance(train_dataset, Dataset): + assert len(decoded_documents) == len(train_dataset) + + assert {id(doc) for doc in decoded_documents}.isdisjoint({id(doc) for doc in train_dataset}) + + expected_scores = [0.8, 0.5, 0.5, 0.6] + i = 0 + for document in decoded_documents: + for entity_expected, entity_decoded in zip( + document["entities"], document["entities"].predictions + ): + assert entity_expected.start == entity_decoded.start + assert entity_expected.end == entity_decoded.end + assert entity_expected.label == entity_decoded.label + assert expected_scores[i] == pytest.approx(entity_decoded.score) + i += 1 + + for document in train_dataset: + assert not document["entities"].predictions + + +@pytest.mark.skipif( + not _HF_CONLL2003_IS_AVAILABLE, + reason="the Huggingface conll2003 dataset is not reachable and the local PIE-variant depends on it", +) +def test_load_with_hf_datasets(): + dataset = datasets.load_dataset(path=str(HF_DATASET_PATH)) + + assert set(dataset.keys()) == {"train", "validation", "test"} + + assert len(dataset["train"]) == 14041 + assert len(dataset["validation"]) == 3250 + assert len(dataset["test"]) == 
3453 + + +@pytest.mark.skipif( + not _HF_CONLL2003_IS_AVAILABLE, + reason="the Huggingface conll2003 dataset is not reachable and the remote PIE-variant depends on it", +) +def test_load_with_hf_datasets_from_hub(): + dataset = datasets.load_dataset(path=str(PIE_DATASET_PATH)) + + assert set(dataset.keys()) == {"train", "validation", "test"} + + assert len(dataset["train"]) == 14041 + assert len(dataset["validation"]) == 3250 + assert len(dataset["test"]) == 3453 + + +def test_get_pie_dataset_type(hf_dataset, iterable_hf_dataset): + assert get_pie_dataset_type(hf_dataset["train"]) == Dataset + assert get_pie_dataset_type(iterable_hf_dataset["train"]) == IterableDataset + with pytest.raises(TypeError) as excinfo: + get_pie_dataset_type("not a dataset") + assert ( + str(excinfo.value) + == "the dataset must be of type Dataset or IterableDataset, but is of type " + ) + + +@dataclass +class TestDocumentWithLabel(TextDocument): + label: AnnotationList[Label] = annotation_field() + + +def convert_to_document_with_label(document: TestDocument) -> TestDocumentWithLabel: + result = TestDocumentWithLabel(text=document.text) + result.label.append(Label(label="label")) + return result + + +@pytest.fixture +def dataset_with_converter_functions(maybe_iterable_dataset) -> Union[Dataset, IterableDataset]: + train_dataset: Union[Dataset, IterableDataset] = maybe_iterable_dataset["train"] + assert len(train_dataset.document_converters) == 0 + + train_dataset.register_document_converter(convert_to_document_with_label) + return train_dataset + + +def test_register_document_converter_function(dataset_with_converter_functions): + assert len(dataset_with_converter_functions.document_converters) == 1 + assert TestDocumentWithLabel in dataset_with_converter_functions.document_converters + assert ( + dataset_with_converter_functions.document_converters[TestDocumentWithLabel] + == convert_to_document_with_label + ) + + +@dataclass +class TestDocumentWithLabeledSpans(TextDocument): + spans: AnnotationList[LabeledSpan] = annotation_field(target="text") + + +@pytest.fixture +def dataset_with_converter_mapping(maybe_iterable_dataset) -> Union[Dataset, IterableDataset]: + train_dataset: Union[Dataset, IterableDataset] = maybe_iterable_dataset["train"] + assert len(train_dataset.document_converters) == 0 + + field_mapping = {"entities": "spans"} + train_dataset.register_document_converter( + converter=field_mapping, document_type=TestDocumentWithLabeledSpans + ) + return train_dataset + + +def test_register_document_converter_mapping(dataset_with_converter_mapping): + assert len(dataset_with_converter_mapping.document_converters) == 1 + assert TestDocumentWithLabeledSpans in dataset_with_converter_mapping.document_converters + assert dataset_with_converter_mapping.document_converters[TestDocumentWithLabeledSpans] == { + "entities": "spans" + } + + +def test_to_document_type_function(dataset_with_converter_functions): + assert dataset_with_converter_functions.document_type == TestDocument + converted_dataset = dataset_with_converter_functions.to_document_type(TestDocumentWithLabel) + assert converted_dataset.document_type == TestDocumentWithLabel + + assert len(converted_dataset.document_converters) == 0 + for doc in converted_dataset: + assert isinstance(doc, TestDocumentWithLabel) + assert len(doc.label) == 1 + assert doc.label[0].label == "label" + + +def test_to_document_type_mapping(dataset_with_converter_mapping): + assert dataset_with_converter_mapping.document_type == TestDocument + converted_dataset = 
dataset_with_converter_mapping.to_document_type( + TestDocumentWithLabeledSpans + ) + assert converted_dataset.document_type == TestDocumentWithLabeledSpans + + assert len(converted_dataset.document_converters) == 0 + for doc_converted, doc in zip(converted_dataset, dataset_with_converter_mapping): + assert isinstance(doc, TestDocument) + assert isinstance(doc_converted, TestDocumentWithLabeledSpans) + assert "spans" in doc_converted + assert doc_converted.spans == doc.entities + original_annotation_field_names = {f.name for f in doc.annotation_fields()} + assert original_annotation_field_names == {"sentences", "entities", "relations"} + for annotation_field_name in original_annotation_field_names: + assert annotation_field_name not in doc_converted + + +def test_to_document_type_noop(maybe_iterable_dataset): + train_dataset: Union[Dataset, IterableDataset] = maybe_iterable_dataset["train"] + assert len(train_dataset.document_converters) == 0 + train_dataset.register_document_converter( + convert_to_document_with_label, document_type=TestDocument + ) + assert train_dataset.document_type == TestDocument + converted_dataset = train_dataset.to_document_type(TestDocument) + # the conversion should be a noop + assert converted_dataset.document_type == TestDocument + assert converted_dataset == train_dataset + assert len(converted_dataset.document_converters) == 1 + assert TestDocument in converted_dataset.document_converters + assert converted_dataset.document_converters[TestDocument] == convert_to_document_with_label + + +def test_to_document_type_convert_and_cast(dataset_with_converter_functions): + @dataclass + class TestDocumentWithLabelAndSpans(TestDocumentWithLabel): + label: AnnotationList[Label] = annotation_field() + spans: AnnotationList[Span] = annotation_field(target="text") + + assert dataset_with_converter_functions.document_type == TestDocument + # The only converter is registered for TestDocumentWithLabel, but we request a conversion to + # TestDocumentWithLabelAndSpans which is a *subclass* of TestDocumentWithLabel. This is a valid type + # and the conversion is performed by first converting to TestDocumentWithLabel and then casting + # to TestDocumentWithLabelAndSpans. + converted_dataset = dataset_with_converter_functions.to_document_type( + TestDocumentWithLabelAndSpans + ) + assert converted_dataset.document_type == TestDocumentWithLabelAndSpans + + assert len(converted_dataset.document_converters) == 0 + for converted_doc, doc in zip(converted_dataset, dataset_with_converter_functions): + assert isinstance(doc, TestDocument) + assert isinstance(converted_doc, TestDocumentWithLabelAndSpans) + assert converted_doc.text == doc.text + assert len(converted_doc.label) == 1 + assert converted_doc.label[0].label == "label" + assert len(converted_doc.spans) == 0 + + +def test_to_document_type_not_found(dataset_with_converter_functions): + assert dataset_with_converter_functions.document_type == TestDocument + + @dataclass + class TestDocumentWithSpans(TestDocument): + spans: AnnotationList[Span] = annotation_field(target="text") + + # The only converter is registered for TestDocumentWithLabel, but we request a conversion to + # TestDocumentWithSpans. This is not a valid type because it is neither a subclass nor a superclass of + # TestDocumentWithLabel, so an error is raised. 
+ with pytest.raises(ValueError) as excinfo: + dataset_with_converter_functions.to_document_type(TestDocumentWithSpans) + assert ( + str(excinfo.value) + == "No valid key (either subclass or superclass) was found for the document type " + "'.TestDocumentWithSpans'>' " + "in the document_converters of the dataset. Available keys: " + "{}. Consider adding a respective converter " + "to the dataset with dataset.register_document_converter(my_converter_method) where " + "my_converter_method should accept as input and return " + "'.TestDocumentWithSpans'>'." + ) diff --git a/tests/unit/test_dataset_casting.py b/tests/unit/test_dataset_casting.py new file mode 100644 index 00000000..fede04e3 --- /dev/null +++ b/tests/unit/test_dataset_casting.py @@ -0,0 +1,238 @@ +import re +from dataclasses import dataclass + +import pytest +from pytorch_ie.annotations import LabeledSpan, Span +from pytorch_ie.core import AnnotationList, annotation_field +from pytorch_ie.documents import TextDocument + +from pie_datasets import Dataset, IterableDataset + + +@dataclass +class CoNLL2002Document(TextDocument): + entities: AnnotationList[LabeledSpan] = annotation_field(target="text") + + +@dataclass +class DocumentWithParts(TextDocument): + parts: AnnotationList[Span] = annotation_field(target="text") + + +@dataclass +class CoNLL2002WithPartsDocument(CoNLL2002Document, DocumentWithParts): + pass + + +@dataclass +class DocumentWithEnts(TextDocument): + ents: AnnotationList[LabeledSpan] = annotation_field(target="text") + + +@dataclass +class DocumentWithEntsWrongType(TextDocument): + ents: AnnotationList[Span] = annotation_field(target="text") + + +@dataclass +class DocumentWithEntsAndParts(DocumentWithParts, DocumentWithEnts): + pass + + +@dataclass +class DocumentWithPartsAndEntitiesSwapped(TextDocument): + parts: AnnotationList[LabeledSpan] = annotation_field(target="text") + entities: AnnotationList[Span] = annotation_field(target="text") + + +@pytest.fixture() +def dataset_train(maybe_iterable_dataset): + return maybe_iterable_dataset["train"].cast_document_type( + CoNLL2002Document, remove_columns=True + ) + + +def _add_full_part(doc: DocumentWithParts) -> DocumentWithParts: + doc.parts.append(Span(start=0, end=len(doc.text))) + return doc + + +def _get_doc(ds): + # use the second document since it has entities + IDX = 2 + if isinstance(ds, Dataset): + return ds[IDX] + elif isinstance(ds, IterableDataset): + it = iter(ds) + doc = None + for i in range(IDX + 1): + doc = next(it) + return doc + else: + raise TypeError(f"Unknown dataset type: {type(ds)}") + + +def test_cast_document_type(dataset_train): + casted = dataset_train.cast_document_type(CoNLL2002WithPartsDocument) + doc0_orig = _get_doc(dataset_train) + with_parts = casted.map(lambda doc: _add_full_part(doc)) + assert "entities" in with_parts.column_names + assert "parts" in with_parts.column_names + doc0 = _get_doc(with_parts) + assert set(doc0) == {"entities", "parts"} + assert doc0.entities == doc0_orig.entities + + part0 = doc0.parts[0] + assert isinstance(part0, Span) + assert part0.start == 0 + assert part0.end == len(doc0.text) + + +def test_cast_document_type_remove_field(dataset_train): + doc0_orig = _get_doc(dataset_train) + casted = dataset_train.cast_document_type(DocumentWithParts, remove_columns=True) + with_partitions = casted.map(lambda doc: _add_full_part(doc)) + assert "entities" not in with_partitions.column_names + assert "parts" in with_partitions.column_names + doc0 = _get_doc(with_partitions) + assert set(doc0) == {"parts"} + + 
part0 = doc0.parts[0] + assert isinstance(part0, Span) + assert part0.start == 0 + assert part0.end == len(doc0.text) + + casted_back = with_partitions.cast_document_type(CoNLL2002Document) + assert "entities" in casted_back.column_names + # original entities are not available anymore after casting back + assert len(doc0_orig.entities) > 0 + assert len(list(casted_back)[0].entities) == 0 + + +def test_cast_document_type_recover_field(dataset_train): + doc_orig = _get_doc(dataset_train) + casted = dataset_train.cast_document_type(DocumentWithParts) + # "entities" stay in the arrow table because remove_columns=False per default + assert "entities" in casted.column_names + assert "parts" in casted.column_names + + doc_casted = _get_doc(casted) + assert set(doc_casted) == {"parts"} + + casted_back = casted.cast_document_type(CoNLL2002Document) + assert "entities" in casted_back.column_names + # original entities are recovered after casting back + doc_back = _get_doc(casted_back) + assert len(doc_back.entities) > 0 + assert doc_back.entities == doc_orig.entities + + +def test_cast_document_type_recover_field_with_mapping(dataset_train): + doc_orig = _get_doc(dataset_train) + casted = dataset_train.cast_document_type(DocumentWithParts) + # "entities" stay in the arrow table because remove_columns=False per default + assert "entities" in casted.column_names + assert "parts" in casted.column_names + + doc_casted = _get_doc(casted) + assert set(doc_casted) == {"parts"} + + casted_back = casted.cast_document_type( + DocumentWithEntsAndParts, field_mapping={"entities": "ents"} + ) + assert "ents" in casted_back.column_names + # original entities are recovered after casting back + doc_back = _get_doc(casted_back) + assert len(doc_back.ents) > 0 + assert doc_back.ents == doc_orig.entities + + +def test_cast_document_type_recover_field_wrong(dataset_train): + casted = dataset_train.cast_document_type(DocumentWithEntsAndParts) + # "entities" stay in the arrow table because remove_columns=False per default + assert "entities" in casted.column_names + assert "parts" in casted.column_names + assert "ents" in casted.column_names + + doc_casted = _get_doc(casted) + assert set(doc_casted) == {"parts", "ents"} + + with pytest.raises( + ValueError, + match=re.escape( + "rename targets are already in column names: {'entities'}. Did you miss to set remove_columns=True in a previous call of cast_document_type?" 
+ ), + ): + casted.cast_document_type(CoNLL2002Document, field_mapping={"ents": "entities"}) + + +def test_cast_document_type_rename_field(dataset_train): + doc0_orig = _get_doc(dataset_train) + casted = dataset_train.cast_document_type( + DocumentWithEntsAndParts, field_mapping={"entities": "ents"} + ) + with_parts = casted.map(lambda doc: _add_full_part(doc)) + assert "ents" in with_parts.column_names + assert "parts" in with_parts.column_names + doc0 = _get_doc(with_parts) + assert set(doc0) == {"ents", "parts"} + assert doc0.ents == doc0_orig.entities + + part0 = doc0.parts[0] + assert isinstance(part0, Span) + assert part0.start == 0 + assert part0.end == len(doc0.text) + + +def test_cast_document_type_swap_fields(dataset_train): + if isinstance(dataset_train, IterableDataset): + # TODO: for now, this would fail because datasets.IterableDataset.rename_columns() is too restrictive + # (does not allow swapping) + return + + # just add "parts" to have another field to swap "entities" with + casted = dataset_train.cast_document_type(CoNLL2002WithPartsDocument) + with_parts = casted.map(lambda doc: _add_full_part(doc)) + doc_with_parts = _get_doc(with_parts) + + swapped = with_parts.cast_document_type( + DocumentWithPartsAndEntitiesSwapped, + field_mapping={"entities": "parts", "parts": "entities"}, + ) + assert "entities" in swapped.column_names + assert "parts" in swapped.column_names + doc_swapped = _get_doc(swapped) + assert set(doc_swapped) == {"entities", "parts"} + assert doc_swapped.parts == doc_with_parts.entities + assert doc_swapped.entities == doc_with_parts.parts + + +def test_cast_document_type_rename_source_not_available(dataset_train): + with pytest.raises( + ValueError, + match=re.escape( + "some fields to rename are not in the original document_type or hidden fields: {'not_in_original_document'}" + ), + ): + dataset_train.cast_document_type( + DocumentWithEntsWrongType, field_mapping={"not_in_original_document": "ents"} + ) + + +def test_cast_document_type_rename_target_not_available(dataset_train): + with pytest.raises( + ValueError, + match=re.escape( + "some renamed fields are not in the new document_type: {'not_in_new_document'}" + ), + ): + dataset_train.cast_document_type( + DocumentWithEntsWrongType, field_mapping={"entities": "not_in_new_document"} + ) + + +def test_cast_document_type_rename_wrong_type(dataset_train): + with pytest.raises(ValueError, match=re.escape("new field is not the same as old field:")): + dataset_train.cast_document_type( + DocumentWithEntsWrongType, field_mapping={"entities": "ents"} + ) diff --git a/tests/unit/test_dataset_dict.py b/tests/unit/test_dataset_dict.py new file mode 100644 index 00000000..dbc420a7 --- /dev/null +++ b/tests/unit/test_dataset_dict.py @@ -0,0 +1,534 @@ +import logging +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Iterable, Optional, Union + +import datasets +import pytest +from pytorch_ie.annotations import Label, LabeledSpan +from pytorch_ie.core import AnnotationList, Document, annotation_field +from pytorch_ie.documents import TextBasedDocument, TextDocument + +from pie_datasets import ( + Dataset, + DatasetDict, + EnterDatasetDictMixin, + EnterDatasetMixin, + ExitDatasetDictMixin, + ExitDatasetMixin, + IterableDataset, +) +from tests import DATASET_BUILDERS_ROOT, FIXTURES_ROOT +from tests.conftest import TestDocument + +logger = logging.getLogger(__name__) + +DATA_PATH = FIXTURES_ROOT / "dataset_dict" / "conll2003_extract" +DATASET_NAME = "conll2003" +PIE_DATASET_PATH 
= DATASET_BUILDERS_ROOT / "pie" / DATASET_NAME +TEST_CLASS_PREFIX = "tests.unit.test_dataset_dict" + +CREATE_FIXTURE_DATA = False + + +@pytest.mark.skipif(condition=not CREATE_FIXTURE_DATA, reason="don't create fixture data again") +def test_create_fixture_data(): + conll2003 = DatasetDict(datasets.load_dataset(str(PIE_DATASET_PATH))) + for split in list(conll2003): + # restrict all splits to 3 examples + conll2003 = conll2003.select(split=split, stop=3) + conll2003.to_json(DATA_PATH) + + +@dataclass +class DocumentWithEntitiesAndRelations(TextBasedDocument): + entities: AnnotationList[LabeledSpan] = annotation_field(target="text") + + +@pytest.fixture(scope="module") +def dataset_dict(): + return DatasetDict.from_json( + data_dir=DATA_PATH, document_type=DocumentWithEntitiesAndRelations + ) + + +def test_from_json(dataset_dict): + assert set(dataset_dict) == {"train", "test", "validation"} + assert len(dataset_dict["train"]) == 3 + assert len(dataset_dict["test"]) == 3 + assert len(dataset_dict["validation"]) == 3 + + +def test_from_json_no_serialized_document_type(dataset_dict): + with pytest.raises(ValueError) as excinfo: + DatasetDict.from_json(data_dir=DATA_PATH) + assert ( + str(excinfo.value) + == "document_type must be provided if it cannot be loaded from the metadata file" + ) + + +@pytest.fixture(scope="module") +def iterable_dataset_dict(): + return DatasetDict.from_json( + data_dir=DATA_PATH, + document_type=DocumentWithEntitiesAndRelations, + streaming=True, + ) + + +def test_iterable_dataset_dict(iterable_dataset_dict): + assert set(iterable_dataset_dict) == {"train", "test", "validation"} + + +def test_to_json_and_back(dataset_dict, tmp_path): + path = Path(tmp_path) / "dataset_dict" + dataset_dict.to_json(path) + dataset_dict_from_json = DatasetDict.from_json( + data_dir=path, + document_type=dataset_dict.document_type, + ) + assert set(dataset_dict_from_json) == set(dataset_dict) + for split in dataset_dict: + assert len(dataset_dict_from_json[split]) == len(dataset_dict[split]) + for doc1, doc2 in zip(dataset_dict_from_json[split], dataset_dict[split]): + assert doc1 == doc2 + + +def test_to_json_and_back_serialize_document_type(dataset_dict, tmp_path): + path = Path(tmp_path) / "dataset_dict" + dataset_dict.to_json(path) + dataset_dict_from_json = DatasetDict.from_json( + data_dir=path, + ) + assert set(dataset_dict_from_json) == set(dataset_dict) + for split in dataset_dict: + assert len(dataset_dict_from_json[split]) == len(dataset_dict[split]) + for doc1, doc2 in zip(dataset_dict_from_json[split], dataset_dict[split]): + assert doc1 == doc2 + + +def test_document_type_empty_no_splits(): + with pytest.raises(ValueError) as excinfo: + DatasetDict().document_type + assert ( + str(excinfo.value) == "dataset does not contain any splits, cannot determine document type" + ) + + +def test_document_type_different_types(dataset_dict): + # load the example dataset as a different document type + dataset_dict_different_type = DatasetDict.from_json( + data_dir=DATA_PATH, + document_type=TextBasedDocument, + ) + assert dataset_dict_different_type.document_type is TextBasedDocument + # create a dataset dict with different document types for train and test splits + dataset_dict_different_types = DatasetDict( + { + "train": dataset_dict["train"], + "test": dataset_dict_different_type["test"], + } + ) + # accessing the document type should raise an error with the message that starts with + # "dataset contains splits with different document types:" + with pytest.raises(ValueError) as 
excinfo:
+        dataset_dict_different_types.document_type
+    assert str(excinfo.value).startswith("dataset contains splits with different document types:")
+
+
+def test_dataset_type(dataset_dict):
+    assert dataset_dict.dataset_type is Dataset
+
+
+def test_dataset_type_no_splits():
+    with pytest.raises(ValueError) as excinfo:
+        DatasetDict().dataset_type
+    assert (
+        str(excinfo.value)
+        == "dataset does not contain any splits, cannot determine the dataset type"
+    )
+
+
+def test_dataset_type_different_type(dataset_dict, iterable_dataset_dict):
+    dataset_dict_different_type = DatasetDict(
+        {
+            "train": dataset_dict["train"],
+            "test": iterable_dataset_dict["test"],
+        }
+    )
+    with pytest.raises(ValueError) as excinfo:
+        dataset_dict_different_type.dataset_type
+    assert str(excinfo.value).startswith("dataset contains splits with different dataset types:")
+
+
+def map_fn(doc):
+    doc.text = doc.text.upper()
+    return doc
+
+
+@pytest.mark.parametrize(
+    "function",
+    [map_fn, f"{TEST_CLASS_PREFIX}.map_fn"],
+)
+def test_map(dataset_dict, function):
+    dataset_dict_mapped = dataset_dict.map(function)
+    for split in dataset_dict:
+        assert len(dataset_dict_mapped[split]) == len(dataset_dict[split])
+        for doc1, doc2 in zip(dataset_dict_mapped[split], dataset_dict[split]):
+            assert doc1.text == doc2.text.upper()
+
+
+def test_map_noop(dataset_dict):
+    dataset_dict_mapped = dataset_dict.map()
+    for split in dataset_dict:
+        assert len(dataset_dict_mapped[split]) == len(dataset_dict[split])
+        for doc1, doc2 in zip(dataset_dict_mapped[split], dataset_dict[split]):
+            assert doc1 == doc2
+
+
+def test_map_with_result_document_type(dataset_dict):
+    dataset_dict_mapped = dataset_dict.map(result_document_type=TextBasedDocument)
+    for split in dataset_dict:
+        assert len(dataset_dict_mapped[split]) == len(dataset_dict[split])
+        for doc1, doc2 in zip(dataset_dict_mapped[split], dataset_dict[split]):
+            assert isinstance(doc1, TextBasedDocument)
+            assert isinstance(doc2, DocumentWithEntitiesAndRelations)
+            assert doc1.text == doc2.text
+
+
+def test_map_with_context_manager(dataset_dict):
+    class DocumentCounter(
+        EnterDatasetMixin, ExitDatasetMixin, EnterDatasetDictMixin, ExitDatasetDictMixin
+    ):
+        def reset_statistics(self):
+            self.number = 0
+
+        def __call__(self, doc):
+            self.number += 1
+            return doc
+
+        def enter_dataset(
+            self, dataset: Union[Dataset, IterableDataset], name: Optional[str] = None
+        ) -> None:
+            self.reset_statistics()
+            self.split = name
+
+        def exit_dataset(
+            self, dataset: Union[Dataset, IterableDataset], name: Optional[str] = None
+        ) -> None:
+            self.all_docs[self.split] = self.number
+
+        def enter_dataset_dict(self, dataset_dict: DatasetDict) -> None:
+            self.all_docs: Dict[Optional[str], int] = {}
+            self.split = None
+
+        def exit_dataset_dict(self, dataset_dict: DatasetDict) -> None:
+            logger.info(f"Number of documents per split: {self.all_docs}")
+
+    document_counter = DocumentCounter()
+    # note that we need to disable caching here, otherwise the __call__ method may not be called for any dataset split
+    dataset_dict_mapped = dataset_dict.map(function=document_counter, load_from_cache_file=False)
+    assert document_counter.all_docs == {"train": 3, "test": 3, "validation": 3}
+
+    # the document_counter should not have modified the dataset
+    assert set(dataset_dict_mapped) == set(dataset_dict)
+    for split in dataset_dict:
+        assert len(dataset_dict_mapped[split]) == len(dataset_dict[split])
+        for doc1, doc2 in zip(dataset_dict_mapped[split], dataset_dict[split]):
+            assert doc1 == doc2
+ + +def test_select(dataset_dict): + # select documents by index + dataset_dict_selected = dataset_dict.select( + split="train", + indices=[0, 2], + ) + assert len(dataset_dict_selected["train"]) == 2 + assert dataset_dict_selected["train"][0] == dataset_dict["train"][0] + assert dataset_dict_selected["train"][1] == dataset_dict["train"][2] + + # select documents by range + dataset_dict_selected = dataset_dict.select( + split="train", + stop=2, + start=1, + step=1, + ) + assert len(dataset_dict_selected["train"]) == 1 + assert dataset_dict_selected["train"][0] == dataset_dict["train"][1] + + # calling with no arguments that do result in the creation of indices should return the same dataset, + # but will log a warning if other arguments (here "any_arg") are passed + dataset_dict_selected = dataset_dict.select(split="train", any_arg="ignored") + assert len(dataset_dict_selected["train"]) == len(dataset_dict["train"]) + assert dataset_dict_selected["train"][0] == dataset_dict["train"][0] + assert dataset_dict_selected["train"][1] == dataset_dict["train"][1] + assert dataset_dict_selected["train"][2] == dataset_dict["train"][2] + + +def test_rename_splits(dataset_dict): + mapping = { + "train": "train_renamed", + "test": "test_renamed", + "validation": "validation_renamed", + } + dataset_dict_renamed = dataset_dict.rename_splits(mapping) + assert set(dataset_dict_renamed) == set(mapping.values()) + for split in dataset_dict: + split_renamed = mapping[split] + assert len(dataset_dict_renamed[split_renamed]) == len(dataset_dict[split]) + for doc1, doc2 in zip(dataset_dict_renamed[split_renamed], dataset_dict[split]): + assert doc1 == doc2 + + +def test_rename_split_noop(dataset_dict): + dataset_dict_renamed = dataset_dict.rename_splits() + assert set(dataset_dict_renamed) == set(dataset_dict) + for split in dataset_dict: + assert len(dataset_dict_renamed[split]) == len(dataset_dict[split]) + for doc1, doc2 in zip(dataset_dict_renamed[split], dataset_dict[split]): + assert doc1 == doc2 + + +def assert_doc_lists_equal(docs: Iterable[Document], other_docs: Iterable[Document]): + assert all(doc1 == doc2 for doc1, doc2 in zip(docs, other_docs)) + + +def test_add_test_split(dataset_dict): + dataset_dict_with_test = dataset_dict.add_test_split( + source_split="test", target_split="new_test", test_size=1, shuffle=False + ) + assert "new_test" in dataset_dict_with_test + assert len(dataset_dict_with_test["new_test"]) + len(dataset_dict_with_test["test"]) == len( + dataset_dict["test"] + ) + assert len(dataset_dict_with_test["new_test"]) == 1 + assert len(dataset_dict_with_test["test"]) == 2 + assert_doc_lists_equal(dataset_dict_with_test["new_test"], dataset_dict["test"][2:]) + assert_doc_lists_equal(dataset_dict_with_test["test"], dataset_dict["test"][:2]) + test_ids = [doc.id for doc in dataset_dict_with_test["test"]] + new_test_ids = [doc.id for doc in dataset_dict_with_test["new_test"]] + assert set(test_ids).intersection(set(new_test_ids)) == set() + + # remaining splits should be unchanged + assert len(dataset_dict_with_test["train"]) == len(dataset_dict["train"]) + assert len(dataset_dict_with_test["validation"]) == len(dataset_dict["validation"]) + assert_doc_lists_equal(dataset_dict_with_test["train"], dataset_dict["train"]) + assert_doc_lists_equal(dataset_dict_with_test["validation"], dataset_dict["validation"]) + + +def test_drop_splits(dataset_dict): + dataset_dict_dropped = dataset_dict.drop_splits(["train", "validation"]) + assert set(dataset_dict_dropped) == {"test"} + assert 
len(dataset_dict_dropped["test"]) == len(dataset_dict["test"]) + assert_doc_lists_equal(dataset_dict_dropped["test"], dataset_dict["test"]) + + +def test_concat_splits(dataset_dict): + dataset_dict_concatenated = dataset_dict.concat_splits(["train", "validation"], target="train") + assert set(dataset_dict_concatenated) == {"test", "train"} + assert len(dataset_dict_concatenated["train"]) == len(dataset_dict["train"]) + len( + dataset_dict["validation"] + ) + assert_doc_lists_equal( + dataset_dict_concatenated["train"], + list(dataset_dict["train"]) + list(dataset_dict["validation"]), + ) + + +def test_concat_splits_no_splits(dataset_dict): + with pytest.raises(ValueError) as excinfo: + dataset_dict.concat_splits(splits=[], target="train") + assert str(excinfo.value) == "please provide at least one split to concatenate" + + +def test_concat_splits_different_dataset_types(dataset_dict, iterable_dataset_dict): + dataset_dict_to_concat = DatasetDict( + { + "train": dataset_dict["train"], + "validation": iterable_dataset_dict["validation"], + } + ) + with pytest.raises(ValueError) as excinfo: + dataset_dict_to_concat.concat_splits(splits=["train", "validation"], target="train") + assert str(excinfo.value).startswith("dataset contains splits with different dataset types:") + + +def test_filter(dataset_dict): + dataset_dict_filtered = dataset_dict.filter( + function=lambda doc: len(doc["text"]) > 15, + split="train", + ) + assert all(len(doc.text) > 15 for doc in dataset_dict_filtered["train"]) + assert len(dataset_dict["train"]) == 3 + assert len(dataset_dict_filtered["train"]) == 2 + assert dataset_dict_filtered["train"][0] == dataset_dict["train"][0] + assert dataset_dict_filtered["train"][1] == dataset_dict["train"][2] + + # remaining splits should be unchanged + assert len(dataset_dict_filtered["validation"]) == len(dataset_dict["validation"]) == 3 + assert len(dataset_dict_filtered["test"]) == len(dataset_dict["test"]) == 3 + assert_doc_lists_equal(dataset_dict_filtered["validation"], dataset_dict["validation"]) + assert_doc_lists_equal(dataset_dict_filtered["test"], dataset_dict["test"]) + + +def test_filter_iterable(iterable_dataset_dict): + dataset_dict_filtered = iterable_dataset_dict.filter( + function=lambda doc: len(doc["text"]) > 15, + split="train", + ) + docs_train = list(dataset_dict_filtered["train"]) + assert len(docs_train) == 2 + assert all(len(doc.text) > 15 for doc in docs_train) + + +def test_filter_unknown_dataset_type(): + dataset_dict = DatasetDict({"train": "foo"}) + with pytest.raises(TypeError) as excinfo: + dataset_dict.filter(function=lambda doc: True, split="train") + assert str(excinfo.value) == "dataset must be of type Dataset, but is " + + +def test_filter_noop(dataset_dict): + # passing no filter function should be a noop + dataset_dict_filtered = dataset_dict.filter(split="train") + assert len(dataset_dict_filtered["train"]) == len(dataset_dict["train"]) == 3 + assert len(dataset_dict_filtered["validation"]) == len(dataset_dict["validation"]) == 3 + assert len(dataset_dict_filtered["test"]) == len(dataset_dict["test"]) == 3 + assert_doc_lists_equal(dataset_dict_filtered["train"], dataset_dict["train"]) + assert_doc_lists_equal(dataset_dict_filtered["validation"], dataset_dict["validation"]) + assert_doc_lists_equal(dataset_dict_filtered["test"], dataset_dict["test"]) + + +@pytest.mark.parametrize( + # we can either provide ids or a filter function + "ids,filter_function", + [ + (["1", "2"], None), + (None, lambda doc: doc["id"] in ["1", "2"]), + ], +) +def 


@pytest.mark.parametrize(
    # we can either provide ids or a filter function
    "ids,filter_function",
    [
        (["1", "2"], None),
        (None, lambda doc: doc["id"] in ["1", "2"]),
    ],
)
def test_move_to_new_split(dataset_dict, ids, filter_function):
    # move the second and third documents from train to new_validation
    dataset_dict_moved = dataset_dict.move_to_new_split(
        ids=ids,
        filter_function=filter_function,
        source_split="train",
        target_split="new_validation",
    )
    assert len(dataset_dict_moved["train"]) == 1
    assert len(dataset_dict_moved["new_validation"]) == 2
    assert_doc_lists_equal(dataset_dict_moved["train"], dataset_dict["train"][:1])

    # the remaining splits should be unchanged
    assert len(dataset_dict_moved["validation"]) == len(dataset_dict["validation"]) == 3
    assert len(dataset_dict_moved["test"]) == len(dataset_dict["test"]) == 3
    assert_doc_lists_equal(dataset_dict_moved["validation"], dataset_dict["validation"])
    assert_doc_lists_equal(dataset_dict_moved["test"], dataset_dict["test"])


def test_move_to_new_split_missing_arguments(dataset_dict):
    with pytest.raises(ValueError) as excinfo:
        dataset_dict.move_to_new_split(
            ids=None,
            filter_function=None,
            source_split="train",
            target_split="new_validation",
        )
    assert str(excinfo.value) == "please provide either a list of ids or a filter function"


def test_cast_document_type(dataset_dict):
    dataset_dict_cast = dataset_dict.cast_document_type(TextBasedDocument)
    assert dataset_dict_cast.document_type == TextBasedDocument
    for split in dataset_dict_cast:
        assert all(isinstance(doc, TextBasedDocument) for doc in dataset_dict_cast[split])


@dataclass
class TestDocumentWithLabel(TextDocument):
    label: AnnotationList[Label] = annotation_field()


def convert_to_document_with_label(document: TestDocument) -> TestDocumentWithLabel:
    result = TestDocumentWithLabel(text=document.text)
    result.label.append(Label(label="label"))
    return result


def test_register_document_converter(dataset_dict):
    dataset_dict.register_document_converter(
        convert_to_document_with_label, document_type=TestDocumentWithLabel
    )

    for name, split in dataset_dict.items():
        assert split.document_converters[TestDocumentWithLabel] == convert_to_document_with_label


def test_register_document_converter_resolve(dataset_dict):
    dataset_dict.register_document_converter(
        f"{TEST_CLASS_PREFIX}.convert_to_document_with_label",
        document_type=f"{TEST_CLASS_PREFIX}.TestDocumentWithLabel",
    )

    for name, split in dataset_dict.items():
        assert split.document_converters[TestDocumentWithLabel] == convert_to_document_with_label


class NoDocument:
    pass


def test_register_document_converter_resolve_wrong_document_type(dataset_dict):
    with pytest.raises(TypeError) as excinfo:
        dataset_dict.register_document_converter(
            convert_to_document_with_label, document_type=f"{TEST_CLASS_PREFIX}.NoDocument"
        )
    assert (
        str(excinfo.value)
        == f"document_type must be or resolve to a subclass of Document, but is '{TEST_CLASS_PREFIX}.NoDocument'"
    )


def test_register_document_converter_resolve_wrong_converter(dataset_dict):
    with pytest.raises(TypeError) as excinfo:
        dataset_dict.register_document_converter([1, 2, 3], document_type=TestDocumentWithLabel)
    assert str(excinfo.value) == "converter must be a callable or a dict, but is <class 'list'>"


def test_to_document_type(dataset_dict):
    dataset_dict.register_document_converter(convert_to_document_with_label)
    dataset_dict_converted = dataset_dict.to_document_type(TestDocumentWithLabel)
    assert dataset_dict_converted.document_type == TestDocumentWithLabel
    for split in dataset_dict_converted.values():
        assert all(isinstance(doc, 
TestDocumentWithLabel) for doc in split) + + +def test_to_document_resolve(dataset_dict): + dataset_dict.register_document_converter(convert_to_document_with_label) + dataset_dict_converted = dataset_dict.to_document_type( + f"{TEST_CLASS_PREFIX}.TestDocumentWithLabel" + ) + assert dataset_dict_converted.document_type == TestDocumentWithLabel + for split in dataset_dict_converted.values(): + assert all(isinstance(doc, TestDocumentWithLabel) for doc in split) + + +def test_to_document_type_resolve_wrong_document_type(dataset_dict): + dataset_dict.register_document_converter(convert_to_document_with_label) + with pytest.raises(TypeError) as excinfo: + dataset_dict.to_document_type(f"{TEST_CLASS_PREFIX}.NoDocument") + assert ( + str(excinfo.value) + == f"document_type must be a document type or a string that can be resolved to such a type, but got " + f"{TEST_CLASS_PREFIX}.NoDocument." + ) + + +def test_to_document_type_noop(dataset_dict): + assert dataset_dict.document_type == DocumentWithEntitiesAndRelations + dataset_dict_converted = dataset_dict.to_document_type(DocumentWithEntitiesAndRelations) + assert dataset_dict_converted.document_type == DocumentWithEntitiesAndRelations + assert dataset_dict_converted == dataset_dict From c45c20ccd2b87ddf85b3ce3c3682983a13fb75da Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Tue, 7 Nov 2023 00:46:42 +0100 Subject: [PATCH 05/12] use pytorch-ie 0.26.1.dev1699314147 (pre-release on test pypi) --- poetry.lock | 29 +++++++++++++++-------------- pyproject.toml | 2 +- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/poetry.lock b/poetry.lock index b9fcea87..46f94b0d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1217,27 +1217,28 @@ six = ">=1.5" [[package]] name = "pytorch-ie" -version = "0.26.0" +version = "0.26.1.dev1699314147" description = "State-of-the-art Information Extraction in PyTorch" optional = false -python-versions = "^3.9" -files = [] -develop = false +python-versions = ">=3.9,<4.0" +files = [ + {file = "pytorch_ie-0.26.1.dev1699314147-py3-none-any.whl", hash = "sha256:10e8aa445058b29cf4ae2a8e905f5614f803afb29ebad93f75e1f4fb1fde32bd"}, + {file = "pytorch_ie-0.26.1.dev1699314147.tar.gz", hash = "sha256:01a398eef1874b5de464941774bbbeeef0ef862b5435a4523ade110f7707934e"}, +] [package.dependencies] -absl-py = "^1.0.0" -datasets = "^2.13" +absl-py = ">=1.0.0,<2.0.0" +datasets = ">=2.13,<3.0" fsspec = "<2023.9.0" -pytorch-lightning = "^2" +pytorch-lightning = ">=2,<3" torch = ">=1.10" -torchmetrics = "^1" -transformers = "^4.18" +torchmetrics = ">=1,<2" +transformers = ">=4.18,<5.0" [package.source] -type = "git" -url = "https://github.com/ChristophAlt/pytorch-ie.git" -reference = "decouple_pie_dataset" -resolved_reference = "15862d54de0066d2dee0f69fd7bf27527bbfb81d" +type = "legacy" +url = "https://test.pypi.org/simple" +reference = "pre-release" [[package]] name = "pytorch-lightning" @@ -2161,4 +2162,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "405005251a71a59558ec8b03c8b5bc44e84682b9035fe6a1053f864dec2c7e54" +content-hash = "8e15677d9874e0189266d03d9fb22cf968f9b1437476fd41ab0010d2551a6975" diff --git a/pyproject.toml b/pyproject.toml index 47162a67..9bfd2a7f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ classifiers = [ [tool.poetry.dependencies] python = "^3.9" #pytorch-ie = ">=0.26.0,<0.27.0" -pytorch-ie = { git = "https://github.com/ChristophAlt/pytorch-ie.git", branch = "decouple_pie_dataset" } +pytorch-ie = { version = "0.26.1.dev1699314147", source = 
"pre-release" } [tool.poetry.group.dev.dependencies] torch = {version = "^2.1.0+cpu", source = "pytorch"} From 1b962121fdded83999d8941b7296dbeb90e1705f Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Tue, 7 Nov 2023 01:46:51 +0100 Subject: [PATCH 06/12] upgrade pytorch-ie to 0.27.0 --- poetry.lock | 365 +++++++++++++++++++++++++------------------------ pyproject.toml | 3 +- 2 files changed, 184 insertions(+), 184 deletions(-) diff --git a/poetry.lock b/poetry.lock index 46f94b0d..3c4aa10d 100644 --- a/poetry.lock +++ b/poetry.lock @@ -186,101 +186,101 @@ files = [ [[package]] name = "charset-normalizer" -version = "3.3.1" +version = "3.3.2" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." optional = false python-versions = ">=3.7.0" files = [ - {file = "charset-normalizer-3.3.1.tar.gz", hash = "sha256:d9137a876020661972ca6eec0766d81aef8a5627df628b664b234b73396e727e"}, - {file = "charset_normalizer-3.3.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:8aee051c89e13565c6bd366813c386939f8e928af93c29fda4af86d25b73d8f8"}, - {file = "charset_normalizer-3.3.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:352a88c3df0d1fa886562384b86f9a9e27563d4704ee0e9d56ec6fcd270ea690"}, - {file = "charset_normalizer-3.3.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:223b4d54561c01048f657fa6ce41461d5ad8ff128b9678cfe8b2ecd951e3f8a2"}, - {file = "charset_normalizer-3.3.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f861d94c2a450b974b86093c6c027888627b8082f1299dfd5a4bae8e2292821"}, - {file = "charset_normalizer-3.3.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1171ef1fc5ab4693c5d151ae0fdad7f7349920eabbaca6271f95969fa0756c2d"}, - {file = "charset_normalizer-3.3.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28f512b9a33235545fbbdac6a330a510b63be278a50071a336afc1b78781b147"}, - {file = "charset_normalizer-3.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0e842112fe3f1a4ffcf64b06dc4c61a88441c2f02f373367f7b4c1aa9be2ad5"}, - {file = "charset_normalizer-3.3.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3f9bc2ce123637a60ebe819f9fccc614da1bcc05798bbbaf2dd4ec91f3e08846"}, - {file = "charset_normalizer-3.3.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:f194cce575e59ffe442c10a360182a986535fd90b57f7debfaa5c845c409ecc3"}, - {file = "charset_normalizer-3.3.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:9a74041ba0bfa9bc9b9bb2cd3238a6ab3b7618e759b41bd15b5f6ad958d17605"}, - {file = "charset_normalizer-3.3.1-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:b578cbe580e3b41ad17b1c428f382c814b32a6ce90f2d8e39e2e635d49e498d1"}, - {file = "charset_normalizer-3.3.1-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:6db3cfb9b4fcecb4390db154e75b49578c87a3b9979b40cdf90d7e4b945656e1"}, - {file = "charset_normalizer-3.3.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:debb633f3f7856f95ad957d9b9c781f8e2c6303ef21724ec94bea2ce2fcbd056"}, - {file = "charset_normalizer-3.3.1-cp310-cp310-win32.whl", hash = "sha256:87071618d3d8ec8b186d53cb6e66955ef2a0e4fa63ccd3709c0c90ac5a43520f"}, - {file = "charset_normalizer-3.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:e372d7dfd154009142631de2d316adad3cc1c36c32a38b16a4751ba78da2a397"}, - {file = "charset_normalizer-3.3.1-cp311-cp311-macosx_10_9_universal2.whl", hash = 
"sha256:ae4070f741f8d809075ef697877fd350ecf0b7c5837ed68738607ee0a2c572cf"}, - {file = "charset_normalizer-3.3.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:58e875eb7016fd014c0eea46c6fa92b87b62c0cb31b9feae25cbbe62c919f54d"}, - {file = "charset_normalizer-3.3.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:dbd95e300367aa0827496fe75a1766d198d34385a58f97683fe6e07f89ca3e3c"}, - {file = "charset_normalizer-3.3.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:de0b4caa1c8a21394e8ce971997614a17648f94e1cd0640fbd6b4d14cab13a72"}, - {file = "charset_normalizer-3.3.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:985c7965f62f6f32bf432e2681173db41336a9c2611693247069288bcb0c7f8b"}, - {file = "charset_normalizer-3.3.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a15c1fe6d26e83fd2e5972425a772cca158eae58b05d4a25a4e474c221053e2d"}, - {file = "charset_normalizer-3.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ae55d592b02c4349525b6ed8f74c692509e5adffa842e582c0f861751701a673"}, - {file = "charset_normalizer-3.3.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:be4d9c2770044a59715eb57c1144dedea7c5d5ae80c68fb9959515037cde2008"}, - {file = "charset_normalizer-3.3.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:851cf693fb3aaef71031237cd68699dded198657ec1e76a76eb8be58c03a5d1f"}, - {file = "charset_normalizer-3.3.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:31bbaba7218904d2eabecf4feec0d07469284e952a27400f23b6628439439fa7"}, - {file = "charset_normalizer-3.3.1-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:871d045d6ccc181fd863a3cd66ee8e395523ebfbc57f85f91f035f50cee8e3d4"}, - {file = "charset_normalizer-3.3.1-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:501adc5eb6cd5f40a6f77fbd90e5ab915c8fd6e8c614af2db5561e16c600d6f3"}, - {file = "charset_normalizer-3.3.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f5fb672c396d826ca16a022ac04c9dce74e00a1c344f6ad1a0fdc1ba1f332213"}, - {file = "charset_normalizer-3.3.1-cp311-cp311-win32.whl", hash = "sha256:bb06098d019766ca16fc915ecaa455c1f1cd594204e7f840cd6258237b5079a8"}, - {file = "charset_normalizer-3.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:8af5a8917b8af42295e86b64903156b4f110a30dca5f3b5aedea123fbd638bff"}, - {file = "charset_normalizer-3.3.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:7ae8e5142dcc7a49168f4055255dbcced01dc1714a90a21f87448dc8d90617d1"}, - {file = "charset_normalizer-3.3.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5b70bab78accbc672f50e878a5b73ca692f45f5b5e25c8066d748c09405e6a55"}, - {file = "charset_normalizer-3.3.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5ceca5876032362ae73b83347be8b5dbd2d1faf3358deb38c9c88776779b2e2f"}, - {file = "charset_normalizer-3.3.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:34d95638ff3613849f473afc33f65c401a89f3b9528d0d213c7037c398a51296"}, - {file = "charset_normalizer-3.3.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9edbe6a5bf8b56a4a84533ba2b2f489d0046e755c29616ef8830f9e7d9cf5728"}, - {file = "charset_normalizer-3.3.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f6a02a3c7950cafaadcd46a226ad9e12fc9744652cc69f9e5534f98b47f3bbcf"}, - {file = "charset_normalizer-3.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:10b8dd31e10f32410751b3430996f9807fc4d1587ca69772e2aa940a82ab571a"}, - {file = "charset_normalizer-3.3.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:edc0202099ea1d82844316604e17d2b175044f9bcb6b398aab781eba957224bd"}, - {file = "charset_normalizer-3.3.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:b891a2f68e09c5ef989007fac11476ed33c5c9994449a4e2c3386529d703dc8b"}, - {file = "charset_normalizer-3.3.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:71ef3b9be10070360f289aea4838c784f8b851be3ba58cf796262b57775c2f14"}, - {file = "charset_normalizer-3.3.1-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:55602981b2dbf8184c098bc10287e8c245e351cd4fdcad050bd7199d5a8bf514"}, - {file = "charset_normalizer-3.3.1-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:46fb9970aa5eeca547d7aa0de5d4b124a288b42eaefac677bde805013c95725c"}, - {file = "charset_normalizer-3.3.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:520b7a142d2524f999447b3a0cf95115df81c4f33003c51a6ab637cbda9d0bf4"}, - {file = "charset_normalizer-3.3.1-cp312-cp312-win32.whl", hash = "sha256:8ec8ef42c6cd5856a7613dcd1eaf21e5573b2185263d87d27c8edcae33b62a61"}, - {file = "charset_normalizer-3.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:baec8148d6b8bd5cee1ae138ba658c71f5b03e0d69d5907703e3e1df96db5e41"}, - {file = "charset_normalizer-3.3.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:63a6f59e2d01310f754c270e4a257426fe5a591dc487f1983b3bbe793cf6bac6"}, - {file = "charset_normalizer-3.3.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d6bfc32a68bc0933819cfdfe45f9abc3cae3877e1d90aac7259d57e6e0f85b1"}, - {file = "charset_normalizer-3.3.1-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4f3100d86dcd03c03f7e9c3fdb23d92e32abbca07e7c13ebd7ddfbcb06f5991f"}, - {file = "charset_normalizer-3.3.1-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:39b70a6f88eebe239fa775190796d55a33cfb6d36b9ffdd37843f7c4c1b5dc67"}, - {file = "charset_normalizer-3.3.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4e12f8ee80aa35e746230a2af83e81bd6b52daa92a8afaef4fea4a2ce9b9f4fa"}, - {file = "charset_normalizer-3.3.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7b6cefa579e1237ce198619b76eaa148b71894fb0d6bcf9024460f9bf30fd228"}, - {file = "charset_normalizer-3.3.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:61f1e3fb621f5420523abb71f5771a204b33c21d31e7d9d86881b2cffe92c47c"}, - {file = "charset_normalizer-3.3.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:4f6e2a839f83a6a76854d12dbebde50e4b1afa63e27761549d006fa53e9aa80e"}, - {file = "charset_normalizer-3.3.1-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:1ec937546cad86d0dce5396748bf392bb7b62a9eeb8c66efac60e947697f0e58"}, - {file = "charset_normalizer-3.3.1-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:82ca51ff0fc5b641a2d4e1cc8c5ff108699b7a56d7f3ad6f6da9dbb6f0145b48"}, - {file = "charset_normalizer-3.3.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:633968254f8d421e70f91c6ebe71ed0ab140220469cf87a9857e21c16687c034"}, - {file = "charset_normalizer-3.3.1-cp37-cp37m-win32.whl", hash = "sha256:c0c72d34e7de5604df0fde3644cc079feee5e55464967d10b24b1de268deceb9"}, - {file = "charset_normalizer-3.3.1-cp37-cp37m-win_amd64.whl", hash = "sha256:63accd11149c0f9a99e3bc095bbdb5a464862d77a7e309ad5938fbc8721235ae"}, - {file = 
"charset_normalizer-3.3.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5a3580a4fdc4ac05f9e53c57f965e3594b2f99796231380adb2baaab96e22761"}, - {file = "charset_normalizer-3.3.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2465aa50c9299d615d757c1c888bc6fef384b7c4aec81c05a0172b4400f98557"}, - {file = "charset_normalizer-3.3.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:cb7cd68814308aade9d0c93c5bd2ade9f9441666f8ba5aa9c2d4b389cb5e2a45"}, - {file = "charset_normalizer-3.3.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:91e43805ccafa0a91831f9cd5443aa34528c0c3f2cc48c4cb3d9a7721053874b"}, - {file = "charset_normalizer-3.3.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:854cc74367180beb327ab9d00f964f6d91da06450b0855cbbb09187bcdb02de5"}, - {file = "charset_normalizer-3.3.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c15070ebf11b8b7fd1bfff7217e9324963c82dbdf6182ff7050519e350e7ad9f"}, - {file = "charset_normalizer-3.3.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c4c99f98fc3a1835af8179dcc9013f93594d0670e2fa80c83aa36346ee763d2"}, - {file = "charset_normalizer-3.3.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3fb765362688821404ad6cf86772fc54993ec11577cd5a92ac44b4c2ba52155b"}, - {file = "charset_normalizer-3.3.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:dced27917823df984fe0c80a5c4ad75cf58df0fbfae890bc08004cd3888922a2"}, - {file = "charset_normalizer-3.3.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a66bcdf19c1a523e41b8e9d53d0cedbfbac2e93c649a2e9502cb26c014d0980c"}, - {file = "charset_normalizer-3.3.1-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:ecd26be9f112c4f96718290c10f4caea6cc798459a3a76636b817a0ed7874e42"}, - {file = "charset_normalizer-3.3.1-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:3f70fd716855cd3b855316b226a1ac8bdb3caf4f7ea96edcccc6f484217c9597"}, - {file = "charset_normalizer-3.3.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:17a866d61259c7de1bdadef418a37755050ddb4b922df8b356503234fff7932c"}, - {file = "charset_normalizer-3.3.1-cp38-cp38-win32.whl", hash = "sha256:548eefad783ed787b38cb6f9a574bd8664468cc76d1538215d510a3cd41406cb"}, - {file = "charset_normalizer-3.3.1-cp38-cp38-win_amd64.whl", hash = "sha256:45f053a0ece92c734d874861ffe6e3cc92150e32136dd59ab1fb070575189c97"}, - {file = "charset_normalizer-3.3.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:bc791ec3fd0c4309a753f95bb6c749ef0d8ea3aea91f07ee1cf06b7b02118f2f"}, - {file = "charset_normalizer-3.3.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:0c8c61fb505c7dad1d251c284e712d4e0372cef3b067f7ddf82a7fa82e1e9a93"}, - {file = "charset_normalizer-3.3.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:2c092be3885a1b7899cd85ce24acedc1034199d6fca1483fa2c3a35c86e43041"}, - {file = "charset_normalizer-3.3.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c2000c54c395d9e5e44c99dc7c20a64dc371f777faf8bae4919ad3e99ce5253e"}, - {file = "charset_normalizer-3.3.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4cb50a0335382aac15c31b61d8531bc9bb657cfd848b1d7158009472189f3d62"}, - {file = "charset_normalizer-3.3.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c30187840d36d0ba2893bc3271a36a517a717f9fd383a98e2697ee890a37c273"}, - {file = "charset_normalizer-3.3.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:fe81b35c33772e56f4b6cf62cf4aedc1762ef7162a31e6ac7fe5e40d0149eb67"}, - {file = "charset_normalizer-3.3.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d0bf89afcbcf4d1bb2652f6580e5e55a840fdf87384f6063c4a4f0c95e378656"}, - {file = "charset_normalizer-3.3.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:06cf46bdff72f58645434d467bf5228080801298fbba19fe268a01b4534467f5"}, - {file = "charset_normalizer-3.3.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:3c66df3f41abee950d6638adc7eac4730a306b022570f71dd0bd6ba53503ab57"}, - {file = "charset_normalizer-3.3.1-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:cd805513198304026bd379d1d516afbf6c3c13f4382134a2c526b8b854da1c2e"}, - {file = "charset_normalizer-3.3.1-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:9505dc359edb6a330efcd2be825fdb73ee3e628d9010597aa1aee5aa63442e97"}, - {file = "charset_normalizer-3.3.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:31445f38053476a0c4e6d12b047b08ced81e2c7c712e5a1ad97bc913256f91b2"}, - {file = "charset_normalizer-3.3.1-cp39-cp39-win32.whl", hash = "sha256:bd28b31730f0e982ace8663d108e01199098432a30a4c410d06fe08fdb9e93f4"}, - {file = "charset_normalizer-3.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:555fe186da0068d3354cdf4bbcbc609b0ecae4d04c921cc13e209eece7720727"}, - {file = "charset_normalizer-3.3.1-py3-none-any.whl", hash = "sha256:800561453acdecedaac137bf09cd719c7a440b6800ec182f077bb8e7025fb708"}, + {file = "charset-normalizer-3.3.2.tar.gz", hash = "sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:25baf083bf6f6b341f4121c2f3c548875ee6f5339300e08be3f2b2ba1721cdd3"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:06435b539f889b1f6f4ac1758871aae42dc3a8c0e24ac9e60c2384973ad73027"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9063e24fdb1e498ab71cb7419e24622516c4a04476b17a2dab57e8baa30d6e03"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6897af51655e3691ff853668779c7bad41579facacf5fd7253b0133308cf000d"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1d3193f4a680c64b4b6a9115943538edb896edc190f0b222e73761716519268e"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cd70574b12bb8a4d2aaa0094515df2463cb429d8536cfb6c7ce983246983e5a6"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8465322196c8b4d7ab6d1e049e4c5cb460d0394da4a27d23cc242fbf0034b6b5"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a9a8e9031d613fd2009c182b69c7b2c1ef8239a0efb1df3f7c8da66d5dd3d537"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:beb58fe5cdb101e3a055192ac291b7a21e3b7ef4f67fa1d74e331a7f2124341c"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e06ed3eb3218bc64786f7db41917d4e686cc4856944f53d5bdf83a6884432e12"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:2e81c7b9c8979ce92ed306c249d46894776a909505d8f5a4ba55b14206e3222f"}, + {file = 
"charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:572c3763a264ba47b3cf708a44ce965d98555f618ca42c926a9c1616d8f34269"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fd1abc0d89e30cc4e02e4064dc67fcc51bd941eb395c502aac3ec19fab46b519"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-win32.whl", hash = "sha256:3d47fa203a7bd9c5b6cee4736ee84ca03b8ef23193c0d1ca99b5089f72645c73"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:10955842570876604d404661fbccbc9c7e684caf432c09c715ec38fbae45ae09"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:802fe99cca7457642125a8a88a084cef28ff0cf9407060f7b93dca5aa25480db"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:573f6eac48f4769d667c4442081b1794f52919e7edada77495aaed9236d13a96"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:549a3a73da901d5bc3ce8d24e0600d1fa85524c10287f6004fbab87672bf3e1e"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f27273b60488abe721a075bcca6d7f3964f9f6f067c8c4c605743023d7d3944f"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ceae2f17a9c33cb48e3263960dc5fc8005351ee19db217e9b1bb15d28c02574"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:65f6f63034100ead094b8744b3b97965785388f308a64cf8d7c34f2f2e5be0c4"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:753f10e867343b4511128c6ed8c82f7bec3bd026875576dfd88483c5c73b2fd8"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4a78b2b446bd7c934f5dcedc588903fb2f5eec172f3d29e52a9096a43722adfc"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e537484df0d8f426ce2afb2d0f8e1c3d0b114b83f8850e5f2fbea0e797bd82ae"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:eb6904c354526e758fda7167b33005998fb68c46fbc10e013ca97f21ca5c8887"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:deb6be0ac38ece9ba87dea880e438f25ca3eddfac8b002a2ec3d9183a454e8ae"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:4ab2fe47fae9e0f9dee8c04187ce5d09f48eabe611be8259444906793ab7cbce"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:80402cd6ee291dcb72644d6eac93785fe2c8b9cb30893c1af5b8fdd753b9d40f"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-win32.whl", hash = "sha256:7cd13a2e3ddeed6913a65e66e94b51d80a041145a026c27e6bb76c31a853c6ab"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:663946639d296df6a2bb2aa51b60a2454ca1cb29835324c640dafb5ff2131a77"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0b2b64d2bb6d3fb9112bafa732def486049e63de9618b5843bcdd081d8144cd8"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ddbb2551d7e0102e7252db79ba445cdab71b26640817ab1e3e3648dad515003b"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:55086ee1064215781fff39a1af09518bc9255b50d6333f2e4c74ca09fac6a8f6"}, + 
{file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f4a014bc36d3c57402e2977dada34f9c12300af536839dc38c0beab8878f38a"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a10af20b82360ab00827f916a6058451b723b4e65030c5a18577c8b2de5b3389"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8d756e44e94489e49571086ef83b2bb8ce311e730092d2c34ca8f7d925cb20aa"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90d558489962fd4918143277a773316e56c72da56ec7aa3dc3dbbe20fdfed15b"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ac7ffc7ad6d040517be39eb591cac5ff87416c2537df6ba3cba3bae290c0fed"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7ed9e526742851e8d5cc9e6cf41427dfc6068d4f5a3bb03659444b4cabf6bc26"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:8bdb58ff7ba23002a4c5808d608e4e6c687175724f54a5dade5fa8c67b604e4d"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:6b3251890fff30ee142c44144871185dbe13b11bab478a88887a639655be1068"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:b4a23f61ce87adf89be746c8a8974fe1c823c891d8f86eb218bb957c924bb143"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:efcb3f6676480691518c177e3b465bcddf57cea040302f9f4e6e191af91174d4"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-win32.whl", hash = "sha256:d965bba47ddeec8cd560687584e88cf699fd28f192ceb452d1d7ee807c5597b7"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:96b02a3dc4381e5494fad39be677abcb5e6634bf7b4fa83a6dd3112607547001"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:95f2a5796329323b8f0512e09dbb7a1860c46a39da62ecb2324f116fa8fdc85c"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c002b4ffc0be611f0d9da932eb0f704fe2602a9a949d1f738e4c34c75b0863d5"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a981a536974bbc7a512cf44ed14938cf01030a99e9b3a06dd59578882f06f985"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3287761bc4ee9e33561a7e058c72ac0938c4f57fe49a09eae428fd88aafe7bb6"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:42cb296636fcc8b0644486d15c12376cb9fa75443e00fb25de0b8602e64c1714"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a55554a2fa0d408816b3b5cedf0045f4b8e1a6065aec45849de2d6f3f8e9786"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:c083af607d2515612056a31f0a8d9e0fcb5876b7bfc0abad3ecd275bc4ebc2d5"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:87d1351268731db79e0f8e745d92493ee2841c974128ef629dc518b937d9194c"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = 
"sha256:bd8f7df7d12c2db9fab40bdd87a7c09b1530128315d047a086fa3ae3435cb3a8"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:c180f51afb394e165eafe4ac2936a14bee3eb10debc9d9e4db8958fe36afe711"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:8c622a5fe39a48f78944a87d4fb8a53ee07344641b0562c540d840748571b811"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-win32.whl", hash = "sha256:db364eca23f876da6f9e16c9da0df51aa4f104a972735574842618b8c6d999d4"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-win_amd64.whl", hash = "sha256:86216b5cee4b06df986d214f664305142d9c76df9b6512be2738aa72a2048f99"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:6463effa3186ea09411d50efc7d85360b38d5f09b870c48e4600f63af490e56a"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6c4caeef8fa63d06bd437cd4bdcf3ffefe6738fb1b25951440d80dc7df8c03ac"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:37e55c8e51c236f95b033f6fb391d7d7970ba5fe7ff453dad675e88cf303377a"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb69256e180cb6c8a894fee62b3afebae785babc1ee98b81cdf68bbca1987f33"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ae5f4161f18c61806f411a13b0310bea87f987c7d2ecdbdaad0e94eb2e404238"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b2b0a0c0517616b6869869f8c581d4eb2dd83a4d79e0ebcb7d373ef9956aeb0a"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:45485e01ff4d3630ec0d9617310448a8702f70e9c01906b0d0118bdf9d124cf2"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eb00ed941194665c332bf8e078baf037d6c35d7c4f3102ea2d4f16ca94a26dc8"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:2127566c664442652f024c837091890cb1942c30937add288223dc895793f898"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a50aebfa173e157099939b17f18600f72f84eed3049e743b68ad15bd69b6bf99"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:4d0d1650369165a14e14e1e47b372cfcb31d6ab44e6e33cb2d4e57265290044d"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:923c0c831b7cfcb071580d3f46c4baf50f174be571576556269530f4bbd79d04"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:06a81e93cd441c56a9b65d8e1d043daeb97a3d0856d177d5c90ba85acb3db087"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-win32.whl", hash = "sha256:6ef1d82a3af9d3eecdba2321dc1b3c238245d890843e040e41e470ffa64c3e25"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-win_amd64.whl", hash = "sha256:eb8821e09e916165e160797a6c17edda0679379a4be5c716c260e836e122f54b"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:c235ebd9baae02f1b77bcea61bce332cb4331dc3617d254df3323aa01ab47bd4"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5b4c145409bef602a690e7cfad0a15a55c13320ff7a3ad7ca59c13bb8ba4d45d"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-macosx_11_0_arm64.whl", hash = 
"sha256:68d1f8a9e9e37c1223b656399be5d6b448dea850bed7d0f87a8311f1ff3dabb0"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:22afcb9f253dac0696b5a4be4a1c0f8762f8239e21b99680099abd9b2b1b2269"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e27ad930a842b4c5eb8ac0016b0a54f5aebbe679340c26101df33424142c143c"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1f79682fbe303db92bc2b1136016a38a42e835d932bab5b3b1bfcfbf0640e519"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b261ccdec7821281dade748d088bb6e9b69e6d15b30652b74cbbac25e280b796"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:122c7fa62b130ed55f8f285bfd56d5f4b4a5b503609d181f9ad85e55c89f4185"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d0eccceffcb53201b5bfebb52600a5fb483a20b61da9dbc885f8b103cbe7598c"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:9f96df6923e21816da7e0ad3fd47dd8f94b2a5ce594e00677c0013018b813458"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:7f04c839ed0b6b98b1a7501a002144b76c18fb1c1850c8b98d458ac269e26ed2"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:34d1c8da1e78d2e001f363791c98a272bb734000fcef47a491c1e3b0505657a8"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ff8fa367d09b717b2a17a052544193ad76cd49979c805768879cb63d9ca50561"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-win32.whl", hash = "sha256:aed38f6e4fb3f5d6bf81bfa990a07806be9d83cf7bacef998ab1a9bd660a581f"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-win_amd64.whl", hash = "sha256:b01b88d45a6fcb69667cd6d2f7a9aeb4bf53760d7fc536bf679ec94fe9f3ff3d"}, + {file = "charset_normalizer-3.3.2-py3-none-any.whl", hash = "sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc"}, ] [[package]] @@ -444,19 +444,19 @@ test = ["pytest (>=6)"] [[package]] name = "filelock" -version = "3.12.4" +version = "3.13.1" description = "A platform independent file lock." 
optional = false python-versions = ">=3.8" files = [ - {file = "filelock-3.12.4-py3-none-any.whl", hash = "sha256:08c21d87ded6e2b9da6728c3dff51baf1dcecf973b768ef35bcbc3447edb9ad4"}, - {file = "filelock-3.12.4.tar.gz", hash = "sha256:2e6f249f1f3654291606e046b09f1fd5eac39b360664c27f5aad072012f8bcbd"}, + {file = "filelock-3.13.1-py3-none-any.whl", hash = "sha256:57dbda9b35157b05fb3e58ee91448612eb674172fab98ee235ccb0b5bee19a1c"}, + {file = "filelock-3.13.1.tar.gz", hash = "sha256:521f5f56c50f8426f5e03ad3b281b490a87ef15bc6c526f168290f0c7148d44e"}, ] [package.extras] -docs = ["furo (>=2023.7.26)", "sphinx (>=7.1.2)", "sphinx-autodoc-typehints (>=1.24)"] -testing = ["covdefaults (>=2.3)", "coverage (>=7.3)", "diff-cover (>=7.7)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)", "pytest-timeout (>=2.1)"] -typing = ["typing-extensions (>=4.7.1)"] +docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.24)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8)", "pytest (>=7.4.3)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)"] +typing = ["typing-extensions (>=4.8)"] [[package]] name = "frozenlist" @@ -602,13 +602,13 @@ typing = ["pydantic (<2.0)", "types-PyYAML", "types-requests", "types-simplejson [[package]] name = "identify" -version = "2.5.30" +version = "2.5.31" description = "File identification library for Python" optional = false python-versions = ">=3.8" files = [ - {file = "identify-2.5.30-py2.py3-none-any.whl", hash = "sha256:afe67f26ae29bab007ec21b03d4114f41316ab9dd15aa8736a167481e108da54"}, - {file = "identify-2.5.30.tar.gz", hash = "sha256:f302a4256a15c849b91cfcdcec052a8ce914634b2f77ae87dad29cd749f2d88d"}, + {file = "identify-2.5.31-py2.py3-none-any.whl", hash = "sha256:90199cb9e7bd3c5407a9b7e81b4abec4bb9d249991c79439ec8af740afc6293d"}, + {file = "identify-2.5.31.tar.gz", hash = "sha256:7736b3c7a28233637e3c36550646fc6389bedd74ae84cb788200cc8e2dd60b75"}, ] [package.extras] @@ -872,13 +872,13 @@ dill = ">=0.3.7" [[package]] name = "networkx" -version = "3.2" +version = "3.2.1" description = "Python package for creating and manipulating graphs and networks" optional = false python-versions = ">=3.9" files = [ - {file = "networkx-3.2-py3-none-any.whl", hash = "sha256:8b25f564bd28f94ac821c58b04ae1a3109e73b001a7d476e4bb0d00d63706bf8"}, - {file = "networkx-3.2.tar.gz", hash = "sha256:bda29edf392d9bfa5602034c767d28549214ec45f620081f0b74dc036a1fbbc1"}, + {file = "networkx-3.2.1-py3-none-any.whl", hash = "sha256:f18c69adc97877c42332c170849c96cefa91881c99a7cb3e95b7c659ebdc1ec2"}, + {file = "networkx-3.2.1.tar.gz", hash = "sha256:9f1bb5cf3409bf324e0a722c20bdb4c20ee39bf1c30ce8ae499c8502b0b5e0c6"}, ] [package.extras] @@ -1007,42 +1007,42 @@ xml = ["lxml (>=4.8.0)"] [[package]] name = "pandas" -version = "2.1.1" +version = "2.1.2" description = "Powerful data structures for data analysis, time series, and statistics" optional = false python-versions = ">=3.9" files = [ - {file = "pandas-2.1.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:58d997dbee0d4b64f3cb881a24f918b5f25dd64ddf31f467bb9b67ae4c63a1e4"}, - {file = "pandas-2.1.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:02304e11582c5d090e5a52aec726f31fe3f42895d6bfc1f28738f9b64b6f0614"}, - {file = "pandas-2.1.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffa8f0966de2c22de408d0e322db2faed6f6e74265aa0856f3824813cf124363"}, - {file = "pandas-2.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:c1f84c144dee086fe4f04a472b5cd51e680f061adf75c1ae4fc3a9275560f8f4"}, - {file = "pandas-2.1.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:75ce97667d06d69396d72be074f0556698c7f662029322027c226fd7a26965cb"}, - {file = "pandas-2.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:4c3f32fd7c4dccd035f71734df39231ac1a6ff95e8bdab8d891167197b7018d2"}, - {file = "pandas-2.1.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:9e2959720b70e106bb1d8b6eadd8ecd7c8e99ccdbe03ee03260877184bb2877d"}, - {file = "pandas-2.1.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:25e8474a8eb258e391e30c288eecec565bfed3e026f312b0cbd709a63906b6f8"}, - {file = "pandas-2.1.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b8bd1685556f3374520466998929bade3076aeae77c3e67ada5ed2b90b4de7f0"}, - {file = "pandas-2.1.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dc3657869c7902810f32bd072f0740487f9e030c1a3ab03e0af093db35a9d14e"}, - {file = "pandas-2.1.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:05674536bd477af36aa2effd4ec8f71b92234ce0cc174de34fd21e2ee99adbc2"}, - {file = "pandas-2.1.1-cp311-cp311-win_amd64.whl", hash = "sha256:b407381258a667df49d58a1b637be33e514b07f9285feb27769cedb3ab3d0b3a"}, - {file = "pandas-2.1.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c747793c4e9dcece7bb20156179529898abf505fe32cb40c4052107a3c620b49"}, - {file = "pandas-2.1.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3bcad1e6fb34b727b016775bea407311f7721db87e5b409e6542f4546a4951ea"}, - {file = "pandas-2.1.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f5ec7740f9ccb90aec64edd71434711f58ee0ea7f5ed4ac48be11cfa9abf7317"}, - {file = "pandas-2.1.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:29deb61de5a8a93bdd033df328441a79fcf8dd3c12d5ed0b41a395eef9cd76f0"}, - {file = "pandas-2.1.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4f99bebf19b7e03cf80a4e770a3e65eee9dd4e2679039f542d7c1ace7b7b1daa"}, - {file = "pandas-2.1.1-cp312-cp312-win_amd64.whl", hash = "sha256:84e7e910096416adec68075dc87b986ff202920fb8704e6d9c8c9897fe7332d6"}, - {file = "pandas-2.1.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:366da7b0e540d1b908886d4feb3d951f2f1e572e655c1160f5fde28ad4abb750"}, - {file = "pandas-2.1.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9e50e72b667415a816ac27dfcfe686dc5a0b02202e06196b943d54c4f9c7693e"}, - {file = "pandas-2.1.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cc1ab6a25da197f03ebe6d8fa17273126120874386b4ac11c1d687df288542dd"}, - {file = "pandas-2.1.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0dbfea0dd3901ad4ce2306575c54348d98499c95be01b8d885a2737fe4d7a98"}, - {file = "pandas-2.1.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:0489b0e6aa3d907e909aef92975edae89b1ee1654db5eafb9be633b0124abe97"}, - {file = "pandas-2.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:4cdb0fab0400c2cb46dafcf1a0fe084c8bb2480a1fa8d81e19d15e12e6d4ded2"}, - {file = "pandas-2.1.1.tar.gz", hash = "sha256:fecb198dc389429be557cde50a2d46da8434a17fe37d7d41ff102e3987fd947b"}, + {file = "pandas-2.1.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:24057459f19db9ebb02984c6fdd164a970b31a95f38e4a49cf7615b36a1b532c"}, + {file = "pandas-2.1.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a6cf8fcc8a63d333970b950a7331a30544cf59b1a97baf0a7409e09eafc1ac38"}, + {file = "pandas-2.1.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", 
hash = "sha256:6ae6ffbd9d614c20d028c7117ee911fc4e266b4dca2065d5c5909e401f8ff683"}, + {file = "pandas-2.1.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eff794eeb7883c5aefb1ed572e7ff533ae779f6c6277849eab9e77986e352688"}, + {file = "pandas-2.1.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:02954e285e8e2f4006b6f22be6f0df1f1c3c97adbb7ed211c6b483426f20d5c8"}, + {file = "pandas-2.1.2-cp310-cp310-win_amd64.whl", hash = "sha256:5b40c9f494e1f27588c369b9e4a6ca19cd924b3a0e1ef9ef1a8e30a07a438f43"}, + {file = "pandas-2.1.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:08d287b68fd28906a94564f15118a7ca8c242e50ae7f8bd91130c362b2108a81"}, + {file = "pandas-2.1.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:bbd98dcdcd32f408947afdb3f7434fade6edd408c3077bbce7bd840d654d92c6"}, + {file = "pandas-2.1.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e90c95abb3285d06f6e4feedafc134306a8eced93cb78e08cf50e224d5ce22e2"}, + {file = "pandas-2.1.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:52867d69a54e71666cd184b04e839cff7dfc8ed0cd6b936995117fdae8790b69"}, + {file = "pandas-2.1.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:8d0382645ede2fde352da2a885aac28ec37d38587864c0689b4b2361d17b1d4c"}, + {file = "pandas-2.1.2-cp311-cp311-win_amd64.whl", hash = "sha256:65177d1c519b55e5b7f094c660ed357bb7d86e799686bb71653b8a4803d8ff0d"}, + {file = "pandas-2.1.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5aa6b86802e8cf7716bf4b4b5a3c99b12d34e9c6a9d06dad254447a620437931"}, + {file = "pandas-2.1.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d594e2ce51b8e0b4074e6644758865dc2bb13fd654450c1eae51201260a539f1"}, + {file = "pandas-2.1.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3223f997b6d2ebf9c010260cf3d889848a93f5d22bb4d14cd32638b3d8bba7ad"}, + {file = "pandas-2.1.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fc4944dc004ca6cc701dfa19afb8bdb26ad36b9bed5bcec617d2a11e9cae6902"}, + {file = "pandas-2.1.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3f76280ce8ec216dde336e55b2b82e883401cf466da0fe3be317c03fb8ee7c7d"}, + {file = "pandas-2.1.2-cp312-cp312-win_amd64.whl", hash = "sha256:7ad20d24acf3a0042512b7e8d8fdc2e827126ed519d6bd1ed8e6c14ec8a2c813"}, + {file = "pandas-2.1.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:021f09c15e1381e202d95d4a21ece8e7f2bf1388b6d7e9cae09dfe27bd2043d1"}, + {file = "pandas-2.1.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e7f12b2de0060b0b858cfec0016e7d980ae5bae455a1746bfcc70929100ee633"}, + {file = "pandas-2.1.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:83c166b9bb27c1715bed94495d9598a7f02950b4749dba9349c1dd2cbf10729d"}, + {file = "pandas-2.1.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:25c9976c17311388fcd953cb3d0697999b2205333f4e11e669d90ff8d830d429"}, + {file = "pandas-2.1.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:851b5afbb0d62f6129ae891b533aa508cc357d5892c240c91933d945fff15731"}, + {file = "pandas-2.1.2-cp39-cp39-win_amd64.whl", hash = "sha256:e78507adcc730533619de07bfdd1c62b2918a68cd4419ea386e28abf7f6a1e5c"}, + {file = "pandas-2.1.2.tar.gz", hash = "sha256:52897edc2774d2779fbeb6880d2cfb305daa0b1a29c16b91f531a18918a6e0f3"}, ] [package.dependencies] numpy = [ - {version = ">=1.22.4", markers = "python_version < \"3.11\""}, - {version = ">=1.23.2", markers = "python_version == \"3.11\""}, + {version = ">=1.22.4,<2", 
markers = "python_version < \"3.11\""}, + {version = ">=1.23.2,<2", markers = "python_version == \"3.11\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -1122,40 +1122,47 @@ virtualenv = ">=20.10.0" [[package]] name = "pyarrow" -version = "13.0.0" +version = "14.0.0" description = "Python library for Apache Arrow" optional = false python-versions = ">=3.8" files = [ - {file = "pyarrow-13.0.0-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:1afcc2c33f31f6fb25c92d50a86b7a9f076d38acbcb6f9e74349636109550148"}, - {file = "pyarrow-13.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:70fa38cdc66b2fc1349a082987f2b499d51d072faaa6b600f71931150de2e0e3"}, - {file = "pyarrow-13.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cd57b13a6466822498238877892a9b287b0a58c2e81e4bdb0b596dbb151cbb73"}, - {file = "pyarrow-13.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f8ce69f7bf01de2e2764e14df45b8404fc6f1a5ed9871e8e08a12169f87b7a26"}, - {file = "pyarrow-13.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:588f0d2da6cf1b1680974d63be09a6530fd1bd825dc87f76e162404779a157dc"}, - {file = "pyarrow-13.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:6241afd72b628787b4abea39e238e3ff9f34165273fad306c7acf780dd850956"}, - {file = "pyarrow-13.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:fda7857e35993673fcda603c07d43889fca60a5b254052a462653f8656c64f44"}, - {file = "pyarrow-13.0.0-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:aac0ae0146a9bfa5e12d87dda89d9ef7c57a96210b899459fc2f785303dcbb67"}, - {file = "pyarrow-13.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d7759994217c86c161c6a8060509cfdf782b952163569606bb373828afdd82e8"}, - {file = "pyarrow-13.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:868a073fd0ff6468ae7d869b5fc1f54de5c4255b37f44fb890385eb68b68f95d"}, - {file = "pyarrow-13.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:51be67e29f3cfcde263a113c28e96aa04362ed8229cb7c6e5f5c719003659d33"}, - {file = "pyarrow-13.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:d1b4e7176443d12610874bb84d0060bf080f000ea9ed7c84b2801df851320295"}, - {file = "pyarrow-13.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:69b6f9a089d116a82c3ed819eea8fe67dae6105f0d81eaf0fdd5e60d0c6e0944"}, - {file = "pyarrow-13.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:ab1268db81aeb241200e321e220e7cd769762f386f92f61b898352dd27e402ce"}, - {file = "pyarrow-13.0.0-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:ee7490f0f3f16a6c38f8c680949551053c8194e68de5046e6c288e396dccee80"}, - {file = "pyarrow-13.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e3ad79455c197a36eefbd90ad4aa832bece7f830a64396c15c61a0985e337287"}, - {file = "pyarrow-13.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:68fcd2dc1b7d9310b29a15949cdd0cb9bc34b6de767aff979ebf546020bf0ba0"}, - {file = "pyarrow-13.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dc6fd330fd574c51d10638e63c0d00ab456498fc804c9d01f2a61b9264f2c5b2"}, - {file = "pyarrow-13.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:e66442e084979a97bb66939e18f7b8709e4ac5f887e636aba29486ffbf373763"}, - {file = "pyarrow-13.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:0f6eff839a9e40e9c5610d3ff8c5bdd2f10303408312caf4c8003285d0b49565"}, - {file = "pyarrow-13.0.0-cp38-cp38-win_amd64.whl", hash = 
"sha256:8b30a27f1cddf5c6efcb67e598d7823a1e253d743d92ac32ec1eb4b6a1417867"}, - {file = "pyarrow-13.0.0-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:09552dad5cf3de2dc0aba1c7c4b470754c69bd821f5faafc3d774bedc3b04bb7"}, - {file = "pyarrow-13.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:3896ae6c205d73ad192d2fc1489cd0edfab9f12867c85b4c277af4d37383c18c"}, - {file = "pyarrow-13.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6647444b21cb5e68b593b970b2a9a07748dd74ea457c7dadaa15fd469c48ada1"}, - {file = "pyarrow-13.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47663efc9c395e31d09c6aacfa860f4473815ad6804311c5433f7085415d62a7"}, - {file = "pyarrow-13.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:b9ba6b6d34bd2563345488cf444510588ea42ad5613df3b3509f48eb80250afd"}, - {file = "pyarrow-13.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:d00d374a5625beeb448a7fa23060df79adb596074beb3ddc1838adb647b6ef09"}, - {file = "pyarrow-13.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:c51afd87c35c8331b56f796eff954b9c7f8d4b7fef5903daf4e05fcf017d23a8"}, - {file = "pyarrow-13.0.0.tar.gz", hash = "sha256:83333726e83ed44b0ac94d8d7a21bbdee4a05029c3b1e8db58a863eec8fd8a33"}, + {file = "pyarrow-14.0.0-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:4fce1db17efbc453080c5b306f021926de7c636456a128328797e574c151f81a"}, + {file = "pyarrow-14.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:28de7c05b4d7a71ec660360639cc9b65ceb1175e0e9d4dfccd879a1545bc38f7"}, + {file = "pyarrow-14.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1541e9209c094e7f4d7b43fdd9de3a8c71d3069cf6fc03b59bf5774042411849"}, + {file = "pyarrow-14.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8c05e6c45d303c80e41ab04996430a0251321f70986ed51213903ea7bc0b7efd"}, + {file = "pyarrow-14.0.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:426ffec63ab9b4dff23dec51be2150e3a4a99eb38e66c10a70e2c48779fe9c9d"}, + {file = "pyarrow-14.0.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:968844f591902160bd3c9ee240ce8822a3b4e7de731e91daea76ad43fe0ff062"}, + {file = "pyarrow-14.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:dcedbc0b4ea955c530145acfe99e324875c386419a09db150291a24cb01aeb81"}, + {file = "pyarrow-14.0.0-cp311-cp311-macosx_10_14_x86_64.whl", hash = "sha256:97993a12aacc781efad9c92d4545a877e803c4d106d34237ec4ce987bec825a3"}, + {file = "pyarrow-14.0.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:80225768d94024d59a31320374f5e6abf8899866c958dfb4f4ea8e2d9ec91bde"}, + {file = "pyarrow-14.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b61546977a8bd7e3d0c697ede723341ef4737e761af2239aef6e1db447f97727"}, + {file = "pyarrow-14.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:42509e6c93b4a1c8ae8ccd939a43f437097783fe130a1991497a6a1abbba026f"}, + {file = "pyarrow-14.0.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:3eccce331a1392e46573f2ce849a9ee3c074e0d7008e9be0b44566ac149fd6a1"}, + {file = "pyarrow-14.0.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:ecc463c45f2b6b36431f5f2025842245e8c15afe4d42072230575785f3bb00c6"}, + {file = "pyarrow-14.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:4362ed90def81640addcd521811dd16a13015f0a8255bec324a41262c1524b6c"}, + {file = "pyarrow-14.0.0-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:2fbb7ab62537782c5ab31aa08db0e1f6de92c2c515fdfc0790128384e919adcb"}, + {file = 
"pyarrow-14.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ad7095f8f0fe0bfa3d3fca1909b8fa15c70e630b0cc1ff8d35e143f5e2704064"}, + {file = "pyarrow-14.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e6602272fce71c0fb64f266e7cdbe51b93b00c22fc1bb57f2b0cb681c4aeedf4"}, + {file = "pyarrow-14.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5b2b8f87951b08a3e72265c8963da3fe4f737bb81290269037e047dd172aa591"}, + {file = "pyarrow-14.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:a1c9675966662a042caebbaafa1ae7fc26291287ebc3da06aa63ad74c323ec30"}, + {file = "pyarrow-14.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:771079fddc0b4440c41af541dbdebc711a7062c93d3c4764476a9442606977db"}, + {file = "pyarrow-14.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:c4096136318de1c4937370c0c365f949961c371201c396d8cc94a353f342069d"}, + {file = "pyarrow-14.0.0-cp38-cp38-macosx_10_14_x86_64.whl", hash = "sha256:6c94056fb5f0ee0bae2206c3f776881e1db2bd0d133d06805755ae7ac5145349"}, + {file = "pyarrow-14.0.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:687d0df1e08876b2d24d42abae129742fc655367e3fe6700aa4d79fcf2e3215e"}, + {file = "pyarrow-14.0.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0f4054e5ee6c88ca256a67fc8b27f9c59bcd385216346265831d462a6069033f"}, + {file = "pyarrow-14.0.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:768b962e4c042ab2c96576ca0757935472e220d11af855c7d0be3279d7fced5f"}, + {file = "pyarrow-14.0.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:77293b1319c7044f68ebfa43db8c929a0a5254ce371f1a0873d343f1460171d0"}, + {file = "pyarrow-14.0.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:d2bc7c53941d85f0133b1bd5a814bca0af213922f50d8a8dc0eed4d9ed477845"}, + {file = "pyarrow-14.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:378955365dd087c285ef4f34ad939d7e551b7715326710e8cd21cfa2ce511bd7"}, + {file = "pyarrow-14.0.0-cp39-cp39-macosx_10_14_x86_64.whl", hash = "sha256:f05e81b4c621e6ad4bcd8f785e3aa1d6c49a935818b809ea6e7bf206a5b1a4e8"}, + {file = "pyarrow-14.0.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6867f6a8057eaef5a7ac6d27fe5518133f67973c5d4295d79a943458350e7c61"}, + {file = "pyarrow-14.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca54b87c46abdfe027f18f959ca388102bd7326c344838f72244807462d091b2"}, + {file = "pyarrow-14.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:35abf61bd0cc9daca3afc715f6ba74ea83d792fa040025352624204bec66bf6a"}, + {file = "pyarrow-14.0.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:65c377523b369f7ef1ba02be814e832443bb3b15065010838f02dae5bdc0f53c"}, + {file = "pyarrow-14.0.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:e8a1e470e4b5f7bda7bede0410291daec55ab69f346d77795d34fd6a45b41579"}, + {file = "pyarrow-14.0.0-cp39-cp39-win_amd64.whl", hash = "sha256:466c1a5a7a4b279cfa363ac34dedd0c3c6af388cec9e6a468ffc095a6627849a"}, + {file = "pyarrow-14.0.0.tar.gz", hash = "sha256:45d3324e1c9871a07de6b4d514ebd73225490963a6dd46c64c465c4b6079fe1e"}, ] [package.dependencies] @@ -1163,13 +1170,13 @@ numpy = ">=1.16.6" [[package]] name = "pytest" -version = "7.4.2" +version = "7.4.3" description = "pytest: simple powerful testing with Python" optional = false python-versions = ">=3.7" files = [ - {file = "pytest-7.4.2-py3-none-any.whl", hash = "sha256:1d881c6124e08ff0a1bb75ba3ec0bfd8b5354a01c194ddd5a0a870a48d99b002"}, - {file = "pytest-7.4.2.tar.gz", hash 
= "sha256:a766259cfab564a2ad52cb1aae1b881a75c3eb7e34ca3779697c23ed47c47069"}, + {file = "pytest-7.4.3-py3-none-any.whl", hash = "sha256:0d009c083ea859a71b76adf7c1d502e4bc170b80a8ef002da5806527b9591fac"}, + {file = "pytest-7.4.3.tar.gz", hash = "sha256:d989d136982de4e3b29dabcc838ad581c64e8ed52c11fbe86ddebd9da0818cd5"}, ] [package.dependencies] @@ -1217,13 +1224,13 @@ six = ">=1.5" [[package]] name = "pytorch-ie" -version = "0.26.1.dev1699314147" +version = "0.27.0" description = "State-of-the-art Information Extraction in PyTorch" optional = false python-versions = ">=3.9,<4.0" files = [ - {file = "pytorch_ie-0.26.1.dev1699314147-py3-none-any.whl", hash = "sha256:10e8aa445058b29cf4ae2a8e905f5614f803afb29ebad93f75e1f4fb1fde32bd"}, - {file = "pytorch_ie-0.26.1.dev1699314147.tar.gz", hash = "sha256:01a398eef1874b5de464941774bbbeeef0ef862b5435a4523ade110f7707934e"}, + {file = "pytorch_ie-0.27.0-py3-none-any.whl", hash = "sha256:d8eec1183d260e2ad13b3aeea10342bd46ef2b3cefb64fafdbddecc91181c14e"}, + {file = "pytorch_ie-0.27.0.tar.gz", hash = "sha256:6711d8afe63c7754e70dc6bf20427f005edd0b0a60d1d670290b4d81068614a4"}, ] [package.dependencies] @@ -1235,11 +1242,6 @@ torch = ">=1.10" torchmetrics = ">=1,<2" transformers = ">=4.18,<5.0" -[package.source] -type = "legacy" -url = "https://test.pypi.org/simple" -reference = "pre-release" - [[package]] name = "pytorch-lightning" version = "2.1.0" @@ -1829,13 +1831,13 @@ telegram = ["requests"] [[package]] name = "transformers" -version = "4.34.1" +version = "4.35.0" description = "State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow" optional = false python-versions = ">=3.8.0" files = [ - {file = "transformers-4.34.1-py3-none-any.whl", hash = "sha256:d06ac09151d7b845e4a4acd6b143a591d946031ee67b4cbb20693b241920ffc0"}, - {file = "transformers-4.34.1.tar.gz", hash = "sha256:1d0258d5a18063b66005bbe1e3276ec5943d9ab4ab47f020db1fd485cc40ea22"}, + {file = "transformers-4.35.0-py3-none-any.whl", hash = "sha256:45aa9370d7d9ba1c43e6bfa04d7f8b61238497d4b646e573fd95e597fe4040ff"}, + {file = "transformers-4.35.0.tar.gz", hash = "sha256:e4b41763f651282fc979348d3aa148244387ddc9165f4b18455798c770ae23b9"}, ] [package.dependencies] @@ -1857,13 +1859,12 @@ all = ["Pillow (<10.0.0)", "accelerate (>=0.20.3)", "av (==9.2.0)", "codecarbon audio = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] codecarbon = ["codecarbon (==1.2.0)"] deepspeed = ["accelerate (>=0.20.3)", "deepspeed (>=0.9.3)"] -deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.20.3)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "optuna", "parameterized", "protobuf", "psutil", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "timeout-decorator"] -dev = ["GitPython (<3.1.19)", "Pillow (<10.0.0)", "accelerate (>=0.20.3)", "av (==9.2.0)", "beautifulsoup4", "black (>=23.1,<24.0)", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax 
(>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune]", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.15)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] -dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (<10.0.0)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.14,<0.15)", "urllib3 (<2.0.0)"] -dev-torch = ["GitPython (<3.1.19)", "Pillow (<10.0.0)", "accelerate (>=0.20.3)", "beautifulsoup4", "black (>=23.1,<24.0)", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune]", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.15)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +deepspeed-testing = ["GitPython (<3.1.19)", "accelerate (>=0.20.3)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "deepspeed (>=0.9.3)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "optuna", "parameterized", "protobuf", "psutil", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "timeout-decorator"] +dev = ["GitPython (<3.1.19)", "Pillow (<10.0.0)", "accelerate (>=0.20.3)", "av (==9.2.0)", "beautifulsoup4", "black (>=23.1,<24.0)", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "decord (==0.6.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "flax (>=0.4.1,<=0.7.0)", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort 
(>=5.5.4)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune]", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.15)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] +dev-tensorflow = ["GitPython (<3.1.19)", "Pillow (<10.0.0)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "isort (>=5.5.4)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "nltk", "onnxconverter-common", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "tensorboard", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx", "timeout-decorator", "tokenizers (>=0.14,<0.15)", "urllib3 (<2.0.0)"] +dev-torch = ["GitPython (<3.1.19)", "Pillow (<10.0.0)", "accelerate (>=0.20.3)", "beautifulsoup4", "black (>=23.1,<24.0)", "codecarbon (==1.2.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "fugashi (>=1.0)", "hf-doc-builder", "hf-doc-builder (>=0.3.0)", "ipadic (>=1.0.0,<2.0)", "isort (>=5.5.4)", "kenlm", "librosa", "nltk", "onnxruntime (>=1.4.0)", "onnxruntime-tools (>=1.4.2)", "optuna", "parameterized", "phonemizer", "protobuf", "psutil", "pyctcdecode (>=0.4.0)", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "ray[tune]", "rhoknp (>=1.1.0,<1.3.1)", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "ruff (>=0.0.241,<=0.0.259)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "scikit-learn", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "sudachidict-core (>=20220729)", "sudachipy (>=0.6.6)", "tensorboard", "timeout-decorator", "timm", "tokenizers (>=0.14,<0.15)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision", "unidic (>=1.0.2)", "unidic-lite (>=1.0.7)", "urllib3 (<2.0.0)"] docs = ["Pillow (<10.0.0)", "accelerate (>=0.20.3)", "av (==9.2.0)", "codecarbon (==1.2.0)", "decord (==0.6.0)", "flax (>=0.4.1,<=0.7.0)", "hf-doc-builder", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "kenlm", "keras-nlp (>=0.3.1)", "librosa", "onnxconverter-common", "optax (>=0.0.8,<=0.1.4)", "optuna", "phonemizer", "protobuf", "pyctcdecode (>=0.4.0)", "ray[tune]", "sentencepiece (>=0.1.91,!=0.1.92)", "sigopt", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx", "timm", "tokenizers (>=0.14,<0.15)", "torch (>=1.10,!=1.12.0)", "torchaudio", "torchvision"] docs-specific = ["hf-doc-builder"] -fairscale = ["fairscale (>0.3)"] flax = ["flax (>=0.4.1,<=0.7.0)", "jax (>=0.4.1,<=0.4.13)", "jaxlib (>=0.4.1,<=0.4.13)", "optax 
(>=0.0.8,<=0.1.4)"] flax-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] ftfy = ["ftfy"] @@ -1883,7 +1884,7 @@ serving = ["fastapi", "pydantic (<2)", "starlette", "uvicorn"] sigopt = ["sigopt"] sklearn = ["scikit-learn"] speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)", "torchaudio"] -testing = ["GitPython (<3.1.19)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "parameterized", "protobuf", "psutil", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "timeout-decorator"] +testing = ["GitPython (<3.1.19)", "beautifulsoup4", "black (>=23.1,<24.0)", "cookiecutter (==1.7.3)", "datasets (!=2.5.0)", "dill (<0.3.5)", "evaluate (>=0.2.0)", "faiss-cpu", "hf-doc-builder (>=0.3.0)", "nltk", "parameterized", "protobuf", "psutil", "pytest (>=7.2.0)", "pytest-timeout", "pytest-xdist", "rjieba", "rouge-score (!=0.0.7,!=0.0.8,!=0.1,!=0.1.1)", "sacrebleu (>=1.4.12,<2.0.0)", "sacremoses", "tensorboard", "timeout-decorator"] tf = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx"] tf-cpu = ["keras-nlp (>=0.3.1)", "onnxconverter-common", "tensorflow-cpu (>=2.6,<2.15)", "tensorflow-text (<2.15)", "tf2onnx"] tf-speech = ["kenlm", "librosa", "phonemizer", "pyctcdecode (>=0.4.0)"] @@ -2162,4 +2163,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "8e15677d9874e0189266d03d9fb22cf968f9b1437476fd41ab0010d2551a6975" +content-hash = "05fdb17a8a21088696573d3a3356c1b815446af17e6589c28c11eb80d2e8788e" diff --git a/pyproject.toml b/pyproject.toml index 9bfd2a7f..ae4f83ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,8 +24,7 @@ classifiers = [ [tool.poetry.dependencies] python = "^3.9" -#pytorch-ie = ">=0.26.0,<0.27.0" -pytorch-ie = { version = "0.26.1.dev1699314147", source = "pre-release" } +pytorch-ie = ">=0.27.0,<0.28.0" [tool.poetry.group.dev.dependencies] torch = {version = "^2.1.0+cpu", source = "pytorch"} From c22f9d56166539888fdf7761abb0b32ed482b35c Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Tue, 7 Nov 2023 14:20:17 +0100 Subject: [PATCH 07/12] use TextBasedDocument instead of deprecated TextDocument --- tests/unit/test_dataset_casting.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/unit/test_dataset_casting.py b/tests/unit/test_dataset_casting.py index fede04e3..891e910a 100644 --- a/tests/unit/test_dataset_casting.py +++ b/tests/unit/test_dataset_casting.py @@ -4,18 +4,18 @@ import pytest from pytorch_ie.annotations import LabeledSpan, Span from pytorch_ie.core import AnnotationList, annotation_field -from pytorch_ie.documents import TextDocument +from pytorch_ie.documents import TextBasedDocument from pie_datasets import Dataset, IterableDataset @dataclass -class CoNLL2002Document(TextDocument): +class CoNLL2002Document(TextBasedDocument): entities: AnnotationList[LabeledSpan] = annotation_field(target="text") @dataclass -class DocumentWithParts(TextDocument): +class DocumentWithParts(TextBasedDocument): parts: AnnotationList[Span] = annotation_field(target="text") @@ -25,12 +25,12 @@ class CoNLL2002WithPartsDocument(CoNLL2002Document, DocumentWithParts): @dataclass -class DocumentWithEnts(TextDocument): +class DocumentWithEnts(TextBasedDocument): ents: 
AnnotationList[LabeledSpan] = annotation_field(target="text") @dataclass -class DocumentWithEntsWrongType(TextDocument): +class DocumentWithEntsWrongType(TextBasedDocument): ents: AnnotationList[Span] = annotation_field(target="text") @@ -40,7 +40,7 @@ class DocumentWithEntsAndParts(DocumentWithParts, DocumentWithEnts): @dataclass -class DocumentWithPartsAndEntitiesSwapped(TextDocument): +class DocumentWithPartsAndEntitiesSwapped(TextBasedDocument): parts: AnnotationList[LabeledSpan] = annotation_field(target="text") entities: AnnotationList[Span] = annotation_field(target="text") From 57bb42ae8e0e63e084658ce498517ed39c2ea366 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Tue, 7 Nov 2023 14:36:26 +0100 Subject: [PATCH 08/12] cleanup test_dataset.py --- tests/__init__.py | 3 - tests/unit/test_dataset.py | 317 +++++++++++++++++-------------------- 2 files changed, 142 insertions(+), 178 deletions(-) diff --git a/tests/__init__.py b/tests/__init__.py index f9869874..4cd2c364 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -19,6 +19,3 @@ def _check_hf_conll2003_is_available(): return True except ConnectionError: return False - - -_HF_CONLL2003_IS_AVAILABLE = _check_hf_conll2003_is_available() diff --git a/tests/unit/test_dataset.py b/tests/unit/test_dataset.py index 0b9641a6..831dd5f1 100644 --- a/tests/unit/test_dataset.py +++ b/tests/unit/test_dataset.py @@ -2,7 +2,6 @@ from dataclasses import dataclass from typing import Union -import datasets import numpy import pytest import torch @@ -13,49 +12,13 @@ TaskEncodingDataset, TaskEncodingSequence, ) -from pytorch_ie.documents import TextDocument +from pytorch_ie.documents import TextBasedDocument from pytorch_ie.taskmodules import TransformerSpanClassificationTaskModule from pie_datasets import Dataset, IterableDataset from pie_datasets.dataset import get_pie_dataset_type -from tests import _HF_CONLL2003_IS_AVAILABLE, DATASET_BUILDERS_ROOT from tests.conftest import TestDocument -DATASET_NAME = "conll2003" -PIE_DATASET_PATH = DATASET_BUILDERS_ROOT / "pie" / DATASET_NAME -HF_DATASET_PATH = DATASET_NAME - - -@pytest.fixture(scope="module") -def taskmodule(): - tokenizer_name_or_path = "bert-base-cased" - taskmodule = TransformerSpanClassificationTaskModule( - tokenizer_name_or_path=tokenizer_name_or_path, - entity_annotation="entities", - ) - return taskmodule - - -@pytest.fixture -def model_output(): - return { - "logits": torch.from_numpy( - numpy.log( - [ - # O, ORG, PER - [0.5, 0.2, 0.3], - [0.1, 0.1, 0.8], - [0.1, 0.5, 0.4], - [0.1, 0.4, 0.5], - [0.1, 0.6, 0.3], - ] - ) - ), - "start_indices": torch.tensor([1, 1, 7, 1, 6]), - "end_indices": torch.tensor([2, 4, 7, 4, 6]), - "batch_indices": torch.tensor([0, 1, 1, 2, 2]), - } - def test_dataset(maybe_iterable_dataset): dataset = { @@ -118,13 +81,13 @@ def clear_relations_batched(documents): def test_dataset_map_with_result_document_type(maybe_iterable_dataset): @dataclass - class TestDocument(TextDocument): + class TestDocument(TextBasedDocument): sentences: AnnotationList[Span] = annotation_field(target="text") entities: AnnotationList[LabeledSpan] = annotation_field(target="text") relations: AnnotationList[BinaryRelation] = annotation_field(target="entities") @dataclass - class TestDocumentWithTokensButNoRelations(TextDocument): + class TestDocumentWithTokensButNoRelations(TextBasedDocument): sentences: AnnotationList[Span] = annotation_field(target="text") tokens: AnnotationList[Span] = annotation_field(target="text") entities: AnnotationList[LabeledSpan] = 
annotation_field(target="text") @@ -165,139 +128,6 @@ def clear_relations_and_add_one_token( } -@pytest.mark.parametrize("encode_target", [False, True]) -@pytest.mark.parametrize("inplace", [False, True]) -@pytest.mark.parametrize("as_dataset", [False, True]) -def test_dataset_with_taskmodule( - maybe_iterable_dataset, taskmodule, model_output, encode_target, inplace, as_dataset -): - train_dataset = maybe_iterable_dataset["train"] - - taskmodule.prepare(train_dataset) - assert set(taskmodule.label_to_id.keys()) == {"PER", "ORG", "O"} - assert [taskmodule.id_to_label[i] for i in range(3)] == ["O", "ORG", "PER"] - assert taskmodule.label_to_id["O"] == 0 - - as_task_encoding_sequence = not encode_target - as_iterator = isinstance(train_dataset, (IterableDataset, Iterator)) - if as_task_encoding_sequence: - if as_iterator: - with pytest.raises( - ValueError, match="can not return a TaskEncodingSequence as Iterator" - ): - taskmodule.encode( - train_dataset, encode_target=encode_target, as_dataset=as_dataset - ) - return - if as_dataset: - with pytest.raises( - ValueError, match="can not return a TaskEncodingSequence as a dataset" - ): - taskmodule.encode( - train_dataset, encode_target=encode_target, as_dataset=as_dataset - ) - return - - task_encodings = taskmodule.encode( - train_dataset, encode_target=encode_target, as_dataset=as_dataset - ) - - if as_iterator: - if as_task_encoding_sequence: - raise NotImplementedError("this is not yet implemented") - if as_dataset: - assert isinstance(task_encodings, IterableTaskEncodingDataset) - else: - assert isinstance(task_encodings, Iterator) - else: - if as_dataset: - if as_task_encoding_sequence: - raise NotImplementedError("this is not yet implemented") - else: - assert isinstance(task_encodings, TaskEncodingDataset) - else: - if as_task_encoding_sequence: - assert isinstance(task_encodings, TaskEncodingSequence) - else: - assert isinstance(task_encodings, Sequence) - - task_encoding_list = list(task_encodings) - assert len(task_encoding_list) == 8 - task_encoding = task_encoding_list[5] - document = list(train_dataset)[5] - assert task_encoding.document == document - assert "input_ids" in task_encoding.inputs - assert ( - taskmodule.tokenizer.decode(task_encoding.inputs["input_ids"], skip_special_tokens=True) - == document.text - ) - - if encode_target: - assert task_encoding.targets == [ - (1, 4, taskmodule.label_to_id["PER"]), - (6, 6, taskmodule.label_to_id["ORG"]), - (9, 9, taskmodule.label_to_id["ORG"]), - ] - else: - assert not task_encoding.has_targets - - unbatched_outputs = taskmodule.unbatch_output(model_output) - - decoded_documents = taskmodule.decode( - task_encodings=task_encodings, - task_outputs=unbatched_outputs, - inplace=inplace, - ) - - if isinstance(train_dataset, Dataset): - assert len(decoded_documents) == len(train_dataset) - - assert {id(doc) for doc in decoded_documents}.isdisjoint({id(doc) for doc in train_dataset}) - - expected_scores = [0.8, 0.5, 0.5, 0.6] - i = 0 - for document in decoded_documents: - for entity_expected, entity_decoded in zip( - document["entities"], document["entities"].predictions - ): - assert entity_expected.start == entity_decoded.start - assert entity_expected.end == entity_decoded.end - assert entity_expected.label == entity_decoded.label - assert expected_scores[i] == pytest.approx(entity_decoded.score) - i += 1 - - for document in train_dataset: - assert not document["entities"].predictions - - -@pytest.mark.skipif( - not _HF_CONLL2003_IS_AVAILABLE, - reason="the Huggingface conll2003 
dataset is not reachable and the local PIE-variant depends on it", -) -def test_load_with_hf_datasets(): - dataset = datasets.load_dataset(path=str(HF_DATASET_PATH)) - - assert set(dataset.keys()) == {"train", "validation", "test"} - - assert len(dataset["train"]) == 14041 - assert len(dataset["validation"]) == 3250 - assert len(dataset["test"]) == 3453 - - -@pytest.mark.skipif( - not _HF_CONLL2003_IS_AVAILABLE, - reason="the Huggingface conll2003 dataset is not reachable and the remote PIE-variant depends on it", -) -def test_load_with_hf_datasets_from_hub(): - dataset = datasets.load_dataset(path=str(PIE_DATASET_PATH)) - - assert set(dataset.keys()) == {"train", "validation", "test"} - - assert len(dataset["train"]) == 14041 - assert len(dataset["validation"]) == 3250 - assert len(dataset["test"]) == 3453 - - def test_get_pie_dataset_type(hf_dataset, iterable_hf_dataset): assert get_pie_dataset_type(hf_dataset["train"]) == Dataset assert get_pie_dataset_type(iterable_hf_dataset["train"]) == IterableDataset @@ -310,7 +140,7 @@ def test_get_pie_dataset_type(hf_dataset, iterable_hf_dataset): @dataclass -class TestDocumentWithLabel(TextDocument): +class TestDocumentWithLabel(TextBasedDocument): label: AnnotationList[Label] = annotation_field() @@ -339,7 +169,7 @@ def test_register_document_converter_function(dataset_with_converter_functions): @dataclass -class TestDocumentWithLabeledSpans(TextDocument): +class TestDocumentWithLabeledSpans(TextBasedDocument): spans: AnnotationList[LabeledSpan] = annotation_field(target="text") @@ -458,3 +288,140 @@ class TestDocumentWithSpans(TestDocument): "my_converter_method should accept as input and return " "'.TestDocumentWithSpans'>'." ) + + +@pytest.fixture(scope="module") +def taskmodule(): + # TODO: use a mock taskmodule instead + tokenizer_name_or_path = "bert-base-cased" + taskmodule = TransformerSpanClassificationTaskModule( + tokenizer_name_or_path=tokenizer_name_or_path, + entity_annotation="entities", + ) + return taskmodule + + +@pytest.fixture +def model_output(): + return { + "logits": torch.from_numpy( + numpy.log( + [ + # O, ORG, PER + [0.5, 0.2, 0.3], + [0.1, 0.1, 0.8], + [0.1, 0.5, 0.4], + [0.1, 0.4, 0.5], + [0.1, 0.6, 0.3], + ] + ) + ), + "start_indices": torch.tensor([1, 1, 7, 1, 6]), + "end_indices": torch.tensor([2, 4, 7, 4, 6]), + "batch_indices": torch.tensor([0, 1, 1, 2, 2]), + } + + +@pytest.mark.parametrize("encode_target", [False, True]) +@pytest.mark.parametrize("inplace", [False, True]) +@pytest.mark.parametrize("as_dataset", [False, True]) +def test_dataset_with_taskmodule( + maybe_iterable_dataset, taskmodule, model_output, encode_target, inplace, as_dataset +): + train_dataset = maybe_iterable_dataset["train"] + + taskmodule.prepare(train_dataset) + assert set(taskmodule.label_to_id.keys()) == {"PER", "ORG", "O"} + assert [taskmodule.id_to_label[i] for i in range(3)] == ["O", "ORG", "PER"] + assert taskmodule.label_to_id["O"] == 0 + + as_task_encoding_sequence = not encode_target + as_iterator = isinstance(train_dataset, (IterableDataset, Iterator)) + if as_task_encoding_sequence: + if as_iterator: + with pytest.raises( + ValueError, match="can not return a TaskEncodingSequence as Iterator" + ): + taskmodule.encode( + train_dataset, encode_target=encode_target, as_dataset=as_dataset + ) + return + if as_dataset: + with pytest.raises( + ValueError, match="can not return a TaskEncodingSequence as a dataset" + ): + taskmodule.encode( + train_dataset, encode_target=encode_target, as_dataset=as_dataset + ) + return + + 
task_encodings = taskmodule.encode( + train_dataset, encode_target=encode_target, as_dataset=as_dataset + ) + + if as_iterator: + if as_task_encoding_sequence: + raise NotImplementedError("this is not yet implemented") + if as_dataset: + assert isinstance(task_encodings, IterableTaskEncodingDataset) + else: + assert isinstance(task_encodings, Iterator) + else: + if as_dataset: + if as_task_encoding_sequence: + raise NotImplementedError("this is not yet implemented") + else: + assert isinstance(task_encodings, TaskEncodingDataset) + else: + if as_task_encoding_sequence: + assert isinstance(task_encodings, TaskEncodingSequence) + else: + assert isinstance(task_encodings, Sequence) + + task_encoding_list = list(task_encodings) + assert len(task_encoding_list) == 8 + task_encoding = task_encoding_list[5] + document = list(train_dataset)[5] + assert task_encoding.document == document + assert "input_ids" in task_encoding.inputs + assert ( + taskmodule.tokenizer.decode(task_encoding.inputs["input_ids"], skip_special_tokens=True) + == document.text + ) + + if encode_target: + assert task_encoding.targets == [ + (1, 4, taskmodule.label_to_id["PER"]), + (6, 6, taskmodule.label_to_id["ORG"]), + (9, 9, taskmodule.label_to_id["ORG"]), + ] + else: + assert not task_encoding.has_targets + + unbatched_outputs = taskmodule.unbatch_output(model_output) + + decoded_documents = taskmodule.decode( + task_encodings=task_encodings, + task_outputs=unbatched_outputs, + inplace=inplace, + ) + + if isinstance(train_dataset, Dataset): + assert len(decoded_documents) == len(train_dataset) + + assert {id(doc) for doc in decoded_documents}.isdisjoint({id(doc) for doc in train_dataset}) + + expected_scores = [0.8, 0.5, 0.5, 0.6] + i = 0 + for document in decoded_documents: + for entity_expected, entity_decoded in zip( + document["entities"], document["entities"].predictions + ): + assert entity_expected.start == entity_decoded.start + assert entity_expected.end == entity_decoded.end + assert entity_expected.label == entity_decoded.label + assert expected_scores[i] == pytest.approx(entity_decoded.score) + i += 1 + + for document in train_dataset: + assert not document["entities"].predictions From 8a13fa658fcf0641e722f3e8968d048fefb06ff2 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Tue, 7 Nov 2023 14:36:47 +0100 Subject: [PATCH 09/12] minor --- tests/unit/test_dataset_dict.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/unit/test_dataset_dict.py b/tests/unit/test_dataset_dict.py index dbc420a7..4952d35e 100644 --- a/tests/unit/test_dataset_dict.py +++ b/tests/unit/test_dataset_dict.py @@ -23,9 +23,10 @@ logger = logging.getLogger(__name__) -DATA_PATH = FIXTURES_ROOT / "dataset_dict" / "conll2003_extract" DATASET_NAME = "conll2003" PIE_DATASET_PATH = DATASET_BUILDERS_ROOT / "pie" / DATASET_NAME +DATA_PATH = FIXTURES_ROOT / "dataset_dict" / f"{DATASET_NAME}_extract" + TEST_CLASS_PREFIX = "tests.unit.test_dataset_dict" CREATE_FIXTURE_DATA = False From 6851a43796ffdef04870f021343297d94375400a Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Tue, 7 Nov 2023 14:43:05 +0100 Subject: [PATCH 10/12] move CREATE_FIXTURE_DATA to conftest.py --- tests/conftest.py | 7 +++++++ tests/unit/test_dataset_dict.py | 4 +--- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 6fcd7a12..b0bccda5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,6 +28,13 @@ _TABULATE_AVAILABLE = "tabulate" in {pkg.key for pkg in pkg_resources.working_set} 
+CREATE_FIXTURE_DATA = False + + +# just ensure that this never happens on CI +def test_dont_create_fixture_data(): + assert not CREATE_FIXTURE_DATA + @pytest.fixture def documents(dataset): diff --git a/tests/unit/test_dataset_dict.py b/tests/unit/test_dataset_dict.py index 4952d35e..cbaaf132 100644 --- a/tests/unit/test_dataset_dict.py +++ b/tests/unit/test_dataset_dict.py @@ -19,7 +19,7 @@ IterableDataset, ) from tests import DATASET_BUILDERS_ROOT, FIXTURES_ROOT -from tests.conftest import TestDocument +from tests.conftest import CREATE_FIXTURE_DATA, TestDocument logger = logging.getLogger(__name__) @@ -29,8 +29,6 @@ TEST_CLASS_PREFIX = "tests.unit.test_dataset_dict" -CREATE_FIXTURE_DATA = False - @pytest.mark.skipif(condition=not CREATE_FIXTURE_DATA, reason="don't create fixture data again") def test_create_fixture_data(): From ef9074f91c3a735291bcc9f41aae992fde51c923 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Tue, 7 Nov 2023 14:43:32 +0100 Subject: [PATCH 11/12] renaming --- tests/unit/test_dataset_dict.py | 53 ++++++++++++++++++++------------- 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/tests/unit/test_dataset_dict.py b/tests/unit/test_dataset_dict.py index cbaaf132..420382d2 100644 --- a/tests/unit/test_dataset_dict.py +++ b/tests/unit/test_dataset_dict.py @@ -24,8 +24,9 @@ logger = logging.getLogger(__name__) DATASET_NAME = "conll2003" +N_FIXTURE_SAMPLES = 3 PIE_DATASET_PATH = DATASET_BUILDERS_ROOT / "pie" / DATASET_NAME -DATA_PATH = FIXTURES_ROOT / "dataset_dict" / f"{DATASET_NAME}_extract" +FIXTURE_DATA_PATH = FIXTURES_ROOT / "dataset_dict" / f"{DATASET_NAME}_extract" TEST_CLASS_PREFIX = "tests.unit.test_dataset_dict" @@ -35,8 +36,8 @@ def test_create_fixture_data(): conll2003 = DatasetDict(datasets.load_dataset(str(PIE_DATASET_PATH))) for split in list(conll2003): # restrict all splits to 3 examples - conll2003 = conll2003.select(split=split, stop=3) - conll2003.to_json(DATA_PATH) + conll2003 = conll2003.select(split=split, stop=N_FIXTURE_SAMPLES) + conll2003.to_json(FIXTURE_DATA_PATH) @dataclass @@ -47,20 +48,20 @@ class DocumentWithEntitiesAndRelations(TextBasedDocument): @pytest.fixture(scope="module") def dataset_dict(): return DatasetDict.from_json( - data_dir=DATA_PATH, document_type=DocumentWithEntitiesAndRelations + data_dir=FIXTURE_DATA_PATH, document_type=DocumentWithEntitiesAndRelations ) def test_from_json(dataset_dict): assert set(dataset_dict) == {"train", "test", "validation"} - assert len(dataset_dict["train"]) == 3 - assert len(dataset_dict["test"]) == 3 - assert len(dataset_dict["validation"]) == 3 + assert len(dataset_dict["train"]) == N_FIXTURE_SAMPLES + assert len(dataset_dict["test"]) == N_FIXTURE_SAMPLES + assert len(dataset_dict["validation"]) == N_FIXTURE_SAMPLES def test_from_json_no_serialized_document_type(dataset_dict): with pytest.raises(ValueError) as excinfo: - DatasetDict.from_json(data_dir=DATA_PATH) + DatasetDict.from_json(data_dir=FIXTURE_DATA_PATH) assert ( str(excinfo.value) == "document_type must be provided if it cannot be loaded from the metadata file" @@ -70,7 +71,7 @@ def test_from_json_no_serialized_document_type(dataset_dict): @pytest.fixture(scope="module") def iterable_dataset_dict(): return DatasetDict.from_json( - data_dir=DATA_PATH, + data_dir=FIXTURE_DATA_PATH, document_type=DocumentWithEntitiesAndRelations, streaming=True, ) @@ -84,7 +85,7 @@ def test_to_json_and_back(dataset_dict, tmp_path): path = Path(tmp_path) / "dataset_dict" dataset_dict.to_json(path) dataset_dict_from_json = 
DatasetDict.from_json( - data_dir=path, + data_dir=str(path), document_type=dataset_dict.document_type, ) assert set(dataset_dict_from_json) == set(dataset_dict) @@ -98,7 +99,7 @@ def test_to_json_and_back_serialize_document_type(dataset_dict, tmp_path): path = Path(tmp_path) / "dataset_dict" dataset_dict.to_json(path) dataset_dict_from_json = DatasetDict.from_json( - data_dir=path, + data_dir=str(path), ) assert set(dataset_dict_from_json) == set(dataset_dict) for split in dataset_dict: @@ -118,7 +119,7 @@ def test_document_type_empty_no_splits(): def test_document_type_different_types(dataset_dict): # load the example dataset as a different document type dataset_dict_different_type = DatasetDict.from_json( - data_dir=DATA_PATH, + data_dir=FIXTURE_DATA_PATH, document_type=TextBasedDocument, ) assert dataset_dict_different_type.document_type is TextBasedDocument @@ -361,14 +362,18 @@ def test_filter(dataset_dict): split="train", ) assert all(len(doc.text) > 15 for doc in dataset_dict_filtered["train"]) - assert len(dataset_dict["train"]) == 3 + assert len(dataset_dict["train"]) == N_FIXTURE_SAMPLES assert len(dataset_dict_filtered["train"]) == 2 assert dataset_dict_filtered["train"][0] == dataset_dict["train"][0] assert dataset_dict_filtered["train"][1] == dataset_dict["train"][2] # remaining splits should be unchanged - assert len(dataset_dict_filtered["validation"]) == len(dataset_dict["validation"]) == 3 - assert len(dataset_dict_filtered["test"]) == len(dataset_dict["test"]) == 3 + assert ( + len(dataset_dict_filtered["validation"]) + == len(dataset_dict["validation"]) + == N_FIXTURE_SAMPLES + ) + assert len(dataset_dict_filtered["test"]) == len(dataset_dict["test"]) == N_FIXTURE_SAMPLES assert_doc_lists_equal(dataset_dict_filtered["validation"], dataset_dict["validation"]) assert_doc_lists_equal(dataset_dict_filtered["test"], dataset_dict["test"]) @@ -393,9 +398,13 @@ def test_filter_unknown_dataset_type(): def test_filter_noop(dataset_dict): # passing no filter function should be a noop dataset_dict_filtered = dataset_dict.filter(split="train") - assert len(dataset_dict_filtered["train"]) == len(dataset_dict["train"]) == 3 - assert len(dataset_dict_filtered["validation"]) == len(dataset_dict["validation"]) == 3 - assert len(dataset_dict_filtered["test"]) == len(dataset_dict["test"]) == 3 + assert len(dataset_dict_filtered["train"]) == len(dataset_dict["train"]) == N_FIXTURE_SAMPLES + assert ( + len(dataset_dict_filtered["validation"]) + == len(dataset_dict["validation"]) + == N_FIXTURE_SAMPLES + ) + assert len(dataset_dict_filtered["test"]) == len(dataset_dict["test"]) == N_FIXTURE_SAMPLES assert_doc_lists_equal(dataset_dict_filtered["train"], dataset_dict["train"]) assert_doc_lists_equal(dataset_dict_filtered["validation"], dataset_dict["validation"]) assert_doc_lists_equal(dataset_dict_filtered["test"], dataset_dict["test"]) @@ -422,8 +431,12 @@ def test_move_to_new_split(dataset_dict, ids, filter_function): assert_doc_lists_equal(dataset_dict_moved["train"], dataset_dict["train"][:1]) # the remaining splits should be unchanged - assert len(dataset_dict_moved["validation"]) == len(dataset_dict["validation"]) == 3 - assert len(dataset_dict_moved["test"]) == len(dataset_dict["test"]) == 3 + assert ( + len(dataset_dict_moved["validation"]) + == len(dataset_dict["validation"]) + == N_FIXTURE_SAMPLES + ) + assert len(dataset_dict_moved["test"]) == len(dataset_dict["test"]) == N_FIXTURE_SAMPLES assert_doc_lists_equal(dataset_dict_moved["validation"], dataset_dict["validation"]) 
assert_doc_lists_equal(dataset_dict_moved["test"], dataset_dict["test"]) From b672b73a60bbfbfc9294dcfd94f978f0ee71a61d Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Tue, 7 Nov 2023 14:45:27 +0100 Subject: [PATCH 12/12] minor --- tests/unit/test_dataset_dict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/test_dataset_dict.py b/tests/unit/test_dataset_dict.py index 420382d2..2f7b97eb 100644 --- a/tests/unit/test_dataset_dict.py +++ b/tests/unit/test_dataset_dict.py @@ -24,6 +24,7 @@ logger = logging.getLogger(__name__) DATASET_NAME = "conll2003" +# restrict all splits to 3 examples N_FIXTURE_SAMPLES = 3 PIE_DATASET_PATH = DATASET_BUILDERS_ROOT / "pie" / DATASET_NAME FIXTURE_DATA_PATH = FIXTURES_ROOT / "dataset_dict" / f"{DATASET_NAME}_extract" @@ -35,7 +36,6 @@ def test_create_fixture_data(): conll2003 = DatasetDict(datasets.load_dataset(str(PIE_DATASET_PATH))) for split in list(conll2003): - # restrict all splits to 3 examples conll2003 = conll2003.select(split=split, stop=N_FIXTURE_SAMPLES) conll2003.to_json(FIXTURE_DATA_PATH)