From c40a3086f7d0a3c26092bd60e5a4aca6a011bf2c Mon Sep 17 00:00:00 2001 From: Yilun Huang Date: Thu, 12 Sep 2024 15:28:26 +0800 Subject: [PATCH] Add new OP: image_tagging_mapper (#423) * * init image tagging mapper * + Add unittest for image_tagging_mapper * support specified tag field names for all tagging OPs * * fix problems of unittest * + add docs * * update docs * * skip two unittests which require ram * * minor fix for gece's comments * * merge main into this branch * + add type hint --- configs/config_all.yaml | 5 + .../video_tagging_from_frames_filter.py | 7 +- data_juicer/ops/mapper/__init__.py | 11 +- .../ops/mapper/image_tagging_mapper.py | 85 +++++++++ .../mapper/video_tagging_from_audio_mapper.py | 16 +- .../video_tagging_from_frames_mapper.py | 16 +- data_juicer/utils/constant.py | 2 + docs/Operators.md | 97 +++++------ docs/Operators_ZH.md | 97 +++++------ tests/ops/mapper/test_image_tagging_mapper.py | 156 +++++++++++++++++ .../test_video_tagging_from_audio_mapper.py | 28 ++- .../test_video_tagging_from_frames_mapper.py | 163 +++++++++++++----- 12 files changed, 522 insertions(+), 161 deletions(-) create mode 100644 data_juicer/ops/mapper/image_tagging_mapper.py create mode 100644 tests/ops/mapper/test_image_tagging_mapper.py diff --git a/configs/config_all.yaml b/configs/config_all.yaml index de53e0420..a28861c77 100644 --- a/configs/config_all.yaml +++ b/configs/config_all.yaml @@ -122,6 +122,8 @@ process: cv_classifier: '' # OpenCV classifier path for face detection. By default, we will use 'haarcascade_frontalface_alt.xml'. blur_type: 'gaussian' # type of blur kernel, including ['mean', 'box', 'gaussian'] radius: 2 # radius of blur kernel + - image_tagging_mapper: # Mapper to generate image tags. + tag_field_name: '__dj__image_tags__' # the field name to store the tags. It's "__dj__image_tags__" in default. - nlpaug_en_mapper: # simply augment texts in English based on the nlpaug library sequential: false # whether combine all augmentation methods to a sequence. If it's True, a sample will be augmented by all opened augmentation methods sequentially. If it's False, each opened augmentation method would generate its augmented samples independently. aug_num: 1 # number of augmented samples to be generated. If `sequential` is True, there will be total aug_num augmented samples generated. If it's False, there will be (aug_num * #opened_aug_method) augmented samples generated. @@ -258,10 +260,12 @@ process: show_progress: false # whether to show progress from scenedetect - video_tagging_from_audio_mapper: # Mapper to generate video tags from audio streams extracted from the video. hf_ast: 'MIT/ast-finetuned-audioset-10-10-0.4593' # Huggingface model name for the audio classification model. + tag_field_name: '__dj__video_audio_tags__' # the field name to store the tags. It's "__dj__video_audio_tags__" in default. mem_required: '500MB' # This operation (Op) utilizes deep neural network models that consume a significant amount of memory for computation, hence the system's available memory might constrains the maximum number of processes that can be launched - video_tagging_from_frames_mapper: # Mapper to generate video tags from frames extracted from the video. frame_sampling_method: 'all_keyframes' # sampling method of extracting frame images from the videos. Should be one of ["all_keyframes", "uniform"]. The former one extracts all key frames and the latter one extract specified number of frames uniformly from the video. Default: "all_keyframes". frame_num: 3 # the number of frames to be extracted uniformly from the video. Only works when frame_sampling_method is "uniform". If it's 1, only the middle frame will be extracted. If it's 2, only the first and the last frames will be extracted. If it's larger than 2, in addition to the first and the last frames, other frames will be extracted uniformly within the video duration. + tag_field_name: '__dj__video_frame_tags__' # the field name to store the tags. It's "__dj__video_frame_tags__" in default. - whitespace_normalization_mapper: # normalize different kinds of whitespaces to English whitespace. # Filter ops @@ -478,6 +482,7 @@ process: contain: any # require the videos containing 'any' or 'all' given tags. When tags equal to [], 'all' keeps all samples, 'any' keeps no sample. frame_sampling_method: all_keyframes # sampling method of extracting frame images from the videos. Should be one of ["all_keyframes", "uniform"]. The former one extracts all key frames and the latter one extract specified number of frames uniformly from the video. Default: "all_keyframes". frame_num: 3 # the number of frames to be extracted uniformly from the video. Only works when frame_sampling_method is "uniform". If it's 1, only the middle frame will be extracted. If it's 2, only the first and the last frames will be extracted. If it's larger than 2, in addition to the first and the last frames, other frames will be extracted uniformly within the video duration. + tag_field_name: '__dj__video_frame_tags__' # the field name to store the tags. It's "__dj__video_frame_tags__" in default. any_or_all: any # keep this sample when any/all videos meet the filter condition - words_num_filter: # filter text with number of words out of specific range lang: en # sample in which language diff --git a/data_juicer/ops/filter/video_tagging_from_frames_filter.py b/data_juicer/ops/filter/video_tagging_from_frames_filter.py index df90e6fd7..056233a9c 100644 --- a/data_juicer/ops/filter/video_tagging_from_frames_filter.py +++ b/data_juicer/ops/filter/video_tagging_from_frames_filter.py @@ -37,6 +37,7 @@ def __init__(self, contain: str = 'any', frame_sampling_method: str = 'all_keyframes', frame_num: PositiveInt = 3, + tag_field_name: str = Fields.video_frame_tags, any_or_all: str = 'any', *args, **kwargs): @@ -61,6 +62,8 @@ def __init__(self, the first and the last frames will be extracted. If it's larger than 2, in addition to the first and the last frames, other frames will be extracted uniformly within the video duration. + :param tag_field_name: the field name to store the tags. It's + "__dj__video_frame_tags__" in default. :param any_or_all: keep this sample with 'any' or 'all' strategy of all videos. 'any': keep this sample if any videos meet the condition. 'all': keep this sample only if all videos meet the @@ -82,10 +85,12 @@ def __init__(self, self.tags = set([tag.lower() for tag in tags]) self.contain_any = (contain == 'any') self.any = (any_or_all == 'any') + self.tag_field_name = tag_field_name self.tagging_producer = VideoTaggingFromFramesMapper( frame_sampling_method=frame_sampling_method, frame_num=frame_num, accelerator=self.accelerator, + tag_field_name=self.tag_field_name, ) def compute_stats(self, sample, rank=None, context=False): @@ -95,7 +100,7 @@ def compute_stats(self, sample, rank=None, context=False): return sample def process(self, sample, rank=None): - video_tags = sample[Fields.video_frame_tags] + video_tags = sample[self.tag_field_name] if len(video_tags) <= 0: return True diff --git a/data_juicer/ops/mapper/__init__.py b/data_juicer/ops/mapper/__init__.py index d0e32825c..eb814b374 100644 --- a/data_juicer/ops/mapper/__init__.py +++ b/data_juicer/ops/mapper/__init__.py @@ -6,10 +6,11 @@ generate_instruction_mapper, image_blur_mapper, image_captioning_from_gpt4v_mapper, image_captioning_mapper, image_diffusion_mapper, image_face_blur_mapper, - nlpaug_en_mapper, nlpcda_zh_mapper, optimize_instruction_mapper, - punctuation_normalization_mapper, remove_bibliography_mapper, - remove_comments_mapper, remove_header_mapper, - remove_long_words_mapper, remove_non_chinese_character_mapper, + image_tagging_mapper, nlpaug_en_mapper, nlpcda_zh_mapper, + optimize_instruction_mapper, punctuation_normalization_mapper, + remove_bibliography_mapper, remove_comments_mapper, + remove_header_mapper, remove_long_words_mapper, + remove_non_chinese_character_mapper, remove_repeat_sentences_mapper, remove_specific_chars_mapper, remove_table_text_mapper, remove_words_with_incorrect_substrings_mapper, @@ -41,6 +42,7 @@ from .image_captioning_mapper import ImageCaptioningMapper from .image_diffusion_mapper import ImageDiffusionMapper from .image_face_blur_mapper import ImageFaceBlurMapper +from .image_tagging_mapper import ImageTaggingMapper from .nlpaug_en_mapper import NlpaugEnMapper from .nlpcda_zh_mapper import NlpcdaZhMapper from .optimize_instruction_mapper import OptimizeInstructionMapper @@ -123,6 +125,7 @@ 'AudioFFmpegWrappedMapper', 'VideoSplitByDurationMapper', 'VideoFaceBlurMapper', + 'ImageTaggingMapper', ] # yapf: enable diff --git a/data_juicer/ops/mapper/image_tagging_mapper.py b/data_juicer/ops/mapper/image_tagging_mapper.py new file mode 100644 index 000000000..0bd2b89e2 --- /dev/null +++ b/data_juicer/ops/mapper/image_tagging_mapper.py @@ -0,0 +1,85 @@ +from collections import Counter + +import numpy as np + +from data_juicer.utils.availability_utils import AvailabilityChecking +from data_juicer.utils.constant import Fields +from data_juicer.utils.mm_utils import load_data_with_context, load_image +from data_juicer.utils.model_utils import get_model, prepare_model + +from ..base_op import OPERATORS, UNFORKABLE, Mapper +from ..op_fusion import LOADED_IMAGES + +OP_NAME = 'image_tagging_mapper' + +with AvailabilityChecking( + ['torch', 'git+https://github.com/xinyu1205/recognize-anything.git'], + OP_NAME): + import ram # noqa: F401 + import torch + + # avoid hanging when calling recognizeAnything in multiprocessing + torch.set_num_threads(1) + + +@UNFORKABLE.register_module(OP_NAME) +@OPERATORS.register_module(OP_NAME) +@LOADED_IMAGES.register_module(OP_NAME) +class ImageTaggingMapper(Mapper): + """Mapper to generate image tags. + """ + + _accelerator = 'cuda' + + def __init__(self, + tag_field_name: str = Fields.image_tags, + *args, + **kwargs): + """ + Initialization method. + :param tag_field_name: the field name to store the tags. It's + "__dj__image_tags__" in default. + :param args: extra args + :param kwargs: extra args + """ + super().__init__(*args, **kwargs) + self.model_key = prepare_model( + model_type='recognizeAnything', + pretrained_model_name_or_path='ram_plus_swin_large_14m.pth', + input_size=384) + from ram import get_transform + self.transform = get_transform(image_size=384) + self.tag_field_name = tag_field_name + + def process(self, sample, rank=None, context=False): + # check if it's generated already + if self.tag_field_name in sample: + return sample + + # there is no image in this sample + if self.image_key not in sample or not sample[self.image_key]: + sample[self.tag_field_name] = np.array([[]], dtype=np.str_) + return sample + + # load images + loaded_image_keys = sample[self.image_key] + sample, images = load_data_with_context(sample, context, + loaded_image_keys, load_image) + + model = get_model(self.model_key, rank, self.use_cuda()) + image_tags = [] + for _, value in enumerate(loaded_image_keys): + image = images[value] + + image_tensor = torch.unsqueeze(self.transform(image), dim=0).to( + next(model.parameters()).device) + with torch.no_grad(): + tags, _ = model.generate_tag(image_tensor) + + words = [word.strip() for word in tags[0].split('|')] + word_count = Counter(words) + sorted_word_list = [item for item, _ in word_count.most_common()] + image_tags.append(np.array(sorted_word_list, dtype=np.str_)) + + sample[self.tag_field_name] = image_tags + return sample diff --git a/data_juicer/ops/mapper/video_tagging_from_audio_mapper.py b/data_juicer/ops/mapper/video_tagging_from_audio_mapper.py index 07d1638e7..c9f0536e2 100644 --- a/data_juicer/ops/mapper/video_tagging_from_audio_mapper.py +++ b/data_juicer/ops/mapper/video_tagging_from_audio_mapper.py @@ -1,4 +1,5 @@ import librosa +import numpy as np from data_juicer.utils.availability_utils import AvailabilityChecking from data_juicer.utils.constant import Fields @@ -29,11 +30,16 @@ class VideoTaggingFromAudioMapper(Mapper): def __init__(self, hf_ast: str = 'MIT/ast-finetuned-audioset-10-10-0.4593', trust_remote_code: bool = False, + tag_field_name: str = Fields.video_audio_tags, *args, **kwargs): """ Initialization method. + :param hf_ast: path to the HF model to tag from audios. + :param trust_remote_code: whether to trust the remote code of HF models + :param tag_field_name: the field name to store the tags. It's + "__dj__video_audio_tags__" in default. :param args: extra args :param kwargs: extra args """ @@ -44,14 +50,16 @@ def __init__(self, self._model_sampling_rate = 16000 self._no_audio_label = 'EMPTY' + self.tag_field_name = tag_field_name + def process(self, sample, rank=None): # check if it's generated already - if Fields.video_audio_tags in sample: + if self.tag_field_name in sample: return sample # there is no video in this sample if self.video_key not in sample or not sample[self.video_key]: - sample[Fields.video_audio_tags] = [] + sample[self.tag_field_name] = np.array([], dtype=np.str_) return sample # load video paths @@ -80,11 +88,11 @@ def process(self, sample, rank=None): sr = self._model_sampling_rate inputs = feature_extractor(y, sampling_rate=sr, - return_tensors='pt') + return_tensors='pt').to(model.device) with torch.no_grad(): logits = model(**inputs).logits predicted_tag_id = torch.argmax(logits, dim=-1).item() predicted_tag = model.config.id2label[predicted_tag_id] video_audio_tags.append(predicted_tag) - sample[Fields.video_audio_tags] = video_audio_tags + sample[self.tag_field_name] = np.array(video_audio_tags, dtype=np.str_) return sample diff --git a/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py b/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py index 014ec2268..8aafdb615 100644 --- a/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py +++ b/data_juicer/ops/mapper/video_tagging_from_frames_mapper.py @@ -1,5 +1,6 @@ from collections import Counter +import numpy as np from pydantic import PositiveInt from data_juicer.utils.availability_utils import AvailabilityChecking @@ -36,6 +37,7 @@ class VideoTaggingFromFramesMapper(Mapper): def __init__(self, frame_sampling_method: str = 'all_keyframes', frame_num: PositiveInt = 3, + tag_field_name: str = Fields.video_frame_tags, *args, **kwargs): """ @@ -54,6 +56,8 @@ def __init__(self, the first and the last frames will be extracted. If it's larger than 2, in addition to the first and the last frames, other frames will be extracted uniformly within the video duration. + :param tag_field_name: the field name to store the tags. It's + "__dj__video_frame_tags__" in default. :param args: extra args :param kwargs: extra args """ @@ -71,14 +75,16 @@ def __init__(self, from ram import get_transform self.transform = get_transform(image_size=384) + self.tag_field_name = tag_field_name + def process(self, sample, rank=None, context=False): # check if it's generated already - if Fields.video_frame_tags in sample: + if self.tag_field_name in sample: return sample # there is no video in this sample if self.video_key not in sample or not sample[self.video_key]: - sample[Fields.video_frame_tags] = [] + sample[self.tag_field_name] = np.array([[]], dtype=np.str_) return sample # load videos @@ -98,7 +104,7 @@ def process(self, sample, rank=None, context=False): frames = extract_video_frames_uniformly(video, self.frame_num) else: video_tags.append([]) - frames = [] + continue frame_tensor = torch.stack([ self.transform(frame.to_image()) for frame in frames @@ -109,11 +115,11 @@ def process(self, sample, rank=None, context=False): words = [word.strip() for tag in tags for word in tag.split('|')] word_count = Counter(words) sorted_word_list = [item for item, _ in word_count.most_common()] - video_tags.append(sorted_word_list) + video_tags.append(np.array(sorted_word_list, dtype=np.str_)) if not context: for vid_key in videos: close_video(videos[vid_key]) - sample[Fields.video_frame_tags] = video_tags + sample[self.tag_field_name] = video_tags return sample diff --git a/data_juicer/utils/constant.py b/data_juicer/utils/constant.py index 8329424bd..99e8724c0 100644 --- a/data_juicer/utils/constant.py +++ b/data_juicer/utils/constant.py @@ -19,6 +19,8 @@ class Fields(object): # video_frame_tags video_frame_tags = DEFAULT_PREFIX + 'video_frame_tags__' video_audio_tags = DEFAULT_PREFIX + 'video_audio_tags__' + # image_tags + image_tags = DEFAULT_PREFIX + 'image_tags__' # the name of the original file from which this sample was derived. source_file = DEFAULT_PREFIX + 'source_file__' diff --git a/docs/Operators.md b/docs/Operators.md index 7b1f8f0f3..dd56871c4 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -11,7 +11,7 @@ The operators in Data-Juicer are categorized into 5 types. | Type | Number | Description | |-----------------------------------|:------:|-------------------------------------------------| | [ Formatter ]( #formatter ) | 7 | Discovers, loads, and canonicalizes source data | -| [ Mapper ]( #mapper ) | 46 | Edits and transforms samples | +| [ Mapper ]( #mapper ) | 47 | Edits and transforms samples | | [ Filter ]( #filter ) | 42 | Filters out low-quality samples | | [ Deduplicator ]( #deduplicator ) | 5 | Detects and removes duplicate samples | | [ Selector ]( #selector ) | 4 | Selects top samples based on ranking | @@ -48,54 +48,55 @@ All the specific operators are listed below, each featured with several capabili ## Mapper -| Operator | Domain | Lang | Description | -|-----------------------------------------------------|--------------------|--------|---------------------------------------------------------------------------------------------------------------| -| audio_ffmpeg_wrapped_mapper | Audio | - | Simple wrapper to run a FFmpeg audio filter | -| chinese_convert_mapper | General | zh | Converts Chinese between Traditional Chinese, Simplified Chinese and Japanese Kanji (by [opencc](https://github.com/BYVoid/OpenCC)) | -| clean_copyright_mapper | Code | en, zh | Removes copyright notice at the beginning of code files (must contain the word *copyright*) | -| clean_email_mapper | General | en, zh | Removes email information | -| clean_html_mapper | General | en, zh | Removes HTML tags and returns plain text of all the nodes | -| clean_ip_mapper | General | en, zh | Removes IP addresses | -| clean_links_mapper | General, Code | en, zh | Removes links, such as those starting with http or ftp | -| expand_macro_mapper | LaTeX | en, zh | Expands macros usually defined at the top of TeX documents | -| extract_qa_mapper | General | en, zh | Extract question and answer pair from text samples. | -| fix_unicode_mapper | General | en, zh | Fixes broken Unicodes (by [ftfy](https://ftfy.readthedocs.io/)) | -| generate_instruction_mapper | General | en, zh | Generate instruction text samples.| -| image_blur_mapper | Image | - | Blur images | -| image_captioning_from_gpt4v_mapper | Multimodal | - | generate samples whose texts are generated based on gpt-4-visison and the image | -| image_captioning_mapper | Multimodal | - | generate samples whose captions are generated based on another model (such as blip2) and the figure within the original sample | -| image_diffusion_mapper | Multimodal | - | Generate and augment images by stable diffusion model | -| image_face_blur_mapper | Image | - | Blur faces detected in images | -| nlpaug_en_mapper | General | en | Simply augments texts in English based on the `nlpaug` library | -| nlpcda_zh_mapper | General | zh | Simply augments texts in Chinese based on the `nlpcda` library | -| optimize_instruction_mapper | General | en, zh | Optimize instruction text samples.| -| punctuation_normalization_mapper | General | en, zh | Normalizes various Unicode punctuations to their ASCII equivalents | -| remove_bibliography_mapper | LaTeX | en, zh | Removes the bibliography of TeX documents | -| remove_comments_mapper | LaTeX | en, zh | Removes the comments of TeX documents | -| remove_header_mapper | LaTeX | en, zh | Removes the running headers of TeX documents, e.g., titles, chapter or section numbers/names | -| remove_long_words_mapper | General | en, zh | Removes words with length outside the specified range | -| remove_non_chinese_character_mapper | General | en, zh | Remove non Chinese character in text samples. | -| remove_repeat_sentences_mapper | General | en, zh | Remove repeat sentences in text samples. | -| remove_specific_chars_mapper | General | en, zh | Removes any user-specified characters or substrings | -| remove_table_text_mapper | General, Financial | en | Detects and removes possible table contents (:warning: relies on regular expression matching and thus fragile)| -| remove_words_with_incorrect_
substrings_mapper | General | en, zh | Removes words containing specified substrings | -| replace_content_mapper | General | en, zh | Replace all content in the text that matches a specific regular expression pattern with a designated replacement string | -| sentence_split_mapper | General | en | Splits and reorganizes sentences according to semantics | -| video_captioning_from_audio_mapper | Multimodal | - | Caption a video according to its audio streams based on Qwen-Audio model | +| Operator | Domain | Lang | Description | +|-----------------------------------------------------|--------------------|--------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| +| audio_ffmpeg_wrapped_mapper | Audio | - | Simple wrapper to run a FFmpeg audio filter | +| chinese_convert_mapper | General | zh | Converts Chinese between Traditional Chinese, Simplified Chinese and Japanese Kanji (by [opencc](https://github.com/BYVoid/OpenCC)) | +| clean_copyright_mapper | Code | en, zh | Removes copyright notice at the beginning of code files (must contain the word *copyright*) | +| clean_email_mapper | General | en, zh | Removes email information | +| clean_html_mapper | General | en, zh | Removes HTML tags and returns plain text of all the nodes | +| clean_ip_mapper | General | en, zh | Removes IP addresses | +| clean_links_mapper | General, Code | en, zh | Removes links, such as those starting with http or ftp | +| expand_macro_mapper | LaTeX | en, zh | Expands macros usually defined at the top of TeX documents | +| extract_qa_mapper | General | en, zh | Extract question and answer pair from text samples. | +| fix_unicode_mapper | General | en, zh | Fixes broken Unicodes (by [ftfy](https://ftfy.readthedocs.io/)) | +| generate_instruction_mapper | General | en, zh | Generate instruction text samples. | +| image_blur_mapper | Image | - | Blur images | +| image_captioning_from_gpt4v_mapper | Multimodal | - | generate samples whose texts are generated based on gpt-4-visison and the image | +| image_captioning_mapper | Multimodal | - | generate samples whose captions are generated based on another model (such as blip2) and the figure within the original sample | +| image_diffusion_mapper | Multimodal | - | Generate and augment images by stable diffusion model | +| image_face_blur_mapper | Image | - | Blur faces detected in images | +| image_tagging_mapper | Multimodal | - | Mapper to generate image tags from the input images. | +| nlpaug_en_mapper | General | en | Simply augments texts in English based on the `nlpaug` library | +| nlpcda_zh_mapper | General | zh | Simply augments texts in Chinese based on the `nlpcda` library | +| optimize_instruction_mapper | General | en, zh | Optimize instruction text samples. | +| punctuation_normalization_mapper | General | en, zh | Normalizes various Unicode punctuations to their ASCII equivalents | +| remove_bibliography_mapper | LaTeX | en, zh | Removes the bibliography of TeX documents | +| remove_comments_mapper | LaTeX | en, zh | Removes the comments of TeX documents | +| remove_header_mapper | LaTeX | en, zh | Removes the running headers of TeX documents, e.g., titles, chapter or section numbers/names | +| remove_long_words_mapper | General | en, zh | Removes words with length outside the specified range | +| remove_non_chinese_character_mapper | General | en, zh | Remove non Chinese character in text samples. | +| remove_repeat_sentences_mapper | General | en, zh | Remove repeat sentences in text samples. | +| remove_specific_chars_mapper | General | en, zh | Removes any user-specified characters or substrings | +| remove_table_text_mapper | General, Financial | en | Detects and removes possible table contents (:warning: relies on regular expression matching and thus fragile) | +| remove_words_with_incorrect_
substrings_mapper | General | en, zh | Removes words containing specified substrings | +| replace_content_mapper | General | en, zh | Replace all content in the text that matches a specific regular expression pattern with a designated replacement string | +| sentence_split_mapper | General | en | Splits and reorganizes sentences according to semantics | +| video_captioning_from_audio_mapper | Multimodal | - | Caption a video according to its audio streams based on Qwen-Audio model | | video_captioning_from_frames_mapper | Multimodal | - | generate samples whose captions are generated based on an image-to-text model and sampled video frames. Captions from different frames will be concatenated to a single string | -| video_captioning_from_summarizer_mapper | Multimodal | - | Generate video captions by summarizing several kinds of generated texts (captions from video/audio/frames, tags from audio/frames, ...) | -| video_captioning_from_video_mapper | Multimodal | - | generate samples whose captions are generated based on another model (video-blip) and sampled video frame within the original sample | -| video_face_blur_mapper | Video | - | Blur faces detected in videos | -| video_ffmpeg_wrapped_mapper | Video | - | Simple wrapper to run a FFmpeg video filter | -| video_remove_watermark_mapper | Video | - | Remove the watermarks in videos given regions | -| video_resize_aspect_ratio_mapper | Video | - | Resize video aspect ratio to a specified range | -| video_resize_resolution_mapper | Video | - | Map videos to ones with given resolution range | -| video_split_by_duration_mapper | Multimodal | - | Mapper to split video by duration | -| video_spit_by_key_frame_mapper | Multimodal | - | Mapper to split video by key frame | -| video_split_by_scene_mapper | Multimodal | - | Split videos into scene clips | -| video_tagging_from_audio_mapper | Multimodal | - | Mapper to generate video tags from audio streams extracted from the video. | -| video_tagging_from_frames_mapper | Multimodal | - | Mapper to generate video tags from frames extracted from the video. | -| whitespace_normalization_mapper | General | en, zh | Normalizes various Unicode whitespaces to the normal ASCII space (U+0020) | +| video_captioning_from_summarizer_mapper | Multimodal | - | Generate video captions by summarizing several kinds of generated texts (captions from video/audio/frames, tags from audio/frames, ...) | +| video_captioning_from_video_mapper | Multimodal | - | generate samples whose captions are generated based on another model (video-blip) and sampled video frame within the original sample | +| video_face_blur_mapper | Video | - | Blur faces detected in videos | +| video_ffmpeg_wrapped_mapper | Video | - | Simple wrapper to run a FFmpeg video filter | +| video_remove_watermark_mapper | Video | - | Remove the watermarks in videos given regions | +| video_resize_aspect_ratio_mapper | Video | - | Resize video aspect ratio to a specified range | +| video_resize_resolution_mapper | Video | - | Map videos to ones with given resolution range | +| video_split_by_duration_mapper | Multimodal | - | Mapper to split video by duration | +| video_spit_by_key_frame_mapper | Multimodal | - | Mapper to split video by key frame | +| video_split_by_scene_mapper | Multimodal | - | Split videos into scene clips | +| video_tagging_from_audio_mapper | Multimodal | - | Mapper to generate video tags from audio streams extracted from the video. | +| video_tagging_from_frames_mapper | Multimodal | - | Mapper to generate video tags from frames extracted from the video. | +| whitespace_normalization_mapper | General | en, zh | Normalizes various Unicode whitespaces to the normal ASCII space (U+0020) | ## Filter
diff --git a/docs/Operators_ZH.md b/docs/Operators_ZH.md index 7ee0bda66..f5d598f54 100644 --- a/docs/Operators_ZH.md +++ b/docs/Operators_ZH.md @@ -11,7 +11,7 @@ Data-Juicer 中的算子分为以下 5 种类型。 | 类型 | 数量 | 描述 | |------------------------------------|:--:|---------------| | [ Formatter ]( #formatter ) | 7 | 发现、加载、规范化原始数据 | -| [ Mapper ]( #mapper ) | 46 | 对数据样本进行编辑和转换 | +| [ Mapper ]( #mapper ) | 47 | 对数据样本进行编辑和转换 | | [ Filter ]( #filter ) | 42 | 过滤低质量样本 | | [ Deduplicator ]( #deduplicator ) | 5 | 识别、删除重复样本 | | [ Selector ]( #selector ) | 4 | 基于排序选取高质量样本 | @@ -47,54 +47,55 @@ Data-Juicer 中的算子分为以下 5 种类型。 ## Mapper -| 算子 | 场景 | 语言 | 描述 | -|-----------------------------------------------------|-----------------------|-----------|--------------------------------------------------------| -| audio_ffmpeg_wrapped_mapper | Audio | - | 运行 FFmpeg 语音过滤器的简单封装 | -| chinese_convert_mapper | General | zh | 用于在繁体中文、简体中文和日文汉字之间进行转换(借助 [opencc](https://github.com/BYVoid/OpenCC)) | -| clean_copyright_mapper | Code | en, zh | 删除代码文件开头的版权声明 (必须包含单词 *copyright*) | -| clean_email_mapper | General | en, zh | 删除邮箱信息 | -| clean_html_mapper | General | en, zh | 删除 HTML 标签并返回所有节点的纯文本 | -| clean_ip_mapper | General | en, zh | 删除 IP 地址 | -| clean_links_mapper | General, Code | en, zh | 删除链接,例如以 http 或 ftp 开头的 | -| expand_macro_mapper | LaTeX | en, zh | 扩展通常在 TeX 文档顶部定义的宏 | -| extract_qa_mapper | General | en, zh | 从文本中抽取问答对 | -| fix_unicode_mapper | General | en, zh | 修复损坏的 Unicode(借助 [ftfy](https://ftfy.readthedocs.io/)) | -| generate_instruction_mapper | General | en, zh | 指令扩充,根据种子数据,生成新的样本。 | -| image_blur_mapper | Image | - | 对图像进行模糊处理 | -| image_captioning_from_gpt4v_mapper | Multimodal | - | 基于gpt-4-vision和图像生成文本 | -| image_captioning_mapper | Multimodal | - | 生成样本,其标题是根据另一个辅助模型(例如 blip2)和原始样本中的图形生成的。 | -| image_diffusion_mapper | Multimodal | - | 用stable diffusion生成图像,对图像进行增强 | -| image_face_blur_mapper | Image | - | 对图像中的人脸进行模糊处理 | -| nlpaug_en_mapper | General | en | 使用`nlpaug`库对英语文本进行简单增强 | -| nlpcda_zh_mapper | General | zh | 使用`nlpcda`库对中文文本进行简单增强 | -| optimize_instruction_mapper | General | en, zh | 指令优化,优化prompt。| -| punctuation_normalization_mapper | General | en, zh | 将各种 Unicode 标点符号标准化为其 ASCII 等效项 | -| remove_bibliography_mapper | LaTeX | en, zh | 删除 TeX 文档的参考文献 | -| remove_comments_mapper | LaTeX | en, zh | 删除 TeX 文档中的注释 | -| remove_header_mapper | LaTeX | en, zh | 删除 TeX 文档头,例如标题、章节数字/名称等 | -| remove_long_words_mapper | General | en, zh | 删除长度超出指定范围的单词 | -| remove_non_chinese_character_mapper | General | en, zh | 删除样本中的非中文字符 | -| remove_repeat_sentences_mapper | General | en, zh | 删除样本中的重复句子 | -| remove_specific_chars_mapper | General | en, zh | 删除任何用户指定的字符或子字符串 | -| remove_table_text_mapper | General, Financial | en | 检测并删除可能的表格内容(:warning: 依赖正则表达式匹配,因此很脆弱) | -| remove_words_with_incorrect_
substrings_mapper | General | en, zh | 删除包含指定子字符串的单词 | -| replace_content_mapper | General | en, zh | 使用一个指定的替换字符串替换文本中满足特定正则表达式模版的所有内容 | -| sentence_split_mapper | General | en | 根据语义拆分和重组句子 | -| video_captioning_from_audio_mapper | Multimodal | - | 基于 Qwen-Audio 模型根据视频的音频流为视频生成新的标题描述 | -| video_captioning_from_frames_mapper | Multimodal | - | 生成样本,其标题是基于一个文字生成图片的模型和原始样本视频中指定帧的图像。不同帧产出的标题会拼接为一条单独的字符串。 | -| video_captioning_from_summarizer_mapper | Multimodal | - | 通过对多种不同方式生成的文本进行摘要以生成样本的标题(从视频/音频/帧生成标题,从音频/帧生成标签,...) | -| video_captioning_from_video_mapper | Multimodal | - | 生成样本,其标题是根据另一个辅助模型(video-blip)和原始样本中的视频中指定帧的图像。 | -| video_face_blur_mapper | Video | - | 对视频中的人脸进行模糊处理 | -| video_ffmpeg_wrapped_mapper | Video | - | 运行 FFmpeg 视频过滤器的简单封装 | -| video_remove_watermark_mapper | Video | - | 去除视频中给定区域的水印 | -| video_resize_aspect_ratio_mapper | Video | - | 将视频的宽高比调整到指定范围内 | -| video_resize_resolution_mapper | Video | - | 将视频映射到给定的分辨率区间 | -| video_split_by_duration_mapper | Multimodal | - | 根据时长将视频切分为多个片段 | -| video_split_by_key_frame_mapper | Multimodal | - | 根据关键帧切分视频 | -| video_split_by_scene_mapper | Multimodal | - | 将视频切分为场景片段 | +| 算子 | 场景 | 语言 | 描述 | +|----------------------------------------------------|-----------------------|-----------|------------------------------------------------------------------------| +| audio_ffmpeg_wrapped_mapper | Audio | - | 运行 FFmpeg 语音过滤器的简单封装 | +| chinese_convert_mapper | General | zh | 用于在繁体中文、简体中文和日文汉字之间进行转换(借助 [opencc](https://github.com/BYVoid/OpenCC)) | +| clean_copyright_mapper | Code | en, zh | 删除代码文件开头的版权声明 (必须包含单词 *copyright*) | +| clean_email_mapper | General | en, zh | 删除邮箱信息 | +| clean_html_mapper | General | en, zh | 删除 HTML 标签并返回所有节点的纯文本 | +| clean_ip_mapper | General | en, zh | 删除 IP 地址 | +| clean_links_mapper | General, Code | en, zh | 删除链接,例如以 http 或 ftp 开头的 | +| expand_macro_mapper | LaTeX | en, zh | 扩展通常在 TeX 文档顶部定义的宏 | +| extract_qa_mapper | General | en, zh | 从文本中抽取问答对 | +| fix_unicode_mapper | General | en, zh | 修复损坏的 Unicode(借助 [ftfy](https://ftfy.readthedocs.io/)) | +| generate_instruction_mapper | General | en, zh | 指令扩充,根据种子数据,生成新的样本。 | +| image_blur_mapper | Image | - | 对图像进行模糊处理 | +| image_captioning_from_gpt4v_mapper | Multimodal | - | 基于gpt-4-vision和图像生成文本 | +| image_captioning_mapper | Multimodal | - | 生成样本,其标题是根据另一个辅助模型(例如 blip2)和原始样本中的图形生成的。 | +| image_diffusion_mapper | Multimodal | - | 用stable diffusion生成图像,对图像进行增强 | +| image_face_blur_mapper | Image | - | 对图像中的人脸进行模糊处理 | +| image_tagging_mapper | Multimodal | - | 从输入图片中生成图片标签 | +| nlpaug_en_mapper | General | en | 使用`nlpaug`库对英语文本进行简单增强 | +| nlpcda_zh_mapper | General | zh | 使用`nlpcda`库对中文文本进行简单增强 | +| optimize_instruction_mapper | General | en, zh | 指令优化,优化prompt。 | +| punctuation_normalization_mapper | General | en, zh | 将各种 Unicode 标点符号标准化为其 ASCII 等效项 | +| remove_bibliography_mapper | LaTeX | en, zh | 删除 TeX 文档的参考文献 | +| remove_comments_mapper | LaTeX | en, zh | 删除 TeX 文档中的注释 | +| remove_header_mapper | LaTeX | en, zh | 删除 TeX 文档头,例如标题、章节数字/名称等 | +| remove_long_words_mapper | General | en, zh | 删除长度超出指定范围的单词 | +| remove_non_chinese_character_mapper | General | en, zh | 删除样本中的非中文字符 | +| remove_repeat_sentences_mapper | General | en, zh | 删除样本中的重复句子 | +| remove_specific_chars_mapper | General | en, zh | 删除任何用户指定的字符或子字符串 | +| remove_table_text_mapper | General, Financial | en | 检测并删除可能的表格内容(:warning: 依赖正则表达式匹配,因此很脆弱) | +| remove_words_with_incorrect_
substrings_mapper | General | en, zh | 删除包含指定子字符串的单词 | +| replace_content_mapper | General | en, zh | 使用一个指定的替换字符串替换文本中满足特定正则表达式模版的所有内容 | +| sentence_split_mapper | General | en | 根据语义拆分和重组句子 | +| video_captioning_from_audio_mapper | Multimodal | - | 基于 Qwen-Audio 模型根据视频的音频流为视频生成新的标题描述 | +| video_captioning_from_frames_mapper | Multimodal | - | 生成样本,其标题是基于一个文字生成图片的模型和原始样本视频中指定帧的图像。不同帧产出的标题会拼接为一条单独的字符串。 | +| video_captioning_from_summarizer_mapper | Multimodal | - | 通过对多种不同方式生成的文本进行摘要以生成样本的标题(从视频/音频/帧生成标题,从音频/帧生成标签,...) | +| video_captioning_from_video_mapper | Multimodal | - | 生成样本,其标题是根据另一个辅助模型(video-blip)和原始样本中的视频中指定帧的图像。 | +| video_face_blur_mapper | Video | - | 对视频中的人脸进行模糊处理 | +| video_ffmpeg_wrapped_mapper | Video | - | 运行 FFmpeg 视频过滤器的简单封装 | +| video_remove_watermark_mapper | Video | - | 去除视频中给定区域的水印 | +| video_resize_aspect_ratio_mapper | Video | - | 将视频的宽高比调整到指定范围内 | +| video_resize_resolution_mapper | Video | - | 将视频映射到给定的分辨率区间 | +| video_split_by_duration_mapper | Multimodal | - | 根据时长将视频切分为多个片段 | +| video_split_by_key_frame_mapper | Multimodal | - | 根据关键帧切分视频 | +| video_split_by_scene_mapper | Multimodal | - | 将视频切分为场景片段 | | video_tagging_from_audio_mapper | Multimodal | - | 从视频提取的音频中生成视频标签 | -| video_tagging_from_frames_mapper | Multimodal | - | 从视频提取的帧中生成视频标签 | -| whitespace_normalization_mapper | General | en, zh | 将各种 Unicode 空白标准化为常规 ASCII 空格 (U+0020) | +| video_tagging_from_frames_mapper | Multimodal | - | 从视频提取的帧中生成视频标签 | +| whitespace_normalization_mapper | General | en, zh | 将各种 Unicode 空白标准化为常规 ASCII 空格 (U+0020) | ## Filter
diff --git a/tests/ops/mapper/test_image_tagging_mapper.py b/tests/ops/mapper/test_image_tagging_mapper.py new file mode 100644 index 000000000..e9609b12f --- /dev/null +++ b/tests/ops/mapper/test_image_tagging_mapper.py @@ -0,0 +1,156 @@ +# flake8: noqa: E501 +import os +import unittest + +from data_juicer.core.data import NestedDataset as Dataset +from data_juicer.ops.mapper.image_tagging_mapper import \ + ImageTaggingMapper +from data_juicer.utils.constant import Fields +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS + +@SKIPPED_TESTS.register_module() +class ImageTaggingMapperTest(DataJuicerTestCaseBase): + data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', + 'data') + img1_path = os.path.join(data_path, 'img1.png') + img2_path = os.path.join(data_path, 'img2.jpg') + img3_path = os.path.join(data_path, 'img3.jpg') + + def _run_image_tagging_mapper(self, + op, + source_list, + target_list, + num_proc=1): + dataset = Dataset.from_list(source_list) + dataset = dataset.map(op.process, num_proc=num_proc) + res_list = dataset.to_list() + self.assertEqual(res_list, target_list) + + def test(self): + ds_list = [{ + 'images': [self.img1_path] + }, { + 'images': [self.img2_path] + }, { + 'images': [self.img3_path] + }] + tgt_list = [{ + 'images': [self.img1_path], + Fields.image_tags: [[ + 'bed', 'bedcover', 'bedroom', 'bedding', 'lamp', 'ceiling', + 'chair', 'pillar', 'comfort', 'side table', 'floor', + 'hardwood floor', 'headboard', 'linen', 'mattress', + 'nightstand', 'picture frame', 'pillow', 'room', 'wall lamp', + 'stool', 'white', 'window', 'wood floor']], + }, { + 'images': [self.img2_path], + Fields.image_tags: [[ + 'advertisement', 'back', 'bus', 'car', 'city bus', + 'city street', 'curb', 'decker bus', 'drive', 'license plate', + 'road', 'street scene', 'tour bus', 'travel', 'white']], + }, { + 'images': [self.img3_path], + Fields.image_tags: [[ + 'alley', 'black', 'building', 'catch', 'person', 'pavement', + 'photo', 'rain', 'road', 'umbrella', 'walk', 'woman']], + }] + op = ImageTaggingMapper() + self._run_image_tagging_mapper(op, ds_list, tgt_list) + + def test_no_images(self): + ds_list = [{ + 'images': [] + }, { + 'images': [self.img2_path] + }] + tgt_list = [{ + 'images': [], + Fields.image_tags: [[]], + }, { + 'images': [self.img2_path], + Fields.image_tags: [[ + 'advertisement', 'back', 'bus', 'car', 'city bus', + 'city street', 'curb', 'decker bus', 'drive', 'license plate', + 'road', 'street scene', 'tour bus', 'travel', 'white']], + }] + op = ImageTaggingMapper() + self._run_image_tagging_mapper(op, ds_list, tgt_list) + + def test_specified_tag_field_name(self): + tag_field_name = 'my_tags' + + ds_list = [{ + 'images': [self.img1_path] + }, { + 'images': [self.img2_path] + }, { + 'images': [self.img3_path] + }] + tgt_list = [{ + 'images': [self.img1_path], + tag_field_name: [[ + 'bed', 'bedcover', 'bedroom', 'bedding', 'lamp', 'ceiling', + 'chair', 'pillar', 'comfort', 'side table', 'floor', + 'hardwood floor', 'headboard', 'linen', 'mattress', + 'nightstand', 'picture frame', 'pillow', 'room', 'wall lamp', + 'stool', 'white', 'window', 'wood floor']], + }, { + 'images': [self.img2_path], + tag_field_name: [[ + 'advertisement', 'back', 'bus', 'car', 'city bus', + 'city street', 'curb', 'decker bus', 'drive', 'license plate', + 'road', 'street scene', 'tour bus', 'travel', 'white']], + }, { + 'images': [self.img3_path], + tag_field_name: [[ + 'alley', 'black', 'building', 'catch', 'person', 'pavement', + 'photo', 'rain', 'road', 'umbrella', 'walk', 'woman']], + }] + op = ImageTaggingMapper(tag_field_name=tag_field_name) + self._run_image_tagging_mapper(op, ds_list, tgt_list) + + def test_multi_process(self): + # WARNING: current parallel tests only work in spawn method + import multiprocess + original_method = multiprocess.get_start_method() + multiprocess.set_start_method('spawn', force=True) + # WARNING: current parallel tests only work in spawn method + ds_list = [{ + 'images': [self.img1_path] + }, { + 'images': [self.img2_path] + }, { + 'images': [self.img3_path] + }] + tgt_list = [{ + 'images': [self.img1_path], + Fields.image_tags: [[ + 'bed', 'bedcover', 'bedroom', 'bedding', 'lamp', 'ceiling', + 'chair', 'pillar', 'comfort', 'side table', 'floor', + 'hardwood floor', 'headboard', 'linen', 'mattress', + 'nightstand', 'picture frame', 'pillow', 'room', 'wall lamp', + 'stool', 'white', 'window', 'wood floor']], + }, { + 'images': [self.img2_path], + Fields.image_tags: [[ + 'advertisement', 'back', 'bus', 'car', 'city bus', + 'city street', 'curb', 'decker bus', 'drive', 'license plate', + 'road', 'street scene', 'tour bus', 'travel', 'white']], + }, { + 'images': [self.img3_path], + Fields.image_tags: [[ + 'alley', 'black', 'building', 'catch', 'person', 'pavement', + 'photo', 'rain', 'road', 'umbrella', 'walk', 'woman']], + }] + op = ImageTaggingMapper() + self._run_image_tagging_mapper(op, + ds_list, + tgt_list, + num_proc=2) + # WARNING: current parallel tests only work in spawn method + multiprocess.set_start_method(original_method, force=True) + # WARNING: current parallel tests only work in spawn method + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/ops/mapper/test_video_tagging_from_audio_mapper.py b/tests/ops/mapper/test_video_tagging_from_audio_mapper.py index a81fb51c7..8bbf05933 100644 --- a/tests/ops/mapper/test_video_tagging_from_audio_mapper.py +++ b/tests/ops/mapper/test_video_tagging_from_audio_mapper.py @@ -6,9 +6,8 @@ VideoTaggingFromAudioMapper from data_juicer.utils.constant import Fields from data_juicer.utils.mm_utils import SpecialTokens -from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase, SKIPPED_TESTS +from data_juicer.utils.unittest_utils import DataJuicerTestCaseBase -@SKIPPED_TESTS.register_module() class VideoTaggingFromAudioMapperTest(DataJuicerTestCaseBase): data_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'data') @@ -29,11 +28,11 @@ def _run_video_tagging_from_audio_mapper(self, op, source_list, target_list, + tag_field_name=Fields.video_audio_tags, num_proc=1): dataset = Dataset.from_list(source_list) dataset = dataset.map(op.process, num_proc=num_proc) - res_list = dataset.select_columns([Fields.video_audio_tags - ])[Fields.video_audio_tags] + res_list = dataset.select_columns([tag_field_name])[tag_field_name] self.assertEqual(res_list, target_list) def test(self): @@ -56,6 +55,27 @@ def test(self): op = VideoTaggingFromAudioMapper(self.hf_ast) self._run_video_tagging_from_audio_mapper(op, ds_list, tgt_list) + def test_specified_tag_field_name(self): + ds_list = [{ + 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': [self.vid1_path] + }, { + 'text': f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。' + f'{SpecialTokens.eoc}', + 'videos': [self.vid2_path] + }, { + 'text': f'{SpecialTokens.video} 一个人在帮另一个人梳头发。 {SpecialTokens.eoc}', + 'videos': [self.vid4_path] + }, { + 'text': + f'{SpecialTokens.video} 一个穿着红色连衣裙的女人在试衣服。 {SpecialTokens.eoc}', + 'videos': [self.vid5_path] + }] + tgt_list = [['Music'], ['Music'], ['Speech'], ['Speech']] + tag_name = 'audio_tags' + op = VideoTaggingFromAudioMapper(self.hf_ast, tag_field_name=tag_name) + self._run_video_tagging_from_audio_mapper(op, ds_list, tgt_list, tag_field_name=tag_name) + def test_multi_chunk(self): ds_list = [{ 'text': diff --git a/tests/ops/mapper/test_video_tagging_from_frames_mapper.py b/tests/ops/mapper/test_video_tagging_from_frames_mapper.py index b34c45151..b310591a4 100644 --- a/tests/ops/mapper/test_video_tagging_from_frames_mapper.py +++ b/tests/ops/mapper/test_video_tagging_from_frames_mapper.py @@ -55,22 +55,97 @@ def test(self): 'videos': [self.vid2_path], Fields.video_frame_tags: [[ 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', - 'catch', 'hand', 'blind', 'cotton candy', 'ball', 'person' + 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket', + 'ball', 'person' ]] }, { 'text': f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', 'videos': [self.vid3_path], Fields.video_frame_tags: [[ - 'woman', 'table', 'girl', 'sit', 'person', 'laptop', - 'bookshelf', 'conversation', 'round table', 'computer', 'man', - 'closet', 'stool', 'computer screen', 'laugh', 'cabinet', - 'hand', 'selfie', 'stand' + 'woman', 'table', 'sit', 'person', 'laptop', 'bookshelf', + 'conversation', 'round table', 'closet', 'computer', 'girl', + 'man', 'stool', 'computer screen', 'laugh', 'cabinet', 'hand', + 'selfie', 'stand' ]] }] op = VideoTaggingFromFramesMapper() self._run_video_tagging_from_frames_mapper(op, ds_list, tgt_list) + def test_no_video(self): + ds_list = [{ + 'text': f'白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': [] + }, { + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', + 'videos': [self.vid2_path] + }] + tgt_list = [{ + 'text': + f'白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': [], + Fields.video_frame_tags: [[]] + }, { + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', + 'videos': [self.vid2_path], + Fields.video_frame_tags: [[ + 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', + 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket', + 'ball', 'person' + ]] + }] + op = VideoTaggingFromFramesMapper() + self._run_video_tagging_from_frames_mapper(op, ds_list, tgt_list) + + def test_specified_tag_field_name(self): + tag_field_name = 'my_tags' + + ds_list = [{ + 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': [self.vid1_path] + }, { + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', + 'videos': [self.vid2_path] + }, { + 'text': + f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'videos': [self.vid3_path] + }] + tgt_list = [{ + 'text': + f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + 'videos': [self.vid1_path], + tag_field_name: [[ + 'animal', 'ray', 'text', 'writing', 'yellow', 'game', + 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe', + 'sky' + ]] + }, { + 'text': + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', + 'videos': [self.vid2_path], + tag_field_name: [[ + 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', + 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket', + 'ball', 'person' + ]] + }, { + 'text': + f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + 'videos': [self.vid3_path], + tag_field_name: [[ + 'woman', 'table', 'sit', 'person', 'laptop', 'bookshelf', + 'conversation', 'round table', 'closet', 'computer', 'girl', + 'man', 'stool', 'computer screen', 'laugh', 'cabinet', 'hand', + 'selfie', 'stand' + ]] + }] + op = VideoTaggingFromFramesMapper(tag_field_name=tag_field_name) + self._run_video_tagging_from_frames_mapper(op, ds_list, tgt_list) + def test_uniform(self): ds_list = [{ 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', @@ -89,28 +164,26 @@ def test_uniform(self): f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', 'videos': [self.vid1_path], Fields.video_frame_tags: [[ - 'animal', 'cartoon', 'anime', 'game', 'screenshot', - 'video game', 'robe', 'ray', 'text', 'writing', 'yellow', - 'doll', 'tail', 'cartoon character', 'sky', 'person' - ]] + 'cartoon', 'animal', 'anime', 'game', 'screenshot', + 'video game', 'cartoon character', 'robe', 'ray', 'text', + 'writing', 'yellow', 'doll', 'tail', 'sky', 'person']] }, { 'text': f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', 'videos': [self.vid2_path], Fields.video_frame_tags: [[ 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', - 'hand', 'catch', 'bulletin board', 'blind', 'play', 'Wii', - 'cotton candy', 'tennis racket', 'game controller', 'remote', - 'stand', 'video game', 'Wii controller', 'racket', - 'baseball uniform', 'toy', 'green' - ]] + 'hand', 'catch', 'bulletin board', 'Wii', 'cotton candy', + 'tennis racket', 'blind', 'game controller', 'remote', 'stand', + 'video game', 'Wii controller', 'play', 'baseball uniform', + 'toy', 'green']] }, { 'text': f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', 'videos': [self.vid3_path], Fields.video_frame_tags: [[ 'table', 'sit', 'woman', 'bookshelf', 'conversation', 'person', - 'round table', 'computer', 'girl', 'laptop', 'man', 'closet', + 'round table', 'computer', 'girl', 'man', 'closet', 'laptop', 'stand', 'computer screen', 'talk', 'room', 'stool', 'hand', 'point' ]] @@ -139,7 +212,7 @@ def test_multi_process(self): }] tgt_list = [{ 'text': - f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', + f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。', 'videos': [self.vid1_path], Fields.video_frame_tags: [[ 'animal', 'ray', 'text', 'writing', 'yellow', 'game', @@ -148,21 +221,22 @@ def test_multi_process(self): ]] }, { 'text': - f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', + f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}', 'videos': [self.vid2_path], Fields.video_frame_tags: [[ 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', - 'catch', 'hand', 'blind', 'cotton candy', 'ball', 'person' + 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket', + 'ball', 'person' ]] }, { 'text': - f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', + f'{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', 'videos': [self.vid3_path], Fields.video_frame_tags: [[ - 'woman', 'table', 'girl', 'sit', 'person', 'laptop', - 'bookshelf', 'conversation', 'round table', 'computer', 'man', - 'closet', 'stool', 'computer screen', 'laugh', 'cabinet', - 'hand', 'selfie', 'stand' + 'woman', 'table', 'sit', 'person', 'laptop', 'bookshelf', + 'conversation', 'round table', 'closet', 'computer', 'girl', + 'man', 'stool', 'computer screen', 'laugh', 'cabinet', 'hand', + 'selfie', 'stand' ]] }] op = VideoTaggingFromFramesMapper() @@ -197,27 +271,25 @@ def test_multi_chunk(self): 'animal', 'ray', 'text', 'writing', 'yellow', 'game', 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe', 'sky' - ], - [ - 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', - 'catch', 'hand', 'blind', 'cotton candy', 'ball', 'person' - ]] + ], [ + 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', + 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket', + 'ball', 'person' + ]] }, { 'text': f'{SpecialTokens.video} 身穿白色上衣的男子,拿着一个东西,拍打自己的胃部。{SpecialTokens.eoc}{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', 'videos': [self.vid2_path, self.vid3_path], Fields.video_frame_tags: [[ 'man', 'shirt', 't shirt', 't-shirt', 'wear', 'white', 'boy', - 'catch', 'hand', 'blind', 'cotton candy', 'ball', 'person' - ], - [ - 'woman', 'table', 'girl', 'sit', - 'person', 'laptop', 'bookshelf', - 'conversation', 'round table', - 'computer', 'man', 'closet', 'stool', - 'computer screen', 'laugh', - 'cabinet', 'hand', 'selfie', 'stand' - ]] + 'catch', 'hand', 'blind', 'cotton candy', 'tennis racket', + 'ball', 'person' + ], [ + 'woman', 'table', 'sit', 'person', 'laptop', 'bookshelf', + 'conversation', 'round table', 'closet', 'computer', 'girl', + 'man', 'stool', 'computer screen', 'laugh', 'cabinet', 'hand', + 'selfie', 'stand' + ]] }, { 'text': f'{SpecialTokens.video} 白色的小羊站在一旁讲话。旁边还有两只灰色猫咪和一只拉着灰狼的猫咪。{SpecialTokens.eoc}{SpecialTokens.video} 两个长头发的女子正坐在一张圆桌前讲话互动。 {SpecialTokens.eoc}', @@ -226,15 +298,12 @@ def test_multi_chunk(self): 'animal', 'ray', 'text', 'writing', 'yellow', 'game', 'screenshot', 'cartoon', 'cartoon character', 'person', 'robe', 'sky' - ], - [ - 'woman', 'table', 'girl', 'sit', - 'person', 'laptop', 'bookshelf', - 'conversation', 'round table', - 'computer', 'man', 'closet', 'stool', - 'computer screen', 'laugh', - 'cabinet', 'hand', 'selfie', 'stand' - ]] + ], [ + 'woman', 'table', 'sit', 'person', 'laptop', 'bookshelf', + 'conversation', 'round table', 'closet', 'computer', 'girl', + 'man', 'stool', 'computer screen', 'laugh', 'cabinet', 'hand', + 'selfie', 'stand' + ]] }] op = VideoTaggingFromFramesMapper() self._run_video_tagging_from_frames_mapper(op, ds_list, tgt_list)