diff --git a/data_juicer/core/__init__.py b/data_juicer/core/__init__.py
index edb24ad7b..f5450fabc 100644
--- a/data_juicer/core/__init__.py
+++ b/data_juicer/core/__init__.py
@@ -1,7 +1,7 @@
 from .adapter import Adapter
 from .analyzer import Analyzer
 from .data import NestedDataset
-from .executor import ExecutorFactory, LocalExecutor, RayExecutor
+from .executor import ExecutorBase, ExecutorFactory, LocalExecutor, RayExecutor
 from .exporter import Exporter
 from .monitor import Monitor
 from .tracer import Tracer
@@ -13,6 +13,7 @@
     'ExecutorFactory',
     'LocalExecutor',
     'RayExecutor',
+    'ExecutorBase',
     'Exporter',
     'Monitor',
     'Tracer',
diff --git a/data_juicer/core/data/__init__.py b/data_juicer/core/data/__init__.py
index 0c8ec69ca..d93899665 100644
--- a/data_juicer/core/data/__init__.py
+++ b/data_juicer/core/data/__init__.py
@@ -1,4 +1,6 @@
-from .dj_dataset import DJDataset, NestedDataset
+from .dj_dataset import DJDataset, NestedDataset, wrap_func_with_nested_access
 from .ray_dataset import RayDataset
 
-__all__ = ['DJDataset', 'NestedDataset', 'RayDataset']
+__all__ = [
+    'DJDataset', 'NestedDataset', 'RayDataset', 'wrap_func_with_nested_access'
+]
diff --git a/data_juicer/core/executor/factory.py b/data_juicer/core/executor/factory.py
index 22944068e..3133e3b1a 100644
--- a/data_juicer/core/executor/factory.py
+++ b/data_juicer/core/executor/factory.py
@@ -1,7 +1,7 @@
 from typing import Union
 
-from local_executor import LocalExecutor
-from ray_executor import RayExecutor
+from .local_executor import LocalExecutor
+from .ray_executor import RayExecutor
 
 
 class ExecutorFactory:
diff --git a/data_juicer/core/executor/local_executor.py b/data_juicer/core/executor/local_executor.py
index 5cb5f27d1..ac8aca353 100644
--- a/data_juicer/core/executor/local_executor.py
+++ b/data_juicer/core/executor/local_executor.py
@@ -13,15 +13,13 @@
 from data_juicer.core.exporter import Exporter
 from data_juicer.core.tracer import Tracer
 from data_juicer.format.load import load_formatter
-from data_juicer.format.mixture_formatter import MixtureFormatter
 from data_juicer.ops import OPERATORS, load_ops
 from data_juicer.ops.op_fusion import fuse_operators
-from data_juicer.ops.selector.frequency_specified_field_selector import \
-    FrequencySpecifiedFieldSelector
-from data_juicer.ops.selector.topk_specified_field_selector import \
-    TopkSpecifiedFieldSelector
+from data_juicer.ops.selector import (FrequencySpecifiedFieldSelector,
+                                      TopkSpecifiedFieldSelector)
 from data_juicer.utils import cache_utils
 from data_juicer.utils.ckpt_utils import CheckpointManager
+from data_juicer.utils.sample import random_sample
 
 
 class LocalExecutor(ExecutorBase):
@@ -97,52 +95,6 @@ def __init__(self, cfg: Optional[Namespace] = None):
             logger.info('Trace for all ops.')
             self.op_list_to_trace = set(OPERATORS.modules.keys())
 
-    def sample_data(self,
-                    dataset_to_sample: Dataset = None,
-                    load_data_np=None,
-                    sample_ratio: float = 1.0,
-                    sample_algo: str = 'uniform',
-                    **kwargs):
-        """
-        Sample a subset from the given dataset.
-
-        :param dataset_to_sample: Dataset to sample from. If None, will use
-            the formatter linked by the executor. Default is None.
-        :param load_data_np: number of workers when loading the dataset.
-        :param sample_ratio: The ratio of the sample size to the original
-            dataset size. Default is 1.0 (no sampling).
-        :param sample_algo: Sampling algorithm to use. Options are "uniform",
-            "frequency_specified_field_selector", or
-            "topk_specified_field_selector".
-            Default is "uniform".
-        :return: A sampled Dataset.
-        """
-        # Determine the dataset to sample from
-        if dataset_to_sample is not None:
-            dataset = dataset_to_sample
-        elif self.cfg.use_checkpoint and self.ckpt_manager.ckpt_available:
-            logger.info('Loading dataset from checkpoint...')
-            dataset = self.ckpt_manager.load_ckpt()
-        elif hasattr(self, 'formatter'):
-            logger.info('Loading dataset from data formatter...')
-            if load_data_np is None:
-                load_data_np = self.cfg.np
-            dataset = self.formatter.load_dataset(load_data_np, self.cfg)
-        else:
-            raise ValueError('No dataset available to sample from.')
-
-        # Perform sampling based on the specified algorithm
-        if sample_algo == 'uniform':
-            return MixtureFormatter.random_sample(dataset, sample_ratio)
-        elif sample_algo == 'frequency_specified_field_selector':
-            dj_op = FrequencySpecifiedFieldSelector(**kwargs)
-            return dj_op.process(dataset)
-        elif sample_algo == 'topk_specified_field_selector':
-            dj_op = TopkSpecifiedFieldSelector(**kwargs)
-            return dj_op.process(dataset)
-        else:
-            raise ValueError(f'Unsupported sample_algo: {sample_algo}')
-
     def run(self,
             load_data_np: Optional[PositiveInt] = None,
             skip_return=False):
