diff --git a/src/pie_datasets/core/dataset.py b/src/pie_datasets/core/dataset.py
index ca18f0b0..d0d81ab7 100644
--- a/src/pie_datasets/core/dataset.py
+++ b/src/pie_datasets/core/dataset.py
@@ -292,6 +292,36 @@ def from_hf_dataset(
         )
         return document_dataset
 
+    @classmethod
+    def from_documents(
+        cls,
+        documents: List[Document],
+        document_converters: Optional[DocumentConvertersType] = None,
+        **dataset_kwargs,
+    ) -> "Dataset":
+        """Create a Dataset from a list of documents. It wraps the Huggingface
+        datasets.Dataset.from_list method; see its documentation for more details.
+
+        Args:
+            documents (List[Document]): A list of documents.
+            document_converters (Optional[DocumentConvertersType], optional): A dictionary of document
+                converters. Defaults to None.
+            **dataset_kwargs: Additional arguments for the Huggingface dataset creation.
+
+        Returns:
+            Dataset: The created dataset.
+        """
+
+        if len(documents) == 0:
+            raise ValueError("No documents to create dataset from")
+        document_type = type(documents[0])
+        data = [doc.asdict() for doc in documents]
+        hf_dataset = datasets.Dataset.from_list(mapping=data, **dataset_kwargs)
+        dataset = cls.from_hf_dataset(
+            hf_dataset, document_type=document_type, document_converters=document_converters
+        )
+        return dataset
+
     def apply_hf_func(self, func, **kwargs) -> "Dataset":
         return Dataset.from_hf_dataset(
             func(self, **kwargs),
@@ -470,6 +500,50 @@ def from_hf_dataset(
         )
         return dataset
 
+    @classmethod
+    def from_documents(
+        cls,
+        documents: Callable,
+        document_converters: Optional[DocumentConvertersType] = None,
+        **dataset_kwargs,
+    ) -> "IterableDataset":
+        """Create an IterableDataset from a generator that yields documents. It wraps the
+        Huggingface datasets.IterableDataset.from_generator method; see its documentation for
+        more details.
+
+        Args:
+            documents (Callable): A generator function that yields documents.
+            document_converters (Optional[DocumentConvertersType], optional): A dictionary of document
+                converters. Defaults to None.
+            **dataset_kwargs: Additional arguments for the Huggingface dataset creation.
+
+        Returns:
+            IterableDataset: The created iterable dataset.
+        """
+
+        # get first document to infer the document type
+        try:
+            gen_kwargs = dataset_kwargs.get("gen_kwargs", {})
+            first_doc = next(documents(**gen_kwargs))
+        except StopIteration:
+            raise ValueError("No documents to create dataset from")
+        document_type = type(first_doc)
+
+        # wrap the generator to yield dictionaries
+        def wrapped_documents_generator(**kwargs):
+            for doc in documents(**kwargs):
+                yield doc.asdict()
+
+        hf_dataset = datasets.IterableDataset.from_generator(
+            wrapped_documents_generator, **dataset_kwargs
+        )
+        dataset = cls.from_hf_dataset(
+            hf_dataset,
+            document_type=document_type,
+            document_converters=document_converters,
+        )
+        return dataset
+
     def __iter__(self):
         for example in iter(super().__iter__()):
             yield self.document_type.fromdict(example)
diff --git a/tests/unit/core/test_dataset.py b/tests/unit/core/test_dataset.py
index 3481d191..789b5d3a 100644
--- a/tests/unit/core/test_dataset.py
+++ b/tests/unit/core/test_dataset.py
@@ -6,6 +6,7 @@
 import numpy
 import pytest
 import torch
+from pytorch_ie import Document
 from pytorch_ie.annotations import BinaryRelation, Label, LabeledSpan, Span
 from pytorch_ie.core import AnnotationList, annotation_field
 from pytorch_ie.core.taskmodule import (
@@ -431,3 +432,50 @@ def test_dataset_with_taskmodule(
 
     for document in train_dataset:
         assert not document["entities"].predictions
+
+
+@pytest.mark.parametrize("as_iterable_dataset", [False, True])
+def test_pie_dataset_from_documents(documents, as_iterable_dataset):
+    if as_iterable_dataset:
+        dataset_class = IterableDataset
+
+        # make generator functions from the list
+        def _documents():
+            yield from documents
+
+        def _empty_docs():
+            return iter([])
+
+    else:
+        dataset_class = Dataset
+        _documents = documents
+        _empty_docs = list[Document]()
+
+    dataset_from_documents = dataset_class.from_documents(_documents)
+
+    assert isinstance(dataset_from_documents, dataset_class)
+
+    assert all(isinstance(doc, TextBasedDocument) for doc in dataset_from_documents)
+    assert all(
+        doc1.asdict() == doc2.asdict() for doc1, doc2 in zip(documents, dataset_from_documents)
+    )
+    assert hasattr(dataset_from_documents, "document_type")
+
+    # Test dataset creation with a document converter
+    dataset_from_documents_with_converter = dataset_class.from_documents(
+        _documents, document_converters={TestDocumentWithLabel: convert_to_document_with_label}
+    )
+
+    assert isinstance(dataset_from_documents_with_converter, dataset_class)
+
+    assert len(dataset_from_documents_with_converter.document_converters) == 1
+    assert TestDocumentWithLabel in dataset_from_documents_with_converter.document_converters
+    assert (
+        dataset_from_documents_with_converter.document_converters[TestDocumentWithLabel]
+        == convert_to_document_with_label
+    )
+
+    # Test dataset creation with an empty list / generator
+    with pytest.raises(ValueError) as excinfo:
+        dataset_class.from_documents(_empty_docs)
+    assert str(excinfo.value) == "No documents to create dataset from"