diff --git a/data_juicer/config/config.py b/data_juicer/config/config.py index 06b388de5..661647472 100644 --- a/data_juicer/config/config.py +++ b/data_juicer/config/config.py @@ -5,7 +5,7 @@ import tempfile import time from argparse import ArgumentError, Namespace -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Union import yaml from jsonargparse import (ActionConfigFile, ArgumentParser, dict_to_namespace, @@ -194,7 +194,7 @@ def init_configs(args=None): 'own special token according to your input dataset.') parser.add_argument( '--suffixes', - type=Union[str, List[str], Tuple[str]], + type=Union[str, List[str]], default=[], help='Suffixes of files that will be find and loaded. If not set, we ' 'will find all suffix files, and select a suitable formatter ' diff --git a/data_juicer/format/formatter.py b/data_juicer/format/formatter.py index 8b41c4f26..2a8cd99ed 100644 --- a/data_juicer/format/formatter.py +++ b/data_juicer/format/formatter.py @@ -1,5 +1,5 @@ import os -from typing import List, Tuple, Union +from typing import List, Union from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset from loguru import logger @@ -27,7 +27,7 @@ def __init__( self, dataset_path: str, type: str, - suffixes: Union[str, List[str], Tuple[str]] = None, + suffixes: Union[str, List[str], None] = None, text_keys: List[str] = None, add_suffix=False, **kwargs, diff --git a/data_juicer/format/mixture_formatter.py b/data_juicer/format/mixture_formatter.py index fffbea672..6c13bdd7c 100644 --- a/data_juicer/format/mixture_formatter.py +++ b/data_juicer/format/mixture_formatter.py @@ -1,5 +1,5 @@ from itertools import chain, repeat -from typing import List, Tuple, Union +from typing import List, Union import numpy as np from datasets import Dataset, concatenate_datasets @@ -15,7 +15,7 @@ class MixtureFormatter(BaseFormatter): def __init__(self, dataset_path: str, - suffixes: Union[str, List[str], Tuple[str]] = None, + suffixes: Union[str, 
List[str], None] = None, text_keys=None, add_suffix=False, max_samples=None, diff --git a/data_juicer/ops/deduplicator/document_minhash_deduplicator.py b/data_juicer/ops/deduplicator/document_minhash_deduplicator.py index b0c5e0a51..6fa47c869 100644 --- a/data_juicer/ops/deduplicator/document_minhash_deduplicator.py +++ b/data_juicer/ops/deduplicator/document_minhash_deduplicator.py @@ -5,12 +5,14 @@ import hashlib import struct from collections import defaultdict +from typing import Optional import numpy as np import regex -from jsonargparse.typing import ClosedUnitInterval, PositiveInt from loguru import logger +from pydantic import Field, PositiveInt from tqdm import tqdm +from typing_extensions import Annotated from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.constant import HashKeys @@ -109,12 +111,12 @@ def __init__( tokenization: str = 'space', window_size: PositiveInt = 5, lowercase: bool = True, - ignore_pattern: str = None, + ignore_pattern: Optional[str] = None, num_permutations: PositiveInt = 256, - jaccard_threshold: ClosedUnitInterval = 0.7, - num_bands: PositiveInt = None, - num_rows_per_band: PositiveInt = None, - tokenizer_model: str = None, + jaccard_threshold: Annotated[float, Field(ge=0, le=1)] = 0.7, + num_bands: Optional[PositiveInt] = None, + num_rows_per_band: Optional[PositiveInt] = None, + tokenizer_model: Optional[str] = None, *args, **kwargs, ): diff --git a/data_juicer/ops/deduplicator/document_simhash_deduplicator.py b/data_juicer/ops/deduplicator/document_simhash_deduplicator.py index 0eaad8edc..b536bca95 100644 --- a/data_juicer/ops/deduplicator/document_simhash_deduplicator.py +++ b/data_juicer/ops/deduplicator/document_simhash_deduplicator.py @@ -3,12 +3,12 @@ # -------------------------------------------------------- from collections import defaultdict, deque -from typing import Dict, Set +from typing import Dict, Optional, Set import numpy as np import regex -from jsonargparse.typing 
import PositiveInt from loguru import logger +from pydantic import PositiveInt from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.constant import HashKeys @@ -30,7 +30,7 @@ def __init__(self, tokenization: str = 'space', window_size: PositiveInt = 6, lowercase: bool = True, - ignore_pattern: str = None, + ignore_pattern: Optional[str] = None, num_blocks: PositiveInt = 6, hamming_distance: PositiveInt = 4, *args, diff --git a/data_juicer/ops/deduplicator/image_deduplicator.py b/data_juicer/ops/deduplicator/image_deduplicator.py index ab3d7fbc9..828fab87f 100644 --- a/data_juicer/ops/deduplicator/image_deduplicator.py +++ b/data_juicer/ops/deduplicator/image_deduplicator.py @@ -104,7 +104,7 @@ def process(self, dataset, show_num=0): if show_num > 0: # sample duplicate pairs if self.consider_text: - hash2ids: Dict[Tuple[int], Set[int]] = defaultdict(set) + hash2ids: Dict[Tuple[int, int], Set[int]] = defaultdict(set) hashes = zip(dataset[HashKeys.imagehash], dataset[HashKeys.hash]) else: diff --git a/data_juicer/ops/deduplicator/ray_basic_deduplicator.py b/data_juicer/ops/deduplicator/ray_basic_deduplicator.py index 3eb453f9a..f8c40525e 100644 --- a/data_juicer/ops/deduplicator/ray_basic_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_basic_deduplicator.py @@ -1,6 +1,6 @@ from typing import Any -from jsonargparse.typing import PositiveInt +from pydantic import PositiveInt from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.constant import HashKeys diff --git a/data_juicer/ops/deduplicator/ray_document_deduplicator.py b/data_juicer/ops/deduplicator/ray_document_deduplicator.py index e12eb149f..ce5cced4e 100644 --- a/data_juicer/ops/deduplicator/ray_document_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_document_deduplicator.py @@ -2,7 +2,7 @@ import string import regex as re -from jsonargparse.typing import PositiveInt +from pydantic import PositiveInt from ..base_op 
import OPERATORS from .ray_basic_deduplicator import RayBasicDeduplicator diff --git a/data_juicer/ops/deduplicator/ray_image_deduplicator.py b/data_juicer/ops/deduplicator/ray_image_deduplicator.py index 10530c48b..038af481f 100644 --- a/data_juicer/ops/deduplicator/ray_image_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_image_deduplicator.py @@ -1,5 +1,5 @@ import numpy as np -from jsonargparse.typing import PositiveInt +from pydantic import PositiveInt from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.mm_utils import load_data_with_context, load_image diff --git a/data_juicer/ops/deduplicator/ray_video_deduplicator.py b/data_juicer/ops/deduplicator/ray_video_deduplicator.py index 7193e9313..902ca1979 100644 --- a/data_juicer/ops/deduplicator/ray_video_deduplicator.py +++ b/data_juicer/ops/deduplicator/ray_video_deduplicator.py @@ -1,6 +1,6 @@ import hashlib -from jsonargparse.typing import PositiveInt +from pydantic import PositiveInt from data_juicer.utils.mm_utils import (close_video, load_data_with_context, load_video) diff --git a/data_juicer/ops/deduplicator/video_deduplicator.py b/data_juicer/ops/deduplicator/video_deduplicator.py index ed5a767a4..63b28310e 100644 --- a/data_juicer/ops/deduplicator/video_deduplicator.py +++ b/data_juicer/ops/deduplicator/video_deduplicator.py @@ -85,7 +85,7 @@ def process(self, dataset, show_num=0): if show_num > 0: # sample duplicate pairs if self.consider_text: - hash2ids: Dict[Tuple[int], Set[int]] = defaultdict(set) + hash2ids: Dict[Tuple[int, int], Set[int]] = defaultdict(set) hashes = zip(dataset[HashKeys.videohash], dataset[HashKeys.hash]) else: diff --git a/data_juicer/ops/filter/alphanumeric_filter.py b/data_juicer/ops/filter/alphanumeric_filter.py index 88e93c534..e1cf90927 100644 --- a/data_juicer/ops/filter/alphanumeric_filter.py +++ b/data_juicer/ops/filter/alphanumeric_filter.py @@ -1,7 +1,5 @@ import sys -from jsonargparse.typing import PositiveFloat - from 
data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.constant import Fields, StatsKeys from data_juicer.utils.model_utils import get_model, prepare_model @@ -23,7 +21,7 @@ class AlphanumericFilter(Filter): def __init__(self, tokenization: bool = False, min_ratio: float = 0.25, - max_ratio: PositiveFloat = sys.maxsize, + max_ratio: float = sys.maxsize, *args, **kwargs): """ diff --git a/data_juicer/ops/filter/audio_duration_filter.py b/data_juicer/ops/filter/audio_duration_filter.py index f9860855e..cf70206f6 100644 --- a/data_juicer/ops/filter/audio_duration_filter.py +++ b/data_juicer/ops/filter/audio_duration_filter.py @@ -2,7 +2,6 @@ import librosa import numpy as np -from jsonargparse.typing import NonNegativeInt from data_juicer.utils.constant import Fields, StatsKeys from data_juicer.utils.mm_utils import load_audio, load_data_with_context @@ -20,8 +19,8 @@ class AudioDurationFilter(Filter): """ def __init__(self, - min_duration: NonNegativeInt = 0, - max_duration: NonNegativeInt = sys.maxsize, + min_duration: int = 0, + max_duration: int = sys.maxsize, any_or_all: str = 'any', *args, **kwargs): diff --git a/data_juicer/ops/filter/audio_nmf_snr_filter.py b/data_juicer/ops/filter/audio_nmf_snr_filter.py index fc4952c46..1ae16c5f8 100644 --- a/data_juicer/ops/filter/audio_nmf_snr_filter.py +++ b/data_juicer/ops/filter/audio_nmf_snr_filter.py @@ -2,8 +2,8 @@ import librosa import numpy as np -from jsonargparse.typing import PositiveInt from librosa.decompose import decompose +from pydantic import PositiveInt from data_juicer.utils.constant import Fields, StatsKeys from data_juicer.utils.mm_utils import load_audio, load_data_with_context diff --git a/data_juicer/ops/filter/average_line_length_filter.py b/data_juicer/ops/filter/average_line_length_filter.py index 0cacd84f0..079a6d9f3 100644 --- a/data_juicer/ops/filter/average_line_length_filter.py +++ b/data_juicer/ops/filter/average_line_length_filter.py @@ -1,7 +1,5 @@ import 
sys -from jsonargparse.typing import PositiveInt - from data_juicer.utils.constant import Fields, InterVars, StatsKeys from ..base_op import OPERATORS, Filter @@ -15,8 +13,8 @@ class AverageLineLengthFilter(Filter): range.""" def __init__(self, - min_len: PositiveInt = 10, - max_len: PositiveInt = sys.maxsize, + min_len: int = 10, + max_len: int = sys.maxsize, *args, **kwargs): """ diff --git a/data_juicer/ops/filter/character_repetition_filter.py b/data_juicer/ops/filter/character_repetition_filter.py index e67423030..1fb6949ff 100644 --- a/data_juicer/ops/filter/character_repetition_filter.py +++ b/data_juicer/ops/filter/character_repetition_filter.py @@ -3,7 +3,7 @@ # -------------------------------------------------------- import numpy as np -from jsonargparse.typing import ClosedUnitInterval, PositiveInt +from pydantic import PositiveInt from data_juicer.utils.constant import Fields, StatsKeys @@ -17,8 +17,8 @@ class CharacterRepetitionFilter(Filter): def __init__(self, rep_len: PositiveInt = 10, - min_ratio: ClosedUnitInterval = 0.0, - max_ratio: ClosedUnitInterval = 0.5, + min_ratio: float = 0.0, + max_ratio: float = 0.5, *args, **kwargs): """ diff --git a/data_juicer/ops/filter/flagged_words_filter.py b/data_juicer/ops/filter/flagged_words_filter.py index c63036914..84aa96036 100644 --- a/data_juicer/ops/filter/flagged_words_filter.py +++ b/data_juicer/ops/filter/flagged_words_filter.py @@ -2,7 +2,9 @@ # https://huggingface.co/spaces/huggingface/text-data-filtering # -------------------------------------------------------- -from jsonargparse.typing import ClosedUnitInterval, List +from typing import List + +from pydantic import PositiveInt from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.constant import Fields, InterVars, StatsKeys @@ -29,10 +31,10 @@ class FlaggedWordFilter(Filter): def __init__(self, lang: str = 'en', tokenization: bool = False, - max_ratio: ClosedUnitInterval = 0.045, + max_ratio: float = 
0.045, flagged_words_dir: str = ASSET_DIR, use_words_aug: bool = False, - words_aug_group_sizes: List = [2], + words_aug_group_sizes: List[PositiveInt] = [2], words_aug_join_char: str = '', *args, **kwargs): diff --git a/data_juicer/ops/filter/image_aesthetics_filter.py b/data_juicer/ops/filter/image_aesthetics_filter.py index 337fe0da6..bc6a2df19 100644 --- a/data_juicer/ops/filter/image_aesthetics_filter.py +++ b/data_juicer/ops/filter/image_aesthetics_filter.py @@ -1,5 +1,4 @@ import numpy as np -from jsonargparse.typing import ClosedUnitInterval from loguru import logger from data_juicer.utils.availability_utils import AvailabilityChecking @@ -32,10 +31,10 @@ class ImageAestheticsFilter(Filter): _accelerator = 'cuda' def __init__(self, - hf_scorer_model='', - trust_remote_code=False, - min_score: ClosedUnitInterval = 0.5, - max_score: ClosedUnitInterval = 1.0, + hf_scorer_model: str = '', + trust_remote_code: bool = False, + min_score: float = 0.5, + max_score: float = 1.0, any_or_all: str = 'any', *args, **kwargs): diff --git a/data_juicer/ops/filter/image_aspect_ratio_filter.py b/data_juicer/ops/filter/image_aspect_ratio_filter.py index 6fa7db2a8..211a40eee 100644 --- a/data_juicer/ops/filter/image_aspect_ratio_filter.py +++ b/data_juicer/ops/filter/image_aspect_ratio_filter.py @@ -1,5 +1,4 @@ import numpy as np -from jsonargparse.typing import PositiveFloat from data_juicer.utils.constant import Fields, StatsKeys from data_juicer.utils.mm_utils import load_data_with_context, load_image @@ -16,8 +15,8 @@ class ImageAspectRatioFilter(Filter): """ def __init__(self, - min_ratio: PositiveFloat = 0.333, - max_ratio: PositiveFloat = 3.0, + min_ratio: float = 0.333, + max_ratio: float = 3.0, any_or_all: str = 'any', *args, **kwargs): diff --git a/data_juicer/ops/filter/image_face_ratio_filter.py b/data_juicer/ops/filter/image_face_ratio_filter.py index 91ad63e6a..2b5d06677 100644 --- a/data_juicer/ops/filter/image_face_ratio_filter.py +++ 
b/data_juicer/ops/filter/image_face_ratio_filter.py @@ -1,7 +1,6 @@ import os import numpy as np -from jsonargparse.typing import ClosedUnitInterval from loguru import logger from data_juicer.utils.availability_utils import AvailabilityChecking @@ -34,9 +33,9 @@ class ImageFaceRatioFilter(Filter): } def __init__(self, - cv_classifier='', - min_ratio: ClosedUnitInterval = 0.0, - max_ratio: ClosedUnitInterval = 0.4, + cv_classifier: str = '', + min_ratio: float = 0.0, + max_ratio: float = 0.4, any_or_all: str = 'any', *args, **kwargs): diff --git a/data_juicer/ops/filter/image_nsfw_filter.py b/data_juicer/ops/filter/image_nsfw_filter.py index fd3a2d980..81f878b5f 100644 --- a/data_juicer/ops/filter/image_nsfw_filter.py +++ b/data_juicer/ops/filter/image_nsfw_filter.py @@ -1,5 +1,4 @@ import numpy as np -from jsonargparse.typing import ClosedUnitInterval from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.constant import Fields, StatsKeys @@ -27,9 +26,9 @@ class ImageNSFWFilter(Filter): _accelerator = 'cuda' def __init__(self, - hf_nsfw_model='Falconsai/nsfw_image_detection', - trust_remote_code=False, - score_threshold: ClosedUnitInterval = 0.5, + hf_nsfw_model: str = 'Falconsai/nsfw_image_detection', + trust_remote_code: bool = False, + score_threshold: float = 0.5, any_or_all: str = 'any', *args, **kwargs): diff --git a/data_juicer/ops/filter/image_shape_filter.py b/data_juicer/ops/filter/image_shape_filter.py index 0a1d333a7..c6de1b4bc 100644 --- a/data_juicer/ops/filter/image_shape_filter.py +++ b/data_juicer/ops/filter/image_shape_filter.py @@ -1,7 +1,6 @@ import sys import numpy as np -from jsonargparse.typing import PositiveInt from data_juicer.utils.constant import Fields, StatsKeys from data_juicer.utils.mm_utils import load_data_with_context, load_image @@ -17,10 +16,10 @@ class ImageShapeFilter(Filter): """ def __init__(self, - min_width: PositiveInt = 1, - max_width: PositiveInt = sys.maxsize, - min_height: 
PositiveInt = 1, - max_height: PositiveInt = sys.maxsize, + min_width: int = 1, + max_width: int = sys.maxsize, + min_height: int = 1, + max_height: int = sys.maxsize, any_or_all: str = 'any', *args, **kwargs): diff --git a/data_juicer/ops/filter/image_text_matching_filter.py b/data_juicer/ops/filter/image_text_matching_filter.py index 572a9ffdb..d5c6ad87c 100644 --- a/data_juicer/ops/filter/image_text_matching_filter.py +++ b/data_juicer/ops/filter/image_text_matching_filter.py @@ -1,5 +1,4 @@ import numpy as np -from jsonargparse.typing import ClosedUnitInterval from PIL import ImageOps from data_juicer.utils.availability_utils import AvailabilityChecking @@ -30,10 +29,10 @@ class ImageTextMatchingFilter(Filter): _accelerator = 'cuda' def __init__(self, - hf_blip='Salesforce/blip-itm-base-coco', - trust_remote_code=False, - min_score: ClosedUnitInterval = 0.003, - max_score: ClosedUnitInterval = 1.0, + hf_blip: str = 'Salesforce/blip-itm-base-coco', + trust_remote_code: bool = False, + min_score: float = 0.003, + max_score: float = 1.0, horizontal_flip: bool = False, vertical_flip: bool = False, any_or_all: str = 'any', diff --git a/data_juicer/ops/filter/image_text_similarity_filter.py b/data_juicer/ops/filter/image_text_similarity_filter.py index 093777bbe..f6d2a0658 100644 --- a/data_juicer/ops/filter/image_text_similarity_filter.py +++ b/data_juicer/ops/filter/image_text_similarity_filter.py @@ -1,5 +1,4 @@ import numpy as np -from jsonargparse.typing import ClosedUnitInterval from PIL import ImageOps from data_juicer.utils.availability_utils import AvailabilityChecking @@ -31,10 +30,10 @@ class ImageTextSimilarityFilter(Filter): _accelerator = 'cuda' def __init__(self, - hf_clip='openai/clip-vit-base-patch32', - trust_remote_code=False, - min_score: ClosedUnitInterval = 0.1, - max_score: ClosedUnitInterval = 1.0, + hf_clip: str = 'openai/clip-vit-base-patch32', + trust_remote_code: bool = False, + min_score: float = 0.1, + max_score: float = 1.0, 
horizontal_flip: bool = False, vertical_flip: bool = False, any_or_all: str = 'any', diff --git a/data_juicer/ops/filter/image_watermark_filter.py b/data_juicer/ops/filter/image_watermark_filter.py index d88e4a9e4..620e80a09 100644 --- a/data_juicer/ops/filter/image_watermark_filter.py +++ b/data_juicer/ops/filter/image_watermark_filter.py @@ -1,5 +1,4 @@ import numpy as np -from jsonargparse.typing import ClosedUnitInterval from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.constant import Fields, StatsKeys @@ -30,9 +29,9 @@ class ImageWatermarkFilter(Filter): _accelerator = 'cuda' def __init__(self, - hf_watermark_model='amrul-hzz/watermark_detector', - trust_remote_code=False, - prob_threshold: ClosedUnitInterval = 0.8, + hf_watermark_model: str = 'amrul-hzz/watermark_detector', + trust_remote_code: bool = False, + prob_threshold: float = 0.8, any_or_all: str = 'any', *args, **kwargs): diff --git a/data_juicer/ops/filter/language_id_score_filter.py b/data_juicer/ops/filter/language_id_score_filter.py index 6b71cf112..69283cf8a 100644 --- a/data_juicer/ops/filter/language_id_score_filter.py +++ b/data_juicer/ops/filter/language_id_score_filter.py @@ -1,6 +1,5 @@ -from typing import List, Tuple, Union +from typing import List, Union -from jsonargparse.typing import ClosedUnitInterval from loguru import logger from data_juicer.utils.availability_utils import AvailabilityChecking @@ -21,8 +20,8 @@ class LanguageIDScoreFilter(Filter): larger than a specific min value.""" def __init__(self, - lang: Union[str, List[str], Tuple[str]] = '', - min_score: ClosedUnitInterval = 0.8, + lang: Union[str, List[str]] = '', + min_score: float = 0.8, *args, **kwargs): """ diff --git a/data_juicer/ops/filter/maximum_line_length_filter.py b/data_juicer/ops/filter/maximum_line_length_filter.py index dab086a0a..2f2a4513e 100644 --- a/data_juicer/ops/filter/maximum_line_length_filter.py +++ b/data_juicer/ops/filter/maximum_line_length_filter.py 
@@ -1,7 +1,5 @@ import sys -from jsonargparse.typing import PositiveInt - from data_juicer.utils.constant import Fields, InterVars, StatsKeys from ..base_op import OPERATORS, Filter @@ -15,8 +13,8 @@ class MaximumLineLengthFilter(Filter): range.""" def __init__(self, - min_len: PositiveInt = 10, - max_len: PositiveInt = sys.maxsize, + min_len: int = 10, + max_len: int = sys.maxsize, *args, **kwargs): """ diff --git a/data_juicer/ops/filter/perplexity_filter.py b/data_juicer/ops/filter/perplexity_filter.py index 64408b872..5d0b396f9 100644 --- a/data_juicer/ops/filter/perplexity_filter.py +++ b/data_juicer/ops/filter/perplexity_filter.py @@ -2,8 +2,6 @@ # https://huggingface.co/spaces/huggingface/text-data-filtering # -------------------------------------------------------- -from jsonargparse.typing import PositiveFloat - from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.constant import Fields, InterVars, StatsKeys from data_juicer.utils.model_utils import get_model, prepare_model @@ -27,7 +25,7 @@ class PerplexityFilter(Filter): def __init__(self, lang: str = 'en', - max_ppl: PositiveFloat = 1500, + max_ppl: float = 1500, *args, **kwargs): """ diff --git a/data_juicer/ops/filter/phrase_grounding_recall_filter.py b/data_juicer/ops/filter/phrase_grounding_recall_filter.py index 9fa0498fb..ad7afe902 100644 --- a/data_juicer/ops/filter/phrase_grounding_recall_filter.py +++ b/data_juicer/ops/filter/phrase_grounding_recall_filter.py @@ -1,7 +1,6 @@ from typing import List import numpy as np -from jsonargparse.typing import ClosedUnitInterval from loguru import logger from PIL import ImageOps @@ -77,17 +76,17 @@ class PhraseGroundingRecallFilter(Filter): _accelerator = 'cuda' def __init__(self, - hf_owlvit='google/owlvit-base-patch32', - trust_remote_code=False, - min_recall: ClosedUnitInterval = 0.1, - max_recall: ClosedUnitInterval = 1.0, + hf_owlvit: str = 'google/owlvit-base-patch32', + trust_remote_code: bool = False, + 
min_recall: float = 0.1, + max_recall: float = 1.0, horizontal_flip: bool = False, vertical_flip: bool = False, any_or_all: str = 'any', reduce_mode: str = 'avg', - iou_thr: ClosedUnitInterval = 0.5, - large_area_ratio_thr: ClosedUnitInterval = 0.95, - conf_thr: ClosedUnitInterval = 0.0, + iou_thr: float = 0.5, + large_area_ratio_thr: float = 0.95, + conf_thr: float = 0.0, *args, **kwargs): """ diff --git a/data_juicer/ops/filter/special_characters_filter.py b/data_juicer/ops/filter/special_characters_filter.py index 3b1b1a893..dc9ef1ed6 100644 --- a/data_juicer/ops/filter/special_characters_filter.py +++ b/data_juicer/ops/filter/special_characters_filter.py @@ -2,8 +2,6 @@ # https://huggingface.co/spaces/huggingface/text-data-filtering # -------------------------------------------------------- -from jsonargparse.typing import ClosedUnitInterval - from data_juicer.utils.constant import Fields, StatsKeys from ..base_op import OPERATORS, Filter @@ -16,8 +14,8 @@ class SpecialCharactersFilter(Filter): range.""" def __init__(self, - min_ratio: ClosedUnitInterval = 0.0, - max_ratio: ClosedUnitInterval = 0.25, + min_ratio: float = 0.0, + max_ratio: float = 0.25, *args, **kwargs): """ diff --git a/data_juicer/ops/filter/specified_field_filter.py b/data_juicer/ops/filter/specified_field_filter.py index 35ea6ac9f..7f79a98b8 100644 --- a/data_juicer/ops/filter/specified_field_filter.py +++ b/data_juicer/ops/filter/specified_field_filter.py @@ -1,4 +1,4 @@ -from typing import List, Tuple, Union +from typing import List from ..base_op import OPERATORS, Filter @@ -14,7 +14,7 @@ class SpecifiedFieldFilter(Filter): def __init__(self, field_key: str = '', - target_value: Union[List, Tuple] = [], + target_value: List = [], *args, **kwargs): """ diff --git a/data_juicer/ops/filter/stopwords_filter.py b/data_juicer/ops/filter/stopwords_filter.py index f61542e13..57dd138d1 100644 --- a/data_juicer/ops/filter/stopwords_filter.py +++ b/data_juicer/ops/filter/stopwords_filter.py @@ -2,7 
+2,9 @@ # https://huggingface.co/spaces/huggingface/text-data-filtering # -------------------------------------------------------- -from jsonargparse.typing import ClosedUnitInterval, List +from typing import List + +from pydantic import PositiveInt from data_juicer.utils.asset_utils import ASSET_DIR, load_words_asset from data_juicer.utils.availability_utils import AvailabilityChecking @@ -29,10 +31,10 @@ class StopWordsFilter(Filter): def __init__(self, lang: str = 'en', tokenization: bool = False, - min_ratio: ClosedUnitInterval = 0.3, + min_ratio: float = 0.3, stopwords_dir: str = ASSET_DIR, use_words_aug: bool = False, - words_aug_group_sizes: List = [2], + words_aug_group_sizes: List[PositiveInt] = [2], words_aug_join_char: str = '', *args, **kwargs): diff --git a/data_juicer/ops/filter/suffix_filter.py b/data_juicer/ops/filter/suffix_filter.py index 0c95e9c4c..52a833691 100644 --- a/data_juicer/ops/filter/suffix_filter.py +++ b/data_juicer/ops/filter/suffix_filter.py @@ -1,4 +1,4 @@ -from typing import List, Tuple, Union +from typing import List, Union from data_juicer.utils.constant import Fields @@ -9,10 +9,7 @@ class SuffixFilter(Filter): """Filter to keep samples with specified suffix.""" - def __init__(self, - suffixes: Union[str, List[str], Tuple[str]] = [], - *args, - **kwargs): + def __init__(self, suffixes: Union[str, List[str]] = [], *args, **kwargs): """ Initialization method. 
diff --git a/data_juicer/ops/filter/text_length_filter.py b/data_juicer/ops/filter/text_length_filter.py index 9d3170bfc..6fa966889 100644 --- a/data_juicer/ops/filter/text_length_filter.py +++ b/data_juicer/ops/filter/text_length_filter.py @@ -1,7 +1,5 @@ import sys -from jsonargparse.typing import PositiveInt - from data_juicer.utils.constant import Fields, StatsKeys from ..base_op import OPERATORS, Filter @@ -13,8 +11,8 @@ class TextLengthFilter(Filter): range.""" def __init__(self, - min_len: PositiveInt = 10, - max_len: PositiveInt = sys.maxsize, + min_len: int = 10, + max_len: int = sys.maxsize, *args, **kwargs): """ diff --git a/data_juicer/ops/filter/token_num_filter.py b/data_juicer/ops/filter/token_num_filter.py index 793a103b2..d3a31c338 100644 --- a/data_juicer/ops/filter/token_num_filter.py +++ b/data_juicer/ops/filter/token_num_filter.py @@ -1,7 +1,5 @@ import sys -from jsonargparse.typing import PositiveInt - from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.constant import Fields, StatsKeys from data_juicer.utils.model_utils import get_model, prepare_model @@ -22,8 +20,8 @@ class TokenNumFilter(Filter): def __init__(self, hf_tokenizer: str = 'EleutherAI/pythia-6.9b-deduped', - min_num: PositiveInt = 10, - max_num: PositiveInt = sys.maxsize, + min_num: int = 10, + max_num: int = sys.maxsize, *args, **kwargs): """ diff --git a/data_juicer/ops/filter/video_aesthetics_filter.py b/data_juicer/ops/filter/video_aesthetics_filter.py index ddb13aa4f..69129b60d 100644 --- a/data_juicer/ops/filter/video_aesthetics_filter.py +++ b/data_juicer/ops/filter/video_aesthetics_filter.py @@ -1,6 +1,6 @@ import numpy as np -from jsonargparse.typing import ClosedUnitInterval, PositiveInt from loguru import logger +from pydantic import PositiveInt from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.constant import Fields, StatsKeys @@ -36,10 +36,10 @@ class VideoAestheticsFilter(Filter): 
_accelerator = 'cuda' def __init__(self, - hf_scorer_model='', - trust_remote_code=False, - min_score: ClosedUnitInterval = 0.4, - max_score: ClosedUnitInterval = 1.0, + hf_scorer_model: str = '', + trust_remote_code: bool = False, + min_score: float = 0.4, + max_score: float = 1.0, frame_sampling_method: str = 'uniform', frame_num: PositiveInt = 3, any_or_all: str = 'any', diff --git a/data_juicer/ops/filter/video_duration_filter.py b/data_juicer/ops/filter/video_duration_filter.py index a224e0dd0..1cccf87c7 100644 --- a/data_juicer/ops/filter/video_duration_filter.py +++ b/data_juicer/ops/filter/video_duration_filter.py @@ -1,7 +1,6 @@ import sys import numpy as np -from jsonargparse.typing import NonNegativeFloat from data_juicer.utils.constant import Fields, StatsKeys from data_juicer.utils.mm_utils import (close_video, load_data_with_context, @@ -20,8 +19,8 @@ class VideoDurationFilter(Filter): """ def __init__(self, - min_duration: NonNegativeFloat = 0, - max_duration: NonNegativeFloat = sys.maxsize, + min_duration: float = 0, + max_duration: float = sys.maxsize, any_or_all: str = 'any', *args, **kwargs): diff --git a/data_juicer/ops/filter/video_frames_text_similarity_filter.py b/data_juicer/ops/filter/video_frames_text_similarity_filter.py index 0b080076f..eae51f66a 100644 --- a/data_juicer/ops/filter/video_frames_text_similarity_filter.py +++ b/data_juicer/ops/filter/video_frames_text_similarity_filter.py @@ -1,6 +1,6 @@ import numpy as np -from jsonargparse.typing import ClosedUnitInterval, PositiveInt from PIL import ImageOps +from pydantic import PositiveInt from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.constant import Fields, StatsKeys @@ -37,8 +37,8 @@ class VideoFramesTextSimilarityFilter(Filter): def __init__(self, hf_clip='openai/clip-vit-base-patch32', trust_remote_code=False, - min_score: ClosedUnitInterval = 0.1, - max_score: ClosedUnitInterval = 1.0, + min_score: float = 0.1, + max_score: float = 
1.0, frame_sampling_method: str = 'all_keyframes', frame_num: PositiveInt = 3, horizontal_flip: bool = False, diff --git a/data_juicer/ops/filter/video_motion_score_filter.py b/data_juicer/ops/filter/video_motion_score_filter.py index 76ddaf4fe..daf94f273 100644 --- a/data_juicer/ops/filter/video_motion_score_filter.py +++ b/data_juicer/ops/filter/video_motion_score_filter.py @@ -1,9 +1,9 @@ import sys from contextlib import contextmanager -from typing import List, Optional, Sequence, Tuple, Union +from typing import Optional, Tuple, Union import numpy as np -from jsonargparse.typing import PositiveFloat, PositiveInt +from pydantic import PositiveFloat, PositiveInt from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.constant import Fields, StatsKeys @@ -46,8 +46,8 @@ def __init__(self, min_score: float = 0.25, max_score: float = sys.float_info.max, sampling_fps: PositiveFloat = 2, - size: Optional[Union[PositiveInt, - Sequence[PositiveInt]]] = None, + size: Union[PositiveInt, Tuple[PositiveInt], + Tuple[PositiveInt, PositiveInt], None] = None, max_size: Optional[PositiveInt] = None, relative: bool = False, any_or_all: str = 'any', @@ -90,7 +90,7 @@ def __init__(self, f'Size must be an int or a 1 or 2 element tuple/list,' f'not a {len(size)} element tuple/list.') if isinstance(size, int): - size = [size] + size = (size, ) self.size = size self.max_size = max_size self.relative = relative @@ -202,9 +202,9 @@ def process(self, sample): def _compute_resized_output_size( frame_size: Tuple[int, int], - size: Optional[List[int]], + size: Union[Tuple[PositiveInt], Tuple[PositiveInt, PositiveInt]], max_size: Optional[int] = None, -) -> List[int]: +) -> Tuple[int, int]: h, w = frame_size short, long = (w, h) if w <= h else (h, w) diff --git a/data_juicer/ops/filter/video_nsfw_filter.py b/data_juicer/ops/filter/video_nsfw_filter.py index c2c61084f..8ce40c045 100644 --- a/data_juicer/ops/filter/video_nsfw_filter.py +++ 
b/data_juicer/ops/filter/video_nsfw_filter.py @@ -1,5 +1,5 @@ import numpy as np -from jsonargparse.typing import ClosedUnitInterval, PositiveInt +from pydantic import PositiveInt from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.constant import Fields, StatsKeys @@ -31,9 +31,9 @@ class VideoNSFWFilter(Filter): _accelerator = 'cuda' def __init__(self, - hf_nsfw_model='Falconsai/nsfw_image_detection', - trust_remote_code=False, - score_threshold: ClosedUnitInterval = 0.5, + hf_nsfw_model: str = 'Falconsai/nsfw_image_detection', + trust_remote_code: bool = False, + score_threshold: float = 0.5, frame_sampling_method: str = 'all_keyframes', frame_num: PositiveInt = 3, reduce_mode: str = 'avg', diff --git a/data_juicer/ops/filter/video_ocr_area_ratio_filter.py b/data_juicer/ops/filter/video_ocr_area_ratio_filter.py index cbece9331..c0a3f1c65 100644 --- a/data_juicer/ops/filter/video_ocr_area_ratio_filter.py +++ b/data_juicer/ops/filter/video_ocr_area_ratio_filter.py @@ -1,7 +1,7 @@ from typing import List, Union import numpy as np -from jsonargparse.typing import ClosedUnitInterval, PositiveInt +from pydantic import PositiveInt from data_juicer import cuda_device_count from data_juicer.utils.availability_utils import AvailabilityChecking @@ -43,8 +43,8 @@ class VideoOcrAreaRatioFilter(Filter): _accelerator = 'cuda' def __init__(self, - min_area_ratio: ClosedUnitInterval = 0, - max_area_ratio: ClosedUnitInterval = 1.0, + min_area_ratio: float = 0, + max_area_ratio: float = 1.0, frame_sample_num: PositiveInt = 3, languages_to_detect: Union[str, List[str]] = ['ch_sim', 'en'], any_or_all: str = 'any', diff --git a/data_juicer/ops/filter/video_resolution_filter.py b/data_juicer/ops/filter/video_resolution_filter.py index f87aae4ca..61e5d13cd 100644 --- a/data_juicer/ops/filter/video_resolution_filter.py +++ b/data_juicer/ops/filter/video_resolution_filter.py @@ -1,7 +1,6 @@ import sys import numpy as np -from jsonargparse.typing 
import PositiveInt from data_juicer.utils.constant import Fields, StatsKeys from data_juicer.utils.mm_utils import (close_video, load_data_with_context, @@ -20,10 +19,10 @@ class VideoResolutionFilter(Filter): """ def __init__(self, - min_width: PositiveInt = 1, - max_width: PositiveInt = sys.maxsize, - min_height: PositiveInt = 1, - max_height: PositiveInt = sys.maxsize, + min_width: int = 1, + max_width: int = sys.maxsize, + min_height: int = 1, + max_height: int = sys.maxsize, any_or_all: str = 'any', *args, **kwargs): diff --git a/data_juicer/ops/filter/video_tagging_from_frames_filter.py b/data_juicer/ops/filter/video_tagging_from_frames_filter.py index fab8f4957..df90e6fd7 100644 --- a/data_juicer/ops/filter/video_tagging_from_frames_filter.py +++ b/data_juicer/ops/filter/video_tagging_from_frames_filter.py @@ -1,5 +1,7 @@ +from typing import List + import numpy as np -from jsonargparse.typing import List, PositiveInt +from pydantic import PositiveInt from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.constant import Fields diff --git a/data_juicer/ops/filter/video_watermark_filter.py b/data_juicer/ops/filter/video_watermark_filter.py index 3443875a1..45f2d11d5 100644 --- a/data_juicer/ops/filter/video_watermark_filter.py +++ b/data_juicer/ops/filter/video_watermark_filter.py @@ -1,5 +1,5 @@ import numpy as np -from jsonargparse.typing import ClosedUnitInterval, PositiveInt +from pydantic import PositiveInt from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.constant import Fields, StatsKeys @@ -34,9 +34,9 @@ class VideoWatermarkFilter(Filter): _accelerator = 'cuda' def __init__(self, - hf_watermark_model='amrul-hzz/watermark_detector', - trust_remote_code=False, - prob_threshold: ClosedUnitInterval = 0.8, + hf_watermark_model: str = 'amrul-hzz/watermark_detector', + trust_remote_code: bool = False, + prob_threshold: float = 0.8, frame_sampling_method: str = 'all_keyframes', 
frame_num: PositiveInt = 3, reduce_mode: str = 'avg', diff --git a/data_juicer/ops/filter/word_repetition_filter.py b/data_juicer/ops/filter/word_repetition_filter.py index 187c23e06..3009c55f9 100644 --- a/data_juicer/ops/filter/word_repetition_filter.py +++ b/data_juicer/ops/filter/word_repetition_filter.py @@ -2,7 +2,7 @@ # https://huggingface.co/spaces/huggingface/text-data-filtering # -------------------------------------------------------- -from jsonargparse.typing import ClosedUnitInterval, PositiveInt +from pydantic import PositiveInt from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.constant import Fields, InterVars, StatsKeys @@ -29,8 +29,8 @@ def __init__(self, lang: str = 'en', tokenization: bool = False, rep_len: PositiveInt = 10, - min_ratio: ClosedUnitInterval = 0.0, - max_ratio: ClosedUnitInterval = 0.5, + min_ratio: float = 0.0, + max_ratio: float = 0.5, *args, **kwargs): """ diff --git a/data_juicer/ops/filter/words_num_filter.py b/data_juicer/ops/filter/words_num_filter.py index 3f6d02d76..7d354cb54 100644 --- a/data_juicer/ops/filter/words_num_filter.py +++ b/data_juicer/ops/filter/words_num_filter.py @@ -1,7 +1,5 @@ import sys -from jsonargparse.typing import PositiveInt - from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.constant import Fields, InterVars, StatsKeys from data_juicer.utils.model_utils import get_model, prepare_model @@ -26,8 +24,8 @@ class WordsNumFilter(Filter): def __init__(self, lang: str = 'en', tokenization: bool = False, - min_num: PositiveInt = 10, - max_num: PositiveInt = sys.maxsize, + min_num: int = 10, + max_num: int = sys.maxsize, *args, **kwargs): """ diff --git a/data_juicer/ops/mapper/clean_email_mapper.py b/data_juicer/ops/mapper/clean_email_mapper.py index 9708363e5..b8d2a1cbb 100644 --- a/data_juicer/ops/mapper/clean_email_mapper.py +++ b/data_juicer/ops/mapper/clean_email_mapper.py @@ -1,3 +1,5 @@ +from typing import Optional 
+ import regex as re from ..base_op import OPERATORS, Mapper @@ -7,7 +9,11 @@ class CleanEmailMapper(Mapper): """Mapper to clean email in text samples.""" - def __init__(self, pattern: str = None, repl: str = '', *args, **kwargs): + def __init__(self, + pattern: Optional[str] = None, + repl: str = '', + *args, + **kwargs): """ Initialization method. diff --git a/data_juicer/ops/mapper/clean_ip_mapper.py b/data_juicer/ops/mapper/clean_ip_mapper.py index 607aeb585..b36d13aae 100644 --- a/data_juicer/ops/mapper/clean_ip_mapper.py +++ b/data_juicer/ops/mapper/clean_ip_mapper.py @@ -1,3 +1,5 @@ +from typing import Optional + import regex as re from ..base_op import OPERATORS, Mapper @@ -7,7 +9,11 @@ class CleanIpMapper(Mapper): """Mapper to clean ipv4 and ipv6 address in text samples.""" - def __init__(self, pattern: str = None, repl: str = '', *args, **kwargs): + def __init__(self, + pattern: Optional[str] = None, + repl: str = '', + *args, + **kwargs): """ Initialization method. diff --git a/data_juicer/ops/mapper/clean_links_mapper.py b/data_juicer/ops/mapper/clean_links_mapper.py index bcd90d524..ebeac8668 100644 --- a/data_juicer/ops/mapper/clean_links_mapper.py +++ b/data_juicer/ops/mapper/clean_links_mapper.py @@ -1,6 +1,8 @@ # Some code here has been modified from: # https://github.com/kallewesterling/CleanText/ # -------------------------------------------------------- +from typing import Optional + import regex as re from ..base_op import OPERATORS, Mapper @@ -10,7 +12,11 @@ class CleanLinksMapper(Mapper): """Mapper to clean links like http/https/ftp in text samples.""" - def __init__(self, pattern: str = None, repl: str = '', *args, **kwargs): + def __init__(self, + pattern: Optional[str] = None, + repl: str = '', + *args, + **kwargs): """ Initialization method. 
diff --git a/data_juicer/ops/mapper/extract_qa_mapper.py b/data_juicer/ops/mapper/extract_qa_mapper.py index db8a397f2..8a41efeb4 100644 --- a/data_juicer/ops/mapper/extract_qa_mapper.py +++ b/data_juicer/ops/mapper/extract_qa_mapper.py @@ -1,6 +1,6 @@ import json import re -from typing import Dict +from typing import Dict, Optional from loguru import logger @@ -42,11 +42,11 @@ class ExtractQAMapper(Mapper): def __init__(self, hf_model: str = 'alibaba-pai/pai-qwen1_5-7b-doc2qa', trust_remote_code: bool = False, - pattern: str = None, + pattern: Optional[str] = None, qa_format: str = 'chatml', enable_vllm: bool = True, - tensor_parallel_size: int = None, - max_model_len: int = None, + tensor_parallel_size: Optional[int] = None, + max_model_len: Optional[int] = None, max_num_seqs: int = 256, sampling_params: Dict = {}, *args, diff --git a/data_juicer/ops/mapper/generate_instruction_mapper.py b/data_juicer/ops/mapper/generate_instruction_mapper.py index f75c54153..9fafa94e3 100644 --- a/data_juicer/ops/mapper/generate_instruction_mapper.py +++ b/data_juicer/ops/mapper/generate_instruction_mapper.py @@ -1,9 +1,10 @@ import json import random import re -from typing import Dict +from typing import Dict, Optional from loguru import logger +from pydantic import PositiveInt from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.model_utils import get_model, prepare_model @@ -52,17 +53,17 @@ class GenerateInstructionMapper(Mapper): def __init__(self, hf_model: str = 'Qwen/Qwen-7B-Chat', - seed_file: str = None, - instruct_num: int = 3, + seed_file: str = '', + instruct_num: PositiveInt = 3, trust_remote_code: bool = False, similarity_threshold: float = 0.7, - prompt_template: str = None, - qa_pair_template: str = None, - example_template: str = None, - qa_extraction_pattern: str = None, + prompt_template: Optional[str] = None, + qa_pair_template: Optional[str] = None, + example_template: Optional[str] = None, + qa_extraction_pattern: 
Optional[str] = None, enable_vllm: bool = True, - tensor_parallel_size: int = None, - max_model_len: int = None, + tensor_parallel_size: Optional[int] = None, + max_model_len: Optional[int] = None, max_num_seqs: int = 256, sampling_params: Dict = {}, *args, @@ -112,8 +113,9 @@ def __init__(self, self.num_proc = 1 if not seed_file: - raise ValueError('Please provide `seed_file` parameter, a file in chatml format. '\ - 'Reference data: data-juicer/demos/data/demo-dataset-chatml.jsonl ') + raise ValueError( + 'Please provide `seed_file` in chatml format. ' + 'Example: data-juicer/demos/data/demo-dataset-chatml.jsonl') self.instruct_num = instruct_num self.similarity_threshold = similarity_threshold diff --git a/data_juicer/ops/mapper/image_captioning_from_gpt4v_mapper.py b/data_juicer/ops/mapper/image_captioning_from_gpt4v_mapper.py index 76cfbfae0..1500a074a 100644 --- a/data_juicer/ops/mapper/image_captioning_from_gpt4v_mapper.py +++ b/data_juicer/ops/mapper/image_captioning_from_gpt4v_mapper.py @@ -1,8 +1,10 @@ import copy +from typing import Optional import requests -from jsonargparse.typing import ClosedUnitInterval from loguru import logger +from pydantic import Field +from typing_extensions import Annotated from data_juicer.utils.mm_utils import (SpecialTokens, image_byte_to_base64, insert_texts_after_placeholders, @@ -104,10 +106,10 @@ def __init__(self, mode: str = 'description', api_key: str = '', max_token: int = 500, - temperature: ClosedUnitInterval = 1.0, + temperature: Annotated[float, Field(ge=0, le=1)] = 1.0, system_prompt: str = '', user_prompt: str = '', - user_prompt_key: str = None, + user_prompt_key: Optional[str] = None, keep_original_sample: bool = True, any_or_all: str = 'any', *args, diff --git a/data_juicer/ops/mapper/image_captioning_mapper.py b/data_juicer/ops/mapper/image_captioning_mapper.py index d28c3d8ef..0e3b3a39c 100644 --- a/data_juicer/ops/mapper/image_captioning_mapper.py +++ b/data_juicer/ops/mapper/image_captioning_mapper.py @@ 
-1,9 +1,10 @@ import copy import random +from typing import Optional import numpy as np -from jsonargparse.typing import PositiveInt from loguru import logger +from pydantic import PositiveInt from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.constant import HashKeys @@ -38,13 +39,13 @@ class ImageCaptioningMapper(Mapper): _batched_op = True def __init__(self, - hf_img2seq='Salesforce/blip2-opt-2.7b', - trust_remote_code=False, + hf_img2seq: str = 'Salesforce/blip2-opt-2.7b', + trust_remote_code: bool = False, caption_num: PositiveInt = 1, keep_candidate_mode: str = 'random_any', keep_original_sample: bool = True, - prompt: str = None, - prompt_key: str = None, + prompt: Optional[str] = None, + prompt_key: Optional[str] = None, *args, **kwargs): """ diff --git a/data_juicer/ops/mapper/image_diffusion_mapper.py b/data_juicer/ops/mapper/image_diffusion_mapper.py index 26f0f8403..bd8702b5c 100644 --- a/data_juicer/ops/mapper/image_diffusion_mapper.py +++ b/data_juicer/ops/mapper/image_diffusion_mapper.py @@ -1,7 +1,10 @@ import copy import os +from typing import Optional from PIL import Image +from pydantic import Field, PositiveInt +from typing_extensions import Annotated from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.constant import Fields @@ -38,15 +41,15 @@ class ImageDiffusionMapper(Mapper): def __init__(self, hf_diffusion: str = 'CompVis/stable-diffusion-v1-4', - trust_remote_code=False, + trust_remote_code: bool = False, torch_dtype: str = 'fp32', revision: str = 'main', - strength: float = 0.8, + strength: Annotated[float, Field(ge=0, le=1)] = 0.8, guidance_scale: float = 7.5, - aug_num: int = 1, + aug_num: PositiveInt = 1, keep_original_sample: bool = True, - caption_key: str = None, - hf_img2seq='Salesforce/blip2-opt-2.7b', + caption_key: Optional[str] = None, + hf_img2seq: str = 'Salesforce/blip2-opt-2.7b', *args, **kwargs): """ diff --git 
a/data_juicer/ops/mapper/image_face_blur_mapper.py b/data_juicer/ops/mapper/image_face_blur_mapper.py index 1ccd4e8c4..e47da2e4e 100644 --- a/data_juicer/ops/mapper/image_face_blur_mapper.py +++ b/data_juicer/ops/mapper/image_face_blur_mapper.py @@ -1,6 +1,7 @@ import os from loguru import logger +from pydantic import NonNegativeFloat from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.constant import Fields @@ -34,9 +35,9 @@ class ImageFaceBlurMapper(Mapper): } def __init__(self, - cv_classifier='', + cv_classifier: str = '', blur_type: str = 'gaussian', - radius: float = 2, + radius: NonNegativeFloat = 2, *args, **kwargs): """ diff --git a/data_juicer/ops/mapper/nlpaug_en_mapper.py b/data_juicer/ops/mapper/nlpaug_en_mapper.py index 581296b6a..3ec5864c7 100644 --- a/data_juicer/ops/mapper/nlpaug_en_mapper.py +++ b/data_juicer/ops/mapper/nlpaug_en_mapper.py @@ -1,6 +1,7 @@ from copy import deepcopy from loguru import logger +from pydantic import PositiveInt from data_juicer.utils.availability_utils import AvailabilityChecking @@ -23,7 +24,7 @@ class NlpaugEnMapper(Mapper): def __init__(self, sequential: bool = False, - aug_num: int = 1, + aug_num: PositiveInt = 1, keep_original_sample: bool = True, delete_random_word: bool = False, swap_random_word: bool = False, diff --git a/data_juicer/ops/mapper/nlpcda_zh_mapper.py b/data_juicer/ops/mapper/nlpcda_zh_mapper.py index 4c7bdefe3..640ea7391 100644 --- a/data_juicer/ops/mapper/nlpcda_zh_mapper.py +++ b/data_juicer/ops/mapper/nlpcda_zh_mapper.py @@ -1,6 +1,7 @@ from copy import deepcopy from loguru import logger +from pydantic import PositiveInt from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.logger_utils import HiddenPrints @@ -21,7 +22,7 @@ class NlpcdaZhMapper(Mapper): def __init__(self, sequential: bool = False, - aug_num: int = 1, + aug_num: PositiveInt = 1, keep_original_sample: bool = True, replace_similar_word: bool = False, 
replace_homophone_char: bool = False, diff --git a/data_juicer/ops/mapper/optimize_instruction_mapper.py b/data_juicer/ops/mapper/optimize_instruction_mapper.py index 32785dc27..a9ec0564c 100644 --- a/data_juicer/ops/mapper/optimize_instruction_mapper.py +++ b/data_juicer/ops/mapper/optimize_instruction_mapper.py @@ -1,4 +1,4 @@ -from typing import Dict +from typing import Dict, Optional from loguru import logger @@ -34,10 +34,10 @@ class OptimizeInstructionMapper(Mapper): def __init__(self, hf_model: str = 'alibaba-pai/Qwen2-7B-Instruct-Refine', trust_remote_code: bool = False, - system_prompt: str = None, + system_prompt: Optional[str] = None, enable_vllm: bool = True, - tensor_parallel_size: int = None, - max_model_len: int = None, + tensor_parallel_size: Optional[int] = None, + max_model_len: Optional[int] = None, max_num_seqs: int = 256, sampling_params: Dict = {}, *args, diff --git a/data_juicer/ops/mapper/remove_long_words_mapper.py b/data_juicer/ops/mapper/remove_long_words_mapper.py index 92ac8fe2d..ff8fa2d29 100644 --- a/data_juicer/ops/mapper/remove_long_words_mapper.py +++ b/data_juicer/ops/mapper/remove_long_words_mapper.py @@ -4,8 +4,6 @@ import sys -from jsonargparse.typing import PositiveInt - from ..base_op import OPERATORS, Mapper from ..common import (SPECIAL_CHARACTERS, merge_on_whitespace_tab_newline, split_on_newline_tab_whitespace, strip) @@ -16,8 +14,8 @@ class RemoveLongWordsMapper(Mapper): """Mapper to remove long words within a specific range.""" def __init__(self, - min_len: PositiveInt = 1, - max_len: PositiveInt = sys.maxsize, + min_len: int = 1, + max_len: int = sys.maxsize, *args, **kwargs): """ diff --git a/data_juicer/ops/mapper/remove_table_text_mapper.py b/data_juicer/ops/mapper/remove_table_text_mapper.py index 4f6dfb233..ca12104c0 100644 --- a/data_juicer/ops/mapper/remove_table_text_mapper.py +++ b/data_juicer/ops/mapper/remove_table_text_mapper.py @@ -1,11 +1,9 @@ import regex as re -from jsonargparse.typing import 
restricted_number_type +from pydantic import Field +from typing_extensions import Annotated from ..base_op import OPERATORS, Mapper -from_2_to_20 = restricted_number_type('from_2_to_20', int, [('>=', 2), - ('<=', 20)]) - @OPERATORS.register_module('remove_table_text_mapper') class RemoveTableTextMapper(Mapper): @@ -17,8 +15,8 @@ class RemoveTableTextMapper(Mapper): """ def __init__(self, - min_col: from_2_to_20 = 2, - max_col: from_2_to_20 = 20, + min_col: Annotated[int, Field(ge=2, le=20)] = 2, + max_col: Annotated[int, Field(ge=2, le=20)] = 20, *args, **kwargs): """ diff --git a/data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py b/data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py index 605a75e3b..d262c1d17 100644 --- a/data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py +++ b/data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py @@ -1,4 +1,4 @@ -from jsonargparse.typing import List +from typing import List, Optional from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.model_utils import get_model, prepare_model @@ -21,7 +21,7 @@ class RemoveWordsWithIncorrectSubstringsMapper(Mapper): def __init__(self, lang: str = 'en', tokenization: bool = False, - substrings: List = None, + substrings: Optional[List[str]] = None, *args, **kwargs): """ diff --git a/data_juicer/ops/mapper/replace_content_mapper.py b/data_juicer/ops/mapper/replace_content_mapper.py index d73669c3e..d16e4ec7c 100644 --- a/data_juicer/ops/mapper/replace_content_mapper.py +++ b/data_juicer/ops/mapper/replace_content_mapper.py @@ -12,7 +12,7 @@ class ReplaceContentMapper(Mapper): replacement string.""" def __init__(self, - pattern: Union[str, List[str]] = None, + pattern: Union[str, List[str], None] = None, repl: Union[str, List[str]] = '', *args, **kwargs): diff --git a/data_juicer/ops/mapper/video_captioning_from_frames_mapper.py 
b/data_juicer/ops/mapper/video_captioning_from_frames_mapper.py index f4b13bd8a..baa3a4b5f 100644 --- a/data_juicer/ops/mapper/video_captioning_from_frames_mapper.py +++ b/data_juicer/ops/mapper/video_captioning_from_frames_mapper.py @@ -1,11 +1,12 @@ # yapf: disable import copy import random +from typing import Optional import numpy as np -from jsonargparse.typing import PositiveInt from loguru import logger from PIL import ImageOps +from pydantic import PositiveInt from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.constant import HashKeys @@ -46,13 +47,13 @@ class VideoCaptioningFromFramesMapper(Mapper): def __init__( self, - hf_img2seq='Salesforce/blip2-opt-2.7b', - trust_remote_code=False, + hf_img2seq: str = 'Salesforce/blip2-opt-2.7b', + trust_remote_code: bool = False, caption_num: PositiveInt = 1, keep_candidate_mode: str = 'random_any', keep_original_sample: bool = True, - prompt: str = None, - prompt_key: str = None, + prompt: Optional[str] = None, + prompt_key: Optional[str] = None, frame_sampling_method: str = 'all_keyframes', frame_num: PositiveInt = 3, horizontal_flip: bool = False, diff --git a/data_juicer/ops/mapper/video_captioning_from_summarizer_mapper.py b/data_juicer/ops/mapper/video_captioning_from_summarizer_mapper.py index a6cec83c0..3cf0ef618 100644 --- a/data_juicer/ops/mapper/video_captioning_from_summarizer_mapper.py +++ b/data_juicer/ops/mapper/video_captioning_from_summarizer_mapper.py @@ -1,7 +1,7 @@ import copy -from typing import Dict +from typing import Dict, Optional -from jsonargparse.typing import PositiveInt +from pydantic import PositiveInt from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.constant import Fields @@ -54,16 +54,16 @@ class VideoCaptioningFromSummarizerMapper(Mapper): def __init__(self, hf_summarizer: str = None, - trust_remote_code=False, + trust_remote_code: bool = False, consider_video_caption_from_video: bool = True, 
consider_video_caption_from_audio: bool = True, consider_video_caption_from_frames: bool = True, consider_video_tags_from_audio: bool = True, consider_video_tags_from_frames: bool = True, - vid_cap_from_vid_args: Dict = None, - vid_cap_from_frm_args: Dict = None, - vid_tag_from_aud_args: Dict = None, - vid_tag_from_frm_args: Dict = None, + vid_cap_from_vid_args: Optional[Dict] = None, + vid_cap_from_frm_args: Optional[Dict] = None, + vid_tag_from_aud_args: Optional[Dict] = None, + vid_tag_from_frm_args: Optional[Dict] = None, keep_tag_num: PositiveInt = 5, keep_original_sample: bool = True, *args, diff --git a/data_juicer/ops/mapper/video_captioning_from_video_mapper.py b/data_juicer/ops/mapper/video_captioning_from_video_mapper.py index ec21b17b8..e697bf0cc 100644 --- a/data_juicer/ops/mapper/video_captioning_from_video_mapper.py +++ b/data_juicer/ops/mapper/video_captioning_from_video_mapper.py @@ -1,11 +1,12 @@ # yapf: disable import copy import random +from typing import Optional import numpy as np -from jsonargparse.typing import PositiveInt from loguru import logger from PIL import ImageOps +from pydantic import PositiveInt from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.constant import HashKeys @@ -45,13 +46,13 @@ class VideoCaptioningFromVideoMapper(Mapper): def __init__( self, - hf_video_blip='kpyu/video-blip-opt-2.7b-ego4d', - trust_remote_code=False, + hf_video_blip: str = 'kpyu/video-blip-opt-2.7b-ego4d', + trust_remote_code: bool = False, caption_num: PositiveInt = 1, keep_candidate_mode: str = 'random_any', keep_original_sample: bool = True, - prompt: str = None, - prompt_key: str = None, + prompt: Optional[str] = None, + prompt_key: Optional[str] = None, frame_sampling_method: str = 'all_keyframes', frame_num: PositiveInt = 3, horizontal_flip: bool = False, diff --git a/data_juicer/ops/mapper/video_face_blur_mapper.py b/data_juicer/ops/mapper/video_face_blur_mapper.py index 5ef6db010..a862cd8e0 100644 --- 
a/data_juicer/ops/mapper/video_face_blur_mapper.py +++ b/data_juicer/ops/mapper/video_face_blur_mapper.py @@ -35,7 +35,7 @@ class VideoFaceBlurMapper(Mapper): } def __init__(self, - cv_classifier='', + cv_classifier: str = '', blur_type: str = 'gaussian', radius: float = 2, *args, diff --git a/data_juicer/ops/mapper/video_remove_watermark_mapper.py b/data_juicer/ops/mapper/video_remove_watermark_mapper.py index f99929439..24924a7cd 100644 --- a/data_juicer/ops/mapper/video_remove_watermark_mapper.py +++ b/data_juicer/ops/mapper/video_remove_watermark_mapper.py @@ -1,8 +1,9 @@ import os +from typing import List, Optional import av import numpy as np -from jsonargparse.typing import List, PositiveInt +from pydantic import PositiveInt from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.constant import Fields @@ -33,11 +34,10 @@ class VideoRemoveWatermarkMapper(Mapper): def __init__(self, roi_strings: List[str] = ['0,0,0.1,0.1'], roi_type: str = 'ratio', - roi_key: str = None, + roi_key: Optional[str] = None, frame_num: PositiveInt = 10, min_frame_threshold: PositiveInt = 7, detection_method: str = 'pixel_value', - threshold: int = None, *args, **kwargs): """ diff --git a/data_juicer/ops/mapper/video_resize_resolution_mapper.py b/data_juicer/ops/mapper/video_resize_resolution_mapper.py index a88d3758d..5f60d14f3 100644 --- a/data_juicer/ops/mapper/video_resize_resolution_mapper.py +++ b/data_juicer/ops/mapper/video_resize_resolution_mapper.py @@ -2,7 +2,7 @@ import os import sys -from jsonargparse.typing import PositiveInt +from pydantic import PositiveInt from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.constant import Fields @@ -28,10 +28,10 @@ class VideoResizeResolutionMapper(Mapper): """ def __init__(self, - min_width: PositiveInt = 1, - max_width: PositiveInt = sys.maxsize, - min_height: PositiveInt = 1, - max_height: PositiveInt = sys.maxsize, + min_width: int = 1, + max_width: 
int = sys.maxsize, + min_height: int = 1, + max_height: int = sys.maxsize, force_original_aspect_ratio: str = 'disable', force_divisible_by: PositiveInt = 2, *args, diff --git a/data_juicer/ops/mapper/video_split_by_scene_mapper.py b/data_juicer/ops/mapper/video_split_by_scene_mapper.py index 18a642c12..4b2e39165 100644 --- a/data_juicer/ops/mapper/video_split_by_scene_mapper.py +++ b/data_juicer/ops/mapper/video_split_by_scene_mapper.py @@ -2,7 +2,7 @@ import re from itertools import chain -from jsonargparse.typing import NonNegativeFloat, NonNegativeInt +from pydantic import NonNegativeFloat, NonNegativeInt from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.constant import Fields diff --git a/data_juicer/ops/mapper/video_tagging_from_audio_mapper.py b/data_juicer/ops/mapper/video_tagging_from_audio_mapper.py index 2b693cb04..07d1638e7 100644 --- a/data_juicer/ops/mapper/video_tagging_from_audio_mapper.py +++ b/data_juicer/ops/mapper/video_tagging_from_audio_mapper.py @@ -27,8 +27,8 @@ class VideoTaggingFromAudioMapper(Mapper): _accelerator = 'cuda' def __init__(self, - hf_ast='MIT/ast-finetuned-audioset-10-10-0.4593', - trust_remote_code=False, + hf_ast: str = 'MIT/ast-finetuned-audioset-10-10-0.4593', + trust_remote_code: bool = False, *args, **kwargs): """ diff --git a/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py b/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py index 4ac0944c8..014ec2268 100644 --- a/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py +++ b/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py @@ -1,6 +1,6 @@ from collections import Counter -from jsonargparse.typing import PositiveInt +from pydantic import PositiveInt from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.constant import Fields diff --git a/data_juicer/ops/selector/frequency_specified_field_selector.py 
b/data_juicer/ops/selector/frequency_specified_field_selector.py index 63cc6d80d..5657e2509 100644 --- a/data_juicer/ops/selector/frequency_specified_field_selector.py +++ b/data_juicer/ops/selector/frequency_specified_field_selector.py @@ -1,6 +1,8 @@ import numbers +from typing import Optional -from jsonargparse.typing import ClosedUnitInterval, PositiveInt +from pydantic import Field, PositiveInt +from typing_extensions import Annotated from ..base_op import OPERATORS, Selector @@ -12,8 +14,9 @@ class FrequencySpecifiedFieldSelector(Selector): def __init__(self, field_key: str = '', - top_ratio: ClosedUnitInterval = None, - topk: PositiveInt = None, + top_ratio: Optional[Annotated[float, + Field(ge=0, le=1)]] = None, + topk: Optional[PositiveInt] = None, reverse: bool = True, *args, **kwargs): diff --git a/data_juicer/ops/selector/random_selector.py b/data_juicer/ops/selector/random_selector.py index 19724d29d..c3990ab19 100644 --- a/data_juicer/ops/selector/random_selector.py +++ b/data_juicer/ops/selector/random_selector.py @@ -1,4 +1,7 @@ -from jsonargparse.typing import ClosedUnitInterval, PositiveInt +from typing import Optional + +from pydantic import Field, PositiveInt +from typing_extensions import Annotated from data_juicer.format.mixture_formatter import MixtureFormatter @@ -10,7 +13,8 @@ class RandomSelector(Selector): """Selector to random select samples. 
""" def __init__(self, - select_ratio: ClosedUnitInterval = None, + select_ratio: Optional[Annotated[float, + Field(ge=0, le=1)]] = None, select_num: PositiveInt = None, *args, **kwargs): diff --git a/data_juicer/ops/selector/range_specified_field_selector.py b/data_juicer/ops/selector/range_specified_field_selector.py index f2e9f12c6..55243b50f 100644 --- a/data_juicer/ops/selector/range_specified_field_selector.py +++ b/data_juicer/ops/selector/range_specified_field_selector.py @@ -1,6 +1,8 @@ import heapq +from typing import Optional -from jsonargparse.typing import ClosedUnitInterval, PositiveInt +from pydantic import Field, PositiveInt +from typing_extensions import Annotated from data_juicer.utils.common_utils import stats_to_number @@ -12,14 +14,17 @@ class RangeSpecifiedFieldSelector(Selector): """Selector to select a range of samples based on the sorted specified field value from smallest to largest. """ - def __init__(self, - field_key: str = '', - lower_percentile: ClosedUnitInterval = None, - upper_percentile: ClosedUnitInterval = None, - lower_rank: PositiveInt = None, - upper_rank: PositiveInt = None, - *args, - **kwargs): + def __init__( + self, + field_key: str = '', + lower_percentile: Optional[Annotated[float, + Field(ge=0, le=1)]] = None, + upper_percentile: Optional[Annotated[float, + Field(ge=0, le=1)]] = None, + lower_rank: Optional[PositiveInt] = None, + upper_rank: Optional[PositiveInt] = None, + *args, + **kwargs): """ Initialization method. 
diff --git a/data_juicer/ops/selector/topk_specified_field_selector.py b/data_juicer/ops/selector/topk_specified_field_selector.py index 573b2e09f..1852f7222 100644 --- a/data_juicer/ops/selector/topk_specified_field_selector.py +++ b/data_juicer/ops/selector/topk_specified_field_selector.py @@ -1,6 +1,8 @@ import heapq +from typing import Optional -from jsonargparse.typing import ClosedUnitInterval, PositiveInt +from pydantic import Field, PositiveInt +from typing_extensions import Annotated from data_juicer.utils.common_utils import stats_to_number @@ -14,8 +16,9 @@ class TopkSpecifiedFieldSelector(Selector): def __init__(self, field_key: str = '', - top_ratio: ClosedUnitInterval = None, - topk: PositiveInt = None, + top_ratio: Optional[Annotated[float, + Field(ge=0, le=1)]] = None, + topk: Optional[PositiveInt] = None, reverse: bool = True, *args, **kwargs): diff --git a/data_juicer/utils/file_utils.py b/data_juicer/utils/file_utils.py index 28a70a7ed..9af6fa993 100644 --- a/data_juicer/utils/file_utils.py +++ b/data_juicer/utils/file_utils.py @@ -6,7 +6,7 @@ import shutil from datetime import datetime, timezone from pathlib import Path -from typing import AsyncGenerator, List, Tuple, Union +from typing import AsyncGenerator, List, Union from datasets.utils.extract import ZstdExtractor as Extractor @@ -46,7 +46,7 @@ async def follow_read( def find_files_with_suffix( path: Union[str, Path], - suffixes: Union[str, List[str], Tuple[str]] = None) -> List[str]: + suffixes: Union[str, List[str], None] = None) -> List[str]: """ Traverse a path to find all files with the specified suffixes. 
diff --git a/data_juicer/utils/mm_utils.py b/data_juicer/utils/mm_utils.py index 37c7faf55..5b3ec0430 100644 --- a/data_juicer/utils/mm_utils.py +++ b/data_juicer/utils/mm_utils.py @@ -3,12 +3,13 @@ import os import re import shutil -from typing import List, Union +from typing import List, Optional, Union import av import numpy as np from datasets import Audio, Image from loguru import logger +from pydantic import PositiveInt from data_juicer.utils.constant import DEFAULT_PREFIX, Fields from data_juicer.utils.file_utils import add_suffix_to_filename @@ -195,7 +196,7 @@ def load_video(path, mode='r'): def get_video_duration(input_video: Union[str, av.container.InputContainer], - video_stream_index=0): + video_stream_index: int = 0): """ Get the video's duration from the container @@ -222,7 +223,7 @@ def get_video_duration(input_video: Union[str, av.container.InputContainer], def get_decoded_frames_from_video( input_video: Union[str, av.container.InputContainer], - video_stream_index=0): + video_stream_index: int = 0): """ Get the video's frames from the container @@ -247,7 +248,7 @@ def cut_video_by_seconds( input_video: Union[str, av.container.InputContainer], output_video: str, start_seconds: float, - end_seconds: float = None, + end_seconds: Optional[float] = None, ): """ Cut a video into several segments by times in second. @@ -466,7 +467,7 @@ def get_key_frame_seconds(input_video: Union[str, def extract_video_frames_uniformly( input_video: Union[str, av.container.InputContainer], - frame_num: int, + frame_num: PositiveInt, ): """ Extract a number of video frames uniformly within the video duration. 
@@ -581,10 +582,10 @@ def extract_video_frames_uniformly( def extract_audio_from_video( input_video: Union[str, av.container.InputContainer], - output_audio: str = None, + output_audio: Optional[str] = None, start_seconds: int = 0, - end_seconds: int = None, - stream_indexes: Union[int, List[int]] = None, + end_seconds: Optional[int] = None, + stream_indexes: Union[int, List[int], None] = None, ): """ Extract audio data for the given video. @@ -804,7 +805,7 @@ def parse_string_to_roi(roi_string, roi_type='pixel'): return None -def close_video(container): +def close_video(container: av.container.InputContainer): """ Close the video stream and container to avoid memory leak. diff --git a/environments/minimal_requires.txt b/environments/minimal_requires.txt index bd55d2008..db90b521d 100644 --- a/environments/minimal_requires.txt +++ b/environments/minimal_requires.txt @@ -25,3 +25,4 @@ spacy==3.7.0 multiprocess==0.70.12 dill==0.3.4 psutil +pydantic>=2.0 diff --git a/tests/config/demo_4_test.yaml b/tests/config/demo_4_test.yaml index 0fe834613..5040049bf 100644 --- a/tests/config/demo_4_test.yaml +++ b/tests/config/demo_4_test.yaml @@ -16,3 +16,4 @@ process: - document_deduplicator: # deduplicate text samples using md5 hashing exact matching method lowercase: false # whether to convert text to lower case ignore_non_character: false + - remove_table_text_mapper: diff --git a/tests/config/demo_4_test_bad_val.yaml b/tests/config/demo_4_test_bad_val.yaml index 3f1b4dbd2..3fca62e32 100644 --- a/tests/config/demo_4_test_bad_val.yaml +++ b/tests/config/demo_4_test_bad_val.yaml @@ -13,7 +13,8 @@ process: - whitespace_normalization_mapper: - language_id_score_filter: lang: 'zh' - min_score: 1.1 # !! a bad value !! - document_deduplicator: # deduplicate text samples using md5 hashing exact matching method lowercase: false # whether to convert text to lower case ignore_non_character: false + - remove_table_text_mapper: + max_col: 30 # !! a bad value !! 
\ No newline at end of file diff --git a/tests/config/test_config_funcs.py b/tests/config/test_config_funcs.py index a0022711b..54d9a44dc 100644 --- a/tests/config/test_config_funcs.py +++ b/tests/config/test_config_funcs.py @@ -71,28 +71,26 @@ def test_yaml_cfg_file(self): }, 'nested dict load fail, un-expected internal value') ops_from_cfg = load_ops(cfg.process) - self.assertTrue(len(ops_from_cfg) == 3) + self.assertTrue(len(ops_from_cfg) == 4) def test_val_range_check_cmd(self): out = StringIO() - err_msg_head = ("language_id_score_filter.min_score") - err_msg = ("Not of type ClosedUnitInterval: 1.1 does not conform to " - "restriction v>=0 and v<=1") + err_msg_head = ("remove_table_text_mapper.min_col") + err_msg = ("Input should be greater than or equal to 2") with redirect_stdout(out), redirect_stderr(out): with self.assertRaises(SystemExit) as cm: init_configs( args=f'--config {test_yaml_path} ' - '--language_id_score_filter.min_score 1.1'.split()) + '--remove_table_text_mapper.min_col 1'.split()) self.assertEqual(cm.exception.code, 2) out_str = out.getvalue() self.assertIn(err_msg_head, out_str) self.assertIn(err_msg, out_str) - def test_val_range_check_yaml(self): + def _test_val_range_check_yaml(self): out = StringIO() - err_msg_head = ("language_id_score_filter.min_score") - err_msg = ("Not of type ClosedUnitInterval: 1.1 does not conform to " - "restriction v>=0 and v<=1") + err_msg_head = ("remove_table_text_mapper.max_col") + err_msg = ("Input should be less than or equal to 20") with redirect_stdout(out), redirect_stderr(out): with self.assertRaises(SystemExit) as cm: init_configs(args=f'--config {test_bad_yaml_path}'.split()) diff --git a/tools/distributed_deduplication/dedup_utils.py b/tools/distributed_deduplication/dedup_utils.py index 25c46d4ed..4a4bf9a23 100644 --- a/tools/distributed_deduplication/dedup_utils.py +++ b/tools/distributed_deduplication/dedup_utils.py @@ -2,14 +2,14 @@ # 
https://github.com/bigcode-project/bigcode-dataset/blob/main/near_deduplication/minhash_deduplication_spark.py # -------------------------------------------------------- -from typing import List, Tuple, Union +from typing import List, Optional, Tuple from loguru import logger from pyspark import SparkConf from pyspark.sql import SparkSession -def init_spark(master_url: Union[str, None] = None, +def init_spark(master_url: Optional[str] = None, spark_executor_memory=None, spark_driver_memory=None, spark_executor_memoryOverhead=None): diff --git a/tools/distributed_deduplication/spark_dedup.py b/tools/distributed_deduplication/spark_dedup.py index 871f1811b..8f0f05d4f 100644 --- a/tools/distributed_deduplication/spark_dedup.py +++ b/tools/distributed_deduplication/spark_dedup.py @@ -1,6 +1,6 @@ import sys import time -from typing import Union +from typing import Optional import fire from loguru import logger @@ -18,11 +18,11 @@ @logger.catch def dedup_dataset(dataset_path: str, result_path: str, - tokenizer: Union[str, None] = None, + tokenizer: Optional[str] = None, num_features: int = 1047576, num_hashtables: int = 10, text_key: str = 'text', - master_url: Union[str, None] = None): + master_url: Optional[str] = None): """ Perform fuzzy text deduplication on the given dataset. :param dataset_path: the path to the dataset to perform deduplication,