Skip to content

Commit

Permalink
Merge pull request #47 from ArneBinder/load_dataset
Browse files Browse the repository at this point in the history
implement `load_dataset`
  • Loading branch information
ArneBinder authored Nov 13, 2023
2 parents 53cb53b + 643c52f commit 03e84f4
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/pie_datasets/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from .builder import ArrowBasedBuilder, GeneratorBasedBuilder
from .dataset import Dataset, IterableDataset
from .dataset_dict import DatasetDict
from .dataset_dict import DatasetDict, load_dataset

__all__ = [
"GeneratorBasedBuilder",
"ArrowBasedBuilder",
"Dataset",
"IterableDataset",
"DatasetDict",
"load_dataset",
]
19 changes: 19 additions & 0 deletions src/pie_datasets/core/dataset_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -670,3 +670,22 @@ def cast_document_type(
}
)
return result


def load_dataset(*args, **kwargs) -> Union[DatasetDict, Dataset, IterableDataset]:
dataset_or_dataset_dict = datasets.load_dataset(*args, **kwargs)
if isinstance(dataset_or_dataset_dict, (Dataset, IterableDataset)):
return dataset_or_dataset_dict
elif isinstance(dataset_or_dataset_dict, (datasets.DatasetDict, datasets.IterableDatasetDict)):
for name, dataset in dataset_or_dataset_dict.items():
if not isinstance(dataset, (Dataset, IterableDataset)):
raise TypeError(
f'expected all splits to be {Dataset} or {IterableDataset}, but split "{name}" is of type '
f"{type(dataset)}"
)
return DatasetDict(dataset_or_dataset_dict)
else:
raise TypeError(
f"expected datasets.load_dataset to return {datasets.DatasetDict}, {datasets.IterableDatasetDict}, "
f"{Dataset}, or {IterableDataset}, but got {type(dataset_or_dataset_dict)}"
)
48 changes: 47 additions & 1 deletion tests/unit/core/test_dataset_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pytorch_ie.core import AnnotationList, Document, annotation_field
from pytorch_ie.documents import TextBasedDocument, TextDocument

from pie_datasets import Dataset, DatasetDict, IterableDataset
from pie_datasets import Dataset, DatasetDict, IterableDataset, load_dataset
from pie_datasets.core.dataset_dict import (
EnterDatasetDictMixin,
EnterDatasetMixin,
Expand Down Expand Up @@ -541,3 +541,49 @@ def test_to_document_type_noop(dataset_dict):
dataset_dict_converted = dataset_dict.to_document_type(DocumentWithEntitiesAndRelations)
assert dataset_dict_converted.document_type == DocumentWithEntitiesAndRelations
assert dataset_dict_converted == dataset_dict


def test_load_dataset_conll2003():
dataset_dict = load_dataset("pie/conll2003")
assert isinstance(dataset_dict, DatasetDict)
assert set(dataset_dict) == {"train", "test", "validation"}
split_sizes = {split: len(dataset_dict[split]) for split in dataset_dict}
assert split_sizes == {"train": 14041, "test": 3453, "validation": 3250}
doc = dataset_dict["train"][0]
assert isinstance(doc, TextBasedDocument)
assert doc.text == "EU rejects German call to boycott British lamb ."
resolved_entities = [(str(ent), ent.label) for ent in doc.entities]
assert resolved_entities == [("EU", "ORG"), ("German", "MISC"), ("British", "MISC")]


def test_load_dataset_conll2003_single_split():
dataset = load_dataset("pie/conll2003", split="train")
assert isinstance(dataset, Dataset)
assert len(dataset) == 14041
doc = dataset[0]
assert isinstance(doc, TextBasedDocument)
assert doc.text == "EU rejects German call to boycott British lamb ."
resolved_entities = [(str(ent), ent.label) for ent in doc.entities]
assert resolved_entities == [("EU", "ORG"), ("German", "MISC"), ("British", "MISC")]


def test_load_dataset_conll2003_wrong_type():
with pytest.raises(TypeError) as excinfo:
load_dataset("conll2003")
assert (
str(excinfo.value)
== "expected all splits to be <class 'pie_datasets.core.dataset.Dataset'> or "
"<class 'pie_datasets.core.dataset.IterableDataset'>, but split \"train\" is of type "
"<class 'datasets.arrow_dataset.Dataset'>"
)


def test_load_dataset_conll2003_wrong_type_single_split():
with pytest.raises(TypeError) as excinfo:
load_dataset("conll2003", split="train")
assert (
str(excinfo.value)
== "expected datasets.load_dataset to return <class 'datasets.dataset_dict.DatasetDict'>, "
"<class 'datasets.dataset_dict.IterableDatasetDict'>, <class 'pie_datasets.core.dataset.Dataset'>, "
"or <class 'pie_datasets.core.dataset.IterableDataset'>, but got <class 'datasets.arrow_dataset.Dataset'>"
)

0 comments on commit 03e84f4

Please sign in to comment.