Add new OP: image_tagging_mapper (#423)
* * init image tagging mapper

* + Add unittest for image_tagging_mapper
* support specified tag field names for all tagging OPs

* * fix problems of unittest

* + add docs

* * update docs

* * skip two unittests which require ram

* * minor fix for gece's comments

* * merge main into this branch

* + add type hint
HYLcool authored Sep 12, 2024
1 parent 762805b commit c40a308
Showing 12 changed files with 522 additions and 161 deletions.
5 changes: 5 additions & 0 deletions configs/config_all.yaml
@@ -122,6 +122,8 @@ process:
cv_classifier: '' # OpenCV classifier path for face detection. By default, we will use 'haarcascade_frontalface_alt.xml'.
blur_type: 'gaussian' # type of blur kernel, including ['mean', 'box', 'gaussian']
radius: 2 # radius of blur kernel
- image_tagging_mapper: # Mapper to generate image tags.
tag_field_name: '__dj__image_tags__' # the field name to store the tags. It's "__dj__image_tags__" by default.
- nlpaug_en_mapper: # simply augment texts in English based on the nlpaug library
sequential: false # whether to combine all augmentation methods into a sequence. If True, a sample will be augmented by all enabled augmentation methods sequentially. If False, each enabled augmentation method generates its augmented samples independently.
aug_num: 1 # number of augmented samples to generate. If `sequential` is True, aug_num augmented samples are generated in total; if it's False, (aug_num * #enabled_aug_methods) augmented samples are generated.
@@ -258,10 +260,12 @@ process:
show_progress: false # whether to show progress from scenedetect
- video_tagging_from_audio_mapper: # Mapper to generate video tags from audio streams extracted from the video.
hf_ast: 'MIT/ast-finetuned-audioset-10-10-0.4593' # Huggingface model name for the audio classification model.
tag_field_name: '__dj__video_audio_tags__' # the field name to store the tags. It's "__dj__video_audio_tags__" by default.
mem_required: '500MB' # This operation (Op) utilizes deep neural network models that consume a significant amount of memory for computation, hence the system's available memory might constrain the maximum number of processes that can be launched
- video_tagging_from_frames_mapper: # Mapper to generate video tags from frames extracted from the video.
frame_sampling_method: 'all_keyframes' # sampling method for extracting frame images from the videos. Should be one of ["all_keyframes", "uniform"]. The former extracts all keyframes and the latter extracts a specified number of frames uniformly from the video. Default: "all_keyframes".
frame_num: 3 # the number of frames to be extracted uniformly from the video. Only works when frame_sampling_method is "uniform". If it's 1, only the middle frame will be extracted. If it's 2, only the first and the last frames will be extracted. If it's larger than 2, in addition to the first and the last frames, other frames will be extracted uniformly within the video duration.
tag_field_name: '__dj__video_frame_tags__' # the field name to store the tags. It's "__dj__video_frame_tags__" by default.
- whitespace_normalization_mapper: # normalize different kinds of whitespace to English spaces.

# Filter ops
@@ -478,6 +482,7 @@ process:
contain: any # require the videos to contain 'any' or 'all' of the given tags. When tags is [], 'all' keeps all samples and 'any' keeps none.
frame_sampling_method: all_keyframes # sampling method for extracting frame images from the videos. Should be one of ["all_keyframes", "uniform"]. The former extracts all keyframes and the latter extracts a specified number of frames uniformly from the video. Default: "all_keyframes".
frame_num: 3 # the number of frames to be extracted uniformly from the video. Only works when frame_sampling_method is "uniform". If it's 1, only the middle frame will be extracted. If it's 2, only the first and the last frames will be extracted. If it's larger than 2, in addition to the first and the last frames, other frames will be extracted uniformly within the video duration.
tag_field_name: '__dj__video_frame_tags__' # the field name to store the tags. It's "__dj__video_frame_tags__" by default.
any_or_all: any # keep this sample when any/all videos meet the filter condition
- words_num_filter: # filter text with number of words out of specific range
lang: en # sample in which language
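The new image_tagging_mapper entry follows the same pattern as the other tagging OPs above: a single tag_field_name option decides where the tags are stored in each sample. A minimal sketch of the equivalent direct construction in Python (the import path comes from the __init__.py change further down; the kwarg value shown is just the documented default):

# Sketch: constructing the new op directly with the same kwarg the YAML
# entry passes; '__dj__image_tags__' is the documented default field.
from data_juicer.ops.mapper import ImageTaggingMapper

op = ImageTaggingMapper(tag_field_name='__dj__image_tags__')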
7 changes: 6 additions & 1 deletion data_juicer/ops/filter/video_tagging_from_frames_filter.py
@@ -37,6 +37,7 @@ def __init__(self,
contain: str = 'any',
frame_sampling_method: str = 'all_keyframes',
frame_num: PositiveInt = 3,
tag_field_name: str = Fields.video_frame_tags,
any_or_all: str = 'any',
*args,
**kwargs):
@@ -61,6 +62,8 @@ def __init__(self,
the first and the last frames will be extracted. If it's larger
than 2, in addition to the first and the last frames, other frames
will be extracted uniformly within the video duration.
:param tag_field_name: the field name to store the tags. It's
"__dj__video_frame_tags__" by default.
:param any_or_all: keep this sample with 'any' or 'all' strategy of
all videos. 'any': keep this sample if any videos meet the
condition. 'all': keep this sample only if all videos meet the
@@ -82,10 +85,12 @@ def __init__(self,
self.tags = set([tag.lower() for tag in tags])
self.contain_any = (contain == 'any')
self.any = (any_or_all == 'any')
self.tag_field_name = tag_field_name
self.tagging_producer = VideoTaggingFromFramesMapper(
frame_sampling_method=frame_sampling_method,
frame_num=frame_num,
accelerator=self.accelerator,
tag_field_name=self.tag_field_name,
)

def compute_stats(self, sample, rank=None, context=False):
@@ -95,7 +100,7 @@ def compute_stats(self, sample, rank=None, context=False):
return sample

def process(self, sample, rank=None):
- video_tags = sample[Fields.video_frame_tags]
+ video_tags = sample[self.tag_field_name]
if len(video_tags) <= 0:
return True

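Since tag_field_name is now forwarded to the internal VideoTaggingFromFramesMapper, a custom field name stays consistent between the tagging producer built in __init__ and the lookup in process. A hedged sketch (the import path is assumed; the custom field name is hypothetical):

# Sketch: one custom tag field shared by the filter's internal tagging
# producer and its process() lookup.
from data_juicer.ops.filter import VideoTaggingFromFramesFilter  # assumed path

flt = VideoTaggingFromFramesFilter(
    tags=['people', 'dog'],              # keep samples whose frames match these
    contain='any',
    tag_field_name='__my_frame_tags__',  # hypothetical custom field name
)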
11 changes: 7 additions & 4 deletions data_juicer/ops/mapper/__init__.py
@@ -6,10 +6,11 @@
generate_instruction_mapper, image_blur_mapper,
image_captioning_from_gpt4v_mapper, image_captioning_mapper,
image_diffusion_mapper, image_face_blur_mapper,
- nlpaug_en_mapper, nlpcda_zh_mapper, optimize_instruction_mapper,
- punctuation_normalization_mapper, remove_bibliography_mapper,
- remove_comments_mapper, remove_header_mapper,
- remove_long_words_mapper, remove_non_chinese_character_mapper,
+ image_tagging_mapper, nlpaug_en_mapper, nlpcda_zh_mapper,
+ optimize_instruction_mapper, punctuation_normalization_mapper,
+ remove_bibliography_mapper, remove_comments_mapper,
+ remove_header_mapper, remove_long_words_mapper,
+ remove_non_chinese_character_mapper,
remove_repeat_sentences_mapper, remove_specific_chars_mapper,
remove_table_text_mapper,
remove_words_with_incorrect_substrings_mapper,
@@ -41,6 +42,7 @@
from .image_captioning_mapper import ImageCaptioningMapper
from .image_diffusion_mapper import ImageDiffusionMapper
from .image_face_blur_mapper import ImageFaceBlurMapper
from .image_tagging_mapper import ImageTaggingMapper
from .nlpaug_en_mapper import NlpaugEnMapper
from .nlpcda_zh_mapper import NlpcdaZhMapper
from .optimize_instruction_mapper import OptimizeInstructionMapper
@@ -123,6 +125,7 @@
'AudioFFmpegWrappedMapper',
'VideoSplitByDurationMapper',
'VideoFaceBlurMapper',
'ImageTaggingMapper',
]

# yapf: enable
85 changes: 85 additions & 0 deletions data_juicer/ops/mapper/image_tagging_mapper.py
@@ -0,0 +1,85 @@
from collections import Counter

import numpy as np

from data_juicer.utils.availability_utils import AvailabilityChecking
from data_juicer.utils.constant import Fields
from data_juicer.utils.mm_utils import load_data_with_context, load_image
from data_juicer.utils.model_utils import get_model, prepare_model

from ..base_op import OPERATORS, UNFORKABLE, Mapper
from ..op_fusion import LOADED_IMAGES

OP_NAME = 'image_tagging_mapper'

with AvailabilityChecking(
['torch', 'git+https://github.com/xinyu1205/recognize-anything.git'],
OP_NAME):
import ram # noqa: F401
import torch

# avoid hanging when calling recognizeAnything in multiprocessing
torch.set_num_threads(1)


@UNFORKABLE.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
@LOADED_IMAGES.register_module(OP_NAME)
class ImageTaggingMapper(Mapper):
"""Mapper to generate image tags.
"""

_accelerator = 'cuda'

def __init__(self,
tag_field_name: str = Fields.image_tags,
*args,
**kwargs):
"""
Initialization method.
:param tag_field_name: the field name to store the tags. It's
"__dj__image_tags__" by default.
:param args: extra args
:param kwargs: extra args
"""
super().__init__(*args, **kwargs)
self.model_key = prepare_model(
model_type='recognizeAnything',
pretrained_model_name_or_path='ram_plus_swin_large_14m.pth',
input_size=384)
from ram import get_transform
self.transform = get_transform(image_size=384)
self.tag_field_name = tag_field_name

def process(self, sample, rank=None, context=False):
# check if it's generated already
if self.tag_field_name in sample:
return sample

# there is no image in this sample
if self.image_key not in sample or not sample[self.image_key]:
sample[self.tag_field_name] = np.array([[]], dtype=np.str_)
return sample

# load images
loaded_image_keys = sample[self.image_key]
sample, images = load_data_with_context(sample, context,
loaded_image_keys, load_image)

model = get_model(self.model_key, rank, self.use_cuda())
image_tags = []
for _, value in enumerate(loaded_image_keys):
image = images[value]

image_tensor = torch.unsqueeze(self.transform(image), dim=0).to(
next(model.parameters()).device)
with torch.no_grad():
tags, _ = model.generate_tag(image_tensor)

words = [word.strip() for word in tags[0].split('|')]
word_count = Counter(words)
sorted_word_list = [item for item, _ in word_count.most_common()]
image_tags.append(np.array(sorted_word_list, dtype=np.str_))

sample[self.tag_field_name] = image_tags
return sample
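A short usage sketch for the new op. The 'images' sample key is an assumption about the base OP's default image_key; the output layout, one frequency-sorted array of np.str_ tags per image, follows from the code above:

# Usage sketch, assuming 'images' is the default image_key.
from data_juicer.ops.mapper import ImageTaggingMapper

op = ImageTaggingMapper()
sample = {'images': ['path/to/pic.jpg']}  # hypothetical local image path
sample = op.process(sample)
# One array of tags per image, sorted by tag frequency, e.g.
# [array(['dog', 'grass'], dtype='<U5')]
print(sample['__dj__image_tags__'])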
16 changes: 12 additions & 4 deletions data_juicer/ops/mapper/video_tagging_from_audio_mapper.py
@@ -1,4 +1,5 @@
import librosa
import numpy as np

from data_juicer.utils.availability_utils import AvailabilityChecking
from data_juicer.utils.constant import Fields
@@ -29,11 +30,16 @@ class VideoTaggingFromAudioMapper(Mapper):
def __init__(self,
hf_ast: str = 'MIT/ast-finetuned-audioset-10-10-0.4593',
trust_remote_code: bool = False,
tag_field_name: str = Fields.video_audio_tags,
*args,
**kwargs):
"""
Initialization method.
:param hf_ast: path to the HF model used to tag the audio streams.
:param trust_remote_code: whether to trust the remote code of HF models
:param tag_field_name: the field name to store the tags. It's
"__dj__video_audio_tags__" by default.
:param args: extra args
:param kwargs: extra args
"""
@@ -44,14 +50,16 @@ def __init__(self,
self._model_sampling_rate = 16000
self._no_audio_label = 'EMPTY'

self.tag_field_name = tag_field_name

def process(self, sample, rank=None):
# check if it's generated already
- if Fields.video_audio_tags in sample:
+ if self.tag_field_name in sample:
return sample

# there is no video in this sample
if self.video_key not in sample or not sample[self.video_key]:
- sample[Fields.video_audio_tags] = []
+ sample[self.tag_field_name] = np.array([], dtype=np.str_)
return sample

# load video paths
@@ -80,11 +88,11 @@ def process(self, sample, rank=None):
sr = self._model_sampling_rate
inputs = feature_extractor(y,
sampling_rate=sr,
- return_tensors='pt')
+ return_tensors='pt').to(model.device)
with torch.no_grad():
logits = model(**inputs).logits
predicted_tag_id = torch.argmax(logits, dim=-1).item()
predicted_tag = model.config.id2label[predicted_tag_id]
video_audio_tags.append(predicted_tag)
- sample[Fields.video_audio_tags] = video_audio_tags
+ sample[self.tag_field_name] = np.array(video_audio_tags, dtype=np.str_)
return sample
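Note the storage change in this file: tags now land in a numpy string array rather than a Python list, including the no-video branch. A tiny sketch of what downstream code can rely on:

# Sketch: the no-video branch now stores an empty numpy string array,
# so emptiness checks via len() or .size keep working with a uniform dtype.
import numpy as np

tags = np.array([], dtype=np.str_)
assert len(tags) == 0 and tags.size == 0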
16 changes: 11 additions & 5 deletions data_juicer/ops/mapper/video_tagging_from_frames_mapper.py
@@ -1,5 +1,6 @@
from collections import Counter

import numpy as np
from pydantic import PositiveInt

from data_juicer.utils.availability_utils import AvailabilityChecking
@@ -36,6 +37,7 @@ class VideoTaggingFromFramesMapper(Mapper):
def __init__(self,
frame_sampling_method: str = 'all_keyframes',
frame_num: PositiveInt = 3,
tag_field_name: str = Fields.video_frame_tags,
*args,
**kwargs):
"""
@@ -54,6 +56,8 @@ def __init__(self,
the first and the last frames will be extracted. If it's larger
than 2, in addition to the first and the last frames, other frames
will be extracted uniformly within the video duration.
:param tag_field_name: the field name to store the tags. It's
"__dj__video_frame_tags__" by default.
:param args: extra args
:param kwargs: extra args
"""
@@ -71,14 +75,16 @@ def __init__(self,
from ram import get_transform
self.transform = get_transform(image_size=384)

self.tag_field_name = tag_field_name

def process(self, sample, rank=None, context=False):
# check if it's generated already
- if Fields.video_frame_tags in sample:
+ if self.tag_field_name in sample:
return sample

# there is no video in this sample
if self.video_key not in sample or not sample[self.video_key]:
- sample[Fields.video_frame_tags] = []
+ sample[self.tag_field_name] = np.array([[]], dtype=np.str_)
return sample

# load videos
@@ -98,7 +104,7 @@ def process(self, sample, rank=None, context=False):
frames = extract_video_frames_uniformly(video, self.frame_num)
else:
video_tags.append([])
- frames = []
+ continue

frame_tensor = torch.stack([
self.transform(frame.to_image()) for frame in frames
@@ -109,11 +115,11 @@
words = [word.strip() for tag in tags for word in tag.split('|')]
word_count = Counter(words)
sorted_word_list = [item for item, _ in word_count.most_common()]
- video_tags.append(sorted_word_list)
+ video_tags.append(np.array(sorted_word_list, dtype=np.str_))

if not context:
for vid_key in videos:
close_video(videos[vid_key])

- sample[Fields.video_frame_tags] = video_tags
+ sample[self.tag_field_name] = video_tags
return sample
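Besides the configurable tag field, note the behavioral fix above: an unrecognized frame_sampling_method now appends an empty tag list and skips the video with continue, instead of handing an empty frame list to torch.stack. A hedged usage sketch (treating 'videos' as the default video_key is an assumption):

# Usage sketch with uniform sampling and the default tag field.
from data_juicer.ops.mapper import VideoTaggingFromFramesMapper

op = VideoTaggingFromFramesMapper(
    frame_sampling_method='uniform',
    frame_num=3,  # per the docstring: first + last + one uniformly placed frame
    tag_field_name='__dj__video_frame_tags__',
)
sample = op.process({'videos': ['path/to/clip.mp4']})  # hypothetical path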
2 changes: 2 additions & 0 deletions data_juicer/utils/constant.py
@@ -19,6 +19,8 @@ class Fields(object):
# video_frame_tags
video_frame_tags = DEFAULT_PREFIX + 'video_frame_tags__'
video_audio_tags = DEFAULT_PREFIX + 'video_audio_tags__'
# image_tags
image_tags = DEFAULT_PREFIX + 'image_tags__'

# the name of the original file from which this sample was derived.
source_file = DEFAULT_PREFIX + 'source_file__'
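The new constant composes with the shared prefix exactly like its video siblings. A quick sketch (DEFAULT_PREFIX == '__dj__' is inferred from the defaults quoted earlier, such as "__dj__image_tags__"):

# Sketch: how Fields.image_tags resolves, assuming DEFAULT_PREFIX == '__dj__'.
DEFAULT_PREFIX = '__dj__'
image_tags = DEFAULT_PREFIX + 'image_tags__'
assert image_tags == '__dj__image_tags__'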