Skip to content

Commit

Permalink
use pydantic types (#422)
Browse files: browse the repository at this point in the history
* use pydantic types

* change config unittest

* fix GenerateInstructionMapper

* update
  • Loading branch information
drcege authored Sep 12, 2024
1 parent 8ab74a3 commit b3fb942
Show file tree
Hide file tree
Showing 83 changed files with 301 additions and 281 deletions.
4 changes: 2 additions & 2 deletions data_juicer/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import tempfile
import time
from argparse import ArgumentError, Namespace
from typing import Dict, List, Tuple, Union
from typing import Dict, List, Union

import yaml
from jsonargparse import (ActionConfigFile, ArgumentParser, dict_to_namespace,
Expand Down Expand Up @@ -194,7 +194,7 @@ def init_configs(args=None):
'own special token according to your input dataset.')
parser.add_argument(
'--suffixes',
type=Union[str, List[str], Tuple[str]],
type=Union[str, List[str]],
default=[],
help='Suffixes of files that will be find and loaded. If not set, we '
'will find all suffix files, and select a suitable formatter '
Expand Down
4 changes: 2 additions & 2 deletions data_juicer/format/formatter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import List, Tuple, Union
from typing import List, Union

from datasets import Dataset, DatasetDict, concatenate_datasets, load_dataset
from loguru import logger
Expand Down Expand Up @@ -27,7 +27,7 @@ def __init__(
self,
dataset_path: str,
type: str,
suffixes: Union[str, List[str], Tuple[str]] = None,
suffixes: Union[str, List[str], None] = None,
text_keys: List[str] = None,
add_suffix=False,
**kwargs,
Expand Down
4 changes: 2 additions & 2 deletions data_juicer/format/mixture_formatter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from itertools import chain, repeat
from typing import List, Tuple, Union
from typing import List, Union

import numpy as np
from datasets import Dataset, concatenate_datasets
Expand All @@ -15,7 +15,7 @@ class MixtureFormatter(BaseFormatter):

def __init__(self,
dataset_path: str,
suffixes: Union[str, List[str], Tuple[str]] = None,
suffixes: Union[str, List[str], None] = None,
text_keys=None,
add_suffix=False,
max_samples=None,
Expand Down
14 changes: 8 additions & 6 deletions data_juicer/ops/deduplicator/document_minhash_deduplicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@
import hashlib
import struct
from collections import defaultdict
from typing import Optional

import numpy as np
import regex
from jsonargparse.typing import ClosedUnitInterval, PositiveInt
from loguru import logger
from pydantic import Field, PositiveInt
from tqdm import tqdm
from typing_extensions import Annotated

from data_juicer.utils.availability_utils import AvailabilityChecking
from data_juicer.utils.constant import HashKeys
Expand Down Expand Up @@ -109,12 +111,12 @@ def __init__(
tokenization: str = 'space',
window_size: PositiveInt = 5,
lowercase: bool = True,
ignore_pattern: str = None,
ignore_pattern: Optional[str] = None,
num_permutations: PositiveInt = 256,
jaccard_threshold: ClosedUnitInterval = 0.7,
num_bands: PositiveInt = None,
num_rows_per_band: PositiveInt = None,
tokenizer_model: str = None,
jaccard_threshold: Annotated[float, Field(ge=0, le=1)] = 0.7,
num_bands: Optional[PositiveInt] = None,
num_rows_per_band: Optional[PositiveInt] = None,
tokenizer_model: Optional[str] = None,
*args,
**kwargs,
):
Expand Down
6 changes: 3 additions & 3 deletions data_juicer/ops/deduplicator/document_simhash_deduplicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
# --------------------------------------------------------

from collections import defaultdict, deque
from typing import Dict, Set
from typing import Dict, Optional, Set

import numpy as np
import regex
from jsonargparse.typing import PositiveInt
from loguru import logger
from pydantic import PositiveInt

from data_juicer.utils.availability_utils import AvailabilityChecking
from data_juicer.utils.constant import HashKeys
Expand All @@ -30,7 +30,7 @@ def __init__(self,
tokenization: str = 'space',
window_size: PositiveInt = 6,
lowercase: bool = True,
ignore_pattern: str = None,
ignore_pattern: Optional[str] = None,
num_blocks: PositiveInt = 6,
hamming_distance: PositiveInt = 4,
*args,
Expand Down
2 changes: 1 addition & 1 deletion data_juicer/ops/deduplicator/image_deduplicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def process(self, dataset, show_num=0):
if show_num > 0:
# sample duplicate pairs
if self.consider_text:
hash2ids: Dict[Tuple[int], Set[int]] = defaultdict(set)
hash2ids: Dict[Tuple[int, int], Set[int]] = defaultdict(set)
hashes = zip(dataset[HashKeys.imagehash],
dataset[HashKeys.hash])
else:
Expand Down
2 changes: 1 addition & 1 deletion data_juicer/ops/deduplicator/ray_basic_deduplicator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any

from jsonargparse.typing import PositiveInt
from pydantic import PositiveInt

from data_juicer.utils.availability_utils import AvailabilityChecking
from data_juicer.utils.constant import HashKeys
Expand Down
2 changes: 1 addition & 1 deletion data_juicer/ops/deduplicator/ray_document_deduplicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import string

import regex as re
from jsonargparse.typing import PositiveInt
from pydantic import PositiveInt

from ..base_op import OPERATORS
from .ray_basic_deduplicator import RayBasicDeduplicator
Expand Down
2 changes: 1 addition & 1 deletion data_juicer/ops/deduplicator/ray_image_deduplicator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from jsonargparse.typing import PositiveInt
from pydantic import PositiveInt

from data_juicer.utils.availability_utils import AvailabilityChecking
from data_juicer.utils.mm_utils import load_data_with_context, load_image
Expand Down
2 changes: 1 addition & 1 deletion data_juicer/ops/deduplicator/ray_video_deduplicator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import hashlib

from jsonargparse.typing import PositiveInt
from pydantic import PositiveInt

from data_juicer.utils.mm_utils import (close_video, load_data_with_context,
load_video)
Expand Down
2 changes: 1 addition & 1 deletion data_juicer/ops/deduplicator/video_deduplicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def process(self, dataset, show_num=0):
if show_num > 0:
# sample duplicate pairs
if self.consider_text:
hash2ids: Dict[Tuple[int], Set[int]] = defaultdict(set)
hash2ids: Dict[Tuple[int, int], Set[int]] = defaultdict(set)
hashes = zip(dataset[HashKeys.videohash],
dataset[HashKeys.hash])
else:
Expand Down
4 changes: 1 addition & 3 deletions data_juicer/ops/filter/alphanumeric_filter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import sys

from jsonargparse.typing import PositiveFloat

from data_juicer.utils.availability_utils import AvailabilityChecking
from data_juicer.utils.constant import Fields, StatsKeys
from data_juicer.utils.model_utils import get_model, prepare_model
Expand All @@ -23,7 +21,7 @@ class AlphanumericFilter(Filter):
def __init__(self,
tokenization: bool = False,
min_ratio: float = 0.25,
max_ratio: PositiveFloat = sys.maxsize,
max_ratio: float = sys.maxsize,
*args,
**kwargs):
"""
Expand Down
5 changes: 2 additions & 3 deletions data_juicer/ops/filter/audio_duration_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import librosa
import numpy as np
from jsonargparse.typing import NonNegativeInt

from data_juicer.utils.constant import Fields, StatsKeys
from data_juicer.utils.mm_utils import load_audio, load_data_with_context
Expand All @@ -20,8 +19,8 @@ class AudioDurationFilter(Filter):
"""

def __init__(self,
min_duration: NonNegativeInt = 0,
max_duration: NonNegativeInt = sys.maxsize,
min_duration: int = 0,
max_duration: int = sys.maxsize,
any_or_all: str = 'any',
*args,
**kwargs):
Expand Down
2 changes: 1 addition & 1 deletion data_juicer/ops/filter/audio_nmf_snr_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

import librosa
import numpy as np
from jsonargparse.typing import PositiveInt
from librosa.decompose import decompose
from pydantic import PositiveInt

from data_juicer.utils.constant import Fields, StatsKeys
from data_juicer.utils.mm_utils import load_audio, load_data_with_context
Expand Down
6 changes: 2 additions & 4 deletions data_juicer/ops/filter/average_line_length_filter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import sys

from jsonargparse.typing import PositiveInt

from data_juicer.utils.constant import Fields, InterVars, StatsKeys

from ..base_op import OPERATORS, Filter
Expand All @@ -15,8 +13,8 @@ class AverageLineLengthFilter(Filter):
range."""

def __init__(self,
min_len: PositiveInt = 10,
max_len: PositiveInt = sys.maxsize,
min_len: int = 10,
max_len: int = sys.maxsize,
*args,
**kwargs):
"""
Expand Down
6 changes: 3 additions & 3 deletions data_juicer/ops/filter/character_repetition_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# --------------------------------------------------------

import numpy as np
from jsonargparse.typing import ClosedUnitInterval, PositiveInt
from pydantic import PositiveInt

from data_juicer.utils.constant import Fields, StatsKeys

Expand All @@ -17,8 +17,8 @@ class CharacterRepetitionFilter(Filter):

def __init__(self,
rep_len: PositiveInt = 10,
min_ratio: ClosedUnitInterval = 0.0,
max_ratio: ClosedUnitInterval = 0.5,
min_ratio: float = 0.0,
max_ratio: float = 0.5,
*args,
**kwargs):
"""
Expand Down
8 changes: 5 additions & 3 deletions data_juicer/ops/filter/flagged_words_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
# https://huggingface.co/spaces/huggingface/text-data-filtering
# --------------------------------------------------------

from jsonargparse.typing import ClosedUnitInterval, List
from typing import List

from pydantic import PositiveInt

from data_juicer.utils.availability_utils import AvailabilityChecking
from data_juicer.utils.constant import Fields, InterVars, StatsKeys
Expand All @@ -29,10 +31,10 @@ class FlaggedWordFilter(Filter):
def __init__(self,
lang: str = 'en',
tokenization: bool = False,
max_ratio: ClosedUnitInterval = 0.045,
max_ratio: float = 0.045,
flagged_words_dir: str = ASSET_DIR,
use_words_aug: bool = False,
words_aug_group_sizes: List = [2],
words_aug_group_sizes: List[PositiveInt] = [2],
words_aug_join_char: str = '',
*args,
**kwargs):
Expand Down
9 changes: 4 additions & 5 deletions data_juicer/ops/filter/image_aesthetics_filter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import numpy as np
from jsonargparse.typing import ClosedUnitInterval
from loguru import logger

from data_juicer.utils.availability_utils import AvailabilityChecking
Expand Down Expand Up @@ -32,10 +31,10 @@ class ImageAestheticsFilter(Filter):
_accelerator = 'cuda'

def __init__(self,
hf_scorer_model='',
trust_remote_code=False,
min_score: ClosedUnitInterval = 0.5,
max_score: ClosedUnitInterval = 1.0,
hf_scorer_model: str = '',
trust_remote_code: bool = False,
min_score: float = 0.5,
max_score: float = 1.0,
any_or_all: str = 'any',
*args,
**kwargs):
Expand Down
5 changes: 2 additions & 3 deletions data_juicer/ops/filter/image_aspect_ratio_filter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import numpy as np
from jsonargparse.typing import PositiveFloat

from data_juicer.utils.constant import Fields, StatsKeys
from data_juicer.utils.mm_utils import load_data_with_context, load_image
Expand All @@ -16,8 +15,8 @@ class ImageAspectRatioFilter(Filter):
"""

def __init__(self,
min_ratio: PositiveFloat = 0.333,
max_ratio: PositiveFloat = 3.0,
min_ratio: float = 0.333,
max_ratio: float = 3.0,
any_or_all: str = 'any',
*args,
**kwargs):
Expand Down
7 changes: 3 additions & 4 deletions data_juicer/ops/filter/image_face_ratio_filter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os

import numpy as np
from jsonargparse.typing import ClosedUnitInterval
from loguru import logger

from data_juicer.utils.availability_utils import AvailabilityChecking
Expand Down Expand Up @@ -34,9 +33,9 @@ class ImageFaceRatioFilter(Filter):
}

def __init__(self,
cv_classifier='',
min_ratio: ClosedUnitInterval = 0.0,
max_ratio: ClosedUnitInterval = 0.4,
cv_classifier: str = '',
min_ratio: float = 0.0,
max_ratio: float = 0.4,
any_or_all: str = 'any',
*args,
**kwargs):
Expand Down
7 changes: 3 additions & 4 deletions data_juicer/ops/filter/image_nsfw_filter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import numpy as np
from jsonargparse.typing import ClosedUnitInterval

from data_juicer.utils.availability_utils import AvailabilityChecking
from data_juicer.utils.constant import Fields, StatsKeys
Expand Down Expand Up @@ -27,9 +26,9 @@ class ImageNSFWFilter(Filter):
_accelerator = 'cuda'

def __init__(self,
hf_nsfw_model='Falconsai/nsfw_image_detection',
trust_remote_code=False,
score_threshold: ClosedUnitInterval = 0.5,
hf_nsfw_model: str = 'Falconsai/nsfw_image_detection',
trust_remote_code: bool = False,
score_threshold: float = 0.5,
any_or_all: str = 'any',
*args,
**kwargs):
Expand Down
9 changes: 4 additions & 5 deletions data_juicer/ops/filter/image_shape_filter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import sys

import numpy as np
from jsonargparse.typing import PositiveInt

from data_juicer.utils.constant import Fields, StatsKeys
from data_juicer.utils.mm_utils import load_data_with_context, load_image
Expand All @@ -17,10 +16,10 @@ class ImageShapeFilter(Filter):
"""

def __init__(self,
min_width: PositiveInt = 1,
max_width: PositiveInt = sys.maxsize,
min_height: PositiveInt = 1,
max_height: PositiveInt = sys.maxsize,
min_width: int = 1,
max_width: int = sys.maxsize,
min_height: int = 1,
max_height: int = sys.maxsize,
any_or_all: str = 'any',
*args,
**kwargs):
Expand Down
9 changes: 4 additions & 5 deletions data_juicer/ops/filter/image_text_matching_filter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import numpy as np
from jsonargparse.typing import ClosedUnitInterval
from PIL import ImageOps

from data_juicer.utils.availability_utils import AvailabilityChecking
Expand Down Expand Up @@ -30,10 +29,10 @@ class ImageTextMatchingFilter(Filter):
_accelerator = 'cuda'

def __init__(self,
hf_blip='Salesforce/blip-itm-base-coco',
trust_remote_code=False,
min_score: ClosedUnitInterval = 0.003,
max_score: ClosedUnitInterval = 1.0,
hf_blip: str = 'Salesforce/blip-itm-base-coco',
trust_remote_code: bool = False,
min_score: float = 0.003,
max_score: float = 1.0,
horizontal_flip: bool = False,
vertical_flip: bool = False,
any_or_all: str = 'any',
Expand Down
Loading

0 comments on commit b3fb942

Please sign in to comment.