From ee863605ded251930bdcb62a4252192c8f23ea17 Mon Sep 17 00:00:00 2001 From: "gece.gc" Date: Tue, 12 Nov 2024 10:51:50 +0000 Subject: [PATCH] refine op --- configs/config_all.yaml | 9 +- data_juicer/ops/mapper/__init__.py | 10 +- .../ops/mapper/image_segment_mapper.py | 70 +++++++++ .../ops/mapper/image_tagging_mapper.py | 4 +- data_juicer/ops/mapper/segment_mapper.py | 87 ----------- .../video_tagging_from_frames_mapper.py | 4 +- data_juicer/utils/auto_install_mapping.py | 144 +++++++++--------- data_juicer/utils/model_utils.py | 34 ++--- docs/Operators.md | 4 +- docs/Operators_ZH.md | 4 +- environments/science_requires.txt | 1 + tests/ops/mapper/test_image_segment_mapper.py | 61 ++++++++ tests/ops/mapper/test_segment_mapper.py | 48 ------ 13 files changed, 237 insertions(+), 243 deletions(-) create mode 100644 data_juicer/ops/mapper/image_segment_mapper.py delete mode 100644 data_juicer/ops/mapper/segment_mapper.py create mode 100644 tests/ops/mapper/test_image_segment_mapper.py delete mode 100644 tests/ops/mapper/test_segment_mapper.py diff --git a/configs/config_all.yaml b/configs/config_all.yaml index 221de3629..a79291a43 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -132,6 +132,10 @@ process: cv_classifier: '' # OpenCV classifier path for face detection. By default, we will use 'haarcascade_frontalface_alt.xml'. blur_type: 'gaussian' # type of blur kernel, including ['mean', 'box', 'gaussian'] radius: 2 # radius of blur kernel + - image_segment_mapper: # perform segment-anything on images and return the bounding boxes. + imgsz: 1024 # image resolution after image resizing + conf: 0.05 # confidence score threshold + iou: 0.5 # IoU (Intersection over Union) score threshold - image_tagging_mapper: # Mapper to generate image tags. tag_field_name: '__dj__image_tags__' # the field name to store the tags. It's "__dj__image_tags__" in default. - nlpaug_en_mapper: # simply augment texts in English based on the nlpaug library @@ -195,11 +199,6 @@ process: lang: en # sample in which language tokenization: false # whether to use model to tokenize documents substrings: ['http', 'www', '.com', 'href', '//'] # incorrect substrings to remove - - segment_mapper: # perform segment-anything on images and return the bounding box values. - fastsam_path: './FastSAM-x.pt' # model name of the FastSAM model on ultralytics - imgsz: 1024 # image resolution after image resizing - conf: 0.05 # confidence score threshold - iou: 0.5 # IoU (Intersection over Union) score threshold - sentence_split_mapper: # split text to multiple sentences and join them with '\n' lang: 'en' # split text in what language - video_captioning_from_audio_mapper: # caption a video according to its audio streams based on Qwen-Audio model diff --git a/data_juicer/ops/mapper/__init__.py b/data_juicer/ops/mapper/__init__.py index dcfc9b7db..4c28017d8 100644 --- a/data_juicer/ops/mapper/__init__.py +++ b/data_juicer/ops/mapper/__init__.py @@ -17,6 +17,7 @@ from .image_captioning_mapper import ImageCaptioningMapper from .image_diffusion_mapper import ImageDiffusionMapper from .image_face_blur_mapper import ImageFaceBlurMapper +from .image_segment_mapper import ImageSegmentMapper from .image_tagging_mapper import ImageTaggingMapper from .nlpaug_en_mapper import NlpaugEnMapper from .nlpcda_zh_mapper import NlpcdaZhMapper @@ -36,7 +37,6 @@ from .remove_words_with_incorrect_substrings_mapper import \ RemoveWordsWithIncorrectSubstringsMapper from .replace_content_mapper import ReplaceContentMapper -from .segment_mapper import SegmentMapper from .sentence_split_mapper import SentenceSplitMapper from .video_captioning_from_audio_mapper import VideoCaptioningFromAudioMapper from .video_captioning_from_frames_mapper import \ @@ -63,15 +63,15 @@ 'ExpandMacroMapper', 'FixUnicodeMapper', 'GenerateQAFromExamplesMapper', 'GenerateQAFromTextMapper', 'ImageBlurMapper', 'ImageCaptioningFromGPT4VMapper', 'ImageCaptioningMapper', - 'ImageDiffusionMapper', 'ImageFaceBlurMapper', 'ImageTaggingMapper', - 'NlpaugEnMapper', 'NlpcdaZhMapper', 'OptimizeQAMapper', - 'OptimizeQueryMapper', 'OptimizeResponseMapper', + 'ImageDiffusionMapper', 'ImageFaceBlurMapper', 'ImageSegmentMapper', + 'ImageTaggingMapper', 'NlpaugEnMapper', 'NlpcdaZhMapper', + 'OptimizeQAMapper', 'OptimizeQueryMapper', 'OptimizeResponseMapper', 'PunctuationNormalizationMapper', 'RemoveBibliographyMapper', 'RemoveCommentsMapper', 'RemoveHeaderMapper', 'RemoveLongWordsMapper', 'RemoveNonChineseCharacterlMapper', 'RemoveRepeatSentencesMapper', 'RemoveSpecificCharsMapper', 'RemoveTableTextMapper', 'RemoveWordsWithIncorrectSubstringsMapper', 'ReplaceContentMapper', - 'SegmentMapper', 'SentenceSplitMapper', 'VideoCaptioningFromAudioMapper', + 'SentenceSplitMapper', 'VideoCaptioningFromAudioMapper', 'VideoCaptioningFromFramesMapper', 'VideoCaptioningFromSummarizerMapper', 'VideoCaptioningFromVideoMapper', 'VideoFFmpegWrappedMapper', 'VideoFaceBlurMapper', 'VideoRemoveWatermarkMapper', diff --git a/data_juicer/ops/mapper/image_segment_mapper.py b/data_juicer/ops/mapper/image_segment_mapper.py new file mode 100644 index 000000000..75303dfb9 --- /dev/null +++ b/data_juicer/ops/mapper/image_segment_mapper.py @@ -0,0 +1,70 @@ +import numpy as np + +from data_juicer.utils.constant import Fields +from data_juicer.utils.lazy_loader import LazyLoader +from data_juicer.utils.mm_utils import load_data_with_context, load_image +from data_juicer.utils.model_utils import get_model, prepare_model + +from ..base_op import OPERATORS, UNFORKABLE, Mapper +from ..op_fusion import LOADED_IMAGES + +OP_NAME = 'image_segment_mapper' + +torch = LazyLoader('torch', 'torch') +ultralytics = LazyLoader('ultralytics', 'ultralytics') + + +@UNFORKABLE.register_module(OP_NAME) +@OPERATORS.register_module(OP_NAME) +@LOADED_IMAGES.register_module(OP_NAME) +class ImageSegmentMapper(Mapper): + """Perform segment-anything on images and return the bounding boxes.""" + + _accelerator = 'cuda' + + def __init__(self, imgsz=1024, conf=0.05, iou=0.5, *args, **kwargs): + """ + Initialization method. + + :param imgsz: resolution for image resizing + :param conf: confidence score threshold + :param iou: IoU (Intersection over Union) score threshold + + """ + super().__init__(*args, **kwargs) + + self.model_key = prepare_model(model_type='fastsam', + model_path='FastSAM-x.pt') + + self.imgsz = imgsz + self.conf = conf + self.iou = iou + + def process_single(self, sample, rank=None, context=False): + # there is no image in this sample + if self.image_key not in sample or not sample[self.image_key]: + # N x M x 4 for N images, M boxes, 4 coords + sample[Fields.bbox_tag] = np.empty((0, 0, 4), dtype=np.float32) + return sample + + loaded_image_keys = sample[self.image_key] + sample, images = load_data_with_context(sample, context, + loaded_image_keys, load_image) + + model = get_model(self.model_key, rank=rank, use_cuda=self.use_cuda()) + sample[Fields.bbox_tag] = [] + + for image in images: + masks = model(image, + retina_masks=True, + imgsz=self.imgsz, + conf=self.conf, + iou=self.iou, + verbose=False)[0] + # breakpoint() + sample[Fields.bbox_tag].append(masks.boxes.xywh.cpu().numpy()) + + # match schema + if len(sample[Fields.bbox_tag]) == 0: + sample[Fields.bbox_tag] = np.empty((0, 0, 4), dtype=np.float32) + return sample diff --git a/data_juicer/ops/mapper/image_tagging_mapper.py b/data_juicer/ops/mapper/image_tagging_mapper.py index d47fbf0ef..55d7f4fe3 100644 --- a/data_juicer/ops/mapper/image_tagging_mapper.py +++ b/data_juicer/ops/mapper/image_tagging_mapper.py @@ -38,8 +38,8 @@ def __init__(self, """ super().__init__(*args, **kwargs) self.model_key = prepare_model( - model_type='recognizeAnything', - pretrained_model_name_or_path='ram_plus_swin_large_14m.pth', + model_type='ram', + model_path='ram_plus_swin_large_14m.pth', input_size=384) self.transform = ram.get_transform(image_size=384) self.tag_field_name = tag_field_name diff --git a/data_juicer/ops/mapper/segment_mapper.py b/data_juicer/ops/mapper/segment_mapper.py deleted file mode 100644 index 75bca1383..000000000 --- a/data_juicer/ops/mapper/segment_mapper.py +++ /dev/null @@ -1,87 +0,0 @@ -import copy - -from data_juicer.ops.base_op import OPERATORS, Mapper -from data_juicer.ops.op_fusion import LOADED_IMAGES -from data_juicer.utils.availability_utils import AvailabilityChecking -from data_juicer.utils.constant import Fields -from data_juicer.utils.mm_utils import load_image -from data_juicer.utils.model_utils import get_model, prepare_model - -OP_NAME = 'segment_mapper' - -with AvailabilityChecking(['torch', 'ultralytics'], OP_NAME): - import torch - import ultralytics # noqa: F401 - - # avoid hanging when calling model in multiprocessing - torch.set_num_threads(1) - - -@OPERATORS.register_module(OP_NAME) -@LOADED_IMAGES.register_module(OP_NAME) -class SegmentMapper(Mapper): - """Perform segment-anything on images and return the bounding boxes.""" - - _accelerator = 'cuda' - _batched_op = True - - def __init__(self, - fastsam_path='FastSAM-x.pt', - imgsz=1024, - conf=0.05, - iou=0.5, - *args, - **kwargs): - """ - Initialization method. - - :param fastsam_path: location of FastSAM - :param imgsz: image resolution after image resizing - :param conf: confidence score threshold - :param iou: IoU (Intersection over Union) score threshold - - """ - super().__init__(*args, **kwargs) - - self.model_key = prepare_model( - model_type='fastsam', pretrained_model_name_or_path=fastsam_path) - - self.imgsz = imgsz - self.conf = conf - self.iou = iou - - def process(self, ori_sample, rank=None): - - # there is no image in this sample - if self.image_key not in ori_sample or \ - not ori_sample[self.image_key]: - return [] - - generated_samples = copy.deepcopy(ori_sample) - - loaded_image_keys = ori_sample[self.image_key] - images = {} - for loaded_image_key in loaded_image_keys: - if loaded_image_key not in images: - # avoid loading the same images - image = load_image(loaded_image_key) - images[loaded_image_key] = image - - model = get_model(self.model_key, rank=rank, use_cuda=self.use_cuda()) - - generated_samples[Fields.bbox_tag] = [] - - for image in images: - masks = model([image], - retina_masks=True, - imgsz=self.imgsz, - conf=self.conf, - iou=self.iou, - verbose=False)[0] - - if len(masks.boxes.xyxy) == 0: - generated_samples[Fields.bbox_tag].append([]) - else: - generated_samples[Fields.bbox_tag].append(masks.boxes.xyxy) - - return generated_samples diff --git a/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py b/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py index 26227738b..3eb0c3f0b 100644 --- a/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py +++ b/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py @@ -61,8 +61,8 @@ def __init__(self, f'Frame sampling method [{frame_sampling_method}] is not ' f'supported. Can only be one of ["all_keyframes", "uniform"].') self.model_key = prepare_model( - model_type='recognizeAnything', - pretrained_model_name_or_path='ram_plus_swin_large_14m.pth', + model_type='ram', + model_path='ram_plus_swin_large_14m.pth', input_size=384) self.frame_sampling_method = frame_sampling_method self.frame_num = frame_num diff --git a/data_juicer/utils/auto_install_mapping.py b/data_juicer/utils/auto_install_mapping.py index 96a54b437..3297bf8e7 100644 --- a/data_juicer/utils/auto_install_mapping.py +++ b/data_juicer/utils/auto_install_mapping.py @@ -1,108 +1,112 @@ # Map the imported module to the require package we need to install +# keep sorted for maintainability MODULE_TO_PKGS = { + 'PIL': ['Pillow'], 'aesthetics_predictor': ['simple-aesthetics-predictor'], 'cv2': ['opencv-python'], 'fasttext': ['fasttext-wheel'], 'ffmpeg': ['ffmpeg-python'], - 'PIL': ['Pillow'], 'ram': ['ram@git+https://github.com/xinyu1205/recognize-anything.git'], 'scenedetect': ['scenedetect[opencv]'], 'simhash': ['simhash-pybind'], + 'ultralytics': ['ultralytics'] } # Packages to corresponding ops that require them +# keep sorted for maintainability PKG_TO_OPS = { - 'torch': [ - 'image_aesthetics_filter', 'image_nsfw_filter', - 'image_text_matching_filter', 'image_text_similarity_filter', - 'image_watermark_filter', 'phrase_grounding_recall_filter', - 'video_aesthetics_filter', 'video_frames_text_similarity_filter', - 'video_nsfw_filter', 'video_tagging_from_frames_filter', - 'video_watermark_filter', 'generate_qa_from_text_mapper', - 'generate_qa_from_examples_mapper', 'image_captioning_mapper', - 'image_diffusion_mapper', 'image_tagging_mapper', - 'optimize_query_mapper', 'optimize_response_mapper', - 'optimize_qa_mapper', 'video_captioning_from_frames_mapper', - 'video_captioning_from_summarizer_mapper', - 'video_captioning_from_video_mapper', - 'video_tagging_from_audio_mapper', 'video_tagging_from_frames_mapper' - ], - 'torchaudio': [ - 'video_captioning_from_summarizer_mapper', - 'video_tagging_from_audio_mapper' + 'accelerate': [ + 'video_captioning_from_audio_mapper', + 'video_captioning_from_summarizer_mapper' ], + 'diffusers': ['image_diffusion_mapper'], 'easyocr': ['video_ocr_area_ratio_filter'], + 'einops': [ + 'video_captioning_from_audio_mapper', + 'video_captioning_from_summarizer_mapper' + ], 'fasttext-wheel': ['language_id_score_filter'], + 'ffmpeg-python': [ + 'audio_ffmpeg_wrapped_mapper', 'video_ffmpeg_wrapped_mapper', + 'video_resize_aspect_ratio_mapper', 'video_resize_resolution_mapper' + ], + 'ftfy': ['fix_unicode_mapper'], + 'imagededup': ['image_deduplicator', 'ray_image_deduplicator'], 'kenlm': ['perplexity_filter'], + 'nlpaug': ['nlpaug_en_mapper'], + 'nlpcda': ['nlpcda'], + 'nltk': ['phrase_grounding_recall_filter', 'sentence_split_mapper'], + 'opencc': ['chinese_convert_mapper'], + 'opencv-python': [ + 'image_face_blur_mapper', 'image_face_ratio_filter', + 'video_face_blur_mapper', 'video_motion_score_filter', + 'video_remove_watermark_mapper' + ], + 'ram': ['image_tagging_mapper', 'video_tagging_from_frames_mapper'], + 'rouge': ['generate_qa_from_examples_mapper'], + 'scenedetect[opencv]': ['video_split_by_scene_mapper'], + 'scipy': ['document_minhash_deduplicator'], + 'selectolax': ['clean_html_mapper'], 'sentencepiece': [ 'flagged_words_filter', 'perplexity_filter', 'stopwords_filter', 'word_repetition_filter', 'words_num_filter' ], - 'scipy': ['document_minhash_deduplicator'], - 'ftfy': ['fix_unicode_mapper'], 'simhash-pybind': [ 'document_simhash_deduplicator', 'image_captioning_mapper', 'image_diffusion_mapper', 'video_captioning_from_frames_mapper', 'video_captioning_from_summarizer_mapper', 'video_captioning_from_video_mapper' ], - 'selectolax': ['clean_html_mapper'], - 'nlpaug': ['nlpaug_en_mapper'], - 'nlpcda': ['nlpcda'], - 'nltk': ['phrase_grounding_recall_filter', 'sentence_split_mapper'], - 'transformers': [ - 'alphanumeric_filter', 'image_aesthetics_filter', 'image_nsfw_filter', - 'image_text_matching_filter', 'image_text_similarity_filter', - 'image_watermark_filter', 'phrase_grounding_recall_filter', - 'token_num_filter', 'video_aesthetics_filter', - 'video_frames_text_similarity_filter', 'video_nsfw_filter', - 'generate_qa_from_text_mapper', 'generate_qa_from_examples_mapper', - 'image_captioning_mapper', 'image_diffusion_mapper', - 'optimize_query_mapper', 'optimize_response_mapper', - 'optimize_qa_mapper', 'video_captioning_from_audio_mapper', - 'video_captioning_from_frames_mapper', - 'video_captioning_from_summarizer_mapper', - 'video_captioning_from_video_mapper', 'video_tagging_from_audio_mapper' - ], - 'transformers_stream_generator': [ + 'simple-aesthetics-predictor': + ['image_aesthetics_filter', 'video_aesthetics_filter'], + 'spacy-pkuseg': ['text_action_filter', 'text_entity_dependency_filter'], + 'tiktoken': [ 'video_captioning_from_audio_mapper', 'video_captioning_from_summarizer_mapper' ], - 'einops': [ - 'video_captioning_from_audio_mapper', - 'video_captioning_from_summarizer_mapper' + 'torch': [ + 'generate_qa_from_examples_mapper', 'generate_qa_from_text_mapper', + 'image_aesthetics_filter', 'image_captioning_mapper', + 'image_diffusion_mapper', 'image_nsfw_filter', 'image_segment_mapper', + 'image_tagging_mapper', 'image_text_matching_filter', + 'image_text_similarity_filter', 'image_watermark_filter', + 'optimize_qa_mapper', 'optimize_query_mapper', + 'optimize_response_mapper', 'phrase_grounding_recall_filter', + 'video_aesthetics_filter', 'video_captioning_from_frames_mapper', + 'video_captioning_from_summarizer_mapper', + 'video_captioning_from_video_mapper', + 'video_frames_text_similarity_filter', 'video_nsfw_filter', + 'video_tagging_from_audio_mapper', 'video_tagging_from_frames_filter', + 'video_tagging_from_frames_mapper', 'video_watermark_filter' ], - 'accelerate': [ + 'torchaudio': [ + 'video_captioning_from_summarizer_mapper', + 'video_tagging_from_audio_mapper' + ], + 'transformers': [ + 'alphanumeric_filter', 'generate_qa_from_examples_mapper', + 'generate_qa_from_text_mapper', 'image_aesthetics_filter', + 'image_captioning_mapper', 'image_diffusion_mapper', + 'image_nsfw_filter', 'image_text_matching_filter', + 'image_text_similarity_filter', 'image_watermark_filter', + 'optimize_qa_mapper', 'optimize_query_mapper', + 'optimize_response_mapper', 'phrase_grounding_recall_filter', + 'token_num_filter', 'video_aesthetics_filter', 'video_captioning_from_audio_mapper', - 'video_captioning_from_summarizer_mapper' + 'video_captioning_from_frames_mapper', + 'video_captioning_from_summarizer_mapper', + 'video_captioning_from_video_mapper', + 'video_frames_text_similarity_filter', 'video_nsfw_filter', + 'video_tagging_from_audio_mapper' ], - 'tiktoken': [ + 'transformers_stream_generator': [ 'video_captioning_from_audio_mapper', 'video_captioning_from_summarizer_mapper' ], - 'opencc': ['chinese_convert_mapper'], - 'imagededup': ['image_deduplicator', 'ray_image_deduplicator'], - 'spacy-pkuseg': ['text_action_filter', 'text_entity_dependency_filter'], - 'diffusers': ['image_diffusion_mapper'], - 'simple-aesthetics-predictor': - ['image_aesthetics_filter', 'video_aesthetics_filter'], - 'scenedetect[opencv]': ['video_split_by_scene_mapper'], - 'ffmpeg-python': [ - 'audio_ffmpeg_wrapped_mapper', 'video_ffmpeg_wrapped_mapper', - 'video_resize_aspect_ratio_mapper', 'video_resize_resolution_mapper' - ], - 'opencv-python': [ - 'image_face_ratio_filter', 'video_motion_score_filter', - 'image_face_blur_mapper', 'video_face_blur_mapper', - 'video_remove_watermark_mapper' - ], + 'ultralytics': ['image_segment_mapper'], 'vllm': [ - 'generate_qa_from_text_mapper', - 'generate_qa_from_examples_mapper', - 'optimize_query_mapper', - 'optimize_response_mapper', - 'optimize_qa_mapper', - ], - 'rouge': ['generate_qa_from_examples_mapper'], - 'ram': ['image_tagging_mapper', 'video_tagging_from_frames_mapper'] + 'generate_qa_from_examples_mapper', 'generate_qa_from_text_mapper', + 'optimize_qa_mapper', 'optimize_query_mapper', + 'optimize_response_mapper' + ] } diff --git a/data_juicer/utils/model_utils.py b/data_juicer/utils/model_utils.py index 1828d96c7..c6a4fef0e 100644 --- a/data_juicer/utils/model_utils.py +++ b/data_juicer/utils/model_utils.py @@ -28,6 +28,7 @@ ram = LazyLoader('ram', 'ram.models') cv2 = LazyLoader('cv2', 'cv2') openai = LazyLoader('openai', 'openai') +ultralytics = LazyLoader('ultralytics', 'ultralytics') MODEL_ZOO = {} @@ -300,10 +301,8 @@ def prepare_diffusion_model(pretrained_model_name_or_path, diffusion_type, return model -def prepare_fastsam_model(pretrained_model_name_or_path): - from ultralytics import FastSAM - - return FastSAM(pretrained_model_name_or_path) +def prepare_fastsam_model(model_path, **model_params): + return ultralytics.FastSAM(model_path) def prepare_fasttext_model(model_name='lid.176.bin', **model_params): @@ -433,27 +432,24 @@ def prepare_opencv_classifier(model_path, **model_params): return model -def prepare_recognizeAnything_model( - pretrained_model_name_or_path='ram_plus_swin_large_14m.pth', - input_size=384, - **model_params): +def prepare_ram_model(model_path='ram_plus_swin_large_14m.pth', + input_size=384, + **model_params): """ - Prepare and load recognizeAnything model. + Prepare and load Recognize Anything Model (RAM). :param model_name: input model name. :param input_size: the input size of the model. """ - logger.info('Loading recognizeAnything model...') + logger.info('Loading Recognize Anything Model (RAM)...') try: - model = ram.ram_plus( - pretrained=check_model(pretrained_model_name_or_path), - image_size=input_size, - vit='swin_l') + model = ram.ram_plus(pretrained=check_model(model_path), + image_size=input_size, + vit='swin_l') except (RuntimeError, UnpicklingError) as e: # noqa: E722 logger.warning(e) - model = ram.ram_plus(pretrained=check_model( - pretrained_model_name_or_path, force=True), + model = ram.ram_plus(pretrained=check_model(model_path, force=True), image_size=input_size, vit='swin_l') device = model_params.pop('device', 'cpu') @@ -759,7 +755,7 @@ def prepare_vllm_model(pretrained_model_name_or_path, **model_params): 'kenlm': prepare_kenlm_model, 'nltk': prepare_nltk_model, 'opencv_classifier': prepare_opencv_classifier, - 'recognizeAnything': prepare_recognizeAnything_model, + 'ram': prepare_ram_model, 'sentencepiece': prepare_sentencepiece_for_lang, 'simple_aesthetics': prepare_simple_aesthetics_model, 'spacy': prepare_spacy_model, @@ -767,9 +763,7 @@ def prepare_vllm_model(pretrained_model_name_or_path, **model_params): 'vllm': prepare_vllm_model, } -_MODELS_WITHOUT_FILE_LOCK = { - 'kenlm', 'nltk', 'recognizeAnything', 'sentencepiece', 'spacy' -} +_MODELS_WITHOUT_FILE_LOCK = {'kenlm', 'nltk', 'ram', 'sentencepiece', 'spacy'} def prepare_model(model_type, **model_kwargs): diff --git a/docs/Operators.md b/docs/Operators.md index d1fe52698..0d964b090 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -11,7 +11,7 @@ The operators in Data-Juicer are categorized into 5 types. | Type | Number | Description | |-----------------------------------|:------:|-------------------------------------------------| | [ Formatter ]( #formatter ) | 9 | Discovers, loads, and canonicalizes source data | -| [ Mapper ]( #mapper ) | 52 | Edits and transforms samples | +| [ Mapper ]( #mapper ) | 53 | Edits and transforms samples | | [ Filter ]( #filter ) | 44 | Filters out low-quality samples | | [ Deduplicator ]( #deduplicator ) | 8 | Detects and removes duplicate samples | | [ Selector ]( #selector ) | 4 | Selects top samples based on ranking | @@ -75,6 +75,7 @@ All the specific operators are listed below, each featured with several capabili | image_captioning_mapper | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | generate samples whose captions are generated based on another model (such as blip2) and the figure within the original sample | [code](../data_juicer/ops/mapper/image_captioning_mapper.py) | [tests](../tests/ops/mapper/test_image_captioning_mapper.py) | | image_diffusion_mapper | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Generate and augment images by stable diffusion model | [code](../data_juicer/ops/mapper/image_diffusion_mapper.py) | [tests](../tests/ops/mapper/test_image_diffusion_mapper.py) | | image_face_blur_mapper | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) | Blur faces detected in images | [code](../data_juicer/ops/mapper/image_face_blur_mapper.py) | [tests](../tests/ops/mapper/test_image_face_blur_mapper.py) | +| image_segment_mapper | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Perform segment-anything on images and return the bounding box values | [code](../data_juicer/ops/mapper/image_segment_mapper.py) | [tests](../tests/ops/mapper/test_image_segment_mapper.py) | | image_tagging_mapper | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Mapper to generate image tags from the input images. | [code](../data_juicer/ops/mapper/image_tagging_mapper.py) | [tests](../tests/ops/mapper/test_image_tagging_mapper.py) | | nlpaug_en_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) | Simply augments texts in English based on the `nlpaug` library | [code](../data_juicer/ops/mapper/nlpaug_en_mapper.py) | [tests](../tests/ops/mapper/test_nlpaug_en_mapper.py) | | nlpcda_zh_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Simply augments texts in Chinese based on the `nlpcda` library | [code](../data_juicer/ops/mapper/nlpcda_zh_mapper.py) | [tests](../tests/ops/mapper/test_nlpcda_zh_mapper.py) | @@ -92,7 +93,6 @@ All the specific operators are listed below, each featured with several capabili | remove_table_text_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![Financial](https://img.shields.io/badge/Financial-A64C44?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) | Detects and removes possible table contents (:warning: relies on regular expression matching and thus fragile) | [code](../data_juicer/ops/mapper/remove_table_text_mapper.py) | [tests](../tests/ops/mapper/test_remove_table_text_mapper.py) | | remove_words_with_incorrect_ substrings_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Removes words containing specified substrings | [code](../data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py) | [tests](../tests/ops/mapper/test_remove_words_with_incorrect_substrings_mapper.py) | | replace_content_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | Replace all content in the text that matches a specific regular expression pattern with a designated replacement string | [code](../data_juicer/ops/mapper/replace_content_mapper.py) | [tests](../tests/ops/mapper/test_replace_content_mapper.py) | -| segment_mapper | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Perform segment-anything on images and return the bounding box values | [code](../data_juicer/ops/mapper/segment_mapper.py) | [tests](../tests/ops/mapper/test_segment_mapper.py) | | sentence_split_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) | Splits and reorganizes sentences according to semantics | [code](../data_juicer/ops/mapper/sentence_split_mapper.py) | [tests](../tests/ops/mapper/test_sentence_split_mapper.py) | | video_captioning_from_audio_mapper | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | Caption a video according to its audio streams based on Qwen-Audio model | [code](../data_juicer/ops/mapper/video_captioning_from_audio_mapper.py) | [tests](../tests/ops/mapper/test_video_captioning_from_audio_mapper.py) | | video_captioning_from_frames_mapper | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | generate samples whose captions are generated based on an image-to-text model and sampled video frames. Captions from different frames will be concatenated to a single string | [code](../data_juicer/ops/mapper/video_captioning_from_frames_mapper.py) | [tests](../tests/ops/mapper/test_video_captioning_from_frames_mapper.py) | diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md index 57e3b5add..b375c98a1 100644 --- a/docs/Operators_ZH.md +++ b/docs/Operators_ZH.md @@ -11,7 +11,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | 类型 | 数量 | 描述 | |------------------------------------|:--:|---------------| | [ Formatter ]( #formatter ) | 9 | 发现、加载、规范化原始数据 | -| [ Mapper ]( #mapper ) | 52 | 对数据样本进行编辑和转换 | +| [ Mapper ]( #mapper ) | 53 | 对数据样本进行编辑和转换 | | [ Filter ]( #filter ) | 44 | 过滤低质量样本 | | [ Deduplicator ]( #deduplicator ) | 8 | 识别、删除重复样本 | | [ Selector ]( #selector ) | 4 | 基于排序选取高质量样本 | @@ -74,6 +74,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | image_captioning_mapper | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 生成样本,其标题是根据另一个辅助模型(例如 blip2)和原始样本中的图形生成的。 | [code](../data_juicer/ops/mapper/image_captioning_mapper.py) | [tests](../tests/ops/mapper/test_image_captioning_mapper.py) | | image_diffusion_mapper | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 用stable diffusion生成图像,对图像进行增强 | [code](../data_juicer/ops/mapper/image_diffusion_mapper.py) | [tests](../tests/ops/mapper/test_image_diffusion_mapper.py) | | image_face_blur_mapper | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) | 对图像中的人脸进行模糊处理 | [code](../data_juicer/ops/mapper/image_face_blur_mapper.py) | [tests](../tests/ops/mapper/test_image_face_blur_mapper.py) | +| image_segment_mapper | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 对图像实施“分割万物”(segment-anything)的语义分割,并返回 bounding boxes 坐标 | [code](../data_juicer/ops/mapper/image_segment_mapper.py) | [tests](../tests/ops/mapper/test_image_segment_mapper.py) | | image_tagging_mapper | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 从输入图片中生成图片标签 | [code](../data_juicer/ops/mapper/image_tagging_mapper.py) | [tests](../tests/ops/mapper/test_image_tagging_mapper.py) | | nlpaug_en_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) | 使用`nlpaug`库对英语文本进行简单增强 | [code](../data_juicer/ops/mapper/nlpaug_en_mapper.py) | [tests](../tests/ops/mapper/test_nlpaug_en_mapper.py) | | nlpcda_zh_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 使用`nlpcda`库对中文文本进行简单增强 | [code](../data_juicer/ops/mapper/nlpcda_zh_mapper.py) | [tests](../tests/ops/mapper/test_nlpcda_zh_mapper.py) | @@ -91,7 +92,6 @@ Data-Juicer 中的算子分为以下 5 种类型。 | remove_table_text_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![Financial](https://img.shields.io/badge/Financial-A64C44?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) | 检测并删除可能的表格内容(:warning: 依赖正则表达式匹配,因此很脆弱) | [code](../data_juicer/ops/mapper/remove_table_text_mapper.py) | [tests](../tests/ops/mapper/test_remove_table_text_mapper.py) | | remove_words_with_incorrect_ substrings_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 删除包含指定子字符串的单词 | [code](../data_juicer/ops/mapper/remove_words_with_incorrect_substrings_mapper.py) | [tests](../tests/ops/mapper/test_remove_words_with_incorrect_substrings_mapper.py) | | replace_content_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) ![zh](https://img.shields.io/badge/zh-F2D6A2?style=plastic) | 使用一个指定的替换字符串替换文本中满足特定正则表达式模版的所有内容 | [code](../data_juicer/ops/mapper/replace_content_mapper.py) | [tests](../tests/ops/mapper/test_replace_content_mapper.py) | -| segment_mapper | ![Image](https://img.shields.io/badge/Image-07B0F2?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 对图像实施“分割万物”(segment-anything)的语义分割,并返回 bounding boxes 坐标 | [code](../data_juicer/ops/mapper/segment_mapper.py) | [tests](../tests/ops/mapper/test_segment_mapper.py) | | sentence_split_mapper | ![General](https://img.shields.io/badge/General-5FBF50?style=plastic) ![Text](https://img.shields.io/badge/Text-010326?style=plastic) ![en](https://img.shields.io/badge/en-A60D1A?style=plastic) | 根据语义拆分和重组句子 | [code](../data_juicer/ops/mapper/sentence_split_mapper.py) | [tests](../tests/ops/mapper/test_sentence_split_mapper.py) | | video_captioning_from_audio_mapper | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 基于 Qwen-Audio 模型根据视频的音频流为视频生成新的标题描述 | [code](../data_juicer/ops/mapper/video_captioning_from_audio_mapper.py) | [tests](../tests/ops/mapper/test_video_captioning_from_audio_mapper.py) | | video_captioning_from_frames_mapper | ![Multimodal](https://img.shields.io/badge/Multimodal-F25922?style=plastic) ![GPU](https://img.shields.io/badge/GPU-F27649?style=plastic) | 生成样本,其标题是基于一个文字生成图片的模型和原始样本视频中指定帧的图像。不同帧产出的标题会拼接为一条单独的字符串。 | [code](../data_juicer/ops/mapper/video_captioning_from_frames_mapper.py) | [tests](../tests/ops/mapper/test_video_captioning_from_frames_mapper.py) | diff --git a/environments/science_requires.txt b/environments/science_requires.txt index f1e613126..846cb9633 100644 --- a/environments/science_requires.txt +++ b/environments/science_requires.txt @@ -26,3 +26,4 @@ ffmpeg-python opencv-python vllm>=0.1.3 rouge +ultralytics diff --git a/tests/ops/mapper/test_image_segment_mapper.py b/tests/ops/mapper/test_image_segment_mapper.py new file mode 100644 index 000000000..ca625a5b7 --- /dev/null +++ b/tests/ops/mapper/test_image_segment_mapper.py @@ -0,0 +1,61 @@ +import os +import unittest + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.mapper.image_segment_mapper import ImageSegmentMapper +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase + + +class ImageSegmentMapperTest(DataJuicerTestCaseBase): + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', + 'data') + img1_path = os.path.join(data_path, 'img1.png') + img2_path = os.path.join(data_path, 'img2.jpg') + img3_path = os.path.join(data_path, 'img3.jpg') + + def _run_op(self, op, source_list, num_proc=1): + dataset = Dataset.from_list(source_list) + dataset = dataset.map(op.process, num_proc=num_proc, with_rank=True) + res_list = dataset.to_list() + + bbox_nums = [[5], [14, 6]] + for sample, sample_bbn in zip(res_list, bbox_nums): + for bb, bbn in zip(sample['__dj__bbox__'], sample_bbn): + self.assertEqual(len(bb), bbn) + + def test_segment_mapper(self): + ds_list = [{ + 'images': [self.img1_path] + }, { + 'images': [self.img2_path, self.img3_path] + }] + # fix params for reproducibility + op = ImageSegmentMapper(imgsz=1024, conf=0.9, iou=0.5) + self._run_op(op, ds_list) + + def test_cpu(self): + ds_list = [{ + 'images': [self.img1_path] + }, { + 'images': [self.img2_path, self.img3_path] + }] + # fix params for reproducibility + op = ImageSegmentMapper(imgsz=1024, + conf=0.9, + iou=0.5, + accelerator='cpu') + self._run_op(op, ds_list) + + def test_multi_process(self): + ds_list = [{ + 'images': [self.img1_path] + }, { + 'images': [self.img2_path, self.img3_path] + }] + # fix params for reproducibility + op = ImageSegmentMapper(imgsz=1024, conf=0.9, iou=0.5) + self._run_op(op, ds_list, num_proc=2) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_segment_mapper.py b/tests/ops/mapper/test_segment_mapper.py deleted file mode 100644 index 8b4668858..000000000 --- a/tests/ops/mapper/test_segment_mapper.py +++ /dev/null @@ -1,48 +0,0 @@ -import os -import unittest - -from data_juicer.core.data import NestedDataset as Dataset -from data_juicer.utils.mm_utils import SpecialTokens -from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, - DataJuicerTestCaseBase) - -from data_juicer.ops.mapper.segment_mapper import SegmentMapper -from data_juicer.utils.unittest_utils import (SKIPPED_TESTS, - DataJuicerTestCaseBase) - - - -class SDXLPrompt2PromptMapperTest(DataJuicerTestCaseBase): - - text_key = 'text' - - def _run_segment_mapper(self, enable_vllm=False): - op = SegmentMapper( - fastsam_path='FastSAM-x.pt', - ) - - data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', - 'data') - img2_path = os.path.join(data_path, 'img2.jpg') - img3_path = os.path.join(data_path, 'img3.jpg') - img5_path = os.path.join(data_path, 'img5.jpg') - - ds_list = [{ - 'images': [img2_path, img3_path] - }, { - 'images': [img5_path] - }] - - - for sample in ds_list: - result = op.process(sample) - print(f'Output results: {result}') - - - def test_segment_mapper(self): - self._run_segment_mapper() - - - -if __name__ == '__main__': - unittest.main() \ No newline at end of file