@@ -215,3 +167,51 @@ def run(self,
 
         if not skip_return:
             return dataset
+
+    def sample_data(self,
+                    dataset_to_sample: Dataset = None,
+                    load_data_np=None,
+                    sample_ratio: float = 1.0,
+                    sample_algo: str = 'uniform',
+                    **kwargs):
+        """
+        Sample a subset from the given dataset.
+        TODO add support other than LocalExecutor
+
+        :param executor: executor
+        :param dataset_to_sample: Dataset to sample from. If None, will use
+            the formatter linked by the executor. Default is None.
+        :param load_data_np: number of workers when loading the dataset.
+        :param sample_ratio: The ratio of the sample size to the original
+            dataset size. Default is 1.0 (no sampling).
+        :param sample_algo: Sampling algorithm to use. Options are "uniform",
+            "frequency_specified_field_selector", or
+            "topk_specified_field_selector".
+            Default is "uniform".
+        :return: A sampled Dataset.
+        """
+        # Determine the dataset to sample from
+        if dataset_to_sample is not None:
+            dataset = dataset_to_sample
+        elif self.cfg.use_checkpoint and self.ckpt_manager.ckpt_available:
+            logger.info('Loading dataset from checkpoint...')
+            dataset = self.ckpt_manager.load_ckpt()
+        elif hasattr(self, 'formatter'):
+            logger.info('Loading dataset from data formatter...')
+            if load_data_np is None:
+                load_data_np = self.cfg.np
+            dataset = self.formatter.load_dataset(load_data_np, self.cfg)
+        else:
+            raise ValueError('No dataset available to sample from.')
+
+        # Perform sampling based on the specified algorithm
+        if sample_algo == 'uniform':
+            return random_sample(dataset, sample_ratio)
+        elif sample_algo == 'frequency_specified_field_selector':
+            dj_op = FrequencySpecifiedFieldSelector(**kwargs)
+            return dj_op.process(dataset)
+        elif sample_algo == 'topk_specified_field_selector':
+            dj_op = TopkSpecifiedFieldSelector(**kwargs)
+            return dj_op.process(dataset)
+        else:
+            raise ValueError(f'Unsupported sample_algo: {sample_algo}')
diff --git a/data_juicer/ops/selector/random_selector.py b/data_juicer/ops/selector/random_selector.py
index c3990ab19..f92d82b68 100644
--- a/data_juicer/ops/selector/random_selector.py
+++ b/data_juicer/ops/selector/random_selector.py
@@ -3,9 +3,8 @@
 from pydantic import Field, PositiveInt
 from typing_extensions import Annotated
 
-from data_juicer.format.mixture_formatter import MixtureFormatter
-
-from ..base_op import OPERATORS, Selector
+from data_juicer.ops.base_op import OPERATORS, Selector
+from data_juicer.utils.sample import random_sample
 
 
 @OPERATORS.register_module('random_selector')
@@ -41,7 +40,6 @@ def process(self, dataset):
         if self.select_ratio is None and self.select_num is None:
             return dataset
 
-        select_num = 0
         if not self.select_ratio:
             select_num = self.select_num
         else:
@@ -49,5 +47,4 @@
             if self.select_num and self.select_num < select_num:
                 select_num = self.select_num
 
-        return MixtureFormatter.random_sample(dataset,
-                                              sample_number=select_num)
+        return random_sample(dataset, sample_number=select_num)
diff --git a/data_juicer/utils/sample.py b/data_juicer/utils/sample.py
index 0164dbec0..17275c588 100644
--- a/data_juicer/utils/sample.py
+++ b/data_juicer/utils/sample.py
@@ -1,60 +1,6 @@
 from itertools import chain, repeat
 
 import numpy as np
-from datasets import Dataset
-from loguru import logger
-
-from data_juicer.ops.selector import (FrequencySpecifiedFieldSelector,
-                                      TopkSpecifiedFieldSelector)
-
-
-class SamplingMixin:
-
-    def sample_data(self,
-                    dataset_to_sample: Dataset = None,
-                    load_data_np=None,
-                    sample_ratio: float = 1.0,
-                    sample_algo: str = 'uniform',
-                    **kwargs):
-        """
-        Sample a subset from the given dataset.
-
-        :param dataset_to_sample: Dataset to sample from. If None, will use
-            the formatter linked by the executor. Default is None.
-        :param load_data_np: number of workers when loading the dataset.
-        :param sample_ratio: The ratio of the sample size to the original
-            dataset size. Default is 1.0 (no sampling).
-        :param sample_algo: Sampling algorithm to use. Options are "uniform",
-            "frequency_specified_field_selector", or
-            "topk_specified_field_selector".
-            Default is "uniform".
-        :return: A sampled Dataset.
-        """
-        # Determine the dataset to sample from
-        if dataset_to_sample is not None:
-            dataset = dataset_to_sample
-        elif self.cfg.use_checkpoint and self.ckpt_manager.ckpt_available:
-            logger.info('Loading dataset from checkpoint...')
-            dataset = self.ckpt_manager.load_ckpt()
-        elif hasattr(self, 'formatter'):
-            logger.info('Loading dataset from data formatter...')
-            if load_data_np is None:
-                load_data_np = self.cfg.np
-            dataset = self.formatter.load_dataset(load_data_np, self.cfg)
-        else:
-            raise ValueError('No dataset available to sample from.')
-
-        # Perform sampling based on the specified algorithm
-        if sample_algo == 'uniform':
-            return random_sample(dataset, sample_ratio)
-        elif sample_algo == 'frequency_specified_field_selector':
-            dj_op = FrequencySpecifiedFieldSelector(**kwargs)
-            return dj_op.process(dataset)
-        elif sample_algo == 'topk_specified_field_selector':
-            dj_op = TopkSpecifiedFieldSelector(**kwargs)
-            return dj_op.process(dataset)
-        else:
-            raise ValueError(f'Unsupported sample_algo: {sample_algo}')
 
 
 def random_sample(dataset, weight=1.0, sample_number=0, seed=None):
diff --git a/tests/core/data/test_config.yaml b/tests/core/data/test_config.yaml
index 642ecd958..9620bed65 100644
--- a/tests/core/data/test_config.yaml
+++ b/tests/core/data/test_config.yaml
@@ -1,5 +1,5 @@
-project_name: 'dataset-local-json'
+project_name: 'dataset-ondisk-json'
 dataset:
-  type: 'local'
+  type: 'ondisk'
   path:
     - 'sample.json'
\ No newline at end of file
diff --git a/tests/core/test_dataset_builder.py b/tests/core/test_dataset_builder.py
index ad55ec867..63b0b8343 100644
--- a/tests/core/test_dataset_builder.py
+++ b/tests/core/test_dataset_builder.py
@@ -1,4 +1,12 @@
 import os
+import unittest
+from argparse import Namespace
+from contextlib import redirect_stdout
+from io import StringIO
+
+from networkx.classes import is_empty
+
+from data_juicer.config import init_configs
 
 from data_juicer.core.data.dataset_builder import rewrite_cli_datapath
 from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS
@@ -42,4 +50,18 @@ def test_rewrite_cli_datapath_with_weights(self):
         self.assertEqual(
             [{'path': ['./data/sample.json'], 'type': 'ondisk', 'weight': 0.5},
              {'path': ['./data/sample.txt'], 'type': 'ondisk', 'weight': 1.0}],
-            ans)
\ No newline at end of file
+            ans)
+
+    def test_dataset_builder_ondisk_config(self):
+        test_config_file = './data/test_config.yaml'
+        out = StringIO()
+        with redirect_stdout(out):
+            cfg = init_configs(args=f'--config {test_config_file}'.split())
+        self.assertIsInstance(cfg, Namespace)
+        self.assertEqual(cfg.project_name, 'dataset-ondisk-json')
+        self.assertEqual(cfg.dataset, {'path': ['sample.json'], 'type': 'ondisk'})
+        self.assertEqual(not cfg.dataset_path, True)
+
+
+if __name__ == '__main__':
+    unittest.main()