generic dataset tests [WIP] #28

Open · wants to merge 2 commits into base: main
131 changes: 131 additions & 0 deletions tests/dataset_builders/dataset_tester.py
@@ -0,0 +1,131 @@
import os
import tempfile
from typing import List, Optional
from unittest import TestCase

from datasets.builder import BuilderConfig
from datasets.download.download_manager import DownloadMode
from datasets.download.mock_download_manager import MockDownloadManager
from datasets.load import dataset_module_factory, import_main_class
from datasets.utils.file_utils import DownloadConfig, is_remote_url
from datasets.utils.logging import get_logger
from pytorch_ie.data.builder import ArrowBasedBuilder, GeneratorBasedBuilder

from tests import DATASET_BUILDERS_ROOT

logger = get_logger(__name__)


# Taken from https://github.com/huggingface/datasets/blob/207be676bffe9d164740a41a883af6125edef135/tests/test_dataset_common.py#L101
class DatasetTester:
def __init__(self, parent):
self.parent = parent if parent is not None else TestCase()

def load_builder_class(self, dataset_name, is_local=False):
# Download/copy dataset script
if is_local is True:
dataset_module = dataset_module_factory(
os.path.join(DATASET_BUILDERS_ROOT, dataset_name)
)
else:
dataset_module = dataset_module_factory(
dataset_name, download_config=DownloadConfig(force_download=True)
)
# Get dataset builder class
builder_cls = import_main_class(dataset_module.module_path)
return builder_cls

def load_all_configs(self, dataset_name, is_local=False) -> List[Optional[BuilderConfig]]:
# get builder class
        builder_cls = self.load_builder_class(dataset_name, is_local=is_local)

        if len(builder_cls.BUILDER_CONFIGS) == 0:
            return [None]
        return builder_cls.BUILDER_CONFIGS

def check_load_dataset(
self, dataset_name, configs, is_local=False, use_local_dummy_data=False
):
for config in configs:
with tempfile.TemporaryDirectory() as processed_temp_dir, tempfile.TemporaryDirectory() as raw_temp_dir:
# create config and dataset
dataset_builder_cls = self.load_builder_class(dataset_name, is_local=is_local)
name = config.name if config is not None else None
dataset_builder = dataset_builder_cls(
config_name=name, cache_dir=processed_temp_dir
)

# TODO: skip Beam datasets and datasets that lack dummy data for now
if not isinstance(dataset_builder, (ArrowBasedBuilder, GeneratorBasedBuilder)):
logger.info("Skip tests for this dataset for now")
return

if config is not None:
version = config.version
else:
version = dataset_builder.VERSION

def check_if_url_is_valid(url):
if is_remote_url(url) and "\\" in url:
raise ValueError(f"Bad remote url '{url} since it contains a backslash")

# create mock data loader manager that has a special download_and_extract() method to download dummy data instead of real data
mock_dl_manager = MockDownloadManager(
dataset_name=dataset_name,
config=config,
version=version,
cache_dir=raw_temp_dir,
use_local_dummy_data=use_local_dummy_data,
download_callbacks=[check_if_url_is_valid],
)
mock_dl_manager.datasets_scripts_dir = str(DATASET_BUILDERS_ROOT)
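                # Note: pointing datasets_scripts_dir at DATASET_BUILDERS_ROOT means the mock
                # manager resolves dummy data relative to the local builder directory,
                # presumably following the Hugging Face convention of a dummy_data.zip stored
                # in a "dummy/<config>/<version>/" folder next to the dataset script.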

# packaged datasets like csv, text, json or pandas require some data files
# builder_name = dataset_builder.__class__.__name__.lower()
# if builder_name in _PACKAGED_DATASETS_MODULES:
# mock_dl_manager.download_dummy_data()
# path_to_dummy_data = mock_dl_manager.dummy_file
# dataset_builder.config.data_files = get_packaged_dataset_dummy_data_files(
# builder_name, path_to_dummy_data
# )
# for config_attr, value in get_packaged_dataset_config_attributes(builder_name).items():
# setattr(dataset_builder.config, config_attr, value)

# mock size needed for dummy data instead of actual dataset
if dataset_builder.info is not None:
# approximate upper bound of order of magnitude of dummy data files
one_mega_byte = 2 << 19
dataset_builder.info.size_in_bytes = 2 * one_mega_byte
dataset_builder.info.download_size = one_mega_byte
dataset_builder.info.dataset_size = one_mega_byte

# generate examples from dummy data
dataset_builder.download_and_prepare(
dl_manager=mock_dl_manager,
download_mode=DownloadMode.FORCE_REDOWNLOAD,
verification_mode="no_checks",
try_from_hf_gcs=False,
)

# get dataset
dataset = dataset_builder.as_dataset(verification_mode="no_checks")

# check that dataset is not empty
self.parent.assertListEqual(
sorted(dataset_builder.info.splits.keys()), sorted(dataset)
)
for split in dataset_builder.info.splits.keys():
# check that loaded dataset is not empty
self.parent.assertTrue(len(dataset[split]) > 0)

# check that we can cast features for each task template
task_templates = dataset_builder.info.task_templates
if task_templates:
for task in task_templates:
task_features = {**task.input_schema, **task.label_schema}
for split in dataset:
casted_dataset = dataset[split].prepare_for_task(task)
self.parent.assertDictEqual(task_features, casted_dataset.features)
del casted_dataset
del dataset
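For reference, a minimal usage sketch of `DatasetTester` outside the parameterized test class defined in the second file; the dataset name `"conll2003"` is only a placeholder and assumes a local builder directory with prepared dummy data under `DATASET_BUILDERS_ROOT`:

```python
# Minimal sketch, assuming a builder directory named "conll2003" (placeholder) exists
# under DATASET_BUILDERS_ROOT and ships dummy data for the mock download manager.
from tests.dataset_builders.dataset_tester import DatasetTester

tester = DatasetTester(parent=None)  # parent=None falls back to a bare unittest.TestCase
configs = tester.load_all_configs("conll2003", is_local=True)[:1]
tester.check_load_dataset("conll2003", configs, is_local=True, use_local_dummy_data=True)
```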
121 changes: 121 additions & 0 deletions tests/dataset_builders/test_dataset_common.py
@@ -0,0 +1,121 @@
import os
import tempfile

import pytest
from absl.testing import parameterized
from datasets.builder import BuilderConfig, DatasetBuilder
from datasets.download.download_manager import DownloadMode
from datasets.load import dataset_module_factory, import_main_class, load_dataset
from datasets.utils.file_utils import DownloadConfig

from tests import DATASET_BUILDERS_ROOT
from tests.dataset_builders.dataset_tester import DatasetTester


def test_datasets_dir_and_script_names():
for dataset_dir in DATASET_BUILDERS_ROOT.iterdir():
name = dataset_dir.name
if (
not name.startswith("__") and len(os.listdir(dataset_dir)) > 0
): # ignore __pycache__ and empty dirs
# check that the script name is the same as the dir name
assert os.path.exists(
os.path.join(dataset_dir, name + ".py")
), f"Bad structure for dataset '{name}'. Please check that the directory name is a valid dataset and that the same the same as the dataset script name."

# if name in _PACKAGED_DATASETS_MODULES:
# continue
# else:
# # check that the script name is the same as the dir name
# assert os.path.exists(
# os.path.join(dataset_dir, name + ".py")
# ), f"Bad structure for dataset '{name}'. Please check that the directory name is a valid dataset and that the same the same as the dataset script name."


def get_local_dataset_names():
dataset_script_files = list(DATASET_BUILDERS_ROOT.absolute().glob("**/*.py"))
datasets = [
dataset_script_file.parent.name
for dataset_script_file in dataset_script_files
if dataset_script_file.name != "__init__.py"
]
return [{"testcase_name": x, "dataset_name": x} for x in datasets]


@parameterized.named_parameters(get_local_dataset_names())
# @for_all_test_methods(skip_if_dataset_requires_faiss, skip_if_not_compatible_with_windows)
class LocalDatasetTest(parameterized.TestCase):
dataset_name = None

def setUp(self):
self.dataset_tester = DatasetTester(self)

def test_load_dataset(self, dataset_name):
configs = self.dataset_tester.load_all_configs(dataset_name, is_local=True)[:1]
self.dataset_tester.check_load_dataset(
dataset_name, configs, is_local=True, use_local_dummy_data=True
)

def test_builder_class(self, dataset_name):
builder_cls = self.dataset_tester.load_builder_class(dataset_name, is_local=True)
name = builder_cls.BUILDER_CONFIGS[0].name if builder_cls.BUILDER_CONFIGS else None
with tempfile.TemporaryDirectory() as tmp_cache_dir:
builder = builder_cls(config_name=name, cache_dir=tmp_cache_dir)
self.assertIsInstance(builder, DatasetBuilder)

def test_builder_configs(self, dataset_name):
builder_configs = self.dataset_tester.load_all_configs(dataset_name, is_local=True)
self.assertTrue(len(builder_configs) > 0)

if builder_configs[0] is not None:
            for config in builder_configs:
                self.assertIsInstance(config, BuilderConfig)

@pytest.mark.slow
def test_load_dataset_all_configs(self, dataset_name):
configs = self.dataset_tester.load_all_configs(dataset_name, is_local=True)
self.dataset_tester.check_load_dataset(
dataset_name, configs, is_local=True, use_local_dummy_data=True
)

@pytest.mark.slow
def test_load_real_dataset(self, dataset_name):
path = str(DATASET_BUILDERS_ROOT / dataset_name)
dataset_module = dataset_module_factory(
path, download_config=DownloadConfig(local_files_only=True)
)
builder_cls = import_main_class(dataset_module.module_path)
name = builder_cls.BUILDER_CONFIGS[0].name if builder_cls.BUILDER_CONFIGS else None
with tempfile.TemporaryDirectory() as temp_cache_dir:
dataset = load_dataset(
path,
name=name,
cache_dir=temp_cache_dir,
download_mode=DownloadMode.FORCE_REDOWNLOAD,
)
for split in dataset.keys():
self.assertTrue(len(dataset[split]) > 0)
del dataset

@pytest.mark.slow
def test_load_real_dataset_all_configs(self, dataset_name):
path = str(DATASET_BUILDERS_ROOT / dataset_name)
dataset_module = dataset_module_factory(
path, download_config=DownloadConfig(local_files_only=True)
)
builder_cls = import_main_class(dataset_module.module_path)
config_names = (
[config.name for config in builder_cls.BUILDER_CONFIGS]
if len(builder_cls.BUILDER_CONFIGS) > 0
else [None]
)
for name in config_names:
with tempfile.TemporaryDirectory() as temp_cache_dir:
dataset = load_dataset(
path,
name=name,
cache_dir=temp_cache_dir,
download_mode=DownloadMode.FORCE_REDOWNLOAD,
)
for split in dataset.keys():
self.assertTrue(len(dataset[split]) > 0)
del dataset
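Note: the `@pytest.mark.slow` cases hit real downloads; running `pytest tests/dataset_builders -m "not slow"` deselects them and exercises only the dummy-data checks (this assumes the `slow` marker is registered in the project's pytest configuration).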