From b80f9913916dd83d9b42b0b5ded6f187d6b95c5f Mon Sep 17 00:00:00 2001
From: cyruszhang
Date: Wed, 11 Dec 2024 13:02:18 -0800
Subject: [PATCH] add more test cases for datapath rewrite logic; fix rewrite
 to handle space in file name

---
 data_juicer/core/data/dataset_builder.py |  18 ++-
 tests/core/test_dataset_builder.py       | 163 +++++++++++++++++++++--
 2 files changed, 165 insertions(+), 16 deletions(-)

diff --git a/data_juicer/core/data/dataset_builder.py b/data_juicer/core/data/dataset_builder.py
index ea7841d34..553761516 100644
--- a/data_juicer/core/data/dataset_builder.py
+++ b/data_juicer/core/data/dataset_builder.py
@@ -1,4 +1,5 @@
 import os
+import shlex
 from typing import List, Tuple, Union
 
 from data_juicer.core.data import NestedDataset
@@ -78,16 +79,25 @@ def parse_cli_datapath(dataset_path) -> Tuple[List[str], List[float]]:
     them, e.g. ` ds1.jsonl ds2_dir ds3_file.json`
     :return: list of dataset path and list of weights
     """
-    data_prefix = dataset_path.split()
+    # Handle empty input
+    if not dataset_path or not dataset_path.strip():
+        return [], []
+
+    # Use shlex to properly handle quoted strings
+    try:
+        tokens = shlex.split(dataset_path)
+    except ValueError as e:
+        raise ValueError(f'Invalid dataset path format: {e}')
+
     prefixes = []
     weights = []
-    for i in range(len(data_prefix)):
+    for i in range(len(tokens)):
         try:
-            value = max(float(data_prefix[i]), 0.0)
+            value = max(float(tokens[i]), 0.0)
             weights.append(value)
         except:  # noqa: E722
-            value = data_prefix[i].strip()
+            value = tokens[i].strip()
             # if not set weight, use 1.0 as default
             if i == 0 or len(weights) == len(prefixes):
                 weights.append(1.0)
             prefixes.append(value)
diff --git a/tests/core/test_dataset_builder.py b/tests/core/test_dataset_builder.py
index 32bd04e2f..86e0aab0e 100644
--- a/tests/core/test_dataset_builder.py
+++ b/tests/core/test_dataset_builder.py
@@ -1,35 +1,32 @@
 import os
 import unittest
+from unittest.mock import patch
 from argparse import Namespace
 from contextlib import redirect_stdout
 from io import StringIO
-
-from networkx.classes import is_empty
-
 from data_juicer.config import init_configs
-from data_juicer.core.data.dataset_builder import rewrite_cli_datapath
+from data_juicer.core.data.dataset_builder import rewrite_cli_datapath, parse_cli_datapath
 from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS
 
 @SKIPPED_TESTS.register_module()
 class DatasetBuilderTest(DataJuicerTestCaseBase):
 
+    def setUp(self):
+        # Get the directory where this test file is located
+        test_file_dir = os.path.dirname(os.path.abspath(__file__))
+        os.chdir(test_file_dir)
+
     def test_rewrite_cli_datapath_local_single_file(self):
         dataset_path = "./data/sample.txt"
         ans = rewrite_cli_datapath(dataset_path)
         self.assertEqual(
-            [{'path': ['./data/sample.txt'], 'type': 'ondisk', 'weight': 1.0}], ans)
+            [{'path': [dataset_path], 'type': 'ondisk', 'weight': 1.0}], ans)
 
     def test_rewrite_cli_datapath_local_directory(self):
         dataset_path = "./data"
         ans = rewrite_cli_datapath(dataset_path)
         self.assertEqual(
-            [{'path': ['./data'], 'type': 'ondisk', 'weight': 1.0}], ans)
-
-    def test_rewrite_cli_datapath_absolute_path(self):
-        dataset_path = os.curdir + "/data/sample.txt"
-        ans = rewrite_cli_datapath(dataset_path)
-        self.assertEqual(
-            [{'type': 'ondisk', 'path': [dataset_path], 'weight': 1.0}], ans)
+            [{'path': [dataset_path], 'type': 'ondisk', 'weight': 1.0}], ans)
 
     def test_rewrite_cli_datapath_hf(self):
         dataset_path = "hf-internal-testing/librispeech_asr_dummy"
@@ -75,6 +72,148 @@ def test_dataset_builder_ondisk_config_list(self):
                           {'path': ['sample.txt'], 'type': 'ondisk'}])
         self.assertEqual(not cfg.dataset_path, True)
 
+    @patch('os.path.isdir')
+    @patch('os.path.isfile')
+    def test_rewrite_cli_datapath_local_files(self, mock_isfile, mock_isdir):
+        # Mock os.path.isdir and os.path.isfile to simulate local files
+        mock_isfile.side_effect = lambda x: x.endswith('.jsonl')
+        mock_isdir.side_effect = lambda x: x.endswith('_dir')
+
+        dataset_path = "1.0 ds1.jsonl 2.0 ds2_dir 3.0 ds3.jsonl"
+        expected = [
+            {'type': 'ondisk', 'path': ['ds1.jsonl'], 'weight': 1.0},
+            {'type': 'ondisk', 'path': ['ds2_dir'], 'weight': 2.0},
+            {'type': 'ondisk', 'path': ['ds3.jsonl'], 'weight': 3.0}
+        ]
+        result = rewrite_cli_datapath(dataset_path)
+        self.assertEqual(result, expected)
+
+    def test_rewrite_cli_datapath_huggingface(self):
+        dataset_path = "1.0 huggingface/dataset"
+        expected = [
+            {'type': 'huggingface', 'path': 'huggingface/dataset', 'split': 'train'}
+        ]
+        result = rewrite_cli_datapath(dataset_path)
+        self.assertEqual(result, expected)
+
+    def test_rewrite_cli_datapath_invalid(self):
+        dataset_path = "1.0 ./invalid_path"
+        with self.assertRaises(ValueError):
+            rewrite_cli_datapath(dataset_path)
+
+    def test_parse_cli_datapath(self):
+        dataset_path = "1.0 ds1.jsonl 2.0 ds2_dir 3.0 ds3.jsonl"
+        expected_paths = ['ds1.jsonl', 'ds2_dir', 'ds3.jsonl']
+        expected_weights = [1.0, 2.0, 3.0]
+        paths, weights = parse_cli_datapath(dataset_path)
+        self.assertEqual(paths, expected_paths)
+        self.assertEqual(weights, expected_weights)
+
+    def test_parse_cli_datapath_default_weight(self):
+        dataset_path = "ds1.jsonl ds2_dir 2.0 ds3.jsonl"
+        expected_paths = ['ds1.jsonl', 'ds2_dir', 'ds3.jsonl']
+        expected_weights = [1.0, 1.0, 2.0]
+        paths, weights = parse_cli_datapath(dataset_path)
+        self.assertEqual(paths, expected_paths)
+        self.assertEqual(weights, expected_weights)
+
+
+    def test_parse_cli_datapath_edge_cases(self):
+        # Test various edge cases
+        test_cases = [
+            # Empty string
+            ("", [], []),
+            # Single path
+            ("file.txt", ['file.txt'], [1.0]),
+            # Multiple spaces between items
+            ("file1.txt    file2.txt", ['file1.txt', 'file2.txt'], [1.0, 1.0]),
+            # Tab characters
+            ("file1.txt\tfile2.txt", ['file1.txt', 'file2.txt'], [1.0, 1.0]),
+            # Paths with spaces in them (quoted)
+            ('"my file.txt" 1.5 "other file.txt"',
+             ['my file.txt', 'other file.txt'],
+             [1.0, 1.5]),
+        ]
+
+        for input_path, expected_paths, expected_weights in test_cases:
+            paths, weights = parse_cli_datapath(input_path)
+            self.assertEqual(paths, expected_paths,
+                             f"Failed paths for input: {input_path}")
+            self.assertEqual(weights, expected_weights,
+                             f"Failed weights for input: {input_path}")
+
+    def test_parse_cli_datapath_valid_weights(self):
+        # Test various valid weight formats
+        test_cases = [
+            ("1.0 file.txt", ['file.txt'], [1.0]),
+            ("1.5 file1.txt 2.0 file2.txt",
+             ['file1.txt', 'file2.txt'],
+             [1.5, 2.0]),
+            ("0.5 file1.txt file2.txt 1.5 file3.txt",
+             ['file1.txt', 'file2.txt', 'file3.txt'],
+             [0.5, 1.0, 1.5]),
+            # Test integer weights
+            ("1 file.txt", ['file.txt'], [1.0]),
+            ("2 file1.txt 3 file2.txt",
+             ['file1.txt', 'file2.txt'],
+             [2.0, 3.0]),
+        ]
+
+        for input_path, expected_paths, expected_weights in test_cases:
+            paths, weights = parse_cli_datapath(input_path)
+            self.assertEqual(paths, expected_paths,
+                             f"Failed paths for input: {input_path}")
+            self.assertEqual(weights, expected_weights,
+                             f"Failed weights for input: {input_path}")
+
+    def test_parse_cli_datapath_special_characters(self):
+        # Test paths with special characters
+        test_cases = [
+            # Paths with hyphens and underscores
+            ("my-file_1.txt", ['my-file_1.txt'], [1.0]),
+            # Paths with dots
+            ("path/to/file.with.dots.txt", ['path/to/file.with.dots.txt'], [1.0]),
+            # Paths with special characters
+            ("file#1.txt", ['file#1.txt'], [1.0]),
+            # Mixed case with weight
+            ("1.0 Path/To/File.TXT", ['Path/To/File.TXT'], [1.0]),
+            # Multiple paths with special characters
+            ("2.0 file#1.txt 3.0 path/to/file-2.txt",
+             ['file#1.txt', 'path/to/file-2.txt'],
+             [2.0, 3.0]),
+        ]
+
+        for input_path, expected_paths, expected_weights in test_cases:
+            paths, weights = parse_cli_datapath(input_path)
+            self.assertEqual(paths, expected_paths,
+                             f"Failed paths for input: {input_path}")
+            self.assertEqual(weights, expected_weights,
+                             f"Failed weights for input: {input_path}")
+
+    def test_parse_cli_datapath_multiple_datasets(self):
+        # Test multiple datasets with various weight combinations
+        test_cases = [
+            # Multiple datasets with all weights specified
+            ("0.5 data1.txt 1.5 data2.txt 2.0 data3.txt",
+             ['data1.txt', 'data2.txt', 'data3.txt'],
+             [0.5, 1.5, 2.0]),
+            # Mix of weighted and unweighted datasets
+            ("data1.txt 1.5 data2.txt data3.txt",
+             ['data1.txt', 'data2.txt', 'data3.txt'],
+             [1.0, 1.5, 1.0]),
+            # Multiple datasets with same weight
+            ("2.0 data1.txt 2.0 data2.txt 2.0 data3.txt",
+             ['data1.txt', 'data2.txt', 'data3.txt'],
+             [2.0, 2.0, 2.0]),
+        ]
+
+        for input_path, expected_paths, expected_weights in test_cases:
+            paths, weights = parse_cli_datapath(input_path)
+            self.assertEqual(paths, expected_paths,
+                             f"Failed paths for input: {input_path}")
+            self.assertEqual(weights, expected_weights,
+                             f"Failed weights for input: {input_path}")
+
 
 if __name__ == '__main__':
     unittest.main()
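-- 
A quick sanity check of the new shlex-based tokenization, kept below the patch
so `git am` ignores it. This is a minimal sketch, assuming a data_juicer
checkout with this patch applied is importable; the expected values mirror the
quoted-path case in test_parse_cli_datapath_edge_cases above.

    from data_juicer.core.data.dataset_builder import parse_cli_datapath

    # Unquoted tokens split on whitespace; quoted names keep their spaces.
    # A bare float is consumed as the weight of the path that follows it,
    # and paths with no explicit weight default to 1.0.
    paths, weights = parse_cli_datapath('"my file.txt" 1.5 "other file.txt"')
    assert paths == ['my file.txt', 'other file.txt']
    assert weights == [1.0, 1.5]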