Skip to content

Commit

Permalink
add more test cases for datapath rewrite logic; fix rewrite to handle…
Browse files Browse the repository at this point in the history
… space in file name
  • Loading branch information
cyruszhang committed Dec 11, 2024
1 parent 940b44d commit b80f991
Show file tree
Hide file tree
Showing 2 changed files with 165 additions and 16 deletions.
18 changes: 14 additions & 4 deletions data_juicer/core/data/dataset_builder.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import shlex
from typing import List, Tuple, Union

from data_juicer.core.data import NestedDataset
Expand Down Expand Up @@ -78,16 +79,25 @@ def parse_cli_datapath(dataset_path) -> Tuple[List[str], List[float]]:
them, e.g. `<w1> ds1.jsonl <w2> ds2_dir <w3> ds3_file.json`
:return: list of dataset path and list of weights
"""
data_prefix = dataset_path.split()
# Handle empty input
if not dataset_path or not dataset_path.strip():
return [], []

# Use shlex to properly handle quoted strings
try:
tokens = shlex.split(dataset_path)
except ValueError as e:
raise ValueError(f'Invalid dataset path format: {e}')

prefixes = []
weights = []

for i in range(len(data_prefix)):
for i in range(len(tokens)):
try:
value = max(float(data_prefix[i]), 0.0)
value = max(float(tokens[i]), 0.0)
weights.append(value)
except: # noqa: E722
value = data_prefix[i].strip()
value = tokens[i].strip()
# if not set weight, use 1.0 as default
if i == 0 or len(weights) == len(prefixes):
weights.append(1.0)
Expand Down
163 changes: 151 additions & 12 deletions tests/core/test_dataset_builder.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,32 @@
import os
import unittest
from unittest.mock import patch
from argparse import Namespace
from contextlib import redirect_stdout
from io import StringIO

from networkx.classes import is_empty

from data_juicer.config import init_configs
from data_juicer.core.data.dataset_builder import rewrite_cli_datapath
from data_juicer.core.data.dataset_builder import rewrite_cli_datapath, parse_cli_datapath
from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS

@SKIPPED_TESTS.register_module()
class DatasetBuilderTest(DataJuicerTestCaseBase):

def setUp(self):
    """Run each test from the directory that contains this test file.

    Tests in this class refer to fixtures through relative paths (for
    example ``./data/sample.txt``), so the working directory is switched
    to this file's directory. The previous working directory is restored
    after every test so the change does not leak into other tests.
    """
    # Register the restore before changing directory; unittest runs
    # cleanups even when the rest of setUp raises.
    self.addCleanup(os.chdir, os.getcwd())
    # Get the directory where this test file is located
    test_file_dir = os.path.dirname(os.path.abspath(__file__))
    os.chdir(test_file_dir)

def test_rewrite_cli_datapath_local_single_file(self):
dataset_path = "./data/sample.txt"
ans = rewrite_cli_datapath(dataset_path)
self.assertEqual(
[{'path': ['./data/sample.txt'], 'type': 'ondisk', 'weight': 1.0}], ans)
[{'path': [dataset_path], 'type': 'ondisk', 'weight': 1.0}], ans)

def test_rewrite_cli_datapath_local_directory(self):
    """A bare local directory path becomes one ``ondisk`` entry with weight 1.0."""
    rewritten = rewrite_cli_datapath("./data")
    expected_entries = [{'path': ['./data'], 'type': 'ondisk', 'weight': 1.0}]
    self.assertEqual(expected_entries, rewritten)

def test_rewrite_cli_datapath_absolute_path(self):
dataset_path = os.curdir + "/data/sample.txt"
ans = rewrite_cli_datapath(dataset_path)
self.assertEqual(
[{'type': 'ondisk', 'path': [dataset_path], 'weight': 1.0}], ans)
[{'path': [dataset_path], 'type': 'ondisk', 'weight': 1.0}], ans)

def test_rewrite_cli_datapath_hf(self):
dataset_path = "hf-internal-testing/librispeech_asr_dummy"
Expand Down Expand Up @@ -75,6 +72,148 @@ def test_dataset_builder_ondisk_config_list(self):
{'path': ['sample.txt'], 'type': 'ondisk'}])
self.assertEqual(not cfg.dataset_path, True)

@patch('os.path.isdir')
@patch('os.path.isfile')
def test_rewrite_cli_datapath_local_files(self, mock_isfile, mock_isdir):
    """Weighted local files and directories are each rewritten to an ``ondisk`` entry."""

    # Simulate the filesystem: anything ending in ``.jsonl`` is a file,
    # anything ending in ``_dir`` is a directory.
    def looks_like_file(candidate):
        return candidate.endswith('.jsonl')

    def looks_like_dir(candidate):
        return candidate.endswith('_dir')

    mock_isfile.side_effect = looks_like_file
    mock_isdir.side_effect = looks_like_dir

    weighted_sources = [
        (1.0, 'ds1.jsonl'),
        (2.0, 'ds2_dir'),
        (3.0, 'ds3.jsonl'),
    ]
    expected = [
        {'type': 'ondisk', 'path': [source], 'weight': weight}
        for weight, source in weighted_sources
    ]

    result = rewrite_cli_datapath("1.0 ds1.jsonl 2.0 ds2_dir 3.0 ds3.jsonl")
    self.assertEqual(result, expected)

def test_rewrite_cli_datapath_huggingface(self):
    """A weighted ``org/name`` style path is expected to be rewritten as a huggingface entry."""
    rewritten = rewrite_cli_datapath("1.0 huggingface/dataset")
    self.assertEqual(
        rewritten,
        [{'type': 'huggingface',
          'path': 'huggingface/dataset',
          'split': 'train'}],
    )

def test_rewrite_cli_datapath_invalid(self):
    """Rewriting the path ``./invalid_path`` is expected to raise ``ValueError``."""
    self.assertRaises(ValueError, rewrite_cli_datapath, "1.0 ./invalid_path")

def test_parse_cli_datapath(self):
    """Explicit ``<weight> <path>`` pairs are split into parallel path and weight lists."""
    parsed_paths, parsed_weights = parse_cli_datapath(
        "1.0 ds1.jsonl 2.0 ds2_dir 3.0 ds3.jsonl")
    self.assertEqual(parsed_paths, ['ds1.jsonl', 'ds2_dir', 'ds3.jsonl'])
    self.assertEqual(parsed_weights, [1.0, 2.0, 3.0])

def test_parse_cli_datapath_default_weight(self):
    """Paths given without a preceding weight are expected to get the default weight 1.0."""
    parsed_paths, parsed_weights = parse_cli_datapath(
        "ds1.jsonl ds2_dir 2.0 ds3.jsonl")
    self.assertEqual(parsed_paths, ['ds1.jsonl', 'ds2_dir', 'ds3.jsonl'])
    self.assertEqual(parsed_weights, [1.0, 1.0, 2.0])


def test_parse_cli_datapath_edge_cases(self):
    """Check empty input, whitespace variations and quoted paths.

    Each case is ``(input string, expected paths, expected weights)``.
    """
    # Test various edge cases
    test_cases = [
        # Empty string
        ("", [], []),
        # Single path
        ("file.txt", ['file.txt'], [1.0]),
        # Multiple spaces between items (the run of spaces is deliberate)
        ("file1.txt    file2.txt", ['file1.txt', 'file2.txt'], [1.0, 1.0]),
        # Tab characters
        ("file1.txt\tfile2.txt", ['file1.txt', 'file2.txt'], [1.0, 1.0]),
        # Paths with spaces in them (quoted)
        ('"my file.txt" 1.5 "other file.txt"',
         ['my file.txt', 'other file.txt'],
         [1.0, 1.5]),
    ]

    for input_path, expected_paths, expected_weights in test_cases:
        # subTest reports every failing case instead of stopping at the
        # first one, and labels the failure with the offending input.
        with self.subTest(input_path=input_path):
            paths, weights = parse_cli_datapath(input_path)
            self.assertEqual(paths, expected_paths,
                             f"Failed paths for input: {input_path}")
            self.assertEqual(weights, expected_weights,
                             f"Failed weights for input: {input_path}")

def test_parse_cli_datapath_valid_weights(self):
    """Check float and integer weight tokens, including omitted weights.

    Each case is ``(input string, expected paths, expected weights)``.
    """
    # Test various valid weight formats
    test_cases = [
        ("1.0 file.txt", ['file.txt'], [1.0]),
        ("1.5 file1.txt 2.0 file2.txt",
         ['file1.txt', 'file2.txt'],
         [1.5, 2.0]),
        ("0.5 file1.txt file2.txt 1.5 file3.txt",
         ['file1.txt', 'file2.txt', 'file3.txt'],
         [0.5, 1.0, 1.5]),
        # Test integer weights
        ("1 file.txt", ['file.txt'], [1.0]),
        ("2 file1.txt 3 file2.txt",
         ['file1.txt', 'file2.txt'],
         [2.0, 3.0]),
    ]

    for input_path, expected_paths, expected_weights in test_cases:
        # subTest reports every failing case instead of stopping at the
        # first one, and labels the failure with the offending input.
        with self.subTest(input_path=input_path):
            paths, weights = parse_cli_datapath(input_path)
            self.assertEqual(paths, expected_paths,
                             f"Failed paths for input: {input_path}")
            self.assertEqual(weights, expected_weights,
                             f"Failed weights for input: {input_path}")

def test_parse_cli_datapath_special_characters(self):
    """Check that hyphens, underscores, dots, ``#`` and mixed case survive parsing.

    Each case is ``(input string, expected paths, expected weights)``.
    """
    # Test paths with special characters
    test_cases = [
        # Paths with hyphens and underscores
        ("my-file_1.txt", ['my-file_1.txt'], [1.0]),
        # Paths with dots
        ("path/to/file.with.dots.txt",
         ['path/to/file.with.dots.txt'],
         [1.0]),
        # Paths with special characters
        ("file#1.txt", ['file#1.txt'], [1.0]),
        # Mixed case with weight
        ("1.0 Path/To/File.TXT", ['Path/To/File.TXT'], [1.0]),
        # Multiple paths with special characters
        ("2.0 file#1.txt 3.0 path/to/file-2.txt",
         ['file#1.txt', 'path/to/file-2.txt'],
         [2.0, 3.0]),
    ]

    for input_path, expected_paths, expected_weights in test_cases:
        # subTest reports every failing case instead of stopping at the
        # first one, and labels the failure with the offending input.
        with self.subTest(input_path=input_path):
            paths, weights = parse_cli_datapath(input_path)
            self.assertEqual(paths, expected_paths,
                             f"Failed paths for input: {input_path}")
            self.assertEqual(weights, expected_weights,
                             f"Failed weights for input: {input_path}")

def test_parse_cli_datapath_multiple_datasets(self):
    """Check several datasets with full, partial and repeated weights.

    Each case is ``(input string, expected paths, expected weights)``.
    """
    # Test multiple datasets with various weight combinations
    test_cases = [
        # Multiple datasets with all weights specified
        ("0.5 data1.txt 1.5 data2.txt 2.0 data3.txt",
         ['data1.txt', 'data2.txt', 'data3.txt'],
         [0.5, 1.5, 2.0]),
        # Mix of weighted and unweighted datasets
        ("data1.txt 1.5 data2.txt data3.txt",
         ['data1.txt', 'data2.txt', 'data3.txt'],
         [1.0, 1.5, 1.0]),
        # Multiple datasets with same weight
        ("2.0 data1.txt 2.0 data2.txt 2.0 data3.txt",
         ['data1.txt', 'data2.txt', 'data3.txt'],
         [2.0, 2.0, 2.0]),
    ]

    for input_path, expected_paths, expected_weights in test_cases:
        # subTest reports every failing case instead of stopping at the
        # first one, and labels the failure with the offending input.
        with self.subTest(input_path=input_path):
            paths, weights = parse_cli_datapath(input_path)
            self.assertEqual(paths, expected_paths,
                             f"Failed paths for input: {input_path}")
            self.assertEqual(weights, expected_weights,
                             f"Failed weights for input: {input_path}")


if __name__ == '__main__':
unittest.main()

0 comments on commit b80f991

Please sign in to comment.