Skip to content

Commit

Permalink
fix circular dependency; add dataset config test
Browse files Browse the repository at this point in the history
  • Loading branch information
cyruszhang committed Dec 10, 2024
1 parent 84803cd commit cb5b80a
Showing 8 changed files with 87 additions and 119 deletions.
3 changes: 2 additions & 1 deletion data_juicer/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .adapter import Adapter
from .analyzer import Analyzer
from .data import NestedDataset
from .executor import ExecutorFactory, LocalExecutor, RayExecutor
from .executor import ExecutorBase, ExecutorFactory, LocalExecutor, RayExecutor
from .exporter import Exporter
from .monitor import Monitor
from .tracer import Tracer
@@ -13,6 +13,7 @@
'ExecutorFactory',
'LocalExecutor',
'RayExecutor',
'ExecutorBase',
'Exporter',
'Monitor',
'Tracer',
6 changes: 4 additions & 2 deletions data_juicer/core/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from .dj_dataset import DJDataset, NestedDataset
from .dj_dataset import DJDataset, NestedDataset, wrap_func_with_nested_access
from .ray_dataset import RayDataset

__all__ = ['DJDataset', 'NestedDataset', 'RayDataset']
__all__ = [
'DJDataset', 'NestedDataset', 'RayDataset', 'wrap_func_with_nested_access'
]
4 changes: 2 additions & 2 deletions data_juicer/core/executor/factory.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Union

from local_executor import LocalExecutor
from ray_executor import RayExecutor
from .local_executor import LocalExecutor
from .ray_executor import RayExecutor


class ExecutorFactory:
102 changes: 51 additions & 51 deletions data_juicer/core/executor/local_executor.py
Original file line number Diff line number Diff line change
@@ -13,15 +13,13 @@
from data_juicer.core.exporter import Exporter
from data_juicer.core.tracer import Tracer
from data_juicer.format.load import load_formatter
from data_juicer.format.mixture_formatter import MixtureFormatter
from data_juicer.ops import OPERATORS, load_ops
from data_juicer.ops.op_fusion import fuse_operators
from data_juicer.ops.selector.frequency_specified_field_selector import \
FrequencySpecifiedFieldSelector
from data_juicer.ops.selector.topk_specified_field_selector import \
TopkSpecifiedFieldSelector
from data_juicer.ops.selector import (FrequencySpecifiedFieldSelector,
TopkSpecifiedFieldSelector)
from data_juicer.utils import cache_utils
from data_juicer.utils.ckpt_utils import CheckpointManager
from data_juicer.utils.sample import random_sample


class LocalExecutor(ExecutorBase):
@@ -97,52 +95,6 @@ def __init__(self, cfg: Optional[Namespace] = None):
logger.info('Trace for all ops.')
self.op_list_to_trace = set(OPERATORS.modules.keys())

    def sample_data(self,
                    dataset_to_sample: Dataset = None,
                    load_data_np=None,
                    sample_ratio: float = 1.0,
                    sample_algo: str = 'uniform',
                    **kwargs):
        """
        Sample a subset from the given dataset.

        :param dataset_to_sample: Dataset to sample from. If None, will use
            the formatter linked by the executor. Default is None.
        :param load_data_np: number of workers when loading the dataset.
        :param sample_ratio: The ratio of the sample size to the original
            dataset size. Default is 1.0 (no sampling).
        :param sample_algo: Sampling algorithm to use. Options are "uniform",
            "frequency_specified_field_selector", or
            "topk_specified_field_selector".
            Default is "uniform".
        :return: A sampled Dataset.
        """
        # Determine the dataset to sample from, in priority order:
        # explicit argument > available checkpoint > attached formatter.
        if dataset_to_sample is not None:
            dataset = dataset_to_sample
        elif self.cfg.use_checkpoint and self.ckpt_manager.ckpt_available:
            logger.info('Loading dataset from checkpoint...')
            dataset = self.ckpt_manager.load_ckpt()
        elif hasattr(self, 'formatter'):
            logger.info('Loading dataset from data formatter...')
            if load_data_np is None:
                # fall back to the configured number of worker processes
                load_data_np = self.cfg.np
            dataset = self.formatter.load_dataset(load_data_np, self.cfg)
        else:
            raise ValueError('No dataset available to sample from.')

        # Perform sampling based on the specified algorithm; the two
        # selector-based branches forward **kwargs to the selector op.
        if sample_algo == 'uniform':
            return MixtureFormatter.random_sample(dataset, sample_ratio)
        elif sample_algo == 'frequency_specified_field_selector':
            dj_op = FrequencySpecifiedFieldSelector(**kwargs)
            return dj_op.process(dataset)
        elif sample_algo == 'topk_specified_field_selector':
            dj_op = TopkSpecifiedFieldSelector(**kwargs)
            return dj_op.process(dataset)
        else:
            raise ValueError(f'Unsupported sample_algo: {sample_algo}')

def run(self,
load_data_np: Optional[PositiveInt] = None,
skip_return=False):
@@ -215,3 +167,51 @@ def run(self,

if not skip_return:
return dataset

def sample_data(self,
dataset_to_sample: Dataset = None,
load_data_np=None,
sample_ratio: float = 1.0,
sample_algo: str = 'uniform',
**kwargs):
"""
Sample a subset from the given dataset.
TODO add support other than LocalExecutor
:param executor: executor
:param dataset_to_sample: Dataset to sample from. If None, will use
the formatter linked by the executor. Default is None.
:param load_data_np: number of workers when loading the dataset.
:param sample_ratio: The ratio of the sample size to the original
dataset size. Default is 1.0 (no sampling).
:param sample_algo: Sampling algorithm to use. Options are "uniform",
"frequency_specified_field_selector", or
"topk_specified_field_selector".
Default is "uniform".
:return: A sampled Dataset.
"""
# Determine the dataset to sample from
if dataset_to_sample is not None:
dataset = dataset_to_sample
elif self.cfg.use_checkpoint and self.ckpt_manager.ckpt_available:
logger.info('Loading dataset from checkpoint...')
dataset = self.ckpt_manager.load_ckpt()
elif hasattr(self, 'formatter'):
logger.info('Loading dataset from data formatter...')
if load_data_np is None:
load_data_np = self.cfg.np
dataset = self.formatter.load_dataset(load_data_np, self.cfg)
else:
raise ValueError('No dataset available to sample from.')

# Perform sampling based on the specified algorithm
if sample_algo == 'uniform':
return random_sample(dataset, sample_ratio)
elif sample_algo == 'frequency_specified_field_selector':
dj_op = FrequencySpecifiedFieldSelector(**kwargs)
return dj_op.process(dataset)
elif sample_algo == 'topk_specified_field_selector':
dj_op = TopkSpecifiedFieldSelector(**kwargs)
return dj_op.process(dataset)
else:
raise ValueError(f'Unsupported sample_algo: {sample_algo}')
9 changes: 3 additions & 6 deletions data_juicer/ops/selector/random_selector.py
Original file line number Diff line number Diff line change
@@ -3,9 +3,8 @@
from pydantic import Field, PositiveInt
from typing_extensions import Annotated

from data_juicer.format.mixture_formatter import MixtureFormatter

from ..base_op import OPERATORS, Selector
from data_juicer.ops.base_op import OPERATORS, Selector
from data_juicer.utils.sample import random_sample


@OPERATORS.register_module('random_selector')
@@ -41,13 +40,11 @@ def process(self, dataset):
if self.select_ratio is None and self.select_num is None:
return dataset

select_num = 0
if not self.select_ratio:
select_num = self.select_num
else:
select_num = int(self.select_ratio * len(dataset))
if self.select_num and self.select_num < select_num:
select_num = self.select_num

return MixtureFormatter.random_sample(dataset,
sample_number=select_num)
return random_sample(dataset, sample_number=select_num)
54 changes: 0 additions & 54 deletions data_juicer/utils/sample.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,6 @@
from itertools import chain, repeat

import numpy as np
from datasets import Dataset
from loguru import logger

from data_juicer.ops.selector import (FrequencySpecifiedFieldSelector,
TopkSpecifiedFieldSelector)


class SamplingMixin:
    # Mixin that adds dataset sampling to a host class. The host is expected
    # to provide `self.cfg` (with `use_checkpoint` and `np`), and optionally
    # `self.ckpt_manager` and `self.formatter`, as read by sample_data below.
    # NOTE(review): the module-level import of data_juicer.ops.selector above
    # appears to be the circular-dependency this commit removes — confirm
    # against the import graph.

    def sample_data(self,
                    dataset_to_sample: Dataset = None,
                    load_data_np=None,
                    sample_ratio: float = 1.0,
                    sample_algo: str = 'uniform',
                    **kwargs):
        """
        Sample a subset from the given dataset.

        :param dataset_to_sample: Dataset to sample from. If None, will use
            the formatter linked by the executor. Default is None.
        :param load_data_np: number of workers when loading the dataset.
        :param sample_ratio: The ratio of the sample size to the original
            dataset size. Default is 1.0 (no sampling).
        :param sample_algo: Sampling algorithm to use. Options are "uniform",
            "frequency_specified_field_selector", or
            "topk_specified_field_selector".
            Default is "uniform".
        :return: A sampled Dataset.
        """
        # Determine the dataset to sample from, in priority order:
        # explicit argument > available checkpoint > attached formatter.
        if dataset_to_sample is not None:
            dataset = dataset_to_sample
        elif self.cfg.use_checkpoint and self.ckpt_manager.ckpt_available:
            logger.info('Loading dataset from checkpoint...')
            dataset = self.ckpt_manager.load_ckpt()
        elif hasattr(self, 'formatter'):
            logger.info('Loading dataset from data formatter...')
            if load_data_np is None:
                # fall back to the configured number of worker processes
                load_data_np = self.cfg.np
            dataset = self.formatter.load_dataset(load_data_np, self.cfg)
        else:
            raise ValueError('No dataset available to sample from.')

        # Perform sampling based on the specified algorithm; `random_sample`
        # is the module-level helper defined later in this file.
        if sample_algo == 'uniform':
            return random_sample(dataset, sample_ratio)
        elif sample_algo == 'frequency_specified_field_selector':
            dj_op = FrequencySpecifiedFieldSelector(**kwargs)
            return dj_op.process(dataset)
        elif sample_algo == 'topk_specified_field_selector':
            dj_op = TopkSpecifiedFieldSelector(**kwargs)
            return dj_op.process(dataset)
        else:
            raise ValueError(f'Unsupported sample_algo: {sample_algo}')


def random_sample(dataset, weight=1.0, sample_number=0, seed=None):
4 changes: 2 additions & 2 deletions tests/core/data/test_config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
project_name: 'dataset-local-json'
project_name: 'dataset-ondisk-json'
dataset:
type: 'local'
type: 'ondisk'
path:
- 'sample.json'
24 changes: 23 additions & 1 deletion tests/core/test_dataset_builder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
import os
import unittest
from argparse import Namespace
from contextlib import redirect_stdout
from io import StringIO

from networkx.classes import is_empty

from data_juicer.config import init_configs
from data_juicer.core.data.dataset_builder import rewrite_cli_datapath
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS

@@ -42,4 +50,18 @@ def test_rewrite_cli_datapath_with_weights(self):
self.assertEqual(
[{'path': ['./data/sample.json'], 'type': 'ondisk', 'weight': 0.5},
{'path': ['./data/sample.txt'], 'type': 'ondisk', 'weight': 1.0}],
ans)
ans)

def test_dataset_builder_ondisk_config(self):
test_config_file = './data/test_config.yaml'
out = StringIO()
with redirect_stdout(out):
cfg = init_configs(args=f'--config {test_config_file}'.split())
self.assertIsInstance(cfg, Namespace)
self.assertEqual(cfg.project_name, 'dataset-ondisk-json')
self.assertEqual(cfg.dataset, {'path': ['sample.json'], 'type': 'ondisk'})
self.assertEqual(not cfg.dataset_path, True)


# Allow running this test module directly with `python <file>`.
if __name__ == '__main__':
    unittest.main()

0 comments on commit cb5b80a

Please sign in to comment